pax_global_header 0000666 0000000 0000000 00000000064 14420263624 0014515 g ustar 00root root 0000000 0000000 52 comment=4d3329f156bd00305db380a0250535af7e752799
microsoft-authentication-library-for-go-1.0.0/ 0000775 0000000 0000000 00000000000 14420263624 0021366 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/.github/ 0000775 0000000 0000000 00000000000 14420263624 0022726 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/ 0000775 0000000 0000000 00000000000 14420263624 0025111 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/bug_report.md 0000664 0000000 0000000 00000003263 14420263624 0027607 0 ustar 00root root 0000000 0000000 ---
name: Bug report
about: Please do NOT file bugs without filling in this form.
title: '[Bug] '
labels: ''
assignees: ''
---
**Which version of MSAL Go are you using?**
Note that to get help, you need to run the latest version.
**Where is the issue?**
* Public client
* [ ] Device code flow
* [ ] Username/Password (ROPC grant)
* [ ] Authorization code flow
* Confidential client
* [ ] Authorization code flow
* [ ] Client credentials:
* [ ] client secret
* [ ] client certificate
* Token cache serialization
* [ ] In-memory cache
* Other (please describe)
**Is this a new or an existing app?**
**What version of Go are you using (`go version`)?**
$ go version
**What operating system and processor architecture are you using (`go env`)?**
go env
Output
$ go env
**Repro**
// Paste a minimal Go program that reproduces the issue here.
**Expected behavior**
A clear and concise description of what you expected to happen (or code).
**Actual behavior**
A clear and concise description of what happens, e.g. an exception is thrown, UI freezes.
**Possible solution**
**Additional context / logs / screenshots**
Add any other context about the problem here, such as logs and screenshots.
microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/documentation.md 0000664 0000000 0000000 00000000755 14420263624 0030313 0 ustar 00root root 0000000 0000000 ---
name: Documentation
about: Suggest a change to the documentation.
title: '[Documentation] '
labels: documentation
assignees: ''
---
### Documentation related to component
### Please check all that apply
- [ ] typo
- [ ] documentation doesn't exist
- [ ] documentation needs clarification
- [ ] error(s) in the example
- [ ] needs an example
### Description of the issue
microsoft-authentication-library-for-go-1.0.0/.github/ISSUE_TEMPLATE/feature_request.md 0000664 0000000 0000000 00000001201 14420263624 0030630 0 ustar 00root root 0000000 0000000 ---
name: Feature request
about: Suggest an idea for this project.
title: "[Feature Request] "
labels: enhancement, Feature Request
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...].
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
**Additional context**
Add any other context or screenshots about the feature request here.
microsoft-authentication-library-for-go-1.0.0/.github/workflows/ 0000775 0000000 0000000 00000000000 14420263624 0024763 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/.github/workflows/go.yml 0000664 0000000 0000000 00000002621 14420263624 0026114 0 ustar 00root root 0000000 0000000 name: Go
on:
push:
branches: [dev]
pull_request:
# This guards against running on unknown PRs until a community member vets and labels them.
types: [ labeled ]
jobs:
build:
name: Build
runs-on: ubuntu-latest
strategy:
matrix:
go: ["1.19", "1.20"]
steps:
- name: Set up Go 1.x
uses: actions/setup-go@v2
with:
go-version: ${{ matrix.go }}
id: go
- name: Check out code into the Go module directory
uses: actions/checkout@v2
- name: Get dependencies
run: go get -v -t -d ./...
# designed to only run on linux
- name: Format Check
run: if [ $(gofmt -l -s . | wc -l) -ne 0 ]; then echo "fmt failed"; exit 1; fi
- name: Build
run: go build ./apps/...
- name: Unit Tests
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/...
- name: Integration Tests
run: go test -race ./apps/tests/integration/...
env :
clientId: ${{ secrets.LAB_APP_CLIENT_ID }}
clientSecret: ${{ secrets.LAB_APP_CLIENT_SECRET }}
oboConfidentialClientId: ${{ secrets.OBO_CONFIDENTIAL_APP_CLIENT_ID }}
oboConfidentialClientSecret: ${{ secrets.OBO_CONFIDENTIAL_APP_CLIENT_SECRET }}
oboPublicClientId: ${{ secrets.OBO_PUBLIC_APP_CLIENT_ID }}
CI: ${{secrets.ENABLECI}}
microsoft-authentication-library-for-go-1.0.0/.github/workflows/golangci-lint.yml 0000664 0000000 0000000 00000001410 14420263624 0030231 0 ustar 00root root 0000000 0000000 name: golangci-lint
on:
push:
tags:
- v*
branches:
- master
- main
- dev
pull_request:
# This guards against running on unknown PRs until a community member vets and labels them.
types: [ labeled ]
jobs:
golangci:
name: lint
runs-on: ubuntu-latest
steps:
- uses: actions/setup-go@v3
with:
go-version: "1.20"
- uses: actions/checkout@v3
- name: golangci-lint
uses: golangci/golangci-lint-action@v3
with:
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
version: v1.51
# Optional: golangci-lint command line arguments.
# args: --issues-exit-code=0
microsoft-authentication-library-for-go-1.0.0/.gitignore 0000664 0000000 0000000 00000000457 14420263624 0023364 0 ustar 00root root 0000000 0000000 # Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
*.golangci.yml
*.swp
*.pprof
# OSX specific os files
*.DS_Store
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
pkg/
github.com/
golang.org/
.vscode/
.idea/
microsoft-authentication-library-for-go-1.0.0/.golangci.yml 0000664 0000000 0000000 00000000423 14420263624 0023751 0 ustar 00root root 0000000 0000000 linters:
# enabled in addition to default
enable:
- gosec
issues:
# Excluding configuration per-path, per-linter, per-text and per-source
exclude-rules:
# Exclude some linters from running on tests files.
- path: _test\.go
linters:
- gosec
microsoft-authentication-library-for-go-1.0.0/CODE_OF_CONDUCT.md 0000664 0000000 0000000 00000000705 14420263624 0024167 0 ustar 00root root 0000000 0000000 # Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
microsoft-authentication-library-for-go-1.0.0/CONTRIBUTING.md 0000664 0000000 0000000 00000007715 14420263624 0023631 0 ustar 00root root 0000000 0000000 # Microsoft Authentication Library for Go welcomes new contributors
This document will guide you through the process.
## Contributor License agreement
Please visit [https://cla.microsoft.com/](https://cla.microsoft.com/) and sign the Contributor License
Agreement. You only need to do that once. We cannot look at your code until you've submitted this request.
## FORK
Fork the project [on GitHub](https://github.com/AzureAD/microsoft-authentication-library-for-go) and check out
your copy.
Example for MSAL Go:
```
$ git clone git@github.com:username/microsoft-authentication-library-for-go.git
$ cd microsoft-authentication-library-for-go
$ git remote add upstream git@github.com:AzureAD/microsoft-authentication-library-for-go.git
```
## Setup, Building and Testing
Please see the [Build & Run](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki/build-and-test) wiki page.
## Decide on which branch to create
**Bug fixes for the current stable version need to go to 'master' branch.**
If you need to contribute to a different branch, please contact us first (open an issue).
All details after this point are standard: make sure your commits have clear messages, and prefer rebase to merge.
In case of doubt, please open an issue in the [issue tracker](https://github.com/AzureAD/microsoft-authentication-library-for-go/issues).
Especially do so if you plan to work on a major change in functionality. Nothing is more
frustrating than seeing your hard work go to waste because your vision
does not align with our goals for the SDK.
## Branch
Okay, so you have decided on the proper branch. Create a feature branch
and start hacking:
```
$ git checkout -b my-feature-branch
```
## Commit
Make sure git knows your name and email address:
```
$ git config --global user.name "J. Random User"
$ git config --global user.email "j.random.user@example.com"
```
Writing good commit logs is important. A commit log should describe what
changed and why. Follow these guidelines when writing one:
1. The first line should be 50 characters or less and contain a short
description of the change prefixed with the name of the changed
subsystem (e.g. "net: add localAddress and localPort to Socket").
2. Keep the second line blank.
3. Wrap all other lines at 72 columns.
A good commit log looks like this:
```
fix: explaining the commit in one line
Body of commit message is a few lines of text, explaining things
in more detail, possibly giving some background about the issue
being fixed, etc.
The body of the commit message can be several paragraphs, and
please do proper word-wrap and keep columns shorter than about
72 characters or so. That way `git log` will show things
nicely even when it is indented.
```
The header line should be meaningful; it is what other people see when they
run `git shortlog` or `git log --oneline`.
Check the output of `git log --oneline files_that_you_changed` to find out
what directories your changes touch.
### Rebase
Use `git rebase` (not `git merge`) to sync your work from time to time.
```
$ git fetch upstream
$ git rebase upstream/v0.1 # or upstream/master
```
### Tests
It's all standard stuff, but please note that you won't be able to run integration tests locally because they connect to a Key Vault to fetch some test users and passwords. The CI will run them for you.
### Push
```
$ git push origin my-feature-branch
```
Go to `https://github.com/username/microsoft-authentication-library-for-go` and select your feature branch. Click
the 'Pull Request' button and fill out the form.
Pull requests are usually reviewed within a few days. If there are comments
to address, apply your changes in a separate commit and push that to your
feature branch. Post a comment in the pull request afterwards; GitHub does
not send out notifications when you add commits.
[on GitHub]: https://github.com/AzureAD/microsoft-authentication-library-for-go
[issue tracker]: https://github.com/AzureAD/microsoft-authentication-library-for-go/issues
microsoft-authentication-library-for-go-1.0.0/LICENSE 0000664 0000000 0000000 00000002212 14420263624 0022370 0 ustar 00root root 0000000 0000000 MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
microsoft-authentication-library-for-go-1.0.0/README.md 0000664 0000000 0000000 00000021067 14420263624 0022653 0 ustar 00root root 0000000 0000000 # Microsoft Authentication Library (MSAL) for Go
**MSAL for Go is a new addition to the MSAL family of libraries and has been made available as a production-ready preview to gauge customer interest and to gather feedback from the community. We welcome all contributors (see [CONTRIBUTING.md](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/dev/CONTRIBUTING.md)) to help us grow our list of supported MSAL SDKs.**
The Microsoft Authentication Library (MSAL) for Go is part of the [Microsoft identity platform for developers](https://aka.ms/aaddevv2) (formerly named Azure AD) v2.0. It allows you to sign in users or apps with Microsoft identities ([Azure AD](https://azure.microsoft.com/services/active-directory/) and [Microsoft Accounts](https://account.microsoft.com)) and obtain tokens to call Microsoft APIs such as [Microsoft Graph](https://graph.microsoft.io/) or your own APIs registered with the Microsoft identity platform. It is built using industry standard OAuth2 and OpenID Connect protocols.
The latest code resides in the `dev` branch.
Quick links:
| [Getting Started](https://docs.microsoft.com/azure/active-directory/develop/#quickstarts) | [GoDoc](https://pkg.go.dev/github.com/AzureAD/microsoft-authentication-library-for-go/apps) | [Wiki](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki) | [Samples](https://github.com/AzureAD/microsoft-authentication-library-for-go/tree/dev/apps/tests/devapps) | [Support](README.md#community-help-and-support) | [Feedback](https://forms.office.com/r/s4waBAytFJ) |
| ------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------- |
## Build Status

## Installation
### Setting up Go
To install Go, visit [this link](https://golang.org/dl/).
### Installing MSAL Go
`go get -u github.com/AzureAD/microsoft-authentication-library-for-go/`
## Usage
Before using MSAL Go, you will need to [register your application with the Microsoft identity platform](https://docs.microsoft.com/azure/active-directory/develop/quickstart-v2-register-an-app).
### Public Surface
The Public API of the library can be found in the following directories under `apps`.
```
apps/ - Contains all our code
confidential/ - The confidential application API
public/ - The public application API
cache/ - The cache interface that can be implemented to provide persistence cache storage of credentials
```
Acquiring tokens with MSAL Go follows this general three-step pattern. There might be some slight differences for other token acquisition flows. Here is a basic example:
1. MSAL separates [public and confidential client applications](https://tools.ietf.org/html/rfc6749#section-2.1). So, you would create an instance of either a `PublicClientApplication` or a `ConfidentialClientApplication` and use it throughout the lifetime of your application.
* Initializing a public client:
```go
publicClientApp, err := public.New("client_id", public.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here"))
```
* Initializing a confidential client:
```go
// Initializing the client credential
cred, err := confidential.NewCredFromSecret("client_secret")
if err != nil {
return nil, fmt.Errorf("could not create a cred from a secret: %w", err)
}
confidentialClientApp, err := confidential.New("client_id", cred, confidential.WithAuthority("https://login.microsoft.com/Enter_The_Tenant_Name_Here"))
```
1. MSAL comes packaged with an in-memory cache. Utilizing the cache is optional, but we would highly recommend it.
```go
var userAccount public.Account
accounts := publicClientApp.Accounts()
if len(accounts) > 0 {
// Assuming the user wanted the first account
userAccount = accounts[0]
// found a cached account, now see if an applicable token has been cached
result, err := publicClientApp.AcquireTokenSilent(context.Background(), []string{"your_scope"}, public.WithSilentAccount(userAccount))
accessToken := result.AccessToken
}
```
1. If there is no suitable token in the cache, or you choose to skip this step, now we can send a request to AAD to obtain a token.
```go
result, err := publicClientApp.AcquireToken"ByOneofTheActualMethods"([]string{"your_scope"}, ...(other parameters depending on the function))
if err != nil {
log.Fatal(err)
}
accessToken := result.AccessToken
```
You can view the [dev apps](https://github.com/AzureAD/microsoft-authentication-library-for-go/tree/dev/apps/tests/devapps) on how to use MSAL Go with various application types in various scenarios. For more detailed information, please refer to the [wiki](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki).
# Releases
The list of [releases](https://github.com/AzureAD/microsoft-authentication-library-for-go/releases)
## Roadmap
This is a preview library. Details of the roadmap will come soon in the [wiki pages](https://github.com/AzureAD/microsoft-authentication-library-for-go/wiki), along with release notes.
## Community Help and Support
We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!). Also browse existing issues to see if someone has already asked your question. Please use the "msal" tag when asking your questions.
If you find a bug or have a feature request, please raise the issue on [GitHub Issues](https://github.com/AzureAD/microsoft-authentication-library-for-go/issues).
## Submit Feedback
We'd like your thoughts on this library. Please complete [this short survey.](https://forms.office.com/r/s4waBAytFJ)
## Contributing
This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
## Security Library
This library controls how users sign-in and access services. We recommend you always take the latest version of our library in your app when possible. We use [semantic versioning](http://semver.org) so you can control the risk associated with updating your app. As an example, always downloading the latest minor version number (e.g. x.*y*.x) ensures you get the latest security and feature enhancements but our API surface remains the same. You can always see the latest version and release notes under the Releases tab of GitHub.
## Security Reporting
If you find a security issue with our libraries or services please report it to [secure@microsoft.com](mailto:secure@microsoft.com) with as much detail as possible. Your submission may be eligible for a bounty through the [Microsoft Bounty](http://aka.ms/bugbounty) program. Please do not post security issues to GitHub Issues or any other public site. We will contact you shortly upon receiving the information. We encourage you to get notifications of when security incidents occur by visiting [this page](https://technet.microsoft.com/en-us/security/dd252948) and subscribing to Security Advisory Alerts.
Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT License (the "License").
microsoft-authentication-library-for-go-1.0.0/RELEASES.md 0000664 0000000 0000000 00000014523 14420263624 0023120 0 ustar 00root root 0000000 0000000 # Microsoft Identity SDK Versioning and Servicing FAQ
We have adopted the semantic versioning flow that is industry standard for OSS projects. It gives the maximum amount of control over what risk you take with which versions. If you know how semantic versioning works with Node.js, Java, and Ruby, none of this will be new.
## Semantic Versioning and API stability promises
Microsoft Identity libraries are independent open source libraries that are used by partners both internal and external to Microsoft. As with the rest of Microsoft, we have moved to a rapid iteration model where bugs are fixed daily and new versions are produced as required. To communicate these frequent changes to external partners and customers, we use semantic versioning for all our public Microsoft Identity SDK libraries. This follows the practices of other open source libraries on the internet. This allows us to support our downstream partners which will lock on certain versions for stability purposes, as well as providing for the distribution over NuGet, CocoaPods, and Maven.
The semantics are: MAJOR.MINOR.PATCH (example 1.1.5)
We will update our code distributions to use the latest PATCH semantic version number in order to make sure our customers and partners get the latest bug fixes. Downstream partners need to pull the latest PATCH version. Most partners should try to lock on the latest MINOR version number in their builds and accept any updates in the PATCH number.
Examples:
Using CocoaPods, the following in the Podfile will take the latest ADALiOS build that is >= 1.1.0 but not 1.2.
```
pod 'ADALiOS', '~> 1.1.0'
```
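Using NuGet, a version range like the following ensures all 1.1.0 to 1.1.x updates are included when building your code, but not 1.2.
```xml
<PackageReference Include="Microsoft.IdentityModel.Clients.ActiveDirectory" Version="[1.1,1.2)" />
```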
| Version | Description | Example |
|:-------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------:|:---------------------------------------------------------------------------------------------------------:|
| x.x.x | PATCH version number. Incrementing these numbers is for bug fixes and updates that do not introduce new features. This is used for close partners who build on our platform release (ex. Azure AD Fabric, Office, etc.),In addition, CocoaPods, NuGet, and Maven use this number to deliver the latest release to customers.,This will update frequently (sometimes within the same day),There are no new features, and no regressions or API surface changes. Code will continue to work unless affected by a particular code fix. | ADAL for iOS 1.0.10,(this was a fix for the Storyboard display that was fixed for a specific Office team) |
| x.x | MINOR version numbers. Incrementing these second numbers is for new feature additions that do not impact existing features or introduce regressions. They are purely additive, but may require testing to ensure nothing is impacted.,All x.x.x bug fixes will also roll up into this number.,There are no regressions or API surface changes. Code will continue to work unless affected by a particular code fix or needs this new feature. | ADAL for iOS 1.1.0,(this added WPJ capability to ADAL, and rolled all the updates from 1.0.0 to 1.0.12) |
| x | MAJOR version numbers. This should be considered a new, supported version of Microsoft Identity SDK and begins the Azure two year support cycle anew. Major new features are introduced and API changes can occur.,This should only be used after a large amount of testing and used only if those features are needed.,We will continue to service MAJOR version numbers with bug fixes up to the two year support cycle. | ADAL for iOS 1.0,(our first official release of ADAL) |
## Serviceability
When we release a new MINOR version, the previous MINOR version is abandoned.
When we release a new MAJOR version, we will continue to apply bug fixes to the existing features in the previous MAJOR version for up to the 2 year support cycle for Azure.
Example: We release ADALiOS 2.0 in the future which supports unified Auth for AAD and MSA. Later, we then have a fix in Conditional Access for ADALiOS. Since that feature exists both in ADALiOS 1.1 and ADALiOS 2.0, we will fix both. It will roll up in a PATCH number for each. Customers that are still locked down on ADALiOS 1.1 will receive the benefit of this fix.
## Microsoft Identity SDKs and Azure Active Directory
Microsoft Identity SDKs major versions will maintain backwards compatibility with Azure Active Directory web services through the support period. This means that the API surface area defined in a MAJOR version will continue to work for 2 years after release.
We will respond to bugs quickly from our partners and customers submitted through GitHub and through our private alias (tellaad@microsoft.com) for security issues and update the PATCH version number. We will also submit a change summary for each PATCH number.
Occasionally, there will be security bugs or breaking bugs from our partners that will require an immediate fix and a publish of an update to all partners and customers. When this occurs, we will do an emergency roll up to a PATCH version number and update all our distribution methods to the latest. microsoft-authentication-library-for-go-1.0.0/SECURITY.md 0000664 0000000 0000000 00000005461 14420263624 0023165 0 ustar 00root root 0000000 0000000
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
microsoft-authentication-library-for-go-1.0.0/apps/ 0000775 0000000 0000000 00000000000 14420263624 0022331 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/cache/ 0000775 0000000 0000000 00000000000 14420263624 0023374 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/cache/cache.go 0000664 0000000 0000000 00000004021 14420263624 0024763 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package cache allows third parties to implement external storage for caching token data,
for use by distributed systems or by multiple local applications.
The data stored and extracted represents the entire cache. Therefore, it is recommended
to use one MSAL instance per user. This data is considered opaque and there are no guarantees to
implementers on the format being passed.
*/
package cache
import "context"
// Marshaler marshals data from an internal cache to bytes that can be stored.
type Marshaler interface {
Marshal() ([]byte, error)
}
// Unmarshaler unmarshals data from a storage medium into the internal cache, overwriting it.
type Unmarshaler interface {
Unmarshal([]byte) error
}
// Serializer can serialize the cache to binary or from binary into the cache.
type Serializer interface {
Marshaler
Unmarshaler
}
// ExportHints carries storage suggestions that accompany a call to
// ExportReplace.Export.
type ExportHints struct {
	PartitionKey string // suggested key for partitioning the cache
}
// ReplaceHints carries loading suggestions that accompany a call to
// ExportReplace.Replace.
type ReplaceHints struct {
	PartitionKey string // suggested key for partitioning the cache
}
// ExportReplace exports and replaces in-memory cache data. It doesn't support nil Context or
// define the outcome of passing one. A Context without a timeout must receive a default timeout
// specified by the implementor. Retries must be implemented inside the implementation.
type ExportReplace interface {
// Replace replaces the cache with what is in external storage. Implementors should honor
// Context cancellations and return context.Canceled or context.DeadlineExceeded in those cases.
Replace(ctx context.Context, cache Unmarshaler, hints ReplaceHints) error
// Export writes the binary representation of the cache (cache.Marshal()) to external storage.
// This is considered opaque. Context cancellations should be honored as in Replace.
Export(ctx context.Context, cache Marshaler, hints ExportHints) error
}
microsoft-authentication-library-for-go-1.0.0/apps/confidential/ 0000775 0000000 0000000 00000000000 14420263624 0024770 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/confidential/confidential.go 0000664 0000000 0000000 00000056455 14420263624 0027775 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package confidential provides a client for authentication of "confidential" applications.
A "confidential" application is defined as an app that runs on servers. They are considered
difficult to access and for that reason capable of keeping an application secret.
Confidential clients can hold configuration-time secrets.
*/
package confidential
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/options"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
/*
Design note:
confidential.Client uses base.Client as an embedded type. base.Client statically assigns its attributes
during creation. As it doesn't have any pointers in it, anything borrowed from it, such as
Base.AuthParams is a copy that is free to be manipulated here.
Duplicate Calls shared between public.Client and this package:
There are some duplicate call options provided here that are the same as in public.Client. This
is a design choice. Go proverb (https://www.youtube.com/watch?v=PAAkCSZUG1c&t=9m28s):
"a little copying is better than a little dependency". Yes, we could have another package with
shared options (fail). That would separate about 2 options from all the others, which makes the user look
through more docs. We could have all clients in one package, but I think separate packages
here make for better naming (public.Client vs client.PublicClient). So I chose a little
duplication.
.NET people, take note on X509:
This uses x509.Certificates and private keys. x509 does not store private keys. .NET
has an X509Certificate2 type that carries a private key, but that is a .NET-specific addition;
X.509 certificates themselves never contain private keys. As such I've put a PEM decoder in here.
*/
// TODO(msal): This should have example code for each method on client using Go's example doc framework.
// Base usage details should be included in the package documentation.
type (
	// AuthResult contains the results of one token acquisition operation.
	// For details see https://aka.ms/msal-net-authenticationresult
	AuthResult = base.AuthResult

	// Account is the account type accepted and returned by Client methods such as
	// Account, RemoveAccount and WithSilentAccount.
	Account = shared.Account
)
// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file
// must contain the public certificate and the private key. If a PEM block is encrypted and
// password is not an empty string, it attempts to decrypt the PEM blocks using the password.
// Multiple certs are due to certificate chaining for use cases like TLS that sign from root to leaf.
func CertFromPEM(pemData []byte, password string) ([]*x509.Certificate, crypto.PrivateKey, error) {
var certs []*x509.Certificate
var priv crypto.PrivateKey
for {
block, rest := pem.Decode(pemData)
if block == nil {
break
}
//nolint:staticcheck // x509.IsEncryptedPEMBlock and x509.DecryptPEMBlock are deprecated. They are used here only to support a usecase.
if x509.IsEncryptedPEMBlock(block) {
b, err := x509.DecryptPEMBlock(block, []byte(password))
if err != nil {
return nil, nil, fmt.Errorf("could not decrypt encrypted PEM block: %v", err)
}
block, _ = pem.Decode(b)
if block == nil {
return nil, nil, fmt.Errorf("encounter encrypted PEM block that did not decode")
}
}
switch block.Type {
case "CERTIFICATE":
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("block labelled 'CERTIFICATE' could not be parsed by x509: %v", err)
}
certs = append(certs, cert)
case "PRIVATE KEY":
if priv != nil {
return nil, nil, errors.New("found multiple private key blocks")
}
var err error
priv, err = x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("could not decode private key: %v", err)
}
case "RSA PRIVATE KEY":
if priv != nil {
return nil, nil, errors.New("found multiple private key blocks")
}
var err error
priv, err = x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, nil, fmt.Errorf("could not decode private key: %v", err)
}
}
pemData = rest
}
if len(certs) == 0 {
return nil, nil, fmt.Errorf("no certificates found")
}
if priv == nil {
return nil, nil, fmt.Errorf("no private key found")
}
return certs, priv, nil
}
// AssertionRequestOptions has the information required to build client assertion claims.
type AssertionRequestOptions = exported.AssertionRequestOptions

// Credential represents the credential used in confidential client flows. Create one with
// NewCredFromSecret, NewCredFromCert, NewCredFromAssertionCallback or NewCredFromTokenProvider.
// A zero Credential holds no credential material and is rejected by New.
type Credential struct {
	// secret is a client secret.
	secret string

	// cert is the signing certificate matching key.
	cert *x509.Certificate
	// key is the certificate's private key.
	key crypto.PrivateKey
	// x5c holds the base64 (standard encoding) DER of each certificate, signing cert first.
	x5c []string

	// assertionCallback supplies client assertions on demand.
	assertionCallback func(context.Context, AssertionRequestOptions) (string, error)

	// tokenProvider supplies access tokens directly.
	tokenProvider func(context.Context, TokenProviderParameters) (TokenProviderResult, error)
}
// toInternal converts c to the accesstokens.Credential used internally. client.go, requests.go
// and confidential.go must share one credential type without an import cycle, so that type
// lives in a shared package and this method bridges to it.
//
// Fields are checked in priority order:
//  1. a non-empty secret;
//  2. a certificate, which must have a private key;
//  3. a private key without a certificate, which is an error;
//  4. an assertion callback;
//  5. a token provider.
//
// A Credential with none of these is invalid.
func (c Credential) toInternal() (*accesstokens.Credential, error) {
	switch {
	case c.secret != "":
		return &accesstokens.Credential{Secret: c.secret}, nil
	case c.cert != nil && c.key == nil:
		return nil, errors.New("missing private key for certificate")
	case c.cert != nil:
		return &accesstokens.Credential{Cert: c.cert, Key: c.key, X5c: c.x5c}, nil
	case c.key != nil:
		return nil, errors.New("missing certificate for private key")
	case c.assertionCallback != nil:
		return &accesstokens.Credential{AssertionCallback: c.assertionCallback}, nil
	case c.tokenProvider != nil:
		return &accesstokens.Credential{TokenProvider: c.tokenProvider}, nil
	default:
		return nil, errors.New("invalid credential")
	}
}
// NewCredFromSecret creates a Credential from a client secret. An empty secret is rejected
// with an error.
func NewCredFromSecret(secret string) (Credential, error) {
	var cred Credential
	if len(secret) == 0 {
		return cred, errors.New("secret can't be empty string")
	}
	cred.secret = secret
	return cred, nil
}
// NewCredFromAssertionCallback creates a Credential that calls callback to obtain assertions
// authenticating the application. callback must be safe for concurrent use.
func NewCredFromAssertionCallback(callback func(context.Context, AssertionRequestOptions) (string, error)) Credential {
	var cred Credential
	cred.assertionCallback = callback
	return cred
}
// NewCredFromCert creates a Credential from a certificate or chain of certificates and an RSA private key
// as returned by [CertFromPEM].
//
// Each certificate whose RSA public key matches key is recorded as the signing certificate and
// moved to the front of the x5c chain. If several certificates match, the last one wins. Other
// certificates keep their relative order. Nil entries in certs are skipped.
//
// An error is returned if key is not a non-nil *rsa.PrivateKey or if no certificate matches it.
func NewCredFromCert(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) {
	cred := Credential{key: key}
	k, ok := key.(*rsa.PrivateKey)
	// A typed nil *rsa.PrivateKey passes the type assertion; reject it here instead of
	// dereferencing it in the comparison below.
	if !ok || k == nil {
		return cred, errors.New("key must be an RSA key")
	}
	for _, cert := range certs {
		if cert == nil {
			// not returning an error here because certs may still contain a sufficient cert/key pair
			continue
		}
		certKey, ok := cert.PublicKey.(*rsa.PublicKey)
		if ok && certKey != nil && k.E == certKey.E && k.N.Cmp(certKey.N) == 0 {
			// We know this is the signing cert because its public key matches the given private key.
			// This cert must be first in x5c.
			cred.cert = cert
			cred.x5c = append([]string{base64.StdEncoding.EncodeToString(cert.Raw)}, cred.x5c...)
		} else {
			cred.x5c = append(cred.x5c, base64.StdEncoding.EncodeToString(cert.Raw))
		}
	}
	if cred.cert == nil {
		return cred, errors.New("key doesn't match any certificate")
	}
	return cred, nil
}
type (
	// TokenProviderParameters is the authentication parameters passed to token providers.
	TokenProviderParameters = exported.TokenProviderParameters

	// TokenProviderResult is the authentication result returned by custom token providers.
	TokenProviderResult = exported.TokenProviderResult
)
// NewCredFromTokenProvider creates a Credential from a function that provides access tokens.
// provider must be safe for concurrent use. This exists only so the Azure SDK can cache MSI
// tokens. It isn't useful to applications in general, because the provider must implement all
// of the authentication logic itself.
func NewCredFromTokenProvider(provider func(context.Context, TokenProviderParameters) (TokenProviderResult, error)) Credential {
	cred := Credential{}
	cred.tokenProvider = provider
	return cred
}
// AutoDetectRegion returns the value that, when passed to WithAzureRegion, instructs MSAL Go to
// auto-detect the region for the Azure regional token service.
func AutoDetectRegion() string {
	const tryAutoDetect = "TryAutoDetect"
	return tryAutoDetect
}
// Client is an authentication client for confidential applications, as defined in the package
// documentation. Create a new Client PER SERVICE USER.
// For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications
type Client struct {
	// base is the shared client implementation that performs cache and token operations.
	base base.Client
	// cred is the internal form of the Credential passed to New.
	cred *accesstokens.Credential
}
// clientOptions holds the optional settings for New(). Values are filled in by applying
// Option functions to it.
type clientOptions struct {
	accessor                 cache.ExportReplace
	authority                string
	azureRegion              string
	capabilities             []string
	disableInstanceDiscovery bool
	sendX5C                  bool
	httpClient               ops.HTTPClient
}

// Option is an optional argument to New(). Each Option mutates the clientOptions it receives.
type Option func(o *clientOptions)
// WithCache provides an accessor that reads and writes authentication data to an externally
// managed cache. New passes it on to base.WithCacheAccessor.
func WithCache(accessor cache.ExportReplace) Option {
	setAccessor := func(o *clientOptions) {
		o.accessor = accessor
	}
	return setAccessor
}
// WithClientCapabilities configures one or more client capabilities, such as "CP1".
func WithClientCapabilities(capabilities []string) Option {
	setCapabilities := func(o *clientOptions) {
		// The slice's memory isn't shared with the application in a harmful way: New only
		// hands it to base.WithClientCapabilities, which copies its data.
		o.capabilities = capabilities
	}
	return setCapabilities
}
// WithHTTPClient sets a custom HTTP client. Without this option, New uses shared.DefaultClient.
func WithHTTPClient(httpClient ops.HTTPClient) Option {
	return Option(func(o *clientOptions) { o.httpClient = httpClient })
}
// WithX5C specifies that the x5c claim (the certificate's public key) should be sent to the STS
// to enable Subject Name Issuer Authentication.
func WithX5C() Option {
	return Option(func(o *clientOptions) { o.sendX5C = true })
}
// WithInstanceDiscovery enables or disables instance discovery. Pass false to disable authority
// validation, for example to support private cloud scenarios.
func WithInstanceDiscovery(enabled bool) Option {
	disable := !enabled
	return Option(func(o *clientOptions) { o.disableInstanceDiscovery = disable })
}
// WithAzureRegion sets the region (preferred), or confidential.AutoDetectRegion() to auto-detect the region.
// Region names follow https://azure.microsoft.com/en-ca/global-infrastructure/geographies/;
// see https://aka.ms/region-map for more details on region names.
// The value should be the short name of the region where the service is deployed;
// for example "centralus" is the short name for Central US.
//
// Not all auth flows can use the regional token service. Service-to-service (client credential
// flow) tokens can be obtained from the regional service. This requires configuration at the
// tenant level.
//
// Auto-detection works on a limited number of Azure artifacts (VMs, Azure functions). If
// auto-detection fails, the non-regional endpoint is used. If an invalid region name is provided,
// the non-regional endpoint MIGHT be used or the token request MIGHT fail.
func WithAzureRegion(val string) Option {
	return Option(func(o *clientOptions) { o.azureRegion = val })
}
// New is the constructor for Client. authority is the URL of a token authority such as "https://login.microsoftonline.com/".
// If the Client will connect directly to AD FS, use "adfs" for the tenant. clientID is the application's client ID (also called its
// "application ID").
//
// cred is validated first, so an incomplete Credential is reported before any options run.
// options are applied in order over the defaults, so a later option overrides an earlier one.
func New(authority, clientID string, cred Credential, options ...Option) (Client, error) {
	internalCred, err := cred.toInternal()
	if err != nil {
		return Client{}, err
	}

	// Defaults. If the caller specified a token provider, it handles all details of
	// authentication, using Client only as a token cache, so instance discovery starts disabled.
	opts := clientOptions{
		authority:                authority,
		disableInstanceDiscovery: cred.tokenProvider != nil,
		httpClient:               shared.DefaultClient,
	}
	for _, apply := range options {
		apply(&opts)
	}

	baseOpts := []base.Option{
		base.WithCacheAccessor(opts.accessor),
		base.WithClientCapabilities(opts.capabilities),
		base.WithInstanceDiscovery(!opts.disableInstanceDiscovery),
		base.WithRegionDetection(opts.azureRegion),
		base.WithX5C(opts.sendX5C),
	}
	// Named baseClient so it doesn't shadow the base package.
	baseClient, err := base.New(clientID, opts.authority, oauth.New(opts.httpClient), baseOpts...)
	if err != nil {
		return Client{}, err
	}
	baseClient.AuthParams.IsConfidentialClient = true
	return Client{base: baseClient, cred: internalCred}, nil
}
type (
	// authCodeURLOptions collects the optional settings applied by AuthCodeURLOption values.
	authCodeURLOptions struct {
		claims     string
		loginHint  string
		tenantID   string
		domainHint string
	}

	// AuthCodeURLOption is implemented by options for AuthCodeURL.
	AuthCodeURLOption interface {
		authCodeURLOption()
	}
)
// AuthCodeURL creates a URL used to acquire an authorization code. The client's authority
// parameters are copied for the requested tenant, the per-call claims, login hint and domain hint
// are applied, and base.AuthCodeURL builds the URL.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
func (cca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
	var o authCodeURLOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return "", err
	}
	params, err := cca.base.AuthParams.WithTenant(o.tenantID)
	if err != nil {
		return "", err
	}
	params.Claims, params.LoginHint, params.DomainHint = o.claims, o.loginHint, o.domainHint
	return cca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, params)
}
// WithLoginHint pre-populates the login prompt with a username. It applies only to AuthCodeURL;
// any other options type is rejected with an error.
func WithLoginHint(username string) interface {
	AuthCodeURLOption
	options.CallOption
} {
	setHint := func(a any) error {
		o, ok := a.(*authCodeURLOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", a)
		}
		o.loginHint = username
		return nil
	}
	return struct {
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setHint)}
}
// WithDomainHint adds the IdP domain as the domain_hint query parameter of the auth URL. It applies
// only to AuthCodeURL; any other options type is rejected with an error.
func WithDomainHint(domain string) interface {
	AuthCodeURLOption
	options.CallOption
} {
	setHint := func(a any) error {
		o, ok := a.(*authCodeURLOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", a)
		}
		o.domainHint = domain
		return nil
	}
	return struct {
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setHint)}
}
// WithClaims sets additional claims to request for the token, such as those required by conditional
// access policies. Use it when Azure AD returned a claims challenge for a prior request; the argument
// must be decoded. The option is accepted by AuthCodeURL and by every token acquisition method.
// Any other options type is rejected with an error.
func WithClaims(claims string) interface {
	AcquireByAuthCodeOption
	AcquireByCredentialOption
	AcquireOnBehalfOfOption
	AcquireSilentOption
	AuthCodeURLOption
	options.CallOption
} {
	apply := func(a any) error {
		switch t := a.(type) {
		case *authCodeURLOptions:
			t.claims = claims
		case *acquireTokenSilentOptions:
			t.claims = claims
		case *acquireTokenOnBehalfOfOptions:
			t.claims = claims
		case *acquireTokenByCredentialOptions:
			t.claims = claims
		case *acquireTokenByAuthCodeOptions:
			t.claims = claims
		default:
			return fmt.Errorf("unexpected options type %T", a)
		}
		return nil
	}
	return struct {
		AcquireByAuthCodeOption
		AcquireByCredentialOption
		AcquireOnBehalfOfOption
		AcquireSilentOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(apply)}
}
// WithTenantID specifies a tenant for a single authentication. It may differ from the tenant set
// in [New]. The option is accepted by AuthCodeURL and by every token acquisition method; any other
// options type is rejected with an error.
func WithTenantID(tenantID string) interface {
	AcquireByAuthCodeOption
	AcquireByCredentialOption
	AcquireOnBehalfOfOption
	AcquireSilentOption
	AuthCodeURLOption
	options.CallOption
} {
	apply := func(a any) error {
		switch t := a.(type) {
		case *authCodeURLOptions:
			t.tenantID = tenantID
		case *acquireTokenSilentOptions:
			t.tenantID = tenantID
		case *acquireTokenOnBehalfOfOptions:
			t.tenantID = tenantID
		case *acquireTokenByCredentialOptions:
			t.tenantID = tenantID
		case *acquireTokenByAuthCodeOptions:
			t.tenantID = tenantID
		default:
			return fmt.Errorf("unexpected options type %T", a)
		}
		return nil
	}
	return struct {
		AcquireByAuthCodeOption
		AcquireByCredentialOption
		AcquireOnBehalfOfOption
		AcquireSilentOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(apply)}
}
type (
	// acquireTokenSilentOptions holds the optional settings for an AcquireTokenSilent() call,
	// filled in by AcquireSilentOption values.
	acquireTokenSilentOptions struct {
		account  Account
		claims   string
		tenantID string
	}

	// AcquireSilentOption is implemented by options for AcquireTokenSilent.
	AcquireSilentOption interface {
		acquireSilentOption()
	}
)
// WithSilentAccount makes an AcquireTokenSilent() call use account. Any other options type is
// rejected with an error.
func WithSilentAccount(account Account) interface {
	AcquireSilentOption
	options.CallOption
} {
	setAccount := func(a any) error {
		o, ok := a.(*acquireTokenSilentOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", a)
		}
		o.account = account
		return nil
	}
	return struct {
		AcquireSilentOption
		options.CallOption
	}{CallOption: options.NewCallOption(setAccount)}
}
// AcquireTokenSilent acquires a token from either the cache or by using a refresh token.
// When no account is given, the app cache is used (IsAppCache is set from the zero account).
//
// Options: [WithClaims], [WithSilentAccount], [WithTenantID]
func (cca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) {
	var o acquireTokenSilentOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	// Silent acquisition doesn't accept claims; the error directs callers to a method that
	// requests a new token.
	if len(o.claims) > 0 {
		return AuthResult{}, errors.New("call another AcquireToken method to request a new token having these claims")
	}
	return cca.base.AcquireTokenSilent(ctx, base.AcquireTokenSilentParameters{
		Scopes:      scopes,
		Account:     o.account,
		RequestType: accesstokens.ATConfidential,
		Credential:  cca.cred,
		IsAppCache:  o.account.IsZero(),
		TenantID:    o.tenantID,
	})
}
type (
	// acquireTokenByAuthCodeOptions holds the optional parameters for acquiring an access token
	// with the authorization code flow.
	acquireTokenByAuthCodeOptions struct {
		challenge string
		claims    string
		tenantID  string
	}

	// AcquireByAuthCodeOption is implemented by options for AcquireTokenByAuthCode.
	AcquireByAuthCodeOption interface {
		acquireByAuthCodeOption()
	}
)
// WithChallenge provides a challenge for the AcquireTokenByAuthCode() call. Any other options type
// is rejected with an error.
func WithChallenge(challenge string) interface {
	AcquireByAuthCodeOption
	options.CallOption
} {
	setChallenge := func(a any) error {
		o, ok := a.(*acquireTokenByAuthCodeOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", a)
		}
		o.challenge = challenge
		return nil
	}
	return struct {
		AcquireByAuthCodeOption
		options.CallOption
	}{CallOption: options.NewCallOption(setChallenge)}
}
// AcquireTokenByAuthCode requests a security token from the authority in exchange for an
// authorization code. redirectURI must be the same URI that was used when the code was requested.
//
// Options: [WithChallenge], [WithClaims], [WithTenantID]
func (cca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) {
	var o acquireTokenByAuthCodeOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	return cca.base.AcquireTokenByAuthCode(ctx, base.AcquireTokenAuthCodeParameters{
		AppType:   accesstokens.ATConfidential,
		Challenge: o.challenge,
		Claims:    o.claims,
		Code:      code,
		// This setting differs from public.Client.AcquireTokenByAuthCode.
		Credential:  cca.cred,
		RedirectURI: redirectURI,
		Scopes:      scopes,
		TenantID:    o.tenantID,
	})
}
type (
	// acquireTokenByCredentialOptions holds the optional settings for AcquireTokenByCredential.
	acquireTokenByCredentialOptions struct {
		claims   string
		tenantID string
	}

	// AcquireByCredentialOption is implemented by options for AcquireTokenByCredential.
	AcquireByCredentialOption interface {
		acquireByCredOption()
	}
)
// AcquireTokenByCredential acquires a security token from the authority using the client
// credentials grant. It authenticates with the client's credential, then converts the token with
// base.AuthResultFromToken.
//
// Options: [WithClaims], [WithTenantID]
func (cca Client) AcquireTokenByCredential(ctx context.Context, scopes []string, opts ...AcquireByCredentialOption) (AuthResult, error) {
	var o acquireTokenByCredentialOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}

	params, err := cca.base.AuthParams.WithTenant(o.tenantID)
	if err != nil {
		return AuthResult{}, err
	}
	params.Scopes = scopes
	params.AuthorizationType = authority.ATClientCredentials
	params.Claims = o.claims

	tk, err := cca.base.Token.Credential(ctx, params, cca.cred)
	if err != nil {
		return AuthResult{}, err
	}
	// NOTE(review): the final argument is true here; presumably it asks base to cache the
	// result — confirm against base.AuthResultFromToken.
	return cca.base.AuthResultFromToken(ctx, params, tk, true)
}
type (
	// acquireTokenOnBehalfOfOptions holds the optional settings for AcquireTokenOnBehalfOf.
	acquireTokenOnBehalfOfOptions struct {
		claims   string
		tenantID string
	}

	// AcquireOnBehalfOfOption is implemented by options for AcquireTokenOnBehalfOf.
	AcquireOnBehalfOfOption interface {
		acquireOBOOption()
	}
)
// AcquireTokenOnBehalfOf acquires a security token for an app using a middle-tier app's access
// token (userAssertion).
// Refer https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-on-behalf-of-flow.
//
// Options: [WithClaims], [WithTenantID]
func (cca Client) AcquireTokenOnBehalfOf(ctx context.Context, userAssertion string, scopes []string, opts ...AcquireOnBehalfOfOption) (AuthResult, error) {
	var o acquireTokenOnBehalfOfOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	return cca.base.AcquireTokenOnBehalfOf(ctx, base.AcquireTokenOnBehalfOfParameters{
		Claims:        o.claims,
		Credential:    cca.cred,
		Scopes:        scopes,
		TenantID:      o.tenantID,
		UserAssertion: userAssertion,
	})
}
// Account returns the account in the token cache whose home account ID is accountID.
func (cca Client) Account(ctx context.Context, accountID string) (Account, error) {
	acct, err := cca.base.Account(ctx, accountID)
	return acct, err
}
// RemoveAccount signs account out and removes it from the token cache.
func (cca Client) RemoveAccount(ctx context.Context, account Account) error {
	err := cca.base.RemoveAccount(ctx, account)
	return err
}
microsoft-authentication-library-for-go-1.0.0/apps/confidential/confidential_test.go 0000664 0000000 0000000 00000121366 14420263624 0031026 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package confidential
import (
"context"
"crypto"
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/golang-jwt/jwt/v4"
"github.com/kylelemons/godebug/pretty"
)
// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
type errorClient struct{}
func (*errorClient) Do(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String())
}
func (*errorClient) CloseIdleConnections() {}
// TestCertFromPEM parses the unencrypted test certificate file and expects exactly one
// certificate plus a private key.
func TestCertFromPEM(t *testing.T) {
	path := filepath.Clean("../testdata/test-cert.pem")
	file, err := os.Open(path)
	if err != nil {
		t.Fatal(err)
	}
	defer file.Close()

	data, err := io.ReadAll(file)
	if err != nil {
		t.Fatal(err)
	}

	certs, key, err := CertFromPEM(data, "")
	switch {
	case err != nil:
		t.Fatalf("TestCertFromPEM: got err == %s, want err == nil", err)
	case len(certs) != 1:
		t.Fatalf("TestCertFromPEM: got %d certs, want 1 cert", len(certs))
	case key == nil:
		t.Fatalf("TestCertFromPEM: got nil key, want key != nil")
	}
}
// URL values used by the fakes and tests in this file.
const (
	authorityFmt      = "https://%s/%s"
	fakeAuthority     = "https://fake_authority/fake"
	fakeTokenEndpoint = "https://fake_authority/fake/token"
	localhost         = "http://localhost"
)

// Client identity, credential and token values used by the fakes and tests in this file.
const (
	fakeClientID = "fake_client_id"
	fakeSecret   = "fake_secret"
	refresh      = "fake_refresh"
	token        = "fake_token"
)

// tokenScope is the scope requested (and granted by the fakes) throughout these tests.
var tokenScope = []string{"the_scope"}
// fakeClient builds a Client with New for fakeAuthority and fakeClientID, then replaces its token
// dependencies with fakes:
//   - access token requests return tk;
//   - instance discovery returns canned metadata for "fake_authority";
//   - endpoint resolution returns fixed fake_authority endpoints;
//   - WS-Trust is stubbed.
//
// Any error from New is returned unchanged.
func fakeClient(tk accesstokens.TokenResponse, credential Credential, options ...Option) (Client, error) {
	c, err := New(fakeAuthority, fakeClientID, credential, options...)
	if err != nil {
		return Client{}, err
	}

	discovery := authority.InstanceDiscoveryResponse{
		TenantDiscoveryEndpoint: "https://fake_authority/fake/discovery/endpoint",
		Metadata: []authority.InstanceDiscoveryMetadata{{
			PreferredNetwork: "fake_authority",
			PreferredCache:   "fake_cache",
			Aliases:          []string{"fake_authority", "fake_auth", "fk_au"},
		}},
		AdditionalFields: map[string]interface{}{"api-version": "2020-02-02"},
	}
	endpoints := authority.NewEndpoints(
		"https://fake_authority/fake/auth",
		fakeTokenEndpoint,
		"https://fake_authority/fake/jwt",
		"fake_authority",
	)

	c.base.Token.AccessTokens = &fake.AccessTokens{AccessToken: tk}
	c.base.Token.Authority = &fake.Authority{InstanceResp: discovery}
	c.base.Token.Resolver = &fake.ResolveEndpoints{Endpoints: endpoints}
	c.base.Token.WSTrust = &fake.WSTrust{}
	return c, nil
}
// TestAcquireTokenByCredential checks, for a secret credential and for an assertion credential,
// that:
//   - a silent request on a fresh client fails;
//   - AcquireTokenByCredential returns the (faked) token;
//   - a later silent request returns that token from the cache.
func TestAcquireTokenByCredential(t *testing.T) {
	secret, err := NewCredFromSecret("fake_secret")
	if err != nil {
		t.Fatal(err)
	}
	// Build the "Signed Assertion" case from an assertion callback, so it exercises the assertion
	// path rather than a second secret credential.
	assertion := NewCredFromAssertionCallback(func(context.Context, AssertionRequestOptions) (string, error) {
		return "fake_assertion", nil
	})
	tests := []struct {
		desc string
		cred Credential
	}{
		{desc: "Secret", cred: secret},
		{desc: "Signed Assertion", cred: assertion},
	}
	for _, test := range tests {
		t.Run(test.desc, func(t *testing.T) {
			client, err := fakeClient(accesstokens.TokenResponse{
				AccessToken:   token,
				ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				ExtExpiresOn:  internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
			}, test.cred)
			if err != nil {
				t.Fatal(err)
			}
			// first attempt should fail: nothing is cached yet
			if _, err := client.AcquireTokenSilent(context.Background(), tokenScope); err == nil {
				t.Fatal("unexpected nil error from AcquireTokenSilent")
			}
			tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
			if err != nil {
				t.Fatalf("got err == %s, want err == nil", err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}
			// second attempt should return the cached token
			tk, err = client.AcquireTokenSilent(context.Background(), tokenScope)
			if err != nil {
				t.Fatalf("got err == %s, want err == nil", err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}
		})
	}
}
// TestAcquireTokenOnBehalfOf runs the on-behalf-of flow against a mocked HTTP client and checks:
//   - the first call for a user assertion returns the queued token;
//   - a repeated call with the same assertion returns the same token, even though no further
//     token response has been queued;
//   - a call with a different assertion returns a newly queued token.
func TestAcquireTokenOnBehalfOf(t *testing.T) {
// this test is an offline version of TestOnBehalfOf in integration_test.go
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
lmo := "login.microsoftonline.com"
tenant := "tenant"
assertion := "assertion"
mockClient := mock.Client{}
// Queue the responses for the first OBO request: instance discovery, tenant discovery,
// then the access token. NOTE(review): assumes mock.Client serves appended responses in
// FIFO order — confirm in the mock package.
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token, "", "rt", "", 3600)))
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
tk, err := client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
}
// should return the cached access token
// (no new token response is queued before this call)
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion, tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token {
t.Fatalf("wanted %q, got %q", token, tk.AccessToken)
}
// new assertion should trigger new token request
token2 := token + "2"
mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(token2, "", "rt", "", 3600)))
tk, err = client.AcquireTokenOnBehalfOf(context.Background(), assertion+"2", tokenScope)
if err != nil {
t.Fatal(err)
}
if tk.AccessToken != token2 {
t.Fatal("expected a new token")
}
}
// TestAcquireTokenByAssertionCallback checks that AcquireTokenByCredential calls the assertion
// callback once per call, with the caller's context, the client ID and the token endpoint. It also
// checks that an error returned by the callback (on the fourth call) is returned to the caller.
func TestAcquireTokenByAssertionCallback(t *testing.T) {
	const successfulCalls = 3
	calls := 0
	key := struct{}{}
	ctx := context.WithValue(context.Background(), key, true)

	getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) {
		if v := c.Value(key); v == nil || !v.(bool) {
			t.Fatal("callback received unexpected context")
		}
		switch {
		case o.ClientID != fakeClientID:
			t.Fatalf(`unexpected client ID "%s"`, o.ClientID)
		case o.TokenEndpoint != fakeTokenEndpoint:
			t.Fatalf(`unexpected token endpoint "%s"`, o.TokenEndpoint)
		}
		calls++
		if calls > successfulCalls {
			return "", errors.New("expected error")
		}
		return "assertion", nil
	}

	client, err := fakeClient(accesstokens.TokenResponse{}, NewCredFromAssertionCallback(getAssertion))
	if err != nil {
		t.Fatal(err)
	}

	for want := 0; want < successfulCalls; want++ {
		if calls != want {
			t.Fatalf("expected %d calls, got %d", want, calls)
		}
		if _, err = client.AcquireTokenByCredential(ctx, tokenScope); err != nil {
			t.Fatal(err)
		}
	}

	_, err = client.AcquireTokenByCredential(ctx, tokenScope)
	if err == nil || err.Error() != "expected error" {
		t.Fatalf("expected an error from the callback, got %v", err)
	}
}
// TestAcquireTokenByAuthCode redeems an authorization code against a faked authority, then checks
// the cached account and a follow-up silent request. It runs twice:
//   - the username only in the ID token's PreferredUsername, with a UTID in client info;
//   - the username only in UPN, with no UTID.
//
// In both runs the account's PreferredUsername must be populated.
func TestAcquireTokenByAuthCode(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	cases := []struct {
		upn, preferredUsername, utid string
	}{
		{upn: "", preferredUsername: "fakeuser@fakeplace.fake", utid: "fake"},
		{upn: "fakeuser@fakeplace.fake", preferredUsername: "", utid: ""},
	}
	for _, params := range cases {
		t.Run("", func(t *testing.T) {
			tr := accesstokens.TokenResponse{
				AccessToken:   token,
				RefreshToken:  refresh,
				ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				ExtExpiresOn:  internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
				IDToken: accesstokens.IDToken{
					PreferredUsername: params.preferredUsername,
					UPN:               params.upn,
					Name:              "fake person",
					Oid:               "123-456",
					TenantID:          "fake",
					Subject:           "nothing",
					Issuer:            "https://fake_authority/fake",
					Audience:          "abc-123",
					ExpirationTime:    time.Now().Add(time.Hour).Unix(),
					IssuedAt:          time.Now().Add(-5 * time.Minute).Unix(),
					NotBefore:         time.Now().Add(-5 * time.Minute).Unix(),
					// NOTE: this is an invalid JWT however this doesn't cause a failure.
					// it simply falls back to calling Token.Refresh() which will obviously succeed.
					RawToken: "fake.raw.token",
				},
				ClientInfo: accesstokens.ClientInfo{
					UID:  "123-456",
					UTID: params.utid,
				},
			}
			client, err := fakeClient(tr, cred)
			if err != nil {
				t.Fatal(err)
			}
			ctx := context.Background()

			// first attempt should fail: nothing is cached yet
			if _, err := client.AcquireTokenSilent(ctx, tokenScope); err == nil {
				t.Fatal("unexpected nil error from AcquireTokenSilent")
			}

			tk, err := client.AcquireTokenByAuthCode(ctx, "fake_auth_code", "fake_redirect_uri", tokenScope)
			if err != nil {
				t.Fatal(err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}

			account, err := client.Account(ctx, tk.Account.HomeAccountID)
			if err != nil {
				t.Fatal(err)
			}
			// Expected home account ID for each fixture variant.
			wantHomeID := "123-456.fake"
			if params.utid == "" {
				wantHomeID = "123-456.123-456"
			}
			if actual := account.HomeAccountID; actual != wantHomeID {
				t.Fatalf("expected %q, got %q", wantHomeID, actual)
			}
			if account.PreferredUsername != "fakeuser@fakeplace.fake" {
				t.Fatal("Unexpected Account.PreferredUsername")
			}

			// second attempt should return the cached token
			tk, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(account))
			if err != nil {
				t.Fatal(err)
			}
			if tk.AccessToken != token {
				t.Fatalf("unexpected access token %s", tk.AccessToken)
			}
		})
	}
}
// TestAcquireTokenSilentTenants checks that access tokens acquired for different tenants via
// WithTenantID are cached separately. After a token is acquired for each tenant,
// AcquireTokenSilent for that tenant must return that tenant's token.
func TestAcquireTokenSilentTenants(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	host := "login.microsoftonline.com"
	tenants := []string{"a", "b"}
	httpClient := mock.Client{}
	httpClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenants[0])))
	authority := fmt.Sprintf(authorityFmt, host, tenants[0])
	client, err := New(authority, fakeClientID, cred, WithHTTPClient(&httpClient))
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()

	// Cache one access token per tenant. Each token's value is the ID of the tenant that issued
	// it, which makes its provenance easy to check below.
	for _, tenant := range tenants {
		_, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant))
		if err == nil {
			t.Fatal("silent auth should fail because the cache is empty")
		}
		httpClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
		httpClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(tenant, "", "", "", 3600)))
		if _, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)); err != nil {
			t.Fatal(err)
		}
	}

	// The cache must hand back each tenant's own token.
	for _, tenant := range tenants {
		result, err := client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant))
		if err != nil {
			t.Fatal(err)
		}
		if got := result.AccessToken; got != tenant {
			t.Fatalf(`expected "%s", got "%s"`, tenant, got)
		}
	}
}
// TestAuthorityValidation checks that New rejects malformed authority URLs with an error whose
// message mentions "authority".
func TestAuthorityValidation(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	invalidAuthorities := []string{
		"",                                        // empty
		"https://login.microsoftonline.com",       // no tenant path segment
		"http://login.microsoftonline.com/tenant", // http rather than https
	}
	for _, authority := range invalidAuthorities {
		t.Run(authority, func(t *testing.T) {
			if _, err := New(authority, fakeClientID, cred); err == nil || !strings.Contains(err.Error(), "authority") {
				t.Fatalf("expected an error about the invalid authority, got %v", err)
			}
		})
	}
}
// TestInvalidCredential checks that New returns an error for a zero-value Credential and for a
// credential built from a nil assertion callback.
func TestInvalidCredential(t *testing.T) {
	invalid := []Credential{{}, NewCredFromAssertionCallback(nil)}
	for _, cred := range invalid {
		t.Run("", func(t *testing.T) {
			if _, err := New(fakeAuthority, fakeClientID, cred); err == nil {
				t.Fatal("expected an error")
			}
		})
	}
}
// TestNewCredFromCert loads PEM files that hold a single cert or a chain (in either order) and
// builds a certificate credential from each. It then checks the client assertion that credential
// produces:
//   - the JWT must be RSA-signed by the file's private key and have valid claims;
//   - its "x5c" header must be present only when WithX5C is set;
//   - when present, "x5c" must hold every cert from the file exactly once, signing cert first.
func TestNewCredFromCert(t *testing.T) {
	for _, file := range []struct {
		path     string
		numCerts int
	}{
		{"../testdata/test-cert.pem", 1},
		{"../testdata/test-cert-chain.pem", 2},
		{"../testdata/test-cert-chain-reverse.pem", 2},
	} {
		// Read inside a closure so the file is closed at the end of this iteration. A defer
		// directly in the loop body would keep every file open until the test returns.
		pemData, err := func() ([]byte, error) {
			f, err := os.Open(filepath.Clean(file.path))
			if err != nil {
				return nil, err
			}
			defer f.Close()
			return io.ReadAll(f)
		}()
		if err != nil {
			t.Fatal(err)
		}
		certs, key, err := CertFromPEM(pemData, "")
		if err != nil {
			t.Fatal(err)
		}
		if len(certs) != file.numCerts {
			t.Fatalf("expected %d certs, got %d", file.numCerts, len(certs))
		}
		// Standard base64 of each cert's DER bytes; x5c entries are compared in this form.
		encodedCerts := make([]string, 0, len(certs))
		for _, cert := range certs {
			encodedCerts = append(encodedCerts, base64.StdEncoding.EncodeToString(cert.Raw))
		}
		k, ok := key.(*rsa.PrivateKey)
		if !ok {
			t.Fatal("expected an RSA private key")
		}
		verifyingKey := &k.PublicKey
		cred, err := NewCredFromCert(certs, key)
		if err != nil {
			t.Fatal(err)
		}
		for _, sendX5c := range []bool{false, true} {
			opts := []Option{}
			if sendX5c {
				opts = append(opts, WithX5C())
			}
			t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) {
				client, err := fakeClient(accesstokens.TokenResponse{
					AccessToken:   token,
					ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
					GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
				}, cred, opts...)
				if err != nil {
					t.Fatal(err)
				}
				// the test fake passes assertions generated by the credential to this function
				validated := false
				client.base.Token.AccessTokens.(*fake.AccessTokens).ValidateAssertion = func(s string) {
					validated = true
					tk, err := jwt.Parse(s, func(tk *jwt.Token) (interface{}, error) {
						if _, ok := tk.Method.(*jwt.SigningMethodRSA); !ok {
							// Report the actual method. The value from the failed type assertion
							// is always a nil *jwt.SigningMethodRSA, so it would hide the real type.
							t.Fatalf("unexpected signing method %T", tk.Method)
						}
						return verifyingKey, nil
					})
					if err != nil {
						t.Fatal(err)
					}
					if err = tk.Claims.Valid(); err != nil {
						t.Fatal(err)
					}
					// x5c header should be set iff sendX5c is true
					x5c, ok := tk.Header["x5c"]
					if ok != sendX5c {
						t.Fatal("x5c should be set only when application passed WithX5C option")
					}
					if !ok {
						return
					}
					chain, ok := x5c.([]interface{})
					if !ok {
						t.Fatalf("unexpected x5c type %T", x5c)
					}
					if n := len(chain); n != file.numCerts {
						t.Fatalf("x5c contains %d certs; expected %d", n, file.numCerts)
					}
					// Build a fresh set per assertion so that validating more than one assertion
					// doesn't find certs already removed by a previous call.
					expectedCerts := make(map[string]struct{}, len(encodedCerts))
					for _, enc := range encodedCerts {
						expectedCerts[enc] = struct{}{}
					}
					// x5c must contain all the file's certs, signing cert first
					for i, cert := range chain {
						entry, ok := cert.(string)
						if !ok {
							t.Fatalf("unexpected x5c entry type %T", cert)
						}
						if _, ok := expectedCerts[entry]; !ok {
							t.Fatal("x5c contains an unexpected cert")
						}
						delete(expectedCerts, entry)
						if i == 0 {
							decoded, err := base64.StdEncoding.DecodeString(entry)
							if err != nil {
								t.Fatal(err)
							}
							parsed, err := x509.ParseCertificate(decoded)
							if err != nil {
								t.Fatal(err)
							}
							if !verifyingKey.Equal(parsed.PublicKey) {
								t.Fatal("signing cert must appear first in x5c")
							}
						}
					}
					if len(expectedCerts) > 0 {
						t.Fatal("x5c header is missing a cert")
					}
				}
				tk, err := client.AcquireTokenByCredential(context.Background(), tokenScope)
				if err != nil {
					t.Fatal(err)
				}
				if tk.AccessToken != token {
					t.Fatalf("unexpected access token %s", tk.AccessToken)
				}
				if !validated {
					t.Fatal("assertion validation function wasn't called")
				}
			})
		}
	}
}
// TestNewCredFromCertError checks that NewCredFromCert returns an error in three situations:
//   - the certs or the key are missing;
//   - the cert slice is empty or holds a nil cert;
//   - the private key doesn't belong to the given certs.
func TestNewCredFromCertError(t *testing.T) {
	data, err := os.ReadFile("../testdata/test-cert.pem")
	if err != nil {
		t.Fatal(err)
	}
	certs, key, err := CertFromPEM(data, "")
	if err != nil {
		t.Fatal(err)
	}
	cases := []struct {
		certs []*x509.Certificate
		key   crypto.PrivateKey
	}{
		{nil, nil},
		{certs, nil},
		{nil, key},
		{[]*x509.Certificate{}, nil},
		{[]*x509.Certificate{}, key},
		{[]*x509.Certificate{nil}, nil},
		{[]*x509.Certificate{nil}, key},
	}
	for _, c := range cases {
		t.Run("", func(t *testing.T) {
			if _, err := NewCredFromCert(c.certs, c.key); err == nil {
				t.Fatal("expected an error")
			}
		})
	}

	// The key in the chain file doesn't match the cert loaded above.
	chainData, err := os.ReadFile("../testdata/test-cert-chain.pem")
	if err != nil {
		t.Fatal(err)
	}
	_, otherKey, err := CertFromPEM(chainData, "")
	if err != nil {
		t.Fatal(err)
	}
	if _, err := NewCredFromCert(certs, otherKey); err == nil {
		t.Fatal("expected an error because key doesn't match certs")
	}
}
// TestNewCredFromTokenProvider checks a credential backed by a token provider callback:
//   - the callback is invoked exactly once, with the caller's context, a non-empty correlation
//     ID and the requested scopes;
//   - its token is returned with an ExpiresOn about expiresIn seconds in the future;
//   - AcquireTokenSilent then returns the same token from the cache.
//
// The client uses errorClient as its HTTP client, presumably so that any HTTP request fails.
func TestNewCredFromTokenProvider(t *testing.T) {
	expectedToken := "expected token"
	called := false
	expiresIn := 4200
	// context.WithValue keys should have their own type. An anonymous struct{}{} key would
	// collide with any other package that also uses one.
	type ctxKey struct{}
	key := ctxKey{}
	ctx := context.WithValue(context.Background(), key, true)
	cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
		if called {
			t.Fatal("expected exactly one token provider invocation")
		}
		called = true
		// checked assertion: a missing or non-bool value fails the test instead of panicking
		if v, ok := c.Value(key).(bool); !ok || !v {
			t.Fatal("callback received unexpected context")
		}
		if tp.CorrelationID == "" {
			t.Fatal("expected CorrelationID")
		}
		if v := fmt.Sprint(tp.Scopes); v != fmt.Sprint(tokenScope) {
			t.Fatalf(`unexpected scopes "%v"`, v)
		}
		return exported.TokenProviderResult{
			AccessToken:      expectedToken,
			ExpiresInSeconds: expiresIn,
		}, nil
	})
	client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
	if err != nil {
		t.Fatal(err)
	}
	ar, err := client.AcquireTokenByCredential(ctx, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if !called {
		t.Fatal("token provider wasn't invoked")
	}
	// allow a small tolerance for time elapsed since the provider returned
	if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn {
		t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v)
	}
	if ar.AccessToken != expectedToken {
		t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
	}
	// the provider must not be called again; the token should come from the cache
	ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != expectedToken {
		t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
	}
}
// TestNewCredFromTokenProviderError checks that an error returned by a token provider callback
// is surfaced by AcquireTokenByCredential, i.e. the returned error's message contains it.
func TestNewCredFromTokenProviderError(t *testing.T) {
	expectedError := "something went wrong"
	cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
		return exported.TokenProviderResult{}, errors.New(expectedError)
	})
	// Use errorClient, like the other token provider tests, so this test can't send requests to the
	// network through the default HTTP client.
	client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
	if err != nil {
		t.Fatal(err)
	}
	_, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
	if err == nil || !strings.Contains(err.Error(), expectedError) {
		t.Fatalf(`unexpected error "%v"`, err)
	}
}
func TestTokenProviderOptions(t *testing.T) {
accessToken, claims, tenant := "at", "claims", "tenant"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp TokenProviderParameters) (TokenProviderResult, error) {
if tpp.Claims != claims {
t.Fatalf(`unexpected claims "%s"`, tpp.Claims)
}
if tpp.TenantID != tenant {
t.Fatalf(`unexpected tenant "%s"`, tpp.TenantID)
}
return TokenProviderResult{AccessToken: accessToken, ExpiresInSeconds: 3600}, nil
})
client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithClaims(claims), WithTenantID(tenant))
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
}
// testCache is a minimal in-memory cache.ExportReplace implementation. It stores one serialized
// cache blob per partition key.
type testCache map[string][]byte

// Export stores the marshaled cache under the hint's partition key. A marshaling error is
// ignored and leaves the stored data unchanged; Export always returns nil.
func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error {
	data, err := m.Marshal()
	if err != nil {
		return nil
	}
	c[h.PartitionKey] = data
	return nil
}

// Replace loads the data stored for the hint's partition key, if any, into u. An unmarshaling
// error is ignored; Replace always returns nil.
func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error {
	data, found := c[h.PartitionKey]
	if !found {
		return nil
	}
	_ = u.Unmarshal(data)
	return nil
}
// TestWithCache checks that clients configured for different tenants can share a persistent
// cache. A client for tenant A fills the cache through the auth code flow. A second client,
// configured for tenant B, must then get tenant A's access token silently when it requests
// tenant A. It must fail silent auth for its own tenant, because the cache holds no token for B.
func TestWithCache(t *testing.T) {
	// named so it doesn't shadow the cache package
	sharedCache := make(testCache)
	accessToken := "*"
	lmo := "login.microsoftonline.com"
	tenantA, tenantB := "a", "b"
	authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB)
	mockClient := mock.Client{}
	mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA)))
	mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), "", "", 3600)))
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	client, err := New(authorityA, fakeClientID, cred, WithCache(&sharedCache), WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	// The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test
	ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken {
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}
	account := ar.Account
	if actual := account.Realm; actual != tenantA {
		t.Fatalf(`unexpected realm "%s"`, actual)
	}
	// a client configured for a different tenant should be able to authenticate silently with the shared cache's data
	client, err = New(authorityB, fakeClientID, cred, WithCache(&sharedCache), WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	// this should succeed because the cache contains an access token from tenantA
	mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA)))
	ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA))
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken {
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}
	// this should fail because the cache doesn't contain an access token from tenantB
	if _, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account)); err == nil {
		t.Fatal("expected an error because the cache doesn't have an appropriate access token")
	}
}
// TestWithClaims checks how client capabilities (WithClientCapabilities) and per-request claims
// (WithClaims) are sent for the auth code, auth code URL, client credential and OBO flows.
// The "claims" parameter must be absent when neither is set; otherwise its JSON must equal the
// case's expected value.
//
// After a token is acquired, the test also checks that:
//   - a request without claims is served from the cache;
//   - AcquireTokenSilent fails when claims are requested;
//   - OBO with claims redeems the cached refresh token, and that request also carries the claims.
func TestWithClaims(t *testing.T) {
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	accessToken := "at"
	lmo, tenant := "login.microsoftonline.com", "tenant"
	authority := fmt.Sprintf(authorityFmt, lmo, tenant)
	for _, test := range []struct {
		capabilities     []string
		claims, expected string
	}{
		{},
		{
			capabilities: []string{"cp1"},
			expected:     `{"access_token":{"xms_cc":{"values":["cp1"]}}}`,
		},
		{
			claims:   `{"id_token":{"auth_time":{"essential":true}}}`,
			expected: `{"id_token":{"auth_time":{"essential":true}}}`,
		},
		{
			capabilities: []string{"cp1", "cp2"},
			claims:       `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
			expected:     `{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`,
		},
	} {
		var expected map[string]any
		if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" {
			t.Fatal("test bug: the expected result must be JSON or an empty string")
		}
		// validate checks the "claims" value in a token request form or an auth code URL query.
		validate := func(t *testing.T, v url.Values) {
			if test.expected == "" {
				if v.Has("claims") {
					t.Fatal("claims shouldn't be set")
				}
				return
			}
			claims, ok := v["claims"]
			if !ok {
				t.Fatal("claims should be set")
			}
			if len(claims) != 1 {
				t.Fatalf("expected 1 value for claims, got %d", len(claims))
			}
			var actual map[string]any
			if err := json.Unmarshal([]byte(claims[0]), &actual); err != nil {
				t.Fatal(err)
			}
			if diff := pretty.Compare(expected, actual); diff != "" {
				t.Fatal(diff)
			}
		}
		for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} {
			t.Run(method, func(t *testing.T) {
				mockClient := mock.Client{}
				clientInfo, idToken, refreshToken := "", "", ""
				if method == "obo" {
					clientInfo = base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
					idToken = mock.GetIDToken(tenant, authority)
					refreshToken = "rt"
					// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
					mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
				}
				mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
				mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
				mockClient.AppendResponse(
					mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)),
					mock.WithCallback(func(r *http.Request) {
						if err := r.ParseForm(); err != nil {
							t.Fatal(err)
						}
						validate(t, r.Form)
					}),
				)
				client, err := New(authority, fakeClientID, cred, WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient))
				if err != nil {
					t.Fatal(err)
				}
				if _, err = client.AcquireTokenSilent(context.Background(), tokenScope); err == nil {
					t.Fatal("silent authentication should fail because the cache is empty")
				}
				ctx := context.Background()
				var ar AuthResult
				switch method {
				case "authcode":
					ar, err = client.AcquireTokenByAuthCode(ctx, "code", localhost, tokenScope, WithClaims(test.claims))
				case "authcodeURL":
					u := ""
					if u, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithClaims(test.claims)); err == nil {
						var parsed *url.URL
						if parsed, err = url.Parse(u); err == nil {
							validate(t, parsed.Query())
							return // didn't acquire a token, no need for further validation
						}
					}
				case "credential":
					ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithClaims(test.claims))
				case "obo":
					ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithClaims(test.claims))
				default:
					// t.Fatal rather than t.Fatalf: the message isn't a format string
					t.Fatal("test bug: no test for " + method)
				}
				if err != nil {
					t.Fatal(err)
				}
				if ar.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
				}
				// silent auth should now succeed, provided no claims are requested, because the client has cached an access token
				if method == "obo" {
					ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope)
				} else {
					ar, err = client.AcquireTokenSilent(ctx, tokenScope)
				}
				if err != nil {
					t.Fatal(err)
				}
				if ar.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
				}
				if test.claims != "" {
					if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithClaims(test.claims)); err == nil {
						t.Fatal("AcquireTokenSilent should fail when given claims")
					}
					if method == "obo" {
						// client has cached access and refresh tokens. When given claims, it should redeem a refresh token for a new access token.
						newToken := "new-access-token"
						mockClient.AppendResponse(
							mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)),
							mock.WithCallback(func(r *http.Request) {
								if err := r.ParseForm(); err != nil {
									t.Fatal(err)
								}
								// all token requests should include any specified claims
								validate(t, r.Form)
								if actual := r.Form.Get("refresh_token"); actual != refreshToken {
									t.Fatalf(`unexpected refresh token "%s"`, actual)
								}
							}),
						)
						ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithClaims(test.claims))
						if err != nil {
							t.Fatal(err)
						}
						if ar.AccessToken != newToken {
							t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
						}
					}
				}
			})
		}
	}
}
// TestWithTenantID checks that WithTenantID overrides the configured authority's tenant for the
// auth code, auth code URL, client credential and OBO flows.
//
// Rows without expectError must send the request to expectedAuthority and cache the token for
// that tenant. Silent auth for another tenant must then fail, except OBO, which redeems its
// refresh token. The expectError rows must fail: a tenant-specific authority given
// "common"/"organizations", and the "consumers" authority given a tenant ID.
//
// The "enables fake authority" subtest checks that a client never sends requests to its
// configured tenant when every call specifies a different one.
func TestWithTenantID(t *testing.T) {
	accessToken := "*"
	uuid1 := "00000000-0000-0000-0000-000000000000"
	uuid2 := strings.ReplaceAll(uuid1, "0", "1")
	lmo := "login.microsoftonline.com"
	host := fmt.Sprintf("https://%s/", lmo)
	for _, test := range []struct {
		authority, expectedAuthority, tenant string
		expectError                          bool
	}{
		{authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
		{authority: host + uuid1, tenant: "common", expectError: true},
		{authority: host + uuid1, tenant: "organizations", expectError: true},
		{authority: host + "consumers", tenant: uuid1, expectError: true},
	} {
		for _, method := range []string{"authcode", "authcodeURL", "credential", "obo"} {
			t.Run(method, func(t *testing.T) {
				cred, err := NewCredFromSecret(fakeSecret)
				if err != nil {
					t.Fatal(err)
				}
				idToken, refreshToken, URL := "", "", ""
				mockClient := mock.Client{}
				if method == "obo" {
					idToken = mock.GetIDToken(test.tenant, test.authority)
					refreshToken = "refresh-token"
					// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
					mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
				}
				mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
				mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
				mockClient.AppendResponse(
					mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
					// record the token request's URL so the authority it went to can be checked
					mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
				)
				client, err := New(test.authority, fakeClientID, cred, WithHTTPClient(&mockClient))
				if err != nil {
					t.Fatal(err)
				}
				ctx := context.Background()
				if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err == nil {
					t.Fatal("silent auth should fail because the cache is empty")
				}
				var ar AuthResult
				switch method {
				case "authcode":
					ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(test.tenant))
				case "authcodeURL":
					URL, err = client.AuthCodeURL(ctx, fakeClientID, localhost, tokenScope, WithTenantID(test.tenant))
				case "credential":
					ar, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(test.tenant))
				case "obo":
					ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant))
				default:
					// t.Fatal rather than t.Fatalf: the message isn't a format string
					t.Fatal("test bug: no test for " + method)
				}
				if err != nil {
					if test.expectError {
						return
					}
					t.Fatal(err)
				} else if test.expectError {
					t.Fatal("expected an error")
				}
				if !strings.HasPrefix(URL, test.expectedAuthority) {
					t.Fatalf(`expected "%s", got "%s"`, test.expectedAuthority, URL)
				}
				if method == "authcodeURL" {
					// didn't acquire a token, no need to test silent auth
					return
				}
				if ar.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
				}
				// silent authentication should now succeed for the given tenant...
				if method == "obo" {
					if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant)); err != nil {
						t.Fatal(err)
					}
				} else if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err != nil {
					t.Fatal(err)
				}
				if ar.AccessToken != accessToken {
					t.Fatal("cached access token should match the one returned by AcquireToken...")
				}
				// ...but fail for another tenant unless we're authenticating OBO, in which case we have a refresh token
				otherTenant := "not-" + test.tenant
				if method == "obo" {
					mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
					mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)))
					if _, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(otherTenant)); err != nil {
						t.Fatal(err)
					}
				} else if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil {
					t.Fatal("expected an error")
				}
			})
		}
	}
	// if every auth call specifies a different tenant, Client shouldn't send requests to its configured authority
	t.Run("enables fake authority", func(t *testing.T) {
		host := "host"
		defaultTenant := "default"
		cred, err := NewCredFromSecret(fakeSecret)
		if err != nil {
			t.Fatal(err)
		}
		URL := ""
		mockClient := mock.Client{}
		client, err := New(fmt.Sprintf(authorityFmt, host, defaultTenant), fakeClientID, cred, WithHTTPClient(&mockClient))
		if err != nil {
			t.Fatal(err)
		}
		checkForWrongTenant := func(r *http.Request) {
			if u := r.URL.String(); strings.Contains(u, defaultTenant) {
				t.Fatalf("unexpected request to the default authority: %q", u)
			}
		}
		ctx := context.Background()
		for i := 0; i < 3; i++ {
			tenant := fmt.Sprint(i)
			expected := fmt.Sprintf(authorityFmt, host, tenant)
			// TODO: prevent redundant discovery requests https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
			mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant))
			mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant))
			mockClient.AppendResponse(
				mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)),
				mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
			)
			if i == 0 {
				// TODO: see above (first silent auth rediscovers instance metadata)
				mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)), mock.WithCallback(checkForWrongTenant))
			}
			ar, err := client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope, WithTenantID(tenant))
			if err != nil {
				t.Fatal(err)
			}
			if !strings.HasPrefix(URL, expected) {
				t.Fatalf(`expected "%s", got "%s"`, expected, URL)
			}
			if ar.AccessToken != accessToken {
				t.Fatalf("unexpected access token %q", ar.AccessToken)
			}
			// silent authentication should now succeed for the given tenant...
			if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err != nil {
				t.Fatal(err)
			}
			if ar.AccessToken != accessToken {
				t.Fatal("cached access token should match the one returned by AcquireToken...")
			}
			// ...but fail for another tenant
			otherTenant := "not-" + tenant
			if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(otherTenant)); err == nil {
				t.Fatal("expected an error")
			}
		}
	})
}
// TestWithInstanceDiscovery runs the auth code, client credential and OBO flows against an
// authority on a custom host ("stack.local") with WithInstanceDiscovery(false). It covers both an
// "adfs" tenant and a GUID tenant. Each subtest expects three things:
//   - silent auth fails on the empty cache;
//   - the flow returns the mocked access token;
//   - a second call returns the same token, with no further token response queued.
func TestWithInstanceDiscovery(t *testing.T) {
	accessToken := "*"
	host := "stack.local"
	stackurl := fmt.Sprintf("https://%s/", host)
	for _, tenant := range []string{
		"adfs",
		"98b8267d-e97f-426e-8b3f-7956511fd63f",
	} {
		for _, method := range []string{"authcode", "credential", "obo"} {
			t.Run(method, func(t *testing.T) {
				authority := stackurl + tenant
				cred, err := NewCredFromSecret(fakeSecret)
				if err != nil {
					t.Fatal(err)
				}
				idToken, refreshToken := "", ""
				mockClient := mock.Client{}
				if method == "obo" {
					// the OBO token response includes an ID token and a refresh token
					idToken = mock.GetIDToken(tenant, authority)
					refreshToken = "refresh-token"
				}
				// Unlike the tests that leave instance discovery enabled, no instance discovery
				// response is queued here: only tenant discovery, then the token response.
				// NOTE(review): other tests pass a bare host to GetTenantDiscoveryBody; this one
				// passes the full URL "https://stack.local/". Confirm that's intended.
				mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(stackurl, tenant)))
				mockClient.AppendResponse(
					mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, "", 3600)),
				)
				client, err := New(authority, fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
				if err != nil {
					t.Fatal(err)
				}
				ctx := context.Background()
				if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil {
					t.Fatal("silent auth should fail because the cache is empty")
				}
				var ar AuthResult
				switch method {
				case "authcode":
					ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope)
				case "credential":
					ar, err = client.AcquireTokenByCredential(ctx, tokenScope)
				case "obo":
					ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope)
				default:
					t.Fatal("test bug: no test for " + method)
				}
				if err != nil {
					t.Fatal(err)
				}
				if ar.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
				}
				// Fetch the token a second time. OBO repeats AcquireTokenOnBehalfOf with the same
				// assertion; the other flows use AcquireTokenSilent. No further token response is
				// queued, so the token is expected to come from the cache.
				if method == "obo" {
					if ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope); err != nil {
						t.Fatal(err)
					}
				} else if ar, err = client.AcquireTokenSilent(ctx, tokenScope); err != nil {
					t.Fatal(err)
				}
				if ar.AccessToken != accessToken {
					t.Fatal("cached access token should match the one returned by AcquireToken...")
				}
			})
		}
	}
}
// TestWithPortAuthority checks that an authority whose host includes a port
// ("stack.local:3001") is kept intact. The auth code token request must go to that authority,
// and the resulting token must then be available silently.
func TestWithPortAuthority(t *testing.T) {
	const (
		accessToken = "*"
		tenant      = "00000000-0000-0000-0000-000000000000"
	)
	host := "stack.local" + ":3001"
	authority := fmt.Sprintf("https://%s/%s", host, tenant)
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	var tokenURL string
	httpClient := mock.Client{}
	// instance discovery is requested twice because the host is not trusted
	for i := 0; i < 2; i++ {
		httpClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)))
	}
	httpClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
	httpClient.AppendResponse(
		mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)),
		mock.WithCallback(func(r *http.Request) { tokenURL = r.URL.String() }),
	)
	client, err := New(authority, fakeClientID, cred, WithHTTPClient(&httpClient))
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil {
		t.Fatal("silent auth should fail because the cache is empty")
	}
	result, err := client.AcquireTokenByAuthCode(ctx, "auth code", localhost, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if !strings.HasPrefix(tokenURL, authority) {
		t.Fatalf(`expected "%s", got "%s"`, authority, tokenURL)
	}
	if result.AccessToken != accessToken {
		t.Fatalf(`unexpected access token "%s"`, result.AccessToken)
	}
	result, err = client.AcquireTokenSilent(ctx, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if result.AccessToken != accessToken {
		t.Fatal("cached access token should match the one returned by AcquireToken...")
	}
}
// TestWithLoginHint checks that AuthCodeURL adds exactly one login_hint query parameter, equal to
// the given UPN, when WithLoginHint is passed, and omits the parameter otherwise.
func TestWithLoginHint(t *testing.T) {
	upn := "user@localhost"
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
	if err != nil {
		t.Fatal(err)
	}
	client.base.Token.Resolver = &fake.ResolveEndpoints{}
	for _, expectHint := range []bool{true, false} {
		t.Run(fmt.Sprint(expectHint), func(t *testing.T) {
			var opts []AuthCodeURLOption
			if expectHint {
				opts = append(opts, WithLoginHint(upn))
			}
			rawURL, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...)
			if err != nil {
				t.Fatal(err)
			}
			parsed, err := url.Parse(rawURL)
			if err != nil {
				t.Fatal(err)
			}
			hints, present := parsed.Query()["login_hint"]
			switch {
			case !present && !expectHint:
				return
			case !present:
				t.Fatal("expected a login hint")
			case !expectHint:
				t.Fatal("expected no login hint")
			}
			if len(hints) != 1 || hints[0] != upn {
				t.Fatalf(`unexpected login_hint "%v"`, hints)
			}
		})
	}
}
// TestWithDomainHint checks that AuthCodeURL adds exactly one domain_hint query parameter, equal to
// the given domain, when WithDomainHint is passed, and omits the parameter otherwise.
func TestWithDomainHint(t *testing.T) {
	domain := "contoso.com"
	cred, err := NewCredFromSecret(fakeSecret)
	if err != nil {
		t.Fatal(err)
	}
	client, err := New(fakeAuthority, fakeClientID, cred, WithHTTPClient(&errorClient{}))
	if err != nil {
		t.Fatal(err)
	}
	client.base.Token.Resolver = &fake.ResolveEndpoints{}
	for _, expectHint := range []bool{true, false} {
		t.Run(fmt.Sprint(expectHint), func(t *testing.T) {
			opts := []AuthCodeURLOption{}
			if expectHint {
				opts = append(opts, WithDomainHint(domain))
			}
			rawURL, err := client.AuthCodeURL(context.Background(), "id", localhost, tokenScope, opts...)
			if err != nil {
				t.Fatal(err)
			}
			parsed, err := url.Parse(rawURL)
			if err != nil {
				t.Fatal(err)
			}
			hints, present := parsed.Query()["domain_hint"]
			switch {
			case !present && !expectHint:
				return
			case !present:
				t.Fatal("expected a domain hint")
			case !expectHint:
				t.Fatal("expected no domain hint")
			}
			if len(hints) != 1 || hints[0] != domain {
				t.Fatalf(`unexpected domain_hint "%v"`, hints)
			}
		})
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/confidential/examples_test.go 0000664 0000000 0000000 00000001337 14420263624 0030200 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package confidential_test
import (
"fmt"
"log"
"os"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)
// ExampleNewCredFromCert_pem shows how to build a certificate credential from a PEM file that
// contains the public certificate(s) and the private key.
func ExampleNewCredFromCert_pem() {
	pemData, err := os.ReadFile("key.pem")
	if err != nil {
		log.Fatal(err)
	}

	// CertFromPEM extracts the public certificates and the private key from the PEM data.
	// If the key is encrypted, pass its password as the second argument.
	certs, key, err := confidential.CertFromPEM(pemData, "")
	if err != nil {
		log.Fatal(err)
	}

	cred, err := confidential.NewCredFromCert(certs, key)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(cred) // Simply here so cred is used, otherwise won't compile.
}
microsoft-authentication-library-for-go-1.0.0/apps/design/ 0000775 0000000 0000000 00000000000 14420263624 0023602 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/design/design.md 0000664 0000000 0000000 00000015126 14420263624 0025402 0 ustar 00root root 0000000 0000000 # MSAL Go Design Guide
Author: John Doak(jdoak@microsoft.com)
Contributors:
- Keegan Caruso(Keegan.Caruso@microsoft.com)
- Joel Hendrix(jhendrix@microsoft.com)
- Santiago Gonzalez(Santiago.Gonzalez@microsoft.com)
- Bogdan Gavril (bogavril@microsoft.com)
## History
The original code submitted for Go MSAL was a translation of either Java or .NET code. This was done as a best effort by an intern who was attempting their first crack at Go. It had a
very interesting structure that didn't fit into Go style and made it difficult to understand or
change. It used global locks, global variables, base type classes (mimicking inheritance), ...
This probably should have been rewritten from scratch, but we decided to try to do it in pieces.
The lesson to be learned from this is that this type of refactor leads to rewriting the code 7 or 8 times instead of once.
Much of this led to a rewrite where we could not see the forest for the trees. Every small change would inevitably become a 60-file refactor and have much larger ramifications than intended.
The work could not be divided up, because the API and the internals were linked across logical
boundaries.
What has resulted should be a design that divides code into logical layers and splits
the public API from the internal structure.
## General Structure
Public Surface:
```
apps/ - Contains all our code
confidential/ - The confidential application API
public/ - The public application API
cache/ - The cache interface that can be implemented to provide persistence cache storage of credentials
```
Internals:
```
apps/
internal/
client/ - Shared package for common calls that Public and Confidential apps share
json/ - Our own json encoder/decoder for special needs
shared/ - Holds types that need to be in multiple packages and can't be moved into a single one due to import cycles
requests/ - The package used to communicate with services to get tokens
```
### Use of the Go special internal/ directory
In Go, a directory called internal/ contains packages that should only be used by other packages
rooted at the same location.
This is documented here: https://golang.org/doc/go1.4#internalpackages
For example, a package .../a/b/c/internal/d/e/f can be imported only by code in the directory tree rooted at .../a/b/c. It cannot be imported by code in .../a/b/g or in any other repository.
We use this feature quite liberally to make clear what is using an internal package. For example:
```
apps/internal/base - Can only be used by packages rooted at apps/
apps/internal/base/internal/storage - Can only be used by package base
```
## Public API
The public API will be encapsulated in apps/. apps/ has 3 packages of interest to users:
- public/ - This is what MSAL calls the Public Application Client (service client)
- confidential/ - This is what MSAL calls the Confidential Application Client (service)
- cache/ - This provides the interfaces that must be implemented to create persistent caches for any MSAL client
## Internals
In this section we will be talking about internal/.
### JSON Handling
JSON must be handled specially in our app. The basics are, if we receive fields that our
structs do not contain, we cannot drop them. We must send them back to the service.
To handle that, we use our own custom json package that handles this.
See the design at: [Design](https://github.com/AzureAD/microsoft-authentication-library-for-go/blob/dev/internal/json/design.md)
### Backend communication
Communication to the backends is done via the requests/ package. oauth.Token is the client
for all communication.
oauth.Token communicates via REST calls that are encapsulated in the ops/ client.
## Adding A Feature
This is the general way to add a new feature to MSAL:
- Add the REST calls to ops.REST
- Add the higher level manipulations to oauth.Token
- Add your logic to the relevant package under apps/ and access the services via your oauth.Token
## Notable Differences To Other Clients
### TBD: Confidential applications needs to handle multiple users without one big cache
The MSAL caching design is rather simple. These design decisions and the fact that multiple applications in different languages can share a cache mean it cannot be easily changed.
The entire cache contents of a confidential.Client is read and written on
almost any action to and from an external cache.
It is not clear to a user that a confidential client should be per user to prevent scaling
problems.
We cannot change the MSAL cache design at this time, therefore it should be clear that
a confidential.Client should be created per user. This must go beyond a simple doc entry
that can be ignored. It's great to say: "we told you in the doc", but that is AFTER a support call.
TBD ...
### Use of x509.Certificate and CertFromPEM() function
The original version of this package used a thumbprint and a private key to do authorizations
based on a certificate. But there wasn't a real way to get a thumbprint.
A thumbprint is defined in the OAuth spec, which we had to track down. It is a SHA-1 hash
of the x509 certificate's DER-encoded ASN.1 bytes.
Since the user was going to need the x509, we moved to having the user provide the x509.Certificate
object.
We wrote the thumbprint creator for the internals.
Since we also require the private key and it is not straightforward to get, we added a CertFromPEM()
function that will extract the x509.Certificate and private key. It also supports encrypted PEM.
It should be noted that Keyvault stores things in PKCS12 and PEM. Keyvault is not straightforward
in how it works. Frankly, I'm in serious doubt that a regular Go user can get certs out of
Keyvault's Go API.
Before I began working on MSAL I was rewriting the Keyvault Go API: https://github.com/element-of-surprise/keyvault . It does the right things to extract certs for TLS now.
I was still working on the Cert() API and hadn't exposed the public surface when I stopped.
Since we have representation from the Go SDK team, we might have them go bridge this problem in
the current implementation using some of that code so it's possible for our users to store the
cert in Keyvault.
## Logging
For errors, see [error design](../errors/error_design.md).
This library does not log personally identifiable information (PII). For a definition of PII, see https://www.microsoft.com/en-us/trust-center/privacy/customer-data-definitions. MSAL Go does not log any of the 3 data categories listed there.
The library may log information related to your organization, such as tenant id, authority, client id etc. as well as information that cannot be tied to a user such as request correlation id, HTTP status codes etc.
microsoft-authentication-library-for-go-1.0.0/apps/errors/ 0000775 0000000 0000000 00000000000 14420263624 0023645 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/errors/error_design.md 0000664 0000000 0000000 00000010542 14420263624 0026653 0 ustar 00root root 0000000 0000000 # MSAL Error Design
Author: Abhidnya Patil(abhidnya.patil@microsoft.com)
Contributors:
- John Doak(jdoak@microsoft.com)
- Keegan Caruso(Keegan.Caruso@microsoft.com)
- Joel Hendrix(jhendrix@microsoft.com)
## Background
Errors in MSAL are intended for app developers to troubleshoot and not for displaying to end-users.
### Go error handling vs other MSAL languages
Most modern languages use exception based errors. Simply put, you "throw" an exception and it must be caught at some routine in the upper stack or it will eventually crash the program.
Go doesn't use exceptions, instead it relies on multiple return values, one of which can be the builtin error interface type. It is up to the user to decide what to do.
### Go custom error types
Errors can be created in Go by simply using errors.New() or fmt.Errorf() to create an "error".
Custom errors can be created in multiple ways. One of the more robust ways is simply to satisfy the error interface:
```go
type MyCustomErr struct {
Msg string
}
func (m MyCustomErr) Error() string { // This implements "error"
return m.Msg
}
```
### MSAL Error Goals
- Provide diagnostics to the user and for tickets that can be used to track down bugs or client misconfigurations
- Detect errors that are transitory and can be retried
- Allow the user to identify certain errors that the program can respond to, such as informing the user of the need to do an enrollment
## Implementing Client Side Errors
Client side errors indicate a misconfiguration or passing of bad arguments that is non-recoverable. Retrying isn't possible.
These errors can simply be standard Go errors created by errors.New() or fmt.Errorf(). If down the line we need a custom error, we can introduce it, but for now the error messages just need to be clear on what the issue was.
## Implementing Service Side Errors
Service side errors occur when an external RPC responds either with an HTTP error code or returns a message that includes an error.
These errors can be transitory (please slow down) or permanent (HTTP 404). To provide our diagnostic goals, we require the ability to differentiate these errors from other errors.
The current implementation includes a specialized type that captures any error from the server:
```go
// CallErr represents an HTTP call error. Has a Verbose() method that allows getting the
// http.Request and Response objects. Implements error.
type CallErr struct {
Req *http.Request
Resp *http.Response
Err error
}
// Errors implements error.Error().
func (e CallErr) Error() string {
return e.Err.Error()
}
// Verbose prints a versbose error message with the request or response.
func (e CallErr) Verbose() string {
e.Resp.Request = nil // This brings in a bunch of TLS stuff we don't need
e.Resp.TLS = nil // Same
return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(e.Resp))
}
```
A user will always receive the most concise error we provide. They can tell whether it is a server-side error using the Go errors package:
```go
var callErr CallErr
if errors.As(err, &callErr) {
...
}
```
We provide a Verbose() function that can retrieve the most verbose message from any error we provide:
```go
fmt.Println(errors.Verbose(err))
```
If further differentiation is required, we can add custom errors that use Go error wrapping on top of CallErr to achieve our diagnostic goals (such as detecting when to retry a call due to transient errors).
CallErr is always returned from the comm package (which handles all HTTP requests) and looks similar to:
```go
return nil, errors.CallErr{
Req: req,
Resp: reply,
Err: fmt.Errorf("http call(%s)(%s) error: reply status code was %d:\n%s", req.URL.String(), req.Method, reply.StatusCode, ErrorResponse), //ErrorResponse is the json body extracted from the http response
}
```
## Future Decisions
The ability to retry calls needs to have centralized responsibility. Either the user is doing it or the client is doing it.
If the user should be responsible, our errors package will include a CanRetry() function that will inform the user if the error provided to them is retryable. This is based on the http error code and possibly the type of error that was returned. It would also include a sleep time if the server returned an amount of time to wait.
Otherwise we will do this internally and retries will be left to us.
microsoft-authentication-library-for-go-1.0.0/apps/errors/errors.go 0000664 0000000 0000000 00000004056 14420263624 0025515 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package errors
import (
"errors"
"fmt"
"io"
"net/http"
"reflect"
"strings"
"github.com/kylelemons/godebug/pretty"
)
// prettyConf is the pretty-printer configuration CallErr.Verbose uses to render the
// HTTP request and response. It omits unexported and zero-valued fields and tracks
// cycles instead of following them.
var prettyConf = &pretty.Config{
IncludeUnexported: false,
SkipZeroFields: true,
TrackCycles: true,
Formatter: map[reflect.Type]interface{}{
// Render io.Reader values as their full content. io.ReadAll drains the reader, so
// whatever is printed here can no longer be read by anyone else.
// NOTE(review): whether this also matches fields declared as io.ReadCloser (such as
// http.Response.Body) depends on how pretty matches Formatter types — confirm.
reflect.TypeOf((*io.Reader)(nil)).Elem(): func(r io.Reader) string {
b, err := io.ReadAll(r)
if err != nil {
return "could not read io.Reader content"
}
return string(b)
},
},
}
type verboser interface {
Verbose() string
}
// Verbose prints the most verbose error that the error message has.
func Verbose(err error) string {
build := strings.Builder{}
for {
if err == nil {
break
}
if v, ok := err.(verboser); ok {
build.WriteString(v.Verbose())
} else {
build.WriteString(err.Error())
}
err = errors.Unwrap(err)
}
return build.String()
}
// New returns a new error whose Error() text is exactly text. It forwards to the
// standard library's errors.New, so each call returns a distinct error value.
func New(text string) error {
	created := errors.New(text)
	return created
}
// CallErr represents an HTTP call error. Has a Verbose() method that allows getting the
// http.Request and Response objects. Implements error.
type CallErr struct {
Req *http.Request
// Resp contains response body
Resp *http.Response
Err error
}
// Errors implements error.Error().
func (e CallErr) Error() string {
return e.Err.Error()
}
// Verbose returns a verbose error message: Err, followed by pretty-printed forms of
// the request and the response. The response is printed without its Request and TLS
// fields, which pull in a large amount of TLS state. Those fields are cleared on a
// shallow copy, so the caller's *http.Response is left untouched. Previously the
// shared response was mutated even though the receiver is a value. A nil Resp is
// tolerated.
func (e CallErr) Verbose() string {
	resp := e.Resp
	if resp != nil {
		trimmed := *resp // shallow copy; do not modify the caller's response
		trimmed.Request = nil
		trimmed.TLS = nil
		resp = &trimmed
	}
	return fmt.Sprintf("%s:\nRequest:\n%s\nResponse:\n%s", e.Err, prettyConf.Sprint(e.Req), prettyConf.Sprint(resp))
}
// Is reports whether any error in errors chain matches target.
func Is(err, target error) bool {
return errors.Is(err, target)
}
// As finds the first error in errors chain that matches target,
// and if so, sets target to that error value and returns true.
// Otherwise, it returns false.
func As(err error, target interface{}) bool {
return errors.As(err, target)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/ 0000775 0000000 0000000 00000000000 14420263624 0024145 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/ 0000775 0000000 0000000 00000000000 14420263624 0025057 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/base.go 0000664 0000000 0000000 00000037141 14420263624 0026326 0 ustar 00root root 0000000 0000000 // Package base contains a "Base" client that is used by the external public.Client and confidential.Client.
// Base holds shared attributes that must be available to both clients and methods that act as
// shared calls.
package base
import (
"context"
"errors"
"fmt"
"net/url"
"reflect"
"strings"
"sync"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
const (
// AuthorityPublicCloud is the default AAD authority: the public-cloud host with the
// "common" tenant.
AuthorityPublicCloud = "https://login.microsoftonline.com/common"
// scopeSeparator joins scopes into a single string (AuthCodeURL) and splits the
// cached scope string back into a slice (AuthResultFromStorage).
scopeSeparator = " "
)
// manager provides an internal cache. It is defined to allow faking the cache in tests.
// In production it's a *storage.Manager or *storage.PartitionedManager.
// Embedding cache.Serializer lets a cache.ExportReplace load (Replace) and save
// (Export) its contents.
type manager interface {
cache.Serializer
// Read looks up cached tokens for the given auth parameters.
Read(context.Context, authority.AuthParams) (storage.TokenResponse, error)
// Write caches a token response. The returned account is attached to the resulting
// AuthResult by Client.AuthResultFromToken.
Write(authority.AuthParams, accesstokens.TokenResponse) (shared.Account, error)
}
// accountManager is a manager that also caches accounts. In production it's a *storage.Manager.
type accountManager interface {
manager
// AllAccounts returns the cached accounts (backs Client.AllAccounts).
AllAccounts() []shared.Account
// Account looks up a cached account by its home account ID (backs Client.Account).
Account(homeAccountID string) shared.Account
// RemoveAccount removes the account's cached data for clientID (backs Client.RemoveAccount).
RemoveAccount(account shared.Account, clientID string)
}
// AcquireTokenSilentParameters contains the parameters to acquire a token silently (from cache).
type AcquireTokenSilentParameters struct {
// Scopes are the requested scopes.
Scopes []string
// Account identifies the user; only its HomeAccountID is read by AcquireTokenSilent.
Account shared.Account
// RequestType selects public vs. confidential behavior when refreshing.
RequestType accesstokens.AppType
// Credential is sent on refresh only when RequestType is accesstokens.ATConfidential.
Credential *accesstokens.Credential
// IsAppCache is passed to AuthParams.CacheKey to choose the cache partition key.
IsAppCache bool
// TenantID overrides the client's tenant; "" means use the client's configured tenant.
TenantID string
// UserAssertion is the incoming assertion for on-behalf-of requests.
UserAssertion string
// AuthorizationType ATOnBehalfOf selects the partitioned cache; any other value is
// replaced by ATRefreshToken and uses the account cache.
AuthorizationType authority.AuthorizeType
// Claims, when non-empty, causes cached access tokens to be skipped in favor of a refresh.
Claims string
}
// AcquireTokenAuthCodeParameters contains the parameters required to acquire an access token using the auth code flow.
// To use PKCE, set the Challenge field.
// Code challenges are used to secure authorization code grants; for more information, visit
// https://tools.ietf.org/html/rfc7636.
type AcquireTokenAuthCodeParameters struct {
// Scopes are the requested scopes.
Scopes []string
// Code is the authorization code returned to the redirect URI.
Code string
// Challenge is passed to accesstokens.NewCodeChallengeRequest.
// NOTE(review): at redemption time PKCE sends the code_verifier — confirm this field
// carries the verifier rather than the challenge.
Challenge string
// Claims is an optional claims request.
Claims string
// RedirectURI must match the redirect URI used to obtain the code.
RedirectURI string
// AppType ATConfidential sends Credential and marks the request as confidential.
AppType accesstokens.AppType
// Credential is used only when AppType is accesstokens.ATConfidential.
Credential *accesstokens.Credential
// TenantID overrides the client's tenant; "" means use the configured tenant.
TenantID string
}
// AcquireTokenOnBehalfOfParameters contains the parameters for Client.AcquireTokenOnBehalfOf.
type AcquireTokenOnBehalfOfParameters struct {
// Scopes are the requested scopes.
Scopes []string
// Claims is an optional claims request; when set, cached access tokens are skipped.
Claims string
// Credential is the confidential client's credential.
Credential *accesstokens.Credential
// TenantID overrides the client's tenant; "" means use the configured tenant.
TenantID string
// UserAssertion is the token the middle-tier app received and exchanges.
UserAssertion string
}
// AuthResult contains the results of one token acquisition operation in PublicClientApplication
// or ConfidentialClientApplication. For details see https://aka.ms/msal-net-authenticationresult
type AuthResult struct {
// Account is the account the tokens belong to; it is the zero value when the
// result was not written to the cache.
Account shared.Account
// IDToken is the decoded ID token; it is the zero value when none was cached.
IDToken accesstokens.IDToken
// AccessToken is the raw access token string.
AccessToken string
// ExpiresOn is when the access token expires.
ExpiresOn time.Time
// GrantedScopes are the scopes the access token was issued for.
GrantedScopes []string
// DeclinedScopes is never populated by AuthResultFromStorage or NewAuthResult;
// NewAuthResult fails instead when the service reports declined scopes.
DeclinedScopes []string
}
// AuthResultFromStorage builds an AuthResult from a token response read from the cache.
// It fails when the cached access token does not validate, or when a cached ID token
// is present but cannot be decoded. Granted scopes come from splitting the cached
// scope string on scopeSeparator; DeclinedScopes is left nil.
func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResult, error) {
	at := storageTokenResponse.AccessToken
	if err := at.Validate(); err != nil {
		return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err)
	}

	// The cache may hold no ID token (NOTE(review): presumably the case for
	// confidential clients); decode only when one exists.
	var idToken accesstokens.IDToken
	if !storageTokenResponse.IDToken.IsZero() {
		if err := idToken.UnmarshalJSON([]byte(storageTokenResponse.IDToken.Secret)); err != nil {
			return AuthResult{}, fmt.Errorf("problem decoding JWT token: %w", err)
		}
	}

	return AuthResult{
		Account:       storageTokenResponse.Account,
		IDToken:       idToken,
		AccessToken:   at.Secret,
		ExpiresOn:     at.ExpiresOn.T,
		GrantedScopes: strings.Split(at.Scopes, scopeSeparator),
	}, nil
}
// NewAuthResult builds an AuthResult from a token response returned by the service
// and the account to attach (which may be the zero Account). It rejects responses
// that report any declined scopes.
func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Account) (AuthResult, error) {
	if declined := tokenResponse.DeclinedScopes; len(declined) != 0 {
		return AuthResult{}, fmt.Errorf("token response failed because declined scopes are present: %s", strings.Join(declined, ","))
	}
	result := AuthResult{Account: account}
	result.IDToken = tokenResponse.IDToken
	result.AccessToken = tokenResponse.AccessToken
	result.ExpiresOn = tokenResponse.ExpiresOn.T
	result.GrantedScopes = tokenResponse.GrantedScopes.Slice
	return result, nil
}
// Client is a base client that provides access to common methods and primitives that
// can be used by multiple clients.
type Client struct {
// Token performs the network operations (endpoint resolution, auth code, refresh, OBO).
Token *oauth.Client
manager accountManager // *storage.Manager or fakeManager in tests
// pmanager is a partitioned cache for OBO authentication. *storage.PartitionedManager or fakeManager in tests
pmanager manager
// AuthParams is copied on use by the value-receiver methods (see Account), which is why
// it must stay a value.
AuthParams authority.AuthParams // DO NOT EVER MAKE THIS A POINTER! See "Note" in New().
// cacheAccessor is optional persistent storage; when non-nil, caches are loaded from it
// (Replace) before use and written back (Export) after changes.
cacheAccessor cache.ExportReplace
// cacheAccessorMu guards cacheAccessor use. It is a pointer so that the copies made by
// value receivers all share one lock.
cacheAccessorMu *sync.RWMutex
}
// Option is an optional argument to the New constructor.
// New applies options in order and stops at the first one that returns an error.
type Option func(c *Client) error
// WithCacheAccessor sets the persistent cache that Client loads (Replace) and saves
// (Export) token data through. A nil accessor is ignored, keeping any existing one.
func WithCacheAccessor(ca cache.ExportReplace) Option {
	return func(c *Client) error {
		if ca == nil {
			return nil
		}
		c.cacheAccessor = ca
		return nil
	}
}
// WithClientCapabilities configures one or more client capabilities, such as "CP1".
// They are merged into the claims request by AuthParams.MergeCapabilitiesAndClaims.
// An empty slice leaves the client unchanged. If authority.NewClientCapabilities
// rejects the input, its error is returned, so New fails. Previously a shadowed err
// variable caused that error to be silently discarded.
func WithClientCapabilities(capabilities []string) Option {
	return func(c *Client) error {
		if len(capabilities) == 0 {
			return nil
		}
		cc, err := authority.NewClientCapabilities(capabilities)
		if err != nil {
			return err
		}
		c.AuthParams.Capabilities = cc
		return nil
	}
}
// WithKnownAuthorityHosts lists authority hosts the user already trusts, so Client
// neither validates them nor requests their metadata. The slice is copied when the
// option is applied, so later changes by the caller do not leak into the client.
func WithKnownAuthorityHosts(hosts []string) Option {
	return func(c *Client) error {
		c.AuthParams.KnownAuthorityHosts = append(make([]string, 0, len(hosts)), hosts...)
		return nil
	}
}
// WithX5C controls whether the x5c claim (the certificate's public key) is sent to the
// STS, which enables Subject Name/Issuer authentication.
func WithX5C(sendX5C bool) Option {
	setX5C := func(c *Client) error {
		c.AuthParams.SendX5C = sendX5C
		return nil
	}
	return setX5C
}
// WithRegionDetection records region in the client's authority info
// (AuthParams.AuthorityInfo.Region). NOTE(review): how the region is consumed (for
// example, to build regional endpoints) is decided elsewhere — confirm.
func WithRegionDetection(region string) Option {
	return func(c *Client) error {
		info := &c.AuthParams.AuthorityInfo
		info.Region = region
		return nil
	}
}
// WithInstanceDiscovery turns authority validation and instance discovery on or off
// together. It sets AuthorityInfo.ValidateAuthority to the given value and
// AuthorityInfo.InstanceDiscoveryDisabled to its negation.
func WithInstanceDiscovery(instanceDiscoveryEnabled bool) Option {
	return func(c *Client) error {
		info := &c.AuthParams.AuthorityInfo
		info.ValidateAuthority, info.InstanceDiscoveryDisabled = instanceDiscoveryEnabled, !instanceDiscoveryEnabled
		return nil
	}
}
// New is the constructor for Base. It parses authorityURI with authority validation
// on and instance discovery enabled, creates the account and partitioned caches, and
// then applies options in order. If an option fails, New stops and returns the
// partially configured client together with that error.
func New(clientID string, authorityURI string, token *oauth.Client, options ...Option) (Client, error) {
	authInfo, err := authority.NewInfoFromAuthorityURI(authorityURI, true, false)
	if err != nil {
		return Client{}, err
	}
	// Client is deliberately a value type, never *Client; see the design notes in
	// public.go and confidential.go.
	client := Client{
		Token:           token,
		AuthParams:      authority.NewAuthParams(clientID, authInfo),
		cacheAccessorMu: &sync.RWMutex{},
		manager:         storage.New(token),
		pmanager:        storage.NewPartitionedManager(token),
	}
	for _, apply := range options {
		if err := apply(&client); err != nil {
			return client, err
		}
	}
	return client, nil
}
// AuthCodeURL builds the URL of the authorization endpoint, resolved for authParams'
// authority, that a user visits to obtain an authorization code. The query always
// carries client_id, response_type=code, redirect_uri and the space-joined scope.
// The following are added only when non-empty: state, claims (merged with client
// capabilities), code_challenge, code_challenge_method, login_hint, prompt and
// domain_hint. response_mode is not sent.
func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, authParams authority.AuthParams) (string, error) {
	endpoints, err := b.Token.ResolveEndpoints(ctx, authParams.AuthorityInfo, "")
	if err != nil {
		return "", err
	}
	authURL, err := url.Parse(endpoints.AuthorizationEndpoint)
	if err != nil {
		return "", err
	}
	claims, err := authParams.MergeCapabilitiesAndClaims()
	if err != nil {
		return "", err
	}

	query := url.Values{
		"client_id":     {clientID},
		"response_type": {"code"},
		"redirect_uri":  {redirectURI},
		"scope":         {strings.Join(scopes, scopeSeparator)},
	}
	optional := []struct{ key, value string }{
		{"state", authParams.State},
		{"claims", claims},
		{"code_challenge", authParams.CodeChallenge},
		{"code_challenge_method", authParams.CodeChallengeMethod},
		{"login_hint", authParams.LoginHint},
		{"prompt", authParams.Prompt},
		{"domain_hint", authParams.DomainHint},
	}
	for _, param := range optional {
		if param.value != "" {
			query.Add(param.key, param.value)
		}
	}

	// Encode sorts by key, so insertion order does not affect the result.
	authURL.RawQuery = query.Encode()
	return authURL.String(), nil
}
// AcquireTokenSilent returns a token without user interaction. It reads the cache that
// matches the flow: the partitioned cache for on-behalf-of requests, the account cache
// otherwise. When a cacheAccessor is configured, that cache is reloaded through it
// first. A valid cached access token is returned unless the caller supplied claims.
// Otherwise a cached refresh token is redeemed and the new tokens are cached. With
// neither, it fails with "no token found".
func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
	// An empty TenantID lets WithTenant fall back to the client's configured tenant.
	authParams, err := b.AuthParams.WithTenant(silent.TenantID)
	if err != nil {
		return AuthResult{}, err
	}
	authParams.Scopes = silent.Scopes
	authParams.HomeAccountID = silent.Account.HomeAccountID
	authParams.AuthorizationType = silent.AuthorizationType
	authParams.Claims = silent.Claims
	authParams.UserAssertion = silent.UserAssertion

	// OBO tokens live in the partitioned cache; any other silent request is treated as
	// a refresh-token lookup in the account cache.
	var store manager
	if authParams.AuthorizationType == authority.ATOnBehalfOf {
		store = b.pmanager
	} else {
		authParams.AuthorizationType = authority.ATRefreshToken
		store = b.manager
	}

	if b.cacheAccessor != nil {
		key := authParams.CacheKey(silent.IsAppCache)
		b.cacheAccessorMu.RLock()
		replaceErr := b.cacheAccessor.Replace(ctx, store, cache.ReplaceHints{PartitionKey: key})
		b.cacheAccessorMu.RUnlock()
		if replaceErr != nil {
			return AuthResult{}, replaceErr
		}
	}

	cached, err := store.Read(ctx, authParams)
	if err != nil {
		return AuthResult{}, err
	}

	// Claims force a round trip, so cached access tokens are only used without them.
	if silent.Claims == "" {
		if result, fromCacheErr := AuthResultFromStorage(cached); fromCacheErr == nil {
			return result, nil
		}
	}

	if reflect.ValueOf(cached.RefreshToken).IsZero() {
		return AuthResult{}, errors.New("no token found")
	}
	var cred *accesstokens.Credential
	if silent.RequestType == accesstokens.ATConfidential {
		cred = silent.Credential
	}
	refreshed, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cred, cached.RefreshToken)
	if err != nil {
		return AuthResult{}, err
	}
	return b.AuthResultFromToken(ctx, authParams, refreshed, true)
}
// AcquireTokenByAuthCode redeems an authorization code for tokens and caches the result.
// For confidential apps, the credential is sent and the request is marked
// confidential; public apps send no credential.
func (b Client) AcquireTokenByAuthCode(ctx context.Context, authCodeParams AcquireTokenAuthCodeParameters) (AuthResult, error) {
	authParams, err := b.AuthParams.WithTenant(authCodeParams.TenantID)
	if err != nil {
		return AuthResult{}, err
	}
	authParams.Claims, authParams.Scopes = authCodeParams.Claims, authCodeParams.Scopes
	authParams.Redirecturi = authCodeParams.RedirectURI
	authParams.AuthorizationType = authority.ATAuthCode

	var cred *accesstokens.Credential
	if authCodeParams.AppType == accesstokens.ATConfidential {
		cred = authCodeParams.Credential
		authParams.IsConfidentialClient = true
	}

	req, err := accesstokens.NewCodeChallengeRequest(authParams, authCodeParams.AppType, cred, authCodeParams.Code, authCodeParams.Challenge)
	if err != nil {
		return AuthResult{}, err
	}
	token, err := b.Token.AuthCode(ctx, req)
	if err != nil {
		return AuthResult{}, err
	}
	return b.AuthResultFromToken(ctx, authParams, token, true)
}
// AcquireTokenOnBehalfOf exchanges the user assertion a middle-tier app received for a
// token to the requested scopes. It tries the partitioned cache first, via
// AcquireTokenSilent, and calls the on-behalf-of endpoint only when that fails. A
// successful response is cached. If the endpoint call fails, the (usually zero)
// result of the failed silent attempt is returned with the error.
func (b Client) AcquireTokenOnBehalfOf(ctx context.Context, onBehalfOfParams AcquireTokenOnBehalfOfParameters) (AuthResult, error) {
	result, err := b.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{
		Scopes:            onBehalfOfParams.Scopes,
		RequestType:       accesstokens.ATConfidential,
		Credential:        onBehalfOfParams.Credential,
		UserAssertion:     onBehalfOfParams.UserAssertion,
		AuthorizationType: authority.ATOnBehalfOf,
		TenantID:          onBehalfOfParams.TenantID,
		Claims:            onBehalfOfParams.Claims,
	})
	if err == nil {
		return result, nil
	}

	authParams, err := b.AuthParams.WithTenant(onBehalfOfParams.TenantID)
	if err != nil {
		return AuthResult{}, err
	}
	authParams.AuthorizationType = authority.ATOnBehalfOf
	authParams.Claims = onBehalfOfParams.Claims
	authParams.Scopes = onBehalfOfParams.Scopes
	authParams.UserAssertion = onBehalfOfParams.UserAssertion

	token, err := b.Token.OnBehalfOf(ctx, authParams, onBehalfOfParams.Credential)
	if err != nil {
		return result, err
	}
	return b.AuthResultFromToken(ctx, authParams, token, true)
}
// AuthResultFromToken converts a token response from the service into an AuthResult.
// When cacheWrite is false, nothing is cached and the result carries a zero Account.
// Otherwise the token is written to the cache that matches the flow: the partitioned
// cache (pmanager) for on-behalf-of requests, the account cache (manager) for
// everything else. When a cacheAccessor is configured, that same cache is reloaded
// (Replace) before the write and persisted (Export) after it, under the token's
// partition key and while holding the write lock.
func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.AuthParams, token accesstokens.TokenResponse, cacheWrite bool) (AuthResult, error) {
	if !cacheWrite {
		return NewAuthResult(token, shared.Account{})
	}
	var m manager = b.manager
	if authParams.AuthorizationType == authority.ATOnBehalfOf {
		m = b.pmanager
	}
	key := token.CacheKey(authParams)
	if b.cacheAccessor != nil {
		b.cacheAccessorMu.Lock()
		defer b.cacheAccessorMu.Unlock()
		if err := b.cacheAccessor.Replace(ctx, m, cache.ReplaceHints{PartitionKey: key}); err != nil {
			return AuthResult{}, err
		}
	}
	account, err := m.Write(authParams, token)
	if err != nil {
		return AuthResult{}, err
	}
	ar, err := NewAuthResult(token, account)
	if err == nil && b.cacheAccessor != nil {
		// Export the cache that was just replaced and written. Exporting b.manager here
		// would drop on-behalf-of tokens, which live in b.pmanager, and would store the
		// account cache under the OBO partition key.
		err = b.cacheAccessor.Export(ctx, m, cache.ExportHints{PartitionKey: key})
	}
	return ar, err
}
// AllAccounts returns every account in the account cache. When a cacheAccessor is
// configured, the cache is first reloaded through it under the client's cache key,
// while holding the read lock.
func (b Client) AllAccounts(ctx context.Context) ([]shared.Account, error) {
	if b.cacheAccessor == nil {
		return b.manager.AllAccounts(), nil
	}
	b.cacheAccessorMu.RLock()
	defer b.cacheAccessorMu.RUnlock()
	key := b.AuthParams.CacheKey(false)
	if err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}); err != nil {
		return nil, err
	}
	return b.manager.AllAccounts(), nil
}
// Account returns the cached account with the given home account ID. When a
// cacheAccessor is configured, the account cache is first reloaded through it, while
// holding the read lock. The partition key comes from auth parameters describing an
// account-by-ID lookup for homeAccountID.
func (b Client) Account(ctx context.Context, homeAccountID string) (shared.Account, error) {
	if b.cacheAccessor != nil {
		b.cacheAccessorMu.RLock()
		defer b.cacheAccessorMu.RUnlock()
		authParams := b.AuthParams // This is a copy, as we don't have a pointer receiver and .AuthParams is not a pointer.
		authParams.AuthorizationType = authority.AccountByID
		authParams.HomeAccountID = homeAccountID
		// Derive the key from the lookup-specific copy. The original used b.AuthParams,
		// which left the two settings above unused.
		key := authParams.CacheKey(false)
		if err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}); err != nil {
			return shared.Account{}, err
		}
	}
	return b.manager.Account(homeAccountID), nil
}
// RemoveAccount removes all the ATs, RTs and IDTs from the cache associated with this
// account, for this client's ID. With a cacheAccessor configured, the cache is reloaded
// before the removal and exported afterwards, all under the write lock.
func (b Client) RemoveAccount(ctx context.Context, account shared.Account) error {
	clientID := b.AuthParams.ClientID
	if b.cacheAccessor == nil {
		b.manager.RemoveAccount(account, clientID)
		return nil
	}
	b.cacheAccessorMu.Lock()
	defer b.cacheAccessorMu.Unlock()
	key := b.AuthParams.CacheKey(false)
	if err := b.cacheAccessor.Replace(ctx, b.manager, cache.ReplaceHints{PartitionKey: key}); err != nil {
		return err
	}
	b.manager.RemoveAccount(account, clientID)
	return b.cacheAccessor.Export(ctx, b.manager, cache.ExportHints{PartitionKey: key})
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/base_test.go 0000664 0000000 0000000 00000034004 14420263624 0027360 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package base
import (
"context"
"errors"
"fmt"
"reflect"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)
// Fixed identifiers shared by the fake client and cached test data in this file.
const (
fakeAccessToken = "fake-access-token"
// fakeAuthority is used as a bare host name when building authority and endpoint URLs.
fakeAuthority = "fake_authority"
fakeClientID = "fake-client-id"
fakeRefreshToken = "fake-refresh-token"
fakeTenantID = "fake-tenant-id"
fakeUsername = "fake-username"
)
var (
// fakeIDToken is the ID token returned by the fake token endpoint in fakeClient.
fakeIDToken = accesstokens.IDToken{
Oid: "oid",
PreferredUsername: fakeUsername,
// "e30" is base64url for "{}", so this raw token's payload decodes to an empty
// JSON object (TestAcquireTokenSilentScopes checks that it unmarshals).
RawToken: "x.e30",
TenantID: fakeTenantID,
UPN: fakeUsername,
}
// testScopes is the scope set requested and granted in these tests.
testScopes = []string{"scope"}
)
// fakeClient builds a Client for fakeClientID against https://fakeAuthority/fakeTenantID.
// Its OAuth dependencies are replaced by fakes: the token endpoint returns an access
// token valid for an hour (with fakeIDToken and fakeRefreshToken), instance discovery
// reports fakeAuthority, and endpoint resolution returns fixed fake URLs.
// opts are forwarded to New; a construction error fails the test.
func fakeClient(t *testing.T, opts ...Option) Client {
	authorityURL := fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID)
	client, err := New(fakeClientID, authorityURL, &oauth.Client{}, opts...)
	if err != nil {
		t.Fatal(err)
	}
	onAuthority := func(path string) string {
		return fmt.Sprintf("https://%s/%s", fakeAuthority, path)
	}

	client.Token.AccessTokens = &fake.AccessTokens{
		AccessToken: accesstokens.TokenResponse{
			AccessToken:   fakeAccessToken,
			ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
			FamilyID:      "family-id",
			GrantedScopes: accesstokens.Scopes{Slice: testScopes},
			IDToken:       fakeIDToken,
			RefreshToken:  fakeRefreshToken,
		},
	}
	client.Token.Authority = &fake.Authority{
		InstanceResp: authority.InstanceDiscoveryResponse{
			Metadata: []authority.InstanceDiscoveryMetadata{
				{Aliases: []string{fakeAuthority}, PreferredNetwork: fakeAuthority},
			},
			TenantDiscoveryEndpoint: onAuthority("fake/discovery/endpoint"),
		},
	}
	client.Token.Resolver = &fake.ResolveEndpoints{
		Endpoints: authority.NewEndpoints(
			onAuthority("fake/auth"),
			onAuthority("fake/token"),
			onAuthority("fake/jwt"),
			fakeAuthority,
		),
	}
	return client
}
// TestAcquireTokenSilentEmptyCache expects silent acquisition to fail when the
// cache holds nothing for the requested account.
func TestAcquireTokenSilentEmptyCache(t *testing.T) {
	client := fakeClient(t)
	acct := shared.NewAccount("homeAccountID", "env", "realm", "localAccountID", authority.AAD, "username")
	params := AcquireTokenSilentParameters{Account: acct, Scopes: testScopes}
	if _, err := client.AcquireTokenSilent(context.Background(), params); err == nil {
		t.Fatal("expected an error because the cache is empty")
	}
}
// TestAcquireTokenSilentScopes caches a refresh token plus an access token that
// can't satisfy the request, either because it has expired or because it was
// issued for other scopes. It then expects AcquireTokenSilent to redeem the
// refresh token. The fake endpoint's callback checks that the refresh request
// carries the requested scopes, no client credential and the cached refresh token.
func TestAcquireTokenSilentScopes(t *testing.T) {
	// fakeIDToken.RawToken must unmarshal (to anything). Otherwise an unmarshalling
	// error could satisfy an "err != nil" check for the wrong reason and hide a test bug.
	var probe accesstokens.IDToken
	if err := probe.UnmarshalJSON([]byte(fakeIDToken.RawToken)); err != nil {
		t.Fatal(err)
	}
	cases := []struct {
		desc              string
		cachedTokenScopes []string
	}{
		{desc: "expired access token", cachedTokenScopes: testScopes},
		{desc: "no access token", cachedTokenScopes: []string{"other-" + testScopes[0]}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			client := fakeClient(t)
			called := false
			fakeTokens := client.Token.AccessTokens.(*fake.AccessTokens)
			fakeTokens.FromRefreshTokenCallback = func(_ accesstokens.AppType, ap authority.AuthParams, cred *accesstokens.Credential, refresh string) {
				called = true
				switch {
				case !reflect.DeepEqual(ap.Scopes, testScopes):
					t.Fatalf("unexpected scopes: %v", ap.Scopes)
				case cred != nil:
					t.Fatal("client shouldn't have a credential")
				case refresh != fakeRefreshToken:
					t.Fatal("unexpected refresh token")
				}
			}

			params := authority.AuthParams{
				AuthorityInfo: authority.Info{
					AuthorityType: authority.AAD,
					Host:          fakeAuthority,
					Tenant:        fakeIDToken.TenantID,
				},
				ClientID: fakeClientID,
				Scopes:   tc.cachedTokenScopes,
				Username: fakeIDToken.PreferredUsername,
			}
			resp := accesstokens.TokenResponse{
				AccessToken:   fakeAccessToken,
				ClientInfo:    accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
				ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
				GrantedScopes: accesstokens.Scopes{Slice: tc.cachedTokenScopes},
				IDToken:       fakeIDToken,
				RefreshToken:  fakeRefreshToken,
			}
			// Cache the refresh token and the already-expired access token; only the
			// public client code path is exercised. FakeValidate is stubbed just around
			// Write (presumably so the expired token is accepted) and cleared right after,
			// so that AcquireTokenSilent sees real validation.
			storage.FakeValidate = func(storage.AccessToken) error { return nil }
			account, err := client.manager.Write(params, resp)
			storage.FakeValidate = nil
			if err != nil {
				t.Fatal(err)
			}

			// with no usable access token cached, AcquireTokenSilent should redeem the refresh token
			silent := AcquireTokenSilentParameters{Account: account, Scopes: testScopes}
			ar, err := client.AcquireTokenSilent(context.Background(), silent)
			if err != nil {
				t.Fatal(err)
			}
			if ar.AccessToken != fakeAccessToken {
				t.Fatal("unexpected access token")
			}
			if !called {
				t.Fatal("FromRefreshTokenCallback wasn't called")
			}
		})
	}
}
// TestAcquireTokenSilentGrantedScopes caches an access token that was requested
// for only part of its granted scopes. It expects AcquireTokenSilent to return
// that token for each granted scope requested on its own.
func TestAcquireTokenSilentGrantedScopes(t *testing.T) {
	client := fakeClient(t)
	granted := []string{"scope1", "scope2"}
	want := "not-" + fakeAccessToken
	params := authority.AuthParams{
		AuthorityInfo: authority.Info{
			AuthorityType: authority.AAD,
			Host:          fakeAuthority,
			Tenant:        fakeIDToken.TenantID,
		},
		ClientID: fakeClientID,
		// request only the second scope; the response grants both
		Scopes: granted[1:],
	}
	resp := accesstokens.TokenResponse{
		AccessToken:   want,
		ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
		GrantedScopes: accesstokens.Scopes{Slice: granted},
	}
	account, err := client.manager.Write(params, resp)
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	for _, s := range granted {
		silent := AcquireTokenSilentParameters{Account: account, Scopes: []string{s}}
		ar, err := client.AcquireTokenSilent(ctx, silent)
		if err != nil {
			t.Fatal(err)
		}
		if ar.AccessToken != want {
			t.Fatal("unexpected access token")
		}
	}
}
// failCache helps tests inject cache I/O errors. It records whether Export was
// called and returns preconfigured errors from Export and Replace.
//
// Both methods use pointer receivers. Export must mutate the struct, and mixing
// receiver kinds on one type is inconsistent. Tests pass *failCache as the
// cache accessor.
type failCache struct {
	// exported is set to true by Export.
	exported bool
	// exportErr is returned by Export; replaceErr is returned by Replace.
	exportErr, replaceErr error
}

// Export records that it was called and returns c.exportErr.
func (c *failCache) Export(context.Context, cache.Marshaler, cache.ExportHints) error {
	c.exported = true
	return c.exportErr
}

// Replace returns c.replaceErr without reading from the unmarshaler.
func (c *failCache) Replace(context.Context, cache.Unmarshaler, cache.ReplaceHints) error {
	return c.replaceErr
}
// TestCacheIOErrors verifies that Client methods surface errors returned by the
// cache accessor's Replace and Export methods. It also verifies that a failed
// token request returns its error without exporting the cache.
func TestCacheIOErrors(t *testing.T) {
	ctx := context.Background()
	expected := errors.New("cache error")
	// requireExpected fails the test unless actual wraps expected
	requireExpected := func(t *testing.T, actual error) {
		t.Helper()
		if !errors.Is(actual, expected) {
			t.Fatalf(`expected "%v", got "%v"`, expected, actual)
		}
	}
	for _, export := range []bool{true, false} {
		name := "replace"
		// named fc rather than cache so the cache package isn't shadowed here
		fc := failCache{}
		if export {
			fc.exportErr = expected
			name = "export"
		} else {
			fc.replaceErr = expected
		}
		t.Run(name, func(t *testing.T) {
			client := fakeClient(t, WithCacheAccessor(&fc))
			if !export {
				// Account and AllAccounts don't export the cache, and AcquireTokenSilent does so
				// only after redeeming a refresh token, so we test them only for replace errors
				_, actual := client.Account(ctx, "...")
				requireExpected(t, actual)
				_, actual = client.AllAccounts(ctx)
				requireExpected(t, actual)
				// fc.replaceErr is always set on this path, so the error must propagate
				_, actual = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{Scopes: testScopes})
				requireExpected(t, actual)
			}
			_, actual := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential, Scopes: testScopes})
			requireExpected(t, actual)
			_, actual = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}, Scopes: testScopes})
			requireExpected(t, actual)
			_, actual = client.AuthResultFromToken(ctx, authority.AuthParams{}, accesstokens.TokenResponse{}, true)
			requireExpected(t, actual)
			actual = client.RemoveAccount(ctx, shared.Account{})
			requireExpected(t, actual)
		})
	}

	// Testing that AcquireTokenSilent propagates errors from Export requires more elaborate
	// setup, because that method exports the cache only after acquiring a new access token.
	t.Run("silent auth export error", func(t *testing.T) {
		fc := failCache{}
		hid := "uid.utid"
		client := fakeClient(t, WithCacheAccessor(&fc))
		// cache fake tokens and app metadata
		_, err := client.AuthResultFromToken(ctx,
			authority.AuthParams{
				AuthorityInfo: authority.Info{Host: fakeAuthority},
				ClientID:      fakeClientID,
				HomeAccountID: hid,
				Scopes:        testScopes,
			},
			accesstokens.TokenResponse{
				AccessToken:   "at",
				ClientInfo:    accesstokens.ClientInfo{UID: "uid", UTID: "utid"},
				GrantedScopes: accesstokens.Scopes{Slice: testScopes},
				IDToken:       fakeIDToken,
				RefreshToken:  "rt",
			},
			true,
		)
		if err != nil {
			t.Fatal(err)
		}
		// AcquireTokenSilent should return this error after redeeming a refresh token
		fc.exportErr = expected
		_, actual := client.AcquireTokenSilent(ctx,
			AcquireTokenSilentParameters{
				Account: shared.NewAccount(hid, fakeAuthority, "realm", "id", authority.AAD, "upn"),
				Scopes:  []string{"not-" + testScopes[0]},
			},
		)
		requireExpected(t, actual)
	})

	// when the client fails to acquire a token, it should return an error instead of exporting the cache
	t.Run("auth error", func(t *testing.T) {
		fc := failCache{}
		client := fakeClient(t, WithCacheAccessor(&fc))
		client.Token.AccessTokens.(*fake.AccessTokens).Err = true
		_, err := client.AcquireTokenByAuthCode(ctx, AcquireTokenAuthCodeParameters{AppType: accesstokens.ATConfidential})
		if err == nil || fc.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
		_, err = client.AcquireTokenOnBehalfOf(ctx, AcquireTokenOnBehalfOfParameters{Credential: &accesstokens.Credential{Secret: "..."}})
		if err == nil || fc.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
		_, err = client.AcquireTokenSilent(ctx, AcquireTokenSilentParameters{})
		if err == nil || fc.exported {
			t.Fatal("client should have returned an error instead of exporting the cache")
		}
	})
}
// TestCreateAuthenticationResult checks that NewAuthResult copies a token
// response into an AuthResult. The table expects an error when the response
// has DeclinedScopes.
func TestCreateAuthenticationResult(t *testing.T) {
	future := time.Now().Add(400 * time.Second)
	tests := []struct {
		desc  string
		input accesstokens.TokenResponse
		want  AuthResult
		err   bool
	}{
		{
			desc: "no declined scopes",
			input: accesstokens.TokenResponse{
				AccessToken:    "accessToken",
				ExpiresOn:      internalTime.DurationTime{T: future},
				GrantedScopes:  accesstokens.Scopes{Slice: []string{"user.read"}},
				DeclinedScopes: nil,
			},
			want: AuthResult{
				AccessToken:    "accessToken",
				ExpiresOn:      future,
				GrantedScopes:  []string{"user.read"},
				DeclinedScopes: nil,
			},
		},
		{
			desc: "declined scopes",
			input: accesstokens.TokenResponse{
				AccessToken:    "accessToken",
				ExpiresOn:      internalTime.DurationTime{T: future},
				GrantedScopes:  accesstokens.Scopes{Slice: []string{"user.read"}},
				DeclinedScopes: []string{"openid"},
			},
			err: true,
		},
	}
	for _, test := range tests {
		got, err := NewAuthResult(test.input, shared.Account{})
		switch {
		case err == nil && test.err:
			t.Errorf("TestCreateAuthenticationResult(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil:
			if !test.err {
				t.Errorf("TestCreateAuthenticationResult(%s): got err == %s, want err == nil", test.desc, err)
			}
			// never diff a result that came with an error
			continue
		}
		if diff := pretty.Compare(test.want, got); diff != "" {
			t.Errorf("TestCreateAuthenticationResult(%s): -want/+got:\n%s", test.desc, diff)
		}
	}
}
// TestAuthResultFromStorage checks that AuthResultFromStorage rejects a stored
// access token without CachedAt. It also checks that a valid token is converted
// to an AuthResult, with its space-separated Scopes split into GrantedScopes.
func TestAuthResultFromStorage(t *testing.T) {
	now := time.Now()
	future := now.Add(500 * time.Second)
	tests := []struct {
		desc       string
		storeToken storage.TokenResponse
		want       AuthResult
		err        bool
	}{
		{
			desc: "Error: AccessToken.Validate error (AccessToken.CachedAt not set)",
			storeToken: storage.TokenResponse{
				AccessToken: storage.AccessToken{
					ExpiresOn: internalTime.Unix{T: future},
					Secret:    "secret",
					Scopes:    "profile openid user.read",
				},
				IDToken: storage.IDToken{Secret: "x.e30"},
			},
			err: true,
		},
		{
			desc: "Success",
			storeToken: storage.TokenResponse{
				AccessToken: storage.AccessToken{
					CachedAt:  internalTime.Unix{T: now},
					ExpiresOn: internalTime.Unix{T: future},
					Secret:    "secret",
					Scopes:    "profile openid user.read",
				},
				IDToken: storage.IDToken{Secret: "x.e30"},
			},
			want: AuthResult{
				AccessToken: "secret",
				IDToken: accesstokens.IDToken{
					RawToken: "x.e30",
				},
				ExpiresOn:     future,
				GrantedScopes: []string{"profile", "openid", "user.read"},
			},
		},
	}
	for _, test := range tests {
		got, err := AuthResultFromStorage(test.storeToken)
		switch {
		case err == nil && test.err:
			t.Errorf("TestAuthResultFromStorage(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil:
			if !test.err {
				t.Errorf("TestAuthResultFromStorage(%s): got err == %s, want err == nil", test.desc, err)
			}
			continue
		}
		if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
			t.Errorf("TestAuthResultFromStorage(%s): -want/+got:\n%s", test.desc, diff)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/ 0000775 0000000 0000000 00000000000 14420263624 0026673 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/ 0000775 0000000 0000000 00000000000 14420263624 0030337 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/items.go 0000664 0000000 0000000 00000016007 14420263624 0032013 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package storage
import (
"errors"
"fmt"
"reflect"
"strings"
"time"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
// Contract is the JSON structure that is written to any storage medium when serializing
// the internal cache. This design is shared between MSAL versions in many languages.
// This cannot be changed without design that includes other SDKs.
//
// Each item map is serialized under the JSON member named in its struct tag and
// omitted when empty. NOTE(review): map keys are presumably each item's cache key
// (see the Key methods); confirm against the writers.
type Contract struct {
	AccessTokens  map[string]AccessToken               `json:"AccessToken,omitempty"`
	RefreshTokens map[string]accesstokens.RefreshToken `json:"RefreshToken,omitempty"`
	IDTokens      map[string]IDToken                   `json:"IdToken,omitempty"`
	Accounts      map[string]shared.Account            `json:"Account,omitempty"`
	AppMetaData   map[string]AppMetaData               `json:"AppMetadata,omitempty"`
	// AdditionalFields holds top-level JSON members that match no field above.
	// The round-trip tests expect such members (e.g. "unknownEntity") to be
	// preserved here.
	AdditionalFields map[string]interface{}
}
// InMemoryContract is the partitioned, in-memory counterpart of Contract and is
// held by PartitionedManager. Each *Partition field maps a partition key to an
// inner map of items. Unlike Contract, it has no JSON tags and no AdditionalFields.
// NOTE(review): the inner keys are presumably item cache keys, and the partition
// key presumably comes from the request's user assertion hash (as
// PartitionedManager.Read suggests); confirm.
type InMemoryContract struct {
	AccessTokensPartition  map[string]map[string]AccessToken
	RefreshTokensPartition map[string]map[string]accesstokens.RefreshToken
	IDTokensPartition      map[string]map[string]IDToken
	AccountsPartition      map[string]map[string]shared.Account
	// AppMetaData is a flat, unpartitioned map.
	AppMetaData map[string]AppMetaData
}
// NewInMemoryContract is the constructor for InMemoryContract. Every partition
// map and the AppMetaData map are allocated and empty, so callers can write to
// them immediately.
func NewInMemoryContract() *InMemoryContract {
	c := new(InMemoryContract)
	c.AccessTokensPartition = make(map[string]map[string]AccessToken)
	c.RefreshTokensPartition = make(map[string]map[string]accesstokens.RefreshToken)
	c.IDTokensPartition = make(map[string]map[string]IDToken)
	c.AccountsPartition = make(map[string]map[string]shared.Account)
	c.AppMetaData = make(map[string]AppMetaData)
	return c
}
// NewContract is the constructor for Contract. Every item map and
// AdditionalFields are allocated and empty.
func NewContract() *Contract {
	c := new(Contract)
	c.AccessTokens = make(map[string]AccessToken)
	c.RefreshTokens = make(map[string]accesstokens.RefreshToken)
	c.IDTokens = make(map[string]IDToken)
	c.Accounts = make(map[string]shared.Account)
	c.AppMetaData = make(map[string]AppMetaData)
	c.AdditionalFields = make(map[string]interface{})
	return c
}
// AccessToken is the JSON representation of a MSAL access token for encoding to storage.
// Empty fields are omitted when serialized (every tag uses omitempty).
type AccessToken struct {
	HomeAccountID string `json:"home_account_id,omitempty"`
	Environment   string `json:"environment,omitempty"`
	Realm         string `json:"realm,omitempty"`
	// CredentialType is "AccessToken" for entries built by NewAccessToken.
	CredentialType string `json:"credential_type,omitempty"`
	ClientID       string `json:"client_id,omitempty"`
	// Secret is the access token string itself.
	Secret string `json:"secret,omitempty"`
	// Scopes is serialized as "target". Tests store it as a space-separated
	// list, e.g. "profile openid user.read".
	Scopes string `json:"target,omitempty"`
	// The timestamps below use internalTime.Unix. The unmarshal test feeds
	// "cached_at": "100" and expects time.Unix(100, 0).
	ExpiresOn         internalTime.Unix `json:"expires_on,omitempty"`
	ExtendedExpiresOn internalTime.Unix `json:"extended_expires_on,omitempty"`
	CachedAt          internalTime.Unix `json:"cached_at,omitempty"`
	UserAssertionHash string            `json:"user_assertion_hash,omitempty"`
	// AdditionalFields holds JSON members that match no field above.
	AdditionalFields map[string]interface{}
}
// NewAccessToken is the constructor for AccessToken. It builds a cache entry
// whose CredentialType is "AccessToken". token becomes the entry's Secret and
// scopes its Scopes. The cachedAt, expiresOn and extendedExpiresOn times are
// converted to UTC before being stored.
func NewAccessToken(homeID, env, realm, clientID string, cachedAt, expiresOn, extendedExpiresOn time.Time, scopes, token string) AccessToken {
	// inUTC wraps a time, normalized to UTC, in the cache's Unix time type
	inUTC := func(tm time.Time) internalTime.Unix {
		return internalTime.Unix{T: tm.UTC()}
	}
	at := AccessToken{CredentialType: "AccessToken"}
	at.HomeAccountID, at.Environment, at.Realm, at.ClientID = homeID, env, realm, clientID
	at.Secret, at.Scopes = token, scopes
	at.CachedAt = inUTC(cachedAt)
	at.ExpiresOn = inUTC(expiresOn)
	at.ExtendedExpiresOn = inUTC(extendedExpiresOn)
	return at
}
// Key returns the map key for this access token. It joins the home account ID,
// environment, credential type, client ID, realm and scopes, in that order,
// with shared.CacheKeySeparator.
func (a AccessToken) Key() string {
	parts := [...]string{a.HomeAccountID, a.Environment, a.CredentialType, a.ClientID, a.Realm, a.Scopes}
	return strings.Join(parts[:], shared.CacheKeySeparator)
}
// FakeValidate enables tests to fake access token validation. When non-nil,
// AccessToken.Validate returns its result instead of running the real checks.
var FakeValidate func(AccessToken) error

// Validate reports whether this AccessToken can be used. Unless FakeValidate is
// set, it returns an error in the first matching case, checked in this order:
//   - CachedAt is in the future
//   - ExpiresOn is less than five minutes from now (or already past)
//   - CachedAt is unset
func (a AccessToken) Validate() error {
	if override := FakeValidate; override != nil {
		return override(a)
	}
	switch {
	case a.CachedAt.T.After(time.Now()):
		return errors.New("access token isn't valid, it was cached at a future time")
	case a.ExpiresOn.T.Before(time.Now().Add(5 * time.Minute)):
		return fmt.Errorf("access token is expired")
	case a.CachedAt.T.IsZero():
		return fmt.Errorf("access token does not have CachedAt set")
	}
	return nil
}
// IDToken is the JSON representation of an MSAL id token for encoding to storage.
// Empty fields are omitted when serialized (every tag uses omitempty).
type IDToken struct {
	HomeAccountID string `json:"home_account_id,omitempty"`
	Environment   string `json:"environment,omitempty"`
	Realm         string `json:"realm,omitempty"`
	// CredentialType is "IDToken" for entries built by NewIDToken.
	CredentialType string `json:"credential_type,omitempty"`
	ClientID       string `json:"client_id,omitempty"`
	// Secret is the raw ID token string passed to NewIDToken.
	Secret            string `json:"secret,omitempty"`
	UserAssertionHash string `json:"user_assertion_hash,omitempty"`
	// AdditionalFields holds JSON members that match no field above.
	AdditionalFields map[string]interface{}
}
// IsZero reports whether every field of the IDToken holds its zero value. A map
// or slice that is allocated but empty (such as an empty AdditionalFields) also
// counts as zero.
func (i IDToken) IsZero() bool {
	rv := reflect.ValueOf(i)
	for n := 0; n < rv.NumField(); n++ {
		f := rv.Field(n)
		if f.IsZero() {
			continue
		}
		if k := f.Kind(); (k == reflect.Map || k == reflect.Slice) && f.Len() == 0 {
			continue
		}
		return false
	}
	return true
}
// NewIDToken is the constructor for IDToken. It builds a cache entry with
// CredentialType "IDToken" and stores the raw token string idToken as Secret.
func NewIDToken(homeID, env, realm, clientID, idToken string) IDToken {
	var tok IDToken
	tok.CredentialType = "IDToken"
	tok.HomeAccountID = homeID
	tok.Environment = env
	tok.Realm = realm
	tok.ClientID = clientID
	tok.Secret = idToken
	return tok
}
// Key returns the map key for this ID token. It joins the home account ID,
// environment, credential type, client ID and realm, in that order, with
// shared.CacheKeySeparator.
func (id IDToken) Key() string {
	fields := [...]string{id.HomeAccountID, id.Environment, id.CredentialType, id.ClientID, id.Realm}
	return strings.Join(fields[:], shared.CacheKeySeparator)
}
// AppMetaData is the JSON representation of application metadata for encoding to storage.
// Empty fields are omitted when serialized (every tag uses omitempty).
type AppMetaData struct {
	// FamilyID is the application's family ID, if any; an empty value is omitted.
	FamilyID    string `json:"family_id,omitempty"`
	ClientID    string `json:"client_id,omitempty"`
	Environment string `json:"environment,omitempty"`
	// AdditionalFields holds JSON members that match no field above.
	AdditionalFields map[string]interface{}
}
// NewAppMetaData is the constructor for AppMetaData. It records the given
// family ID, client ID and environment and leaves AdditionalFields nil.
func NewAppMetaData(familyID, clientID, environment string) AppMetaData {
	return AppMetaData{ClientID: clientID, Environment: environment, FamilyID: familyID}
}
// Key returns the map key for this metadata entry. It is the literal prefix
// "AppMetaData", then the environment and client ID, joined with
// shared.CacheKeySeparator.
func (a AppMetaData) Key() string {
	const prefix = "AppMetaData"
	return strings.Join([]string{prefix, a.Environment, a.ClientID}, shared.CacheKeySeparator)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/items_test.go 0000664 0000000 0000000 00000037217 14420263624 0033060 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package storage
import (
stdJSON "encoding/json"
"os"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)
// Field values for the access token fixtures.
var (
	testHID    = "testHID"
	env        = "env"
	credential = "AccessToken"
	clientID   = "clientID"
	realm      = "realm"
	scopes     = "user.read"
	secret     = "access"
)

// All three fixture timestamps are 2020-06-13 12:00:00 UTC.
var (
	expiresOn    = time.Unix(1592049600, 0)
	extExpiresOn = time.Unix(1592049600, 0)
	cachedAt     = time.Unix(1592049600, 0)
)

// atCacheEntity is an AccessToken assembled from the fixture values above.
var atCacheEntity = &AccessToken{
	HomeAccountID:     testHID,
	Environment:       env,
	Realm:             realm,
	CredentialType:    credential,
	ClientID:          clientID,
	Secret:            secret,
	Scopes:            scopes,
	ExpiresOn:         internalTime.Unix{T: expiresOn},
	ExtendedExpiresOn: internalTime.Unix{T: extExpiresOn},
	CachedAt:          internalTime.Unix{T: cachedAt},
}
// TestCreateAccessToken checks that NewAccessToken copies each argument into
// the matching AccessToken field and sets CredentialType to "AccessToken".
// Timestamps are compared with the values passed in, not with package-level
// fixtures that merely happen to be equal.
func TestCreateAccessToken(t *testing.T) {
	testExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
	testExtExpiresOn := time.Date(2020, time.June, 13, 12, 0, 0, 0, time.UTC)
	testCachedAt := time.Date(2020, time.June, 13, 11, 0, 0, 0, time.UTC)
	actualAt := NewAccessToken("testHID",
		"env",
		"realm",
		"clientID",
		testCachedAt,
		testExpiresOn,
		testExtExpiresOn,
		"user.read",
		"access",
	)
	if !testExtExpiresOn.Equal(actualAt.ExtendedExpiresOn.T) {
		t.Errorf("Actual ext expires on %s differs from expected ext expires on %s", actualAt.ExtendedExpiresOn.T, testExtExpiresOn)
	}
	if !testExpiresOn.Equal(actualAt.ExpiresOn.T) {
		t.Errorf("Actual expires on %s differs from expected expires on %s", actualAt.ExpiresOn.T, testExpiresOn)
	}
	if !testCachedAt.Equal(actualAt.CachedAt.T) {
		t.Errorf("Actual cached at %s differs from expected cached at %s", actualAt.CachedAt.T, testCachedAt)
	}
	for _, c := range []struct{ field, got, want string }{
		{"HomeAccountID", actualAt.HomeAccountID, "testHID"},
		{"Environment", actualAt.Environment, "env"},
		{"Realm", actualAt.Realm, "realm"},
		{"ClientID", actualAt.ClientID, "clientID"},
		{"CredentialType", actualAt.CredentialType, "AccessToken"},
		{"Scopes", actualAt.Scopes, "user.read"},
		{"Secret", actualAt.Secret, "access"},
	} {
		if c.got != c.want {
			t.Errorf("%s: got %q, want %q", c.field, c.got, c.want)
		}
	}
}
// TestKeyForAccessToken checks the cache key of atCacheEntity; the expected
// value shows its fields joined with "-".
func TestKeyForAccessToken(t *testing.T) {
	want := "testHID-env-AccessToken-clientID-realm-user.read"
	if got := atCacheEntity.Key(); got != want {
		t.Errorf("TestKeyForAccessToken: got %s, want %s", got, want)
	}
}
// TestAccessTokenUnmarshal checks that known JSON members populate their
// AccessToken fields and that unknown members are kept as raw JSON in
// AdditionalFields. Here cached_at is given as a string of Unix seconds.
// Setup and decode failures stop the test via t.Fatalf rather than panicking,
// or continuing into a misleading diff.
func TestAccessTokenUnmarshal(t *testing.T) {
	jsonMap := map[string]interface{}{
		"home_account_id": "testHID",
		"environment":     "env",
		"extra":           "this_is_extra",
		"cached_at":       "100",
	}
	jsonData, err := stdJSON.Marshal(jsonMap)
	if err != nil {
		t.Fatalf("TestAccessTokenUnmarshal: unable to marshal test input: %s", err)
	}
	want := &AccessToken{
		HomeAccountID: testHID,
		Environment:   env,
		CachedAt:      internalTime.Unix{T: time.Unix(100, 0)},
		AdditionalFields: map[string]interface{}{
			"extra": json.MarshalRaw("this_is_extra"),
		},
	}
	got := &AccessToken{}
	if err := json.Unmarshal(jsonData, got); err != nil {
		t.Fatalf("Error is supposed to be nil, but it is %v", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestAccessTokenUnmarshal(access tokens): -want/+got:\n %s", diff)
	}
}
// TestAccessTokenMarshal round-trips an AccessToken, including an extra JSON
// member in AdditionalFields, through json.Marshal and json.Unmarshal and
// expects it back unchanged.
func TestAccessTokenMarshal(t *testing.T) {
	accessToken := &AccessToken{
		HomeAccountID:  testHID,
		Environment:    "",
		CachedAt:       internalTime.Unix{T: time.Unix(100, 0)},
		CredentialType: credential,
		AdditionalFields: map[string]interface{}{
			"extra": json.MarshalRaw("this_is_extra"),
		},
	}
	b, err := json.Marshal(accessToken)
	if err != nil {
		t.Fatalf("TestAccessTokenMarshal: unable to marshal: %s", err)
	}
	got := AccessToken{}
	if err := json.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestAccessTokenMarshal: unable to take JSON byte output and unmarshal: %s", err)
	}
	// compare values, not a pointer against a value
	if diff := pretty.Compare(*accessToken, got); diff != "" {
		t.Errorf("TestAccessTokenMarshal(access token): -want/+got:\n%s", diff)
	}
}
// Client ID and environment of the appMeta fixture.
var appClient, appEnv = "cid", "env"

// appMeta is the AppMetaData fixture whose key TestKeyForAppMetaData checks.
var appMeta = &AppMetaData{
	Environment: appEnv,
	ClientID:    appClient,
	FamilyID:    "",
}
// TestKeyForAppMetaData checks the cache key built for the appMeta fixture.
// The failure message now reports the actual key first and the expected key
// second, matching its wording.
func TestKeyForAppMetaData(t *testing.T) {
	want := "AppMetaData-env-cid"
	got := appMeta.Key()
	if got != want {
		t.Errorf("actual key %v differs from expected key %v", got, want)
	}
}
// TestAppMetaDataUnmarshal checks that known members fill AppMetaData fields
// (a null family_id yields ""). Members AppMetaData has no field for (here
// "extra" and "cached_at") must be kept as raw JSON in AdditionalFields.
func TestAppMetaDataUnmarshal(t *testing.T) {
	jsonMap := map[string]interface{}{
		"environment": "env",
		"extra":       "this_is_extra",
		"cached_at":   "100",
		"client_id":   "cid",
		"family_id":   nil,
	}
	want := AppMetaData{
		ClientID:    "cid",
		Environment: "env",
		AdditionalFields: map[string]interface{}{
			"extra":     json.MarshalRaw("this_is_extra"),
			"cached_at": json.MarshalRaw("100"),
		},
	}
	b, err := stdJSON.Marshal(jsonMap)
	if err != nil {
		t.Fatalf("TestAppMetaDataUnmarshal(marshal input): got err == %s, want err == nil", err)
	}
	got := AppMetaData{}
	if err := json.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestAppMetaDataUnmarshal(unmarshal): got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Fatalf("TestAppMetaDataUnmarshal: -want/+got:\n%s", diff)
	}
}
// TestAppMetaDataMarshal checks that marshaling AppMetaData omits the empty
// Environment and FamilyID. It also checks that AdditionalFields entries are
// written as top-level JSON members.
func TestAppMetaDataMarshal(t *testing.T) {
	// named appMetaData so the AppMetaData type isn't shadowed
	appMetaData := AppMetaData{
		Environment: "",
		ClientID:    appClient,
		FamilyID:    "",
		AdditionalFields: map[string]interface{}{
			"extra":     "this_is_extra",
			"cached_at": "100",
		},
	}
	want := map[string]interface{}{
		"client_id": "cid",
		"extra":     "this_is_extra",
		"cached_at": "100",
	}
	b, err := json.Marshal(appMetaData)
	if err != nil {
		t.Fatalf("TestAppMetaDataMarshal(marshal): err == %s, want err == nil", err)
	}
	got := map[string]interface{}{}
	if err := stdJSON.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestAppMetaDataMarshal(unmarshal): err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestAppMetaDataMarshal: -want/+got:\n%s", diff)
	}
}
// TestContractUnmarshalJSON decodes the cache file at testFile into a Contract
// and compares it with the expected entries. Unknown members must land in the
// matching AdditionalFields: an unknown access-token member and the top-level
// "unknownEntity".
func TestContractUnmarshalJSON(t *testing.T) {
	testCache, err := os.ReadFile(testFile)
	if err != nil {
		t.Fatalf("TestContractUnmarshalJSON(read %s): %v", testFile, err)
	}
	got := Contract{}
	if err := json.Unmarshal(testCache, &got); err != nil {
		t.Fatalf("TestContractUnmarshalJSON(unmarshal): %v", err)
	}
	want := Contract{
		AccessTokens: map[string]AccessToken{
			"an-entry": {
				AdditionalFields: map[string]interface{}{
					"foo": json.MarshalRaw("bar"),
				},
			},
			"uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": {
				Environment:       defaultEnvironment,
				CredentialType:    "AccessToken",
				Secret:            accessTokenSecret,
				Realm:             defaultRealm,
				Scopes:            defaultScopes,
				ClientID:          defaultClientID,
				CachedAt:          internalTime.Unix{T: atCached},
				HomeAccountID:     defaultHID,
				ExpiresOn:         internalTime.Unix{T: atExpires},
				ExtendedExpiresOn: internalTime.Unix{T: atExpires},
			},
		},
		Accounts: map[string]shared.Account{
			"uid.utid-login.windows.net-contoso": {
				PreferredUsername: "John Doe",
				LocalAccountID:    "object1234",
				Realm:             "contoso",
				Environment:       "login.windows.net",
				HomeAccountID:     "uid.utid",
				AuthorityType:     "MSSTS",
			},
		},
		RefreshTokens: map[string]accesstokens.RefreshToken{
			"uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": {
				Target:         defaultScopes,
				Environment:    defaultEnvironment,
				CredentialType: "RefreshToken",
				Secret:         rtSecret,
				ClientID:       defaultClientID,
				HomeAccountID:  defaultHID,
			},
		},
		IDTokens: map[string]IDToken{
			"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": {
				Realm:          defaultRealm,
				Environment:    defaultEnvironment,
				CredentialType: idCred,
				Secret:         idSecret,
				ClientID:       defaultClientID,
				HomeAccountID:  defaultHID,
			},
		},
		AppMetaData: map[string]AppMetaData{
			"AppMetadata-login.windows.net-my_client_id": {
				Environment: defaultEnvironment,
				FamilyID:    "",
				ClientID:    defaultClientID,
			},
		},
		AdditionalFields: map[string]interface{}{
			"unknownEntity": json.MarshalRaw(
				map[string]interface{}{
					"field1": "1",
					"field2": "whats",
				},
			),
		},
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestContractUnmarshalJSON: -want/+got:\n%s", diff)
		// Also log the raw unknown entity. A constant format string keeps go vet
		// happy, and the checked assertion can't panic when the entry is missing.
		if raw, ok := got.AdditionalFields["unknownEntity"].(stdJSON.RawMessage); ok {
			t.Errorf("unknownEntity: %s", raw)
		} else {
			t.Errorf("unknownEntity: unexpected value %#v", got.AdditionalFields["unknownEntity"])
		}
	}
}
func TestContractMarshalJSON(t *testing.T) {
want := Contract{
AccessTokens: map[string]AccessToken{
"an-entry": {
AdditionalFields: map[string]interface{}{
"foo": json.MarshalRaw("bar"),
},
},
"uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": {
Environment: defaultEnvironment,
CredentialType: "AccessToken",
Secret: accessTokenSecret,
Realm: defaultRealm,
Scopes: defaultScopes,
ClientID: defaultClientID,
CachedAt: internalTime.Unix{T: atCached},
HomeAccountID: defaultHID,
ExpiresOn: internalTime.Unix{T: atExpires},
ExtendedExpiresOn: internalTime.Unix{T: atExpires},
},
},
RefreshTokens: map[string]accesstokens.RefreshToken{
"uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": {
Target: defaultScopes,
Environment: defaultEnvironment,
CredentialType: "RefreshToken",
Secret: rtSecret,
ClientID: defaultClientID,
HomeAccountID: defaultHID,
},
},
IDTokens: map[string]IDToken{
"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": {
Realm: defaultRealm,
Environment: defaultEnvironment,
CredentialType: idCred,
Secret: idSecret,
ClientID: defaultClientID,
HomeAccountID: defaultHID,
},
},
Accounts: map[string]shared.Account{
"uid.utid-login.windows.net-contoso": {
PreferredUsername: accUser,
LocalAccountID: accLID,
Realm: defaultRealm,
Environment: defaultEnvironment,
HomeAccountID: defaultHID,
AuthorityType: accAuth,
},
},
AppMetaData: map[string]AppMetaData{
"AppMetadata-login.windows.net-my_client_id": {
Environment: defaultEnvironment,
FamilyID: "",
ClientID: defaultClientID,
},
},
AdditionalFields: map[string]interface{}{
"unknownEntity": json.MarshalRaw(
map[string]interface{}{
"field1": "1",
"field2": "whats",
},
),
},
}
b, err := json.Marshal(want)
if err != nil {
t.Fatalf("TestContractMarshalJSON(marshal): got err == %s, want err == nil", err)
}
got := Contract{}
if err := json.Unmarshal(b, &got); err != nil {
t.Fatalf("TestContractMarshalJSON(unmarshal back): got err == %s, want err == nil", err)
}
if diff := pretty.Compare(want, got); diff != "" {
t.Errorf("TestContractMarshalJSON: -want/+got:\n%s", diff)
}
}
// Field values for the idToken fixture.
var idHid = "HID"
var idEnv = "env"
var idCredential = "IdToken"
var idClient = "clientID"
var idRealm = "realm"
var idTokSecret = "id"
// idToken is an IDToken assembled from the id* fixture values; its key is
// checked by TestKeyForIDToken.
var idToken = IDToken{
	ClientID:       idClient,
	CredentialType: idCredential,
	Environment:    idEnv,
	HomeAccountID:  idHid,
	Realm:          idRealm,
	Secret:         idTokSecret,
}
// TestKeyForIDToken checks the cache key built for the idToken fixture.
func TestKeyForIDToken(t *testing.T) {
	const want = "HID-env-IdToken-clientID-realm"
	got := idToken.Key()
	if got != want {
		t.Errorf("actual key %v differs from expected key %v", got, want)
	}
}
// TestIDTokenUnmarshal checks that known JSON members fill IDToken fields and
// that the unknown "extra" member is kept as raw JSON in AdditionalFields.
// Failures stop the test via t.Fatalf instead of panicking.
func TestIDTokenUnmarshal(t *testing.T) {
	jsonMap := map[string]interface{}{
		"home_account_id": "HID",
		"environment":     "env",
		"extra":           "this_is_extra",
	}
	b, err := stdJSON.Marshal(jsonMap)
	if err != nil {
		t.Fatalf("TestIDTokenUnmarshal(marshal input): %v", err)
	}
	want := IDToken{
		HomeAccountID: "HID",
		Environment:   "env",
		AdditionalFields: map[string]interface{}{
			"extra": json.MarshalRaw("this_is_extra"),
		},
	}
	got := IDToken{}
	if err := json.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestIDTokenUnmarshal(unmarshal): %v", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestIDTokenUnmarshal: -want/+got:\n%s", diff)
	}
}
// TestIDTokenMarshal checks that marshaling an IDToken omits the empty Realm
// and writes AdditionalFields entries as top-level JSON members. The local is
// named tok so it doesn't shadow the package-level idToken fixture, and
// failures use t.Fatalf instead of panicking.
func TestIDTokenMarshal(t *testing.T) {
	tok := IDToken{
		HomeAccountID:    idHid,
		Environment:      idEnv,
		Realm:            "",
		AdditionalFields: map[string]interface{}{"extra": "this_is_extra"},
	}
	want := map[string]interface{}{
		"home_account_id": "HID",
		"environment":     "env",
		"extra":           "this_is_extra",
	}
	b, err := json.Marshal(tok)
	if err != nil {
		t.Fatalf("TestIDTokenMarshal(marshal): %v", err)
	}
	got := map[string]interface{}{}
	if err := stdJSON.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestIDTokenMarshal(unmarshal): %v", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestIDTokenMarshal: -want/+got:\n%s", diff)
	}
}
// Field values for the rt refresh token fixture.
var hid = "HID"
var rtEnv = "env"
var rtClientID = "clientID"
var rtCredential = "accesstokens.RefreshToken"
var refSecret = "secret"
// rt is a RefreshToken built entirely from the refresh token fixture values.
// Environment uses rtEnv rather than the access token fixture's env (both are
// "env"), so changing one fixture set doesn't silently affect the other.
var rt = &accesstokens.RefreshToken{
	HomeAccountID:  hid,
	Environment:    rtEnv,
	ClientID:       rtClientID,
	CredentialType: rtCredential,
	Secret:         refSecret,
}
// TestNewRefreshToken checks that accesstokens.NewRefreshToken stores the
// given refresh token string as Secret.
func TestNewRefreshToken(t *testing.T) {
	token := accesstokens.NewRefreshToken("HID", "env", "clientID", "secret", "")
	if actual := token.Secret; actual != refSecret {
		t.Errorf("expected secret %s differs from actualSecret %s", refSecret, actual)
	}
}
// TestKeyForRefreshToken checks the cache key built for the rt fixture.
func TestKeyForRefreshToken(t *testing.T) {
	const want = "HID-env-accesstokens.RefreshToken-clientID"
	if got := rt.Key(); got != want {
		t.Errorf("Actual key %v differs from expected key %v", got, want)
	}
}
// TestRefreshTokenUnmarshal checks that known JSON members fill RefreshToken
// fields and that the unknown "extra" member is kept as raw JSON in
// AdditionalFields. Failures stop the test via t.Fatalf instead of panicking.
func TestRefreshTokenUnmarshal(t *testing.T) {
	jsonMap := map[string]interface{}{
		"home_account_id": "hid",
		"environment":     "env",
		"extra":           "this_is_extra",
		"secret":          "secret",
	}
	b, err := stdJSON.Marshal(jsonMap)
	if err != nil {
		t.Fatalf("TestRefreshTokenUnmarshal(marshal input): %v", err)
	}
	want := accesstokens.RefreshToken{
		HomeAccountID: "hid",
		Environment:   "env",
		Secret:        "secret",
		AdditionalFields: map[string]interface{}{
			"extra": json.MarshalRaw("this_is_extra"),
		},
	}
	got := accesstokens.RefreshToken{}
	if err := json.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestRefreshTokenUnmarshal(unmarshal): %v", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestRefreshTokenUnmarshal: -want/+got:\n%s", diff)
	}
}
// TestRefreshTokenMarshal checks that marshaling a RefreshToken leaves out the
// empty HomeAccountID and ClientID, per the expected map. It also checks that
// AdditionalFields entries are written as top-level JSON members. Failures stop
// the test via t.Fatalf instead of panicking.
func TestRefreshTokenMarshal(t *testing.T) {
	refreshToken := accesstokens.RefreshToken{
		HomeAccountID:  "",
		Environment:    rtEnv,
		CredentialType: rtCredential,
		Secret:         refSecret,
		AdditionalFields: map[string]interface{}{
			"extra": "this_is_extra",
		},
	}
	want := map[string]interface{}{
		"environment":     "env",
		"credential_type": "accesstokens.RefreshToken",
		"secret":          "secret",
		"extra":           "this_is_extra",
	}
	b, err := json.Marshal(refreshToken)
	if err != nil {
		t.Fatalf("TestRefreshTokenMarshal(marshal): %v", err)
	}
	got := map[string]interface{}{}
	if err := stdJSON.Unmarshal(b, &got); err != nil {
		t.Fatalf("TestRefreshTokenMarshal(unmarshal): %v", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestRefreshTokenMarshal: -want/+got:\n%s", diff)
	}
}
func TestRegression196(t *testing.T) {
// https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/196
// Note: all values here look real, but they have been altered to prevent any exposure
// of even a temporary security value.
contract := &Contract{
AccessTokens: map[string]AccessToken{
"-login.microsoftonline.com-AccessToken-5b0c5134eacb-https://graph.microsoft.com/.default": {
HomeAccountID: "",
Environment: "login.microsoftonline.com",
Realm: "2cce-489d-4002-8293-5b0eacb",
CredentialType: "AccessToken",
ClientID: "841-b1d2-460b-bc46-11cfb",
Secret: "secret",
Scopes: "https://graph.microsoft.com/.default",
ExpiresOn: internalTime.Unix{T: expiresOn},
ExtendedExpiresOn: internalTime.Unix{T: extExpiresOn},
CachedAt: internalTime.Unix{T: cachedAt},
},
},
AppMetaData: map[string]AppMetaData{
"AppMetaData-login.microsoftonline.com-84a31-b1d2-460b-bc46-1158fb": {
ClientID: "8431-bd2-460b-bc46-11c4c8fb",
Environment: "login.microsoftonline.com",
},
},
}
b, err := json.Marshal(contract)
if err != nil {
t.Fatalf("TestRegression196: Marshal had unexpected error: %v", err)
}
got := &Contract{}
if err := json.Unmarshal(b, got); err != nil {
t.Fatalf("TestRegression196: Unmarshal had unexpected error: %v, json was:\n%s", err, string(b))
}
if diff := pretty.Compare(contract, got); diff != "" {
t.Fatalf("TestRegression196: -want/+got:\n%s", diff)
}
}
partitioned_storage.go 0000664 0000000 0000000 00000036613 14420263624 0034666 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package storage
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
// PartitionedManager is a partitioned in-memory cache of access tokens, accounts and meta data.
// Access tokens, refresh tokens, ID tokens and accounts are kept in nested maps indexed first by a
// partition key (computed by the getPartitionKey* helpers) and then by the entry's own Key().
// App metadata is not partitioned.
type PartitionedManager struct {
	// contract holds every cached entry; reads take contractMu.RLock, writes take contractMu.Lock.
	contract   *InMemoryContract
	contractMu sync.RWMutex
	// requests performs AAD instance discovery when aadCache has no entry for a host.
	requests aadInstanceDiscoveryer // *oauth.Token
	// aadCacheMu guards aadCache.
	aadCacheMu sync.RWMutex
	// aadCache maps an authority host, and every alias returned by instance discovery,
	// to that authority's discovery metadata.
	aadCache map[string]authority.InstanceDiscoveryMetadata
}
// NewPartitionedManager returns a PartitionedManager with an empty in-memory contract and an
// empty instance-discovery cache. requests is used to perform AAD instance discovery.
func NewPartitionedManager(requests *oauth.Client) *PartitionedManager {
	return &PartitionedManager{
		contract: NewInMemoryContract(),
		requests: requests,
		aadCache: map[string]authority.InstanceDiscoveryMetadata{},
	}
}
// Read collects the cached access token, ID token, refresh token and account that match
// authParameters. A miss for any individual item is not an error; that item's field in the
// returned TokenResponse is just left at its zero value. The only error returned comes from
// resolving instance discovery metadata for the authority.
func (m *PartitionedManager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) {
	info := authParameters.AuthorityInfo
	realm, clientID := info.Tenant, authParameters.ClientID

	// With instance discovery disabled, only the configured host counts as a matching environment.
	envAliases := []string{info.Host}
	if !info.InstanceDiscoveryDisabled {
		md, err := m.getMetadataEntry(ctx, info)
		if err != nil {
			return TokenResponse{}, err
		}
		envAliases = md.Aliases
	}

	// The request's user assertion hash is both a match criterion and the partition key for
	// the access token and refresh token lookups.
	assertionHash := authParameters.AssertionHash()
	var result TokenResponse

	at, atErr := m.readAccessToken(envAliases, realm, clientID, assertionHash, authParameters.Scopes, assertionHash)
	if atErr == nil {
		result.AccessToken = at
	}

	// The ID token partition is derived from the access token above (its zero value on a miss).
	idt, idtErr := m.readIDToken(envAliases, realm, clientID, assertionHash, getPartitionKeyIDTokenRead(at))
	if idtErr == nil {
		result.IDToken = idt
	}

	if app, err := m.readAppMetaData(envAliases, clientID); err == nil {
		// The app's family ID decides which refresh token match strategy is tried first.
		if rt, err := m.readRefreshToken(envAliases, app.FamilyID, clientID, assertionHash, assertionHash); err == nil {
			result.RefreshToken = rt
		}
	}

	// The account partition is the home account ID of the ID token above (empty on a miss).
	if acc, err := m.readAccount(envAliases, realm, assertionHash, idt.HomeAccountID); err == nil {
		result.Account = acc
	}
	return result, nil
}
// Write caches what tokenResponse carries and returns the account derived from its ID token.
// Cached items are the refresh token, the access token, the ID token, that account, and the
// app metadata. The returned Account is the zero value when the response has no ID token or
// when any write fails. For on-behalf-of requests every cached entry is tagged with the user
// assertion hash. An access token that fails Validate is not cached and its error is returned.
func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) {
	authParameters.HomeAccountID = tokenResponse.ClientInfo.HomeAccountID()
	var (
		homeID   = authParameters.HomeAccountID
		env      = authParameters.AuthorityInfo.Host
		realm    = authParameters.AuthorityInfo.Tenant
		clientID = authParameters.ClientID
		target   = strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator)
		hash     = authParameters.AssertionHash()
		now      = time.Now()
		obo      = authParameters.AuthorizationType == authority.ATOnBehalfOf
	)
	var account shared.Account

	if len(tokenResponse.RefreshToken) != 0 {
		rt := accesstokens.NewRefreshToken(homeID, env, clientID, tokenResponse.RefreshToken, tokenResponse.FamilyID)
		if obo {
			rt.UserAssertionHash = hash
		}
		if err := m.writeRefreshToken(rt, getPartitionKeyRefreshToken(rt)); err != nil {
			return shared.Account{}, err
		}
	}

	if len(tokenResponse.AccessToken) != 0 {
		at := NewAccessToken(homeID, env, realm, clientID, now, tokenResponse.ExpiresOn.T, tokenResponse.ExtExpiresOn.T, target, tokenResponse.AccessToken)
		if obo {
			at.UserAssertionHash = hash
		}
		// Only a token that passes validation is cached; otherwise the whole write fails.
		if err := at.Validate(); err != nil {
			return shared.Account{}, err
		}
		if err := m.writeAccessToken(at, getPartitionKeyAccessToken(at)); err != nil {
			return shared.Account{}, err
		}
	}

	if jwt := tokenResponse.IDToken; !jwt.IsZero() {
		idt := NewIDToken(homeID, env, realm, clientID, jwt.RawToken)
		if obo {
			idt.UserAssertionHash = hash
		}
		if err := m.writeIDToken(idt, getPartitionKeyIDToken(idt)); err != nil {
			return shared.Account{}, err
		}

		localID := jwt.LocalAccountID()
		// Prefer PreferredUsername, falling back to the UPN when it is empty.
		username := jwt.PreferredUsername
		if username == "" {
			username = jwt.UPN
		}
		account = shared.NewAccount(homeID, env, realm, localID, authParameters.AuthorityInfo.AuthorityType, username)
		if obo {
			account.UserAssertionHash = hash
		}
		if err := m.writeAccount(account, getPartitionKeyAccount(account)); err != nil {
			return shared.Account{}, err
		}
	}

	meta := NewAppMetaData(tokenResponse.FamilyID, clientID, env)
	if err := m.writeAppMetaData(meta); err != nil {
		return shared.Account{}, err
	}
	return account, nil
}
// getMetadataEntry returns the instance discovery metadata for authorityInfo. It serves the
// entry from aadCache when present and otherwise fetches and caches it through aadMetadata.
func (m *PartitionedManager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	if cached, err := m.aadMetadataFromCache(ctx, authorityInfo); err == nil {
		return cached, nil
	}
	return m.aadMetadata(ctx, authorityInfo)
}
// aadMetadataFromCache looks up authorityInfo.Host in aadCache under the read lock. On a miss it
// returns the zero metadata value and a "not found" error. ctx is not used.
func (m *PartitionedManager) aadMetadataFromCache(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	m.aadCacheMu.RLock()
	md, found := m.aadCache[authorityInfo.Host]
	m.aadCacheMu.RUnlock()
	if !found {
		return md, errors.New("not found")
	}
	return md, nil
}
// aadMetadata performs AAD instance discovery for authorityInfo and stores every returned
// metadata entry in aadCache under each of its aliases. If the response covers no entry for
// authorityInfo.Host, a fallback entry is cached that uses the host as both preferred network
// and preferred cache. It returns the entry cached for authorityInfo.Host.
func (m *PartitionedManager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	resp, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo)
	if err != nil {
		return authority.InstanceDiscoveryMetadata{}, err
	}

	m.aadCacheMu.Lock()
	defer m.aadCacheMu.Unlock()

	for _, entry := range resp.Metadata {
		for _, alias := range entry.Aliases {
			m.aadCache[alias] = entry
		}
	}

	host := authorityInfo.Host
	md, ok := m.aadCache[host]
	if !ok {
		md = authority.InstanceDiscoveryMetadata{
			PreferredNetwork: host,
			PreferredCache:   host,
		}
		m.aadCache[host] = md
	}
	return md, nil
}
// readAccessToken scans one access token partition for a token meeting all of these conditions:
//   - its realm, client ID and user assertion hash equal the arguments;
//   - its environment is one of envAliases;
//   - its scopes satisfy the requested scopes, as decided by isMatchingScopes.
//
// Map iteration order is unspecified, so if several tokens qualify, any one may be returned.
func (m *PartitionedManager) readAccessToken(envAliases []string, realm, clientID, userAssertionHash string, scopes []string, partitionKey string) (AccessToken, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()

	// NOTE: linear scan over a map. It dominates profiles only with thousands of cached
	// tokens; index this if that ever matters in practice.
	for _, candidate := range m.contract.AccessTokensPartition[partitionKey] {
		if candidate.Realm != realm || candidate.ClientID != clientID || candidate.UserAssertionHash != userAssertionHash {
			continue
		}
		if !checkAlias(candidate.Environment, envAliases) || !isMatchingScopes(scopes, candidate.Scopes) {
			continue
		}
		return candidate, nil
	}
	return AccessToken{}, errors.New("access token not found")
}
// writeAccessToken stores accessToken under its Key() in the given partition, creating the
// partition's map on first use. It always returns nil.
func (m *PartitionedManager) writeAccessToken(accessToken AccessToken, partitionKey string) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()

	partition := m.contract.AccessTokensPartition[partitionKey]
	if partition == nil {
		partition = make(map[string]AccessToken)
		m.contract.AccessTokensPartition[partitionKey] = partition
	}
	partition[accessToken.Key()] = accessToken
	return nil
}
// matchFamilyRefreshTokenObo reports whether rt has the given user assertion hash, was issued
// for one of envAliases, and carries a non-empty family ID.
func matchFamilyRefreshTokenObo(rt accesstokens.RefreshToken, userAssertionHash string, envAliases []string) bool {
	if rt.UserAssertionHash != userAssertionHash {
		return false
	}
	if !checkAlias(rt.Environment, envAliases) {
		return false
	}
	return rt.FamilyID != ""
}
// matchClientIDRefreshTokenObo reports whether rt has the given user assertion hash, was issued
// for one of envAliases, and belongs to clientID.
func matchClientIDRefreshTokenObo(rt accesstokens.RefreshToken, userAssertionHash string, envAliases []string, clientID string) bool {
	if rt.UserAssertionHash != userAssertionHash {
		return false
	}
	if !checkAlias(rt.Environment, envAliases) {
		return false
	}
	return rt.ClientID == clientID
}
// readRefreshToken searches one refresh token partition using two strategies:
//   - by family: any family refresh token for the user assertion and environment;
//   - by client: a refresh token for this client ID.
//
// When familyID is empty the client strategy runs first, otherwise the family strategy runs
// first. The first token accepted by the earlier strategy wins; the second strategy runs only
// if the first finds nothing.
//
// TODO(keegan): the reference algorithm (MSAL .NET) is more nuanced: search by client ID only
// when the app is known NOT to be in a family; otherwise search by family ID, then client ID.
// See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95
func (m *PartitionedManager) readRefreshToken(envAliases []string, familyID, clientID, userAssertionHash, partitionKey string) (accesstokens.RefreshToken, error) {
	byFamily := func(rt accesstokens.RefreshToken) bool {
		return matchFamilyRefreshTokenObo(rt, userAssertionHash, envAliases)
	}
	byClient := func(rt accesstokens.RefreshToken) bool {
		return matchClientIDRefreshTokenObo(rt, userAssertionHash, envAliases, clientID)
	}

	first, second := byFamily, byClient
	if familyID == "" {
		first, second = byClient, byFamily
	}

	m.contractMu.RLock()
	defer m.contractMu.RUnlock()

	partition := m.contract.RefreshTokensPartition[partitionKey]
	for _, accept := range [...]func(accesstokens.RefreshToken) bool{first, second} {
		for _, rt := range partition {
			if accept(rt) {
				return rt, nil
			}
		}
	}
	return accesstokens.RefreshToken{}, errors.New("refresh token not found")
}
// writeRefreshToken stores refreshToken under its Key() in the given partition of
// RefreshTokensPartition, creating that partition's map on first use. It always returns nil.
func (m *PartitionedManager) writeRefreshToken(refreshToken accesstokens.RefreshToken, partitionKey string) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	key := refreshToken.Key()
	// The nil check must look at the refresh token partition itself. Keying it off the access
	// token partition causes two failures:
	//   - it panics (nil map write) when ATs exist for this partition but RTs do not;
	//   - it discards existing RTs (by replacing their map) when no ATs exist yet.
	if m.contract.RefreshTokensPartition[partitionKey] == nil {
		m.contract.RefreshTokensPartition[partitionKey] = make(map[string]accesstokens.RefreshToken)
	}
	m.contract.RefreshTokensPartition[partitionKey][key] = refreshToken
	return nil
}
// readIDToken scans one ID token partition for a token whose realm, client ID and user
// assertion hash equal the arguments and whose environment is one of envAliases.
func (m *PartitionedManager) readIDToken(envAliases []string, realm, clientID, userAssertionHash, partitionKey string) (IDToken, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()

	for _, candidate := range m.contract.IDTokensPartition[partitionKey] {
		sameOwner := candidate.Realm == realm && candidate.ClientID == clientID && candidate.UserAssertionHash == userAssertionHash
		if sameOwner && checkAlias(candidate.Environment, envAliases) {
			return candidate, nil
		}
	}
	return IDToken{}, errors.New("token not found")
}
// writeIDToken stores idToken under its Key() in the given partition, creating the partition's
// map on first use. It always returns nil.
func (m *PartitionedManager) writeIDToken(idToken IDToken, partitionKey string) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()

	partition := m.contract.IDTokensPartition[partitionKey]
	if partition == nil {
		partition = make(map[string]IDToken)
		m.contract.IDTokensPartition[partitionKey] = partition
	}
	partition[idToken.Key()] = idToken
	return nil
}
// readAccount scans one account partition for an account whose environment is one of
// envAliases and whose user assertion hash and realm equal the arguments.
func (m *PartitionedManager) readAccount(envAliases []string, realm, userAssertionHash, partitionKey string) (shared.Account, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()

	// Accounts live in a map only because the storage contract shared with the other MSAL
	// implementations mandates one. Map keys are built from a single environment, but a lookup
	// may match any of several aliases. The options are to compute a key per alias or to scan;
	// with the small number of entries expected per partition, scanning is cheaper than hashing.
	for _, candidate := range m.contract.AccountsPartition[partitionKey] {
		if !checkAlias(candidate.Environment, envAliases) {
			continue
		}
		if candidate.UserAssertionHash == userAssertionHash && candidate.Realm == realm {
			return candidate, nil
		}
	}
	return shared.Account{}, errors.New("account not found")
}
// writeAccount stores account under its Key() in the given partition, creating the partition's
// map on first use. It always returns nil.
func (m *PartitionedManager) writeAccount(account shared.Account, partitionKey string) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()

	partition := m.contract.AccountsPartition[partitionKey]
	if partition == nil {
		partition = make(map[string]shared.Account)
		m.contract.AccountsPartition[partitionKey] = partition
	}
	partition[account.Key()] = account
	return nil
}
// readAppMetaData returns cached app metadata whose environment is one of envAliases and whose
// client ID equals clientID. App metadata is not partitioned, so the whole collection is scanned.
func (m *PartitionedManager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()

	for _, meta := range m.contract.AppMetaData {
		if !checkAlias(meta.Environment, envAliases) || meta.ClientID != clientID {
			continue
		}
		return meta, nil
	}
	return AppMetaData{}, errors.New("not found")
}
// writeAppMetaData stores appMetaData under its Key(), replacing any previous entry with that
// key. It always returns nil.
func (m *PartitionedManager) writeAppMetaData(appMetaData AppMetaData) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.AppMetaData[appMetaData.Key()] = appMetaData
	return nil
}
// update replaces the manager's whole contract with cache. It exists for tests; other uses
// are not supported.
func (m *PartitionedManager) update(cache *InMemoryContract) {
	m.contractMu.Lock()
	m.contract = cache
	m.contractMu.Unlock()
}
// Marshal implements cache.Marshaler.
//
// The contract is serialized while holding the read lock, mirroring Unmarshal and every
// write* method. Without it, a concurrent Write could mutate the contract's maps while the
// encoder iterates them. That is a data race, and the Go runtime may abort with
// "concurrent map iteration and map write".
func (m *PartitionedManager) Marshal() ([]byte, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	return json.Marshal(m.contract)
}
// Unmarshal implements cache.Unmarshaler. It decodes b into a fresh InMemoryContract and only
// installs it if decoding succeeds; on error the current contract is left untouched.
func (m *PartitionedManager) Unmarshal(b []byte) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()

	fresh := NewInMemoryContract()
	if err := json.Unmarshal(b, fresh); err != nil {
		return err
	}
	m.contract = fresh
	return nil
}
// getPartitionKeyAccessToken returns the partition for an access token. That is its user
// assertion hash when set (Write sets it for on-behalf-of requests), otherwise its home
// account ID.
func getPartitionKeyAccessToken(item AccessToken) string {
	key := item.HomeAccountID
	if item.UserAssertionHash != "" {
		key = item.UserAssertionHash
	}
	return key
}
// getPartitionKeyRefreshToken returns the partition for a refresh token. That is its user
// assertion hash when set (Write sets it for on-behalf-of requests), otherwise its home
// account ID.
func getPartitionKeyRefreshToken(item accesstokens.RefreshToken) string {
	key := item.HomeAccountID
	if item.UserAssertionHash != "" {
		key = item.UserAssertionHash
	}
	return key
}
// getPartitionKeyIDToken returns the partition an ID token is written to: its home account ID.
func getPartitionKeyIDToken(item IDToken) (key string) {
	key = item.HomeAccountID
	return key
}
// getPartitionKeyAccount returns the partition an account is written to: its home account ID.
func getPartitionKeyAccount(item shared.Account) (key string) {
	key = item.HomeAccountID
	return key
}
// getPartitionKeyIDTokenRead returns the ID token partition to search when reading, derived from
// the access token found for the request: that token's home account ID.
func getPartitionKeyIDTokenRead(item AccessToken) (key string) {
	key = item.HomeAccountID
	return key
}
partitioned_storage_test.go 0000664 0000000 0000000 00000043525 14420263624 0035725 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package storage
import (
"context"
"fmt"
"reflect"
"testing"
"time"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)
// newPartitionedManagerForTest returns a PartitionedManager with an empty contract whose
// instance discovery calls go to authorityClient. authorityClient may be a fake, or nil when
// the test never triggers discovery.
func newPartitionedManagerForTest(authorityClient aadInstanceDiscoveryer) *PartitionedManager {
	return &PartitionedManager{
		requests: authorityClient,
		aadCache: map[string]authority.InstanceDiscoveryMetadata{},
		contract: NewInMemoryContract(),
	}
}
func TestOBOAccessTokenScopes(t *testing.T) {
	// Two OBO writes for the same user and client but different scopes must each read back
	// their own access token.
	const host = "fakeauthority"
	mgr := newPartitionedManagerForTest(&fakeDiscoveryResponser{
		ret: authority.InstanceDiscoveryResponse{
			Metadata: []authority.InstanceDiscoveryMetadata{{Aliases: []string{host}}},
		},
	})

	const upn = "upn"
	idt := accesstokens.IDToken{
		Oid:               upn + "-oid",
		PreferredUsername: upn,
		TenantID:          "tenant",
		UPN:               upn,
	}

	scopeSets := [][]string{{"scopeA"}, {"scopeB"}}
	written := make([]authority.AuthParams, 0, len(scopeSets))
	for _, scopes := range scopeSets {
		params := authority.AuthParams{
			AuthorityInfo: authority.Info{
				AuthorityType: authority.AAD,
				Host:          host,
				Tenant:        idt.TenantID,
			},
			AuthorizationType: authority.ATOnBehalfOf,
			ClientID:          "client-id",
			Scopes:            scopes,
			UserAssertion:     upn + "-assertion",
			Username:          idt.PreferredUsername,
		}
		resp := accesstokens.TokenResponse{
			AccessToken:   scopes[0] + "-at",
			ClientInfo:    accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
			ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
			GrantedScopes: accesstokens.Scopes{Slice: scopes},
			IDToken:       idt,
			RefreshToken:  upn + "-rt",
		}
		if _, err := mgr.Write(params, resp); err != nil {
			t.Fatal(err)
		}
		written = append(written, params)
	}

	for _, params := range written {
		tr, err := mgr.Read(context.Background(), params)
		if err != nil {
			t.Fatal(err)
		}
		if want := params.Scopes[0] + "-at"; tr.AccessToken.Secret != want {
			t.Fatalf(`unexpected access token "%s"`, tr.AccessToken.Secret)
		}
	}
}
func TestOBOPartitioning(t *testing.T) {
	// OBO tokens written for different users (distinct user assertions) must land in separate
	// partitions and each read back its own access token.
	const host = "fakeauthority"
	mgr := newPartitionedManagerForTest(&fakeDiscoveryResponser{
		ret: authority.InstanceDiscoveryResponse{
			Metadata: []authority.InstanceDiscoveryMetadata{{Aliases: []string{host}}},
		},
	})

	scopes := []string{"scope"}
	const users = 2
	accounts := make([]shared.Account, users)
	params := make([]authority.AuthParams, users)
	for i := range accounts {
		upn := fmt.Sprintf("%d", i)
		idt := accesstokens.IDToken{
			Oid:               upn + "-oid",
			PreferredUsername: upn,
			TenantID:          "tenant",
			UPN:               upn,
		}
		params[i] = authority.AuthParams{
			AuthorityInfo: authority.Info{
				AuthorityType: authority.AAD,
				Host:          host,
				Tenant:        idt.TenantID,
			},
			AuthorizationType: authority.ATOnBehalfOf,
			ClientID:          "client-id",
			Scopes:            scopes,
			UserAssertion:     upn + "-assertion",
			Username:          idt.PreferredUsername,
		}
		resp := accesstokens.TokenResponse{
			AccessToken:   upn + "-at",
			ClientInfo:    accesstokens.ClientInfo{UID: upn, UTID: idt.TenantID},
			ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
			GrantedScopes: accesstokens.Scopes{Slice: scopes},
			IDToken:       idt,
			RefreshToken:  upn + "-rt",
		}
		acc, err := mgr.Write(params[i], resp)
		if err != nil {
			t.Fatal(err)
		}
		accounts[i] = acc
	}

	for i := range params {
		tr, err := mgr.Read(context.Background(), params[i])
		if err != nil {
			t.Fatal(err)
		}
		if want := accounts[i].PreferredUsername + "-at"; tr.AccessToken.Secret != want {
			t.Fatalf(`unexpected access token "%s"`, tr.AccessToken.Secret)
		}
	}
}
func TestReadPartitionedAccessToken(t *testing.T) {
	now := time.Now()
	want := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid user.read", "secret")
	want.UserAssertionHash = "user_assertion_hash"

	mgr := newPartitionedManagerForTest(nil)
	mgr.update(&InMemoryContract{
		AccessTokensPartition: map[string]map[string]AccessToken{
			"at_partition": {want.Key(): want},
		},
	})
	aliases := []string{"hello", "env", "test"}
	requested := []string{"user.read", "openid"}

	got, err := mgr.readAccessToken(aliases, "realm", "cid", "user_assertion_hash", requested, "at_partition")
	if err != nil {
		t.Errorf("TestReadPartitionedAccessToken: got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Fatalf("Returned access token is not the same as expected access token: -want/+got:\n%s", diff)
	}

	// A different user assertion hash must produce a miss.
	if _, err = mgr.readAccessToken(aliases, "realm", "cid", "this_should_break_it", requested, "at_partition"); err == nil {
		t.Errorf("TestReadPartitionedAccessToken: got err == nil, want err != nil")
	}
}
func TestWritePartitionedAccessToken(t *testing.T) {
	now := time.Now()
	mgr := newPartitionedManagerForTest(nil)
	want := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid", "secret")

	if err := mgr.writeAccessToken(want, "at_partition"); err != nil {
		t.Fatalf("TestWritePartitionedAccessToken: got err == %s, want err == nil", err)
	}
	got := mgr.contract.AccessTokensPartition["at_partition"][want.Key()]
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestWritePartitionedAccessToken: -want/+got:\n%s", diff)
	}
}
func TestReadPartitionedAccount(t *testing.T) {
	want := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")
	want.UserAssertionHash = "user_assertion_hash"

	mgr := newPartitionedManagerForTest(nil)
	mgr.update(&InMemoryContract{
		AccountsPartition: map[string]map[string]shared.Account{
			"acc_partition": {want.Key(): want},
		},
	})
	aliases := []string{"hello", "env", "test"}

	got, err := mgr.readAccount(aliases, "realm", "user_assertion_hash", "acc_partition")
	if err != nil {
		t.Fatalf("TestReadPartitionedAccount: got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestReadPartitionedAccount: -want/+got:\n%s", diff)
	}

	// A different user assertion hash must produce a miss.
	if _, err = mgr.readAccount(aliases, "realm", "this_should_break_it", "acc_partition"); err == nil {
		t.Errorf("TestReadPartitionedAccount: got err == nil, want err != nil")
	}
}
func TestWritePartitionedAccount(t *testing.T) {
	mgr := newPartitionedManagerForTest(nil)
	want := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")
	want.UserAssertionHash = "user_assertion_hash"

	if err := mgr.writeAccount(want, "acc_partition"); err != nil {
		t.Fatalf("TestWritePartitionedAccount: got err == %s, want err == nil", err)
	}
	got := mgr.contract.AccountsPartition["acc_partition"][want.Key()]
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestWritePartitionedAccount: -want/+got:\n%s", diff)
	}
}
func TestReadAppMetaDataPartitionedManager(t *testing.T) {
	want := NewAppMetaData("fid", "cid", "env")
	mgr := newPartitionedManagerForTest(nil)
	mgr.update(&InMemoryContract{
		AppMetaData: map[string]AppMetaData{want.Key(): want},
	})
	aliases := []string{"hello", "test", "env"}

	got, err := mgr.readAppMetaData(aliases, "cid")
	if err != nil {
		t.Fatalf("TestreadAppMetaDataPartitionedManager(readAppMetaData): got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Fatalf("TestreadAppMetaDataPartitionedManager(readAppMetaData): -want/+got:\n%s", diff)
	}

	// An unknown client ID must produce a miss.
	if _, err = mgr.readAppMetaData(aliases, "break_this"); err == nil {
		t.Fatalf("TestreadAppMetaDataPartitionedManager(bad readAppMetaData): got err == nil, want err != nil")
	}
}
func TestWriteAppMetaDataPartitionedManager(t *testing.T) {
	mgr := newPartitionedManagerForTest(nil)
	want := NewAppMetaData("fid", "cid", "env")

	if err := mgr.writeAppMetaData(want); err != nil {
		t.Fatalf("TestwriteAppMetaDataPartitionedManager: got err == %s, want err == nil", err)
	}
	got := mgr.contract.AppMetaData[want.Key()]
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestwriteAppMetaDataPartitionedManager: -want/+got:\n%s", diff)
	}
}
func TestReadPartitionedIDToken(t *testing.T) {
	testIDToken := NewIDToken(
		"hid",
		"env",
		"realm",
		"cid",
		"secret",
	)
	testIDToken.UserAssertionHash = "user_assertion_hash"
	cache := &InMemoryContract{
		IDTokensPartition: map[string]map[string]IDToken{
			"idt_partition": {testIDToken.Key(): testIDToken},
		},
	}
	storageManager := newPartitionedManagerForTest(nil)
	storageManager.update(cache)

	returnedIDToken, err := storageManager.readIDToken(
		[]string{"hello", "env", "test"},
		"realm",
		"cid",
		"user_assertion_hash",
		"idt_partition",
	)
	if err != nil {
		// Fail through the testing framework rather than panicking. That way the failure is
		// attributed to this test and the rest of the suite keeps running.
		t.Fatalf("TestReadPartitionedIDToken(good token): got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(testIDToken, returnedIDToken); diff != "" {
		t.Fatalf("TestReadPartitionedIDToken(good token): -want/+got:\n%s", diff)
	}

	// A different user assertion hash must produce a miss.
	_, err = storageManager.readIDToken(
		[]string{"hello", "env", "test"},
		"realm",
		"cid",
		"this_should_break_it",
		"idt_partition",
	)
	if err == nil {
		t.Errorf("TestReadPartitionedIDToken(bad token): got err == nil, want err != nil")
	}
}
func TestWritePartitionedIDToken(t *testing.T) {
	mgr := newPartitionedManagerForTest(nil)
	want := NewIDToken("hid", "env", "realm", "cid", "secret")
	want.UserAssertionHash = "user_assertion_hash"

	if err := mgr.writeIDToken(want, "idt_partition"); err != nil {
		t.Fatalf("TestWritePartitionedIDToken: got err == %s, want err == nil", err)
	}
	got := mgr.contract.IDTokensPartition["idt_partition"][want.Key()]
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestWritePartitionedIDToken: -want/+got:\n%s", diff)
	}
}
func TestReadPartionedRefreshToken(t *testing.T) {
	testRefreshTokenWithFID := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	testRefreshTokenWithFID.UserAssertionHash = "user_assertion_hash"
	testRefreshTokenWoFID := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "")
	testRefreshTokenWoFID.UserAssertionHash = "user_assertion_hash"
	testRefreshTokenWoFIDAltCID := accesstokens.NewRefreshToken("hid", "env", "cid2", "secret", "")
	testRefreshTokenWoFIDAltCID.UserAssertionHash = "user_assertion_hash"

	// partition builds a single-partition contract with each token stored under its own Key().
	// The previous literals reused testRefreshTokenWoFID.Key() for several tokens, so one entry
	// silently overwrote another.
	partition := func(rts ...accesstokens.RefreshToken) *InMemoryContract {
		p := make(map[string]accesstokens.RefreshToken, len(rts))
		for _, rt := range rts {
			p[rt.Key()] = rt
		}
		return &InMemoryContract{
			RefreshTokensPartition: map[string]map[string]accesstokens.RefreshToken{"rt_partition": p},
		}
	}

	type args struct {
		envAliases        []string
		familyID          string
		clientID          string
		userAssertionHash string
	}
	tests := []struct {
		name     string
		contract *InMemoryContract
		args     args
		want     accesstokens.RefreshToken
		err      bool
	}{
		{
			name:     "Token without fid, read with fid, cid, env, and hid",
			contract: partition(testRefreshTokenWoFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "fid",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWoFID,
		},
		{
			name:     "Token without fid, read with cid, env, and hid",
			contract: partition(testRefreshTokenWoFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWoFID,
		},
		{
			name:     "Token without fid, verify CID is required",
			contract: partition(testRefreshTokenWoFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "",
				clientID:          "",
				userAssertionHash: "user_assertion_hash",
			},
			err: true,
		},
		{
			name:     "Token without fid, Verify env is required",
			contract: partition(testRefreshTokenWoFID),
			args: args{
				envAliases:        []string{},
				familyID:          "",
				clientID:          "",
				userAssertionHash: "user_assertion_hash",
			},
			err: true,
		},
		{
			name:     "Token with fid, read with fid, cid, env, and hid",
			contract: partition(testRefreshTokenWithFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "fid",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name:     "Token with fid, read with cid, env, and hid",
			contract: partition(testRefreshTokenWithFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name:     "Token with fid, verify CID is not required", // match on hid, env, and has fid
			contract: partition(testRefreshTokenWithFID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "",
				clientID:          "",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name:     "Token with fid, Verify env is required",
			contract: partition(testRefreshTokenWithFID),
			args: args{
				envAliases:        []string{},
				familyID:          "",
				clientID:          "",
				userAssertionHash: "user_assertion_hash",
			},
			err: true,
		},
		{
			// Only testRefreshTokenWithFID has a family ID, so the family-first search is deterministic.
			name:     "Multiple items in cache, given a fid, item with fid will be returned",
			contract: partition(testRefreshTokenWoFID, testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "fid",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWithFID,
		},
		{
			// Only testRefreshTokenWoFIDAltCID has client ID "cid2", so the client-first search is
			// deterministic.
			name:     "Multiple items in cache, without a fid and with alternate CID, token with alternate CID is returned",
			contract: partition(testRefreshTokenWoFID, testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID),
			args: args{
				envAliases:        []string{"test", "env", "hello"},
				familyID:          "",
				clientID:          "cid2",
				userAssertionHash: "user_assertion_hash",
			},
			want: testRefreshTokenWoFIDAltCID,
		},
		{
			name:     "Multiple items in cache, Verify env is required",
			contract: partition(testRefreshTokenWoFID, testRefreshTokenWithFID, testRefreshTokenWoFIDAltCID),
			args: args{
				envAliases:        []string{},
				familyID:          "fid",
				clientID:          "cid",
				userAssertionHash: "user_assertion_hash",
			},
			err: true,
		},
	}

	m := &PartitionedManager{}
	for _, test := range tests {
		t.Run(test.name, func(t *testing.T) {
			m.update(test.contract)
			got, err := m.readRefreshToken(test.args.envAliases, test.args.familyID, test.args.clientID, test.args.userAssertionHash, "rt_partition")
			switch {
			case test.err && err == nil:
				t.Fatalf("TestReadPartionedRefreshToken(%s): got err == nil, want err != nil", test.name)
			case !test.err && err != nil:
				t.Fatalf("TestReadPartionedRefreshToken(%s): got err == %s, want err == nil", test.name, err)
			case err != nil:
				return
			}
			if diff := pretty.Compare(test.want, got); diff != "" {
				t.Errorf("TestReadPartionedRefreshToken(%s): -want/+got:\n%s", test.name, diff)
			}
		})
	}
}
func TestWritePartitionedRefreshToken(t *testing.T) {
	mgr := newPartitionedManagerForTest(nil)
	want := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	want.UserAssertionHash = "user_assertion_hash"
	key := want.Key()

	if err := mgr.writeRefreshToken(want, "rt_partition"); err != nil {
		t.Errorf("Error should be nil, but it is %v", err)
	}
	got := mgr.contract.RefreshTokensPartition["rt_partition"][key]
	if !reflect.DeepEqual(got, want) {
		t.Errorf("Added refresh token %v differs from expected refresh token %v", got, want)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/storage.go 0000664 0000000 0000000 00000042027 14420263624 0032337 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package storage holds all cached token information for MSAL. This storage can be
// augmented with third-party extensions to provide persistent storage. In that case,
// reads and writes in upper packages will call Marshal() to take the entire in-memory
// representation and write it to storage and Unmarshal() to update the entire in-memory
// storage with what was in the persistent storage. The persistent storage can only be
// accessed in this way because multiple MSAL clients written in multiple languages can
// access the same storage and must adhere to the same method that was defined
// previously.
package storage
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
// aadInstanceDiscoveryer abstracts the AAD instance discovery request so tests can substitute a
// fake. In production New wires in an *oauth.Client to satisfy it.
type aadInstanceDiscoveryer interface {
	AADInstanceDiscovery(ctx context.Context, info authority.Info) (authority.InstanceDiscoveryResponse, error)
}
// TokenResponse mimics a token response that was pulled from the cache.
// Manager.Read fills in only the fields it finds a cached match for; any field without a match
// is left as its zero value.
type TokenResponse struct {
// RefreshToken is the refresh token chosen for the account (see Manager.readRefreshToken).
RefreshToken accesstokens.RefreshToken
// IDToken is the cached ID token for the account, realm and client.
IDToken IDToken
// AccessToken is a cached access token whose scopes cover the requested scopes.
AccessToken AccessToken
// Account is the cached account record for the home account ID and realm.
Account shared.Account
}
// Manager is an in-memory cache of access tokens, accounts and meta data. This data is
// updated on read/write calls. Unmarshal() replaces all data stored here with whatever
// was given to it on each call.
type Manager struct {
// contract holds all cached credentials, accounts and app metadata; guarded by contractMu.
contract *Contract
contractMu sync.RWMutex
// requests performs AAD instance discovery. New supplies an *oauth.Client; tests supply fakes.
requests aadInstanceDiscoveryer
// aadCache maps an authority host (or any of its aliases) to instance discovery metadata;
// guarded by aadCacheMu.
aadCacheMu sync.RWMutex
aadCache map[string]authority.InstanceDiscoveryMetadata
}
// New is the constructor for Manager. The returned Manager starts with an empty credential
// contract and an empty instance-discovery cache, and uses requests for instance discovery.
func New(requests *oauth.Client) *Manager {
	return &Manager{
		contract: NewContract(),
		requests: requests,
		aadCache: make(map[string]authority.InstanceDiscoveryMetadata),
	}
}
// checkAlias reports whether alias equals (case-sensitively) any element of aliases.
func checkAlias(alias string, aliases []string) bool {
	for i := range aliases {
		if aliases[i] == alias {
			return true
		}
	}
	return false
}
// isMatchingScopes reports whether every scope in scopesOne appears in scopesTwo, a list of scopes
// joined by scopeSeparator. Comparison is case-insensitive (strings.EqualFold), and scopesTwo may
// contain scopes that were not requested.
//
// Each requested scope is checked independently and counts at most once. Counting every match is
// wrong in both directions:
//   - duplicates in scopesTwo could stand in for a requested scope that is actually missing;
//   - duplicates in scopesOne could overshoot the expected count and cause a false mismatch.
func isMatchingScopes(scopesOne []string, scopesTwo string) bool {
	newScopesTwo := strings.Split(scopesTwo, scopeSeparator)
	for _, scope := range scopesOne {
		found := false
		for _, otherScope := range newScopesTwo {
			if strings.EqualFold(scope, otherScope) {
				found = true
				break
			}
		}
		if !found {
			return false
		}
	}
	return true
}
// Read looks up cached credentials for authParameters.
//
// The environments a cached credential may belong to are the authority host alone when instance
// discovery is disabled, and otherwise the aliases from the host's instance discovery metadata.
// The access token is always looked up. The ID token, refresh token and account are looked up only
// when authParameters.HomeAccountID is non-empty. The refresh token is chosen using the family ID
// recorded in the client's app metadata.
//
// A missing cache entry is not an error; the affected field is simply left zero. The only error
// Read returns comes from fetching instance discovery metadata.
func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) {
	info := authParameters.AuthorityInfo
	homeID := authParameters.HomeAccountID
	clientID := authParameters.ClientID

	envs := []string{info.Host}
	if !info.InstanceDiscoveryDisabled {
		md, err := m.getMetadataEntry(ctx, info)
		if err != nil {
			return TokenResponse{}, err
		}
		envs = md.Aliases
	}

	var resp TokenResponse
	resp.AccessToken = m.readAccessToken(homeID, envs, info.Tenant, clientID, authParameters.Scopes)
	if homeID == "" {
		// No user was specified, so there is no ID token, refresh token or account to search for.
		return resp, nil
	}

	// read* errors only signal a cache miss; keep filling in whatever else is cached so that, e.g.,
	// a missing ID token doesn't prevent returning a refresh token.
	if idt, err := m.readIDToken(homeID, envs, info.Tenant, clientID); err == nil {
		resp.IDToken = idt
	}
	if app, err := m.readAppMetaData(envs, clientID); err == nil {
		// The family ID determines which refresh token is preferred.
		if rt, err := m.readRefreshToken(homeID, envs, app.FamilyID, clientID); err == nil {
			resp.RefreshToken = rt
		}
	}
	if acct, err := m.readAccount(homeID, envs, info.Tenant); err == nil {
		resp.Account = acct
	}
	return resp, nil
}
// scopeSeparator delimits the individual scopes in a cached token's scope string.
const (
	scopeSeparator = " "
)
// Write caches the credentials in tokenResponse and returns the account they were stored under.
//
// The home account ID comes from tokenResponse.ClientInfo. The environment and realm come from
// authParameters.AuthorityInfo. In order, Write stores:
//   - the refresh token, if the response has one;
//   - the access token, if the response has one and it passes Validate (an access token that fails
//     validation is skipped without an error);
//   - the ID token and an account built from its claims, if the response has an ID token;
//   - app metadata recording the response's family ID, always.
//
// The returned account is the zero value when the response has no ID token, and on any error.
func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) {
	homeID := tokenResponse.ClientInfo.HomeAccountID()
	env := authParameters.AuthorityInfo.Host
	tenant := authParameters.AuthorityInfo.Tenant
	clientID := authParameters.ClientID
	scopes := strings.Join(tokenResponse.GrantedScopes.Slice, scopeSeparator)
	now := time.Now()

	if len(tokenResponse.RefreshToken) > 0 {
		rt := accesstokens.NewRefreshToken(homeID, env, clientID, tokenResponse.RefreshToken, tokenResponse.FamilyID)
		if err := m.writeRefreshToken(rt); err != nil {
			return shared.Account{}, err
		}
	}

	if len(tokenResponse.AccessToken) > 0 {
		at := NewAccessToken(
			homeID,
			env,
			tenant,
			clientID,
			now,
			tokenResponse.ExpiresOn.T,
			tokenResponse.ExtExpiresOn.T,
			scopes,
			tokenResponse.AccessToken,
		)
		// Only a token that validates is cached; a failed validation is deliberately not an error.
		if at.Validate() == nil {
			if err := m.writeAccessToken(at); err != nil {
				return shared.Account{}, err
			}
		}
	}

	var account shared.Account
	if idJWT := tokenResponse.IDToken; !idJWT.IsZero() {
		if err := m.writeIDToken(NewIDToken(homeID, env, tenant, clientID, idJWT.RawToken)); err != nil {
			return shared.Account{}, err
		}
		// Prefer the PreferredUsername claim; fall back to the UPN when it is empty.
		username := idJWT.PreferredUsername
		if username == "" {
			username = idJWT.UPN
		}
		account = shared.NewAccount(
			homeID,
			env,
			tenant,
			idJWT.LocalAccountID(),
			authParameters.AuthorityInfo.AuthorityType,
			username,
		)
		if err := m.writeAccount(account); err != nil {
			return shared.Account{}, err
		}
	}

	appMeta := NewAppMetaData(tokenResponse.FamilyID, clientID, env)
	if err := m.writeAppMetaData(appMeta); err != nil {
		return shared.Account{}, err
	}
	return account, nil
}
// getMetadataEntry returns instance discovery metadata for authorityInfo.Host. It consults the
// in-memory cache first and performs instance discovery only on a cache miss.
func (m *Manager) getMetadataEntry(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	if cached, err := m.aadMetadataFromCache(ctx, authorityInfo); err == nil {
		return cached, nil
	}
	// Cache miss: fetch the metadata (which also populates the cache).
	return m.aadMetadata(ctx, authorityInfo)
}
// aadMetadataFromCache returns the cached instance discovery metadata for authorityInfo.Host, or
// an error when nothing is cached for that host. ctx is not used.
func (m *Manager) aadMetadataFromCache(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	m.aadCacheMu.RLock()
	defer m.aadCacheMu.RUnlock()
	if md, found := m.aadCache[authorityInfo.Host]; found {
		return md, nil
	}
	return authority.InstanceDiscoveryMetadata{}, errors.New("not found")
}
// aadMetadata performs AAD instance discovery for authorityInfo and caches each returned metadata
// entry under every one of its aliases. When the response contains no entry for
// authorityInfo.Host, a fallback entry for that host is cached so later lookups don't repeat the
// request. The write lock is held for the whole call, including the discovery request.
func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryMetadata, error) {
	m.aadCacheMu.Lock()
	defer m.aadCacheMu.Unlock()
	discoveryResponse, err := m.requests.AADInstanceDiscovery(ctx, authorityInfo)
	if err != nil {
		return authority.InstanceDiscoveryMetadata{}, err
	}
	for _, metadataEntry := range discoveryResponse.Metadata {
		for _, aliasedAuthority := range metadataEntry.Aliases {
			m.aadCache[aliasedAuthority] = metadataEntry
		}
	}
	if _, ok := m.aadCache[authorityInfo.Host]; !ok {
		// Read treats the returned Aliases as the environments a cached credential may belong to,
		// and Write stores credentials under authorityInfo.Host. The fallback must therefore list
		// the host itself as an alias (mirroring Read's instance-discovery-disabled path).
		// Otherwise every cache lookup for this host would miss.
		m.aadCache[authorityInfo.Host] = authority.InstanceDiscoveryMetadata{
			Aliases:          []string{authorityInfo.Host},
			PreferredNetwork: authorityInfo.Host,
			PreferredCache:   authorityInfo.Host,
		}
	}
	return m.aadCache[authorityInfo.Host], nil
}
// readAccessToken returns a cached access token that satisfies all of the following:
//   - its home account ID, realm and client ID equal the arguments;
//   - its environment is one of envAliases;
//   - its scopes cover every requested scope (case-insensitively).
//
// If several tokens qualify, which one is returned depends on map iteration order. It returns the
// zero AccessToken when nothing matches.
func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string) AccessToken {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	// TODO: this linear scan over the map is slow for very large caches (thousands of tokens) and
	// then dominates profiles. Real-world caches are unlikely to hit that, but this is the place
	// to look if they do.
	for _, tok := range m.contract.AccessTokens {
		if tok.HomeAccountID != homeID || tok.Realm != realm || tok.ClientID != clientID {
			continue
		}
		if !checkAlias(tok.Environment, envAliases) || !isMatchingScopes(scopes, tok.Scopes) {
			continue
		}
		return tok
	}
	return AccessToken{}
}
// writeAccessToken stores accessToken under accessToken.Key(), replacing any entry with the same
// key. It always returns nil.
func (m *Manager) writeAccessToken(accessToken AccessToken) error {
	k := accessToken.Key()
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.AccessTokens[k] = accessToken
	return nil
}
// readRefreshToken finds a cached refresh token for homeID whose environment is one of envAliases.
//
// Two kinds of match are tried:
//   - by client: the token's ClientID equals clientID;
//   - by family: the token has a non-empty FamilyID (any family; it is not compared to familyID).
//
// When familyID is empty the client match is tried first, otherwise the family match is. An error
// is returned when neither finds a token.
func (m *Manager) readRefreshToken(homeID string, envAliases []string, familyID, clientID string) (accesstokens.RefreshToken, error) {
	byFamily := func(rt accesstokens.RefreshToken) bool {
		return matchFamilyRefreshToken(rt, homeID, envAliases)
	}
	byClient := func(rt accesstokens.RefreshToken) bool {
		return matchClientIDRefreshToken(rt, homeID, envAliases, clientID)
	}
	first, second := byFamily, byClient
	if familyID == "" {
		first, second = byClient, byFamily
	}

	// TODO(keegan): the tests pass, but Bogdan notes the reference algorithm is more involved and
	// suggests keeping this order. An issue tracks getting tests that would break this, so it can
	// be rewritten against them. His description:
	//   - If the application is NOT part of a family, search by client_id.
	//   - If it is part of a family, or we DO NOT KNOW whether it is, search by family ID, then by
	//     client_id (family membership is known after the first token response).
	// https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/311fe8b16e7c293462806f397e189a6aa1159769/src/client/Microsoft.Identity.Client/Internal/Requests/Silent/CacheSilentStrategy.cs#L95
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	for _, matches := range [...]func(accesstokens.RefreshToken) bool{first, second} {
		for _, rt := range m.contract.RefreshTokens {
			if matches(rt) {
				return rt, nil
			}
		}
	}
	return accesstokens.RefreshToken{}, fmt.Errorf("refresh token not found")
}
// matchFamilyRefreshToken reports whether rt belongs to homeID, has an environment in envAliases,
// and is a family refresh token (non-empty FamilyID). The specific family is not checked.
func matchFamilyRefreshToken(rt accesstokens.RefreshToken, homeID string, envAliases []string) bool {
	if rt.FamilyID == "" || rt.HomeAccountID != homeID {
		return false
	}
	return checkAlias(rt.Environment, envAliases)
}
// matchClientIDRefreshToken reports whether rt belongs to homeID, was issued to clientID, and has
// an environment in envAliases.
func matchClientIDRefreshToken(rt accesstokens.RefreshToken, homeID string, envAliases []string, clientID string) bool {
	if rt.ClientID != clientID || rt.HomeAccountID != homeID {
		return false
	}
	return checkAlias(rt.Environment, envAliases)
}
// writeRefreshToken stores refreshToken under refreshToken.Key(), replacing any entry with the
// same key. It always returns nil.
func (m *Manager) writeRefreshToken(refreshToken accesstokens.RefreshToken) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.RefreshTokens[refreshToken.Key()] = refreshToken
	return nil
}
// readIDToken returns a cached ID token whose home account ID, realm and client ID equal the
// arguments and whose environment is one of envAliases, or an error if none matches.
func (m *Manager) readIDToken(homeID string, envAliases []string, realm, clientID string) (IDToken, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	for _, tok := range m.contract.IDTokens {
		matches := tok.HomeAccountID == homeID &&
			tok.Realm == realm &&
			tok.ClientID == clientID &&
			checkAlias(tok.Environment, envAliases)
		if matches {
			return tok, nil
		}
	}
	return IDToken{}, fmt.Errorf("token not found")
}
// writeIDToken stores idToken under idToken.Key(), replacing any entry with the same key.
// It always returns nil.
func (m *Manager) writeIDToken(idToken IDToken) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.IDTokens[idToken.Key()] = idToken
	return nil
}
// AllAccounts returns every cached account in unspecified (map iteration) order. The result is
// nil when no accounts are cached.
func (m *Manager) AllAccounts() []shared.Account {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	var all []shared.Account
	for key := range m.contract.Accounts {
		all = append(all, m.contract.Accounts[key])
	}
	return all
}
// Account returns a cached account whose HomeAccountID equals homeAccountID, or the zero
// shared.Account when there is none. If several cached accounts share that ID, which one is
// returned depends on map iteration order.
func (m *Manager) Account(homeAccountID string) shared.Account {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	var found shared.Account
	for _, acct := range m.contract.Accounts {
		if acct.HomeAccountID == homeAccountID {
			found = acct
			break
		}
	}
	return found
}
// readAccount returns a cached account matching homeAccountID and realm whose environment is one
// of envAliases, or an error if none matches.
func (m *Manager) readAccount(homeAccountID string, envAliases []string, realm string) (shared.Account, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	// Accounts is a map only because the storage contract shared by every MSAL language
	// implementation requires one, so we can't change it. Its keys embed a single environment, but
	// this lookup accepts any of several aliases. Keyed lookups would mean building and hashing one
	// key per alias. With one storage.Manager per user only a handful of accounts (say 2) are
	// cached, and scanning them is cheaper than computing those hashes.
	for _, acct := range m.contract.Accounts {
		if acct.HomeAccountID != homeAccountID || acct.Realm != realm {
			continue
		}
		if checkAlias(acct.Environment, envAliases) {
			return acct, nil
		}
	}
	return shared.Account{}, fmt.Errorf("account not found")
}
// writeAccount stores account under account.Key(), replacing any entry with the same key.
// It always returns nil.
func (m *Manager) writeAccount(account shared.Account) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.Accounts[account.Key()] = account
	return nil
}
// readAppMetaData returns cached app metadata for clientID whose environment is one of envAliases,
// or an error if none matches.
func (m *Manager) readAppMetaData(envAliases []string, clientID string) (AppMetaData, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	for _, md := range m.contract.AppMetaData {
		if !checkAlias(md.Environment, envAliases) {
			continue
		}
		if md.ClientID == clientID {
			return md, nil
		}
	}
	return AppMetaData{}, fmt.Errorf("not found")
}
// writeAppMetaData stores appMetaData under appMetaData.Key(), replacing any entry with the same
// key. It always returns nil.
func (m *Manager) writeAppMetaData(appMetaData AppMetaData) error {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract.AppMetaData[appMetaData.Key()] = appMetaData
	return nil
}
// RemoveAccount removes the account's cached data in account.Environment, namely:
//   - refresh tokens owned by clientID or carrying a family ID;
//   - all of the account's access tokens and ID tokens;
//   - the account records themselves.
//
// Each category is removed under a separate lock acquisition, so the removal as a whole is not
// atomic with respect to concurrent readers.
func (m *Manager) RemoveAccount(account shared.Account, clientID string) {
	homeID, env := account.HomeAccountID, account.Environment
	m.removeRefreshTokens(homeID, env, clientID)
	m.removeAccessTokens(homeID, env)
	m.removeIDTokens(homeID, env)
	m.removeAccounts(homeID, env)
}
// removeRefreshTokens deletes refresh tokens for homeID in environment env that are owned by
// clientID or carry a family ID.
func (m *Manager) removeRefreshTokens(homeID string, env string, clientID string) {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	for key, rt := range m.contract.RefreshTokens {
		if rt.HomeAccountID != homeID || rt.Environment != env {
			continue
		}
		// Ownership check as a precaution, in case family apps and third-party apps share a token
		// cache (which they should not).
		if rt.ClientID == clientID || rt.FamilyID != "" {
			delete(m.contract.RefreshTokens, key)
		}
	}
}
// removeAccessTokens deletes every access token for homeID in environment env, for any client.
//
// There is deliberately no app-ownership check, which avoids having to locate sibling family
// apps' tokens. Removing other apps' access tokens is acceptable because:
//   - non-family apps are not supposed to share a token cache in the first place;
//   - even if they do, their refresh tokens are kept, so SSO still works.
func (m *Manager) removeAccessTokens(homeID string, env string) {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	for key := range m.contract.AccessTokens {
		at := m.contract.AccessTokens[key]
		if at.HomeAccountID == homeID && at.Environment == env {
			delete(m.contract.AccessTokens, key)
		}
	}
}
// removeIDTokens deletes every ID token for homeID in environment env.
func (m *Manager) removeIDTokens(homeID string, env string) {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	for key := range m.contract.IDTokens {
		idt := m.contract.IDTokens[key]
		if idt.HomeAccountID == homeID && idt.Environment == env {
			delete(m.contract.IDTokens, key)
		}
	}
}
// removeAccounts deletes every account record for homeID in environment env.
func (m *Manager) removeAccounts(homeID string, env string) {
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	for key := range m.contract.Accounts {
		acct := m.contract.Accounts[key]
		if acct.HomeAccountID == homeID && acct.Environment == env {
			delete(m.contract.Accounts, key)
		}
	}
}
// update replaces the Manager's whole contract with cache. It exists for tests; other uses are
// not supported.
func (m *Manager) update(cache *Contract) {
	m.contractMu.Lock()
	m.contract = cache
	m.contractMu.Unlock()
}
// Marshal implements cache.Marshaler by JSON-encoding the entire in-memory contract while holding
// the read lock.
func (m *Manager) Marshal() ([]byte, error) {
	m.contractMu.RLock()
	defer m.contractMu.RUnlock()
	encoded, err := json.Marshal(m.contract)
	return encoded, err
}
// Unmarshal implements cache.Unmarshaler. It decodes b into a fresh Contract and, only if decoding
// succeeds, replaces the Manager's entire in-memory cache with it. On error the existing cache is
// left untouched.
func (m *Manager) Unmarshal(b []byte) error {
	// Decoding touches only the new contract, not m, so do it before taking the lock. Holding the
	// exclusive lock during a potentially large decode would needlessly block every concurrent
	// reader (Read, AllAccounts, Marshal, ...).
	contract := NewContract()
	if err := json.Unmarshal(b, contract); err != nil {
		return err
	}
	m.contractMu.Lock()
	defer m.contractMu.Unlock()
	m.contract = contract
	return nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage/storage_test.go 0000664 0000000 0000000 00000076325 14420263624 0033406 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package storage
import (
"context"
"errors"
"os"
"reflect"
"sort"
"testing"
"time"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/kylelemons/godebug/pretty"
)
const (
testFile = "test_serialized_cache.json"
defaultEnvironment = "login.windows.net"
defaultHID = "uid.utid"
defaultRealm = "contoso"
defaultScopes = "s2 s1 s3"
defaultClientID = "my_client_id"
accessTokenSecret = "an access token"
rtSecret = "a refresh token"
idCred = "IdToken"
idSecret = "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature"
accUser = "John Doe"
accLID = "object1234"
accAuth = "MSSTS"
)
var (
atCached = time.Unix(1000, 0)
atExpires = time.Unix(4600, 0)
)
// newForTest returns a Manager with an empty contract and metadata cache that uses
// authorityClient for instance discovery, letting tests inject a fake.
func newForTest(authorityClient aadInstanceDiscoveryer) *Manager {
	return &Manager{
		requests: authorityClient,
		aadCache: make(map[string]authority.InstanceDiscoveryMetadata),
		contract: NewContract(),
	}
}
// fakeDiscoveryResponser is a test double for aadInstanceDiscoveryer.
type fakeDiscoveryResponser struct {
// err makes AADInstanceDiscovery fail with a generic error when true.
err bool
// ret is the response AADInstanceDiscovery returns when err is false.
ret authority.InstanceDiscoveryResponse
}
// AADInstanceDiscovery returns f.ret, or a generic error (and an empty response) when f.err is
// set. ctx and authorityInfo are ignored.
func (f *fakeDiscoveryResponser) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
	if !f.err {
		return f.ret, nil
	}
	return authority.InstanceDiscoveryResponse{}, errors.New("error")
}
// TestCheckAlias verifies that checkAlias rejects an alias missing from the list and accepts one
// that is present.
func TestCheckAlias(t *testing.T) {
	aliases := []string{"testOne", "testTwo", "testThree"}

	if missing := "noTest"; checkAlias(missing, aliases) {
		t.Errorf("%v isn't supposed to be in %v", missing, aliases)
	}
	if present := "testOne"; !checkAlias(present, aliases) {
		t.Errorf("%v is supposed to be in %v", present, aliases)
	}
}
// TestIsMatchingScopes verifies that isMatchingScopes ignores scope order and case, and fails when
// a requested scope is absent from the target.
func TestIsMatchingScopes(t *testing.T) {
	requested := []string{"user.read", "openid", "user.write"}
	cases := []struct {
		target string
		want   bool
		format string
	}{
		{"openid user.write user.read", true, "Scopes %v and %v are supposed to be the same"},
		{"openid User.Write User.Read", true, "Scopes %v and %v are supposed to be the same as the comparison is case insensitive"},
		{"openid user.read hello", false, "Scopes %v and %v are not supposed to be the same"},
	}
	for _, c := range cases {
		if isMatchingScopes(requested, c.target) != c.want {
			t.Fatalf(c.format, requested, c.target)
		}
	}
}
// TestAllAccounts verifies that AllAccounts returns every account stored in the contract.
func TestAllAccounts(t *testing.T) {
	lower := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")
	upper := shared.NewAccount("HID", "ENV", "REALM", "LID", accAuth, "USERNAME")
	var mgr Manager
	mgr.update(&Contract{
		Accounts: map[string]shared.Account{
			lower.Key(): lower,
			upper.Key(): upper,
		},
	})

	got := mgr.AllAccounts()
	// Map iteration makes AllAccounts' order unstable, so sort by HomeAccountID, descending
	// ("hid" sorts before "HID").
	sort.Slice(got, func(i, j int) bool { return got[i].HomeAccountID > got[j].HomeAccountID })

	if diff := pretty.Compare([]shared.Account{lower, upper}, got); diff != "" {
		t.Errorf("Actual accounts differ from expected accounts: -want/+got:\n%s", diff)
	}
}
// TestReadAccessToken verifies that readAccessToken finds a token when:
//   - the home account, realm and client ID match;
//   - the requested scopes are covered;
//   - one alias equals the token's environment.
//
// It also verifies that the zero value is returned for an unknown home account ID.
func TestReadAccessToken(t *testing.T) {
	now := time.Now()
	want := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid user.read", "secret")
	mgr := newForTest(nil)
	mgr.update(&Contract{AccessTokens: map[string]AccessToken{want.Key(): want}})

	aliases := []string{"hello", "env", "test"}
	requested := []string{"user.read", "openid"}

	got := mgr.readAccessToken("hid", aliases, "realm", "cid", requested)
	if diff := pretty.Compare(want, got); diff != "" {
		t.Fatalf("Returned access token is not the same as expected access token: -want/+got:\n%s", diff)
	}

	got = mgr.readAccessToken("this_should_break_it", aliases, "realm", "cid", requested)
	if !reflect.ValueOf(got).IsZero() {
		t.Fatal("expected to find no access token")
	}
}
// TestWriteAccessToken verifies that writeAccessToken stores the token under its Key().
func TestWriteAccessToken(t *testing.T) {
	now := time.Now()
	mgr := newForTest(nil)
	tok := NewAccessToken("hid", "env", "realm", "cid", now, now, now, "openid", "secret")

	if err := mgr.writeAccessToken(tok); err != nil {
		t.Fatalf("TestwriteAccessToken: got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(tok, mgr.contract.AccessTokens[tok.Key()]); diff != "" {
		t.Errorf("TestwriteAccessToken: -want/+got:\n%s", diff)
	}
}
// TestReadAccount verifies that readAccount finds an account by home account ID, environment
// alias and realm, and errors for an unknown home account ID.
func TestReadAccount(t *testing.T) {
	acc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")
	mgr := newForTest(nil)
	mgr.update(&Contract{Accounts: map[string]shared.Account{acc.Key(): acc}})
	aliases := []string{"hello", "env", "test"}

	got, err := mgr.readAccount("hid", aliases, "realm")
	if err != nil {
		t.Fatalf("TestreadAccount: got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(acc, got); diff != "" {
		t.Errorf("TestreadAccount: -want/+got:\n%s", diff)
	}

	if _, err := mgr.readAccount("this_should_break_it", aliases, "realm"); err == nil {
		t.Errorf("TestreadAccount: got err == nil, want err != nil")
	}
}
// TestWriteAccount verifies that writeAccount stores the account under its Key().
func TestWriteAccount(t *testing.T) {
	mgr := newForTest(nil)
	acc := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")

	if err := mgr.writeAccount(acc); err != nil {
		t.Fatalf("TestwriteAccount: got err == %s, want err == nil", err)
	}
	stored := mgr.contract.Accounts[acc.Key()]
	if diff := pretty.Compare(acc, stored); diff != "" {
		t.Errorf("TestwriteAccount: -want/+got:\n%s", diff)
	}
}
// TestReadAppMetaData verifies that readAppMetaData finds metadata by client ID and environment
// alias, and errors for an unknown client ID.
func TestReadAppMetaData(t *testing.T) {
	meta := NewAppMetaData("fid", "cid", "env")
	mgr := newForTest(nil)
	mgr.update(&Contract{AppMetaData: map[string]AppMetaData{meta.Key(): meta}})
	aliases := []string{"hello", "test", "env"}

	got, err := mgr.readAppMetaData(aliases, "cid")
	if err != nil {
		t.Fatalf("TestreadAppMetaData(readAppMetaData): got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(meta, got); diff != "" {
		t.Fatalf("TestreadAppMetaData(readAppMetaData): -want/+got:\n%s", diff)
	}

	if _, err := mgr.readAppMetaData(aliases, "break_this"); err == nil {
		t.Fatalf("TestreadAppMetaData(bad readAppMetaData): got err == nil, want err != nil")
	}
}
// TestWriteAppMetaData verifies that writeAppMetaData stores the metadata under its Key().
func TestWriteAppMetaData(t *testing.T) {
	mgr := newForTest(nil)
	meta := NewAppMetaData("fid", "cid", "env")

	if err := mgr.writeAppMetaData(meta); err != nil {
		t.Fatalf("TestwriteAppMetaData: got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(meta, mgr.contract.AppMetaData[meta.Key()]); diff != "" {
		t.Errorf("TestwriteAppMetaData: -want/+got:\n%s", diff)
	}
}
// TestReadIDToken verifies that readIDToken finds a token by home account ID, environment alias,
// realm and client ID, and errors for an unknown home account ID.
func TestReadIDToken(t *testing.T) {
	testIDToken := NewIDToken(
		"hid",
		"env",
		"realm",
		"cid",
		"secret",
	)
	cache := &Contract{
		IDTokens: map[string]IDToken{
			testIDToken.Key(): testIDToken,
		},
	}
	storageManager := newForTest(nil)
	storageManager.update(cache)
	returnedIDToken, err := storageManager.readIDToken(
		"hid",
		[]string{"hello", "env", "test"},
		"realm",
		"cid",
	)
	if err != nil {
		// Report through the testing framework instead of panicking, so the failure is attributed
		// to this test and the rest of the package's tests still run.
		t.Fatalf("TestreadIDToken(good token): got err == %s, want err == nil", err)
	}
	if diff := pretty.Compare(testIDToken, returnedIDToken); diff != "" {
		t.Fatalf("TestreadIDToken(good token): -want/+got:\n%s", diff)
	}
	_, err = storageManager.readIDToken(
		"this_should_break_it",
		[]string{"hello", "env", "test"},
		"realm",
		"cid",
	)
	if err == nil {
		t.Errorf("TestreadIDToken(bad token): got err == nil, want err != nil")
	}
}
// TestWriteIDToken verifies that writeIDToken stores the token under its Key().
func TestWriteIDToken(t *testing.T) {
	mgr := newForTest(nil)
	tok := NewIDToken("hid", "env", "realm", "cid", "secret")

	if err := mgr.writeIDToken(tok); err != nil {
		t.Fatalf("TestwriteIDToken: got err == %s, want err == nil", err)
	}
	stored := mgr.contract.IDTokens[tok.Key()]
	if diff := pretty.Compare(tok, stored); diff != "" {
		t.Errorf("TestwriteIDToken: -want/+got:\n%s", diff)
	}
}
// TestDefaultStorageManagerreadRefreshToken exercises readRefreshToken's matching rules:
//   - a token is found only when its home account ID matches and its environment is one of the
//     given aliases;
//   - a token carrying a family ID matches regardless of client ID, while a token without one
//     needs a matching client ID;
//   - the familyID argument decides which of those two kinds of match is tried first.
func TestDefaultStorageManagerreadRefreshToken(t *testing.T) {
	testRefreshTokenWithFID := accesstokens.NewRefreshToken(
		"hid",
		"env",
		"cid",
		"secret",
		"fid",
	)
	testRefreshTokenWoFID := accesstokens.NewRefreshToken(
		"hid",
		"env",
		"cid",
		"secret",
		"",
	)
	testRefreshTokenWoFIDAltCID := accesstokens.NewRefreshToken(
		"hid",
		"env",
		"cid2",
		"secret",
		"",
	)
	type args struct {
		homeAccountID string
		envAliases    []string
		familyID      string
		clientID      string
	}
	tests := []struct {
		name     string
		contract *Contract
		args     args
		want     accesstokens.RefreshToken
		err      bool
	}{
		{
			name: "Token without fid, read with fid, cid, env, and hid",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key(): testRefreshTokenWoFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "fid",
				clientID:      "cid",
			},
			want: testRefreshTokenWoFID,
		},
		{
			name: "Token without fid, read with cid, env, and hid",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key(): testRefreshTokenWoFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "",
				clientID:      "cid",
			},
			want: testRefreshTokenWoFID,
		},
		{
			name: "Token without fid, verify CID is required",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key(): testRefreshTokenWoFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "",
				clientID:      "",
			},
			err: true,
		},
		{
			name: "Token without fid, Verify env is required",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key(): testRefreshTokenWoFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{},
				familyID:      "",
				clientID:      "",
			},
			err: true,
		},
		{
			name: "Token with fid, read with fid, cid, env, and hid",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWithFID.Key(): testRefreshTokenWithFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "fid",
				clientID:      "cid",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name: "Token with fid, read with cid, env, and hid",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWithFID.Key(): testRefreshTokenWithFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "",
				clientID:      "cid",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name: "Token with fid, verify CID is not required", // match on hid, env, and has fid
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWithFID.Key(): testRefreshTokenWithFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"test", "env", "hello"},
				familyID:      "",
				clientID:      "",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name: "Token with fid, Verify env is required",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWithFID.Key(): testRefreshTokenWithFID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{},
				familyID:      "",
				clientID:      "",
			},
			err: true,
		},
		{
			name: "Multiple items in cache, given a fid but no env aliases, nothing is returned",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key():       testRefreshTokenWoFID,
					testRefreshTokenWithFID.Key():     testRefreshTokenWithFID,
					testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{},
				familyID:      "fid",
				clientID:      "cid",
			},
			err: true,
		},
		{
			name: "Multiple items in cache, alternate CID but no env aliases, nothing is returned",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key():       testRefreshTokenWoFID,
					testRefreshTokenWithFID.Key():     testRefreshTokenWithFID,
					testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{},
				familyID:      "",
				clientID:      "cid2",
			},
			err: true,
		},
		// When several tokens match on HID, CID and env, which one is returned depends on map
		// iteration order. The multi-item cases below therefore rely on a family ID or an
		// alternate CID to make exactly one token eligible for the first matcher tried.
		{
			name: "Multiple items in cache, given a fid, item with fid will be returned",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key():       testRefreshTokenWoFID,
					testRefreshTokenWithFID.Key():     testRefreshTokenWithFID,
					testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"env"},
				familyID:      "fid",
				clientID:      "cid",
			},
			want: testRefreshTokenWithFID,
		},
		{
			name: "Multiple items in cache, without a fid and with alternate CID, token with alternate CID is returned",
			contract: &Contract{
				RefreshTokens: map[string]accesstokens.RefreshToken{
					testRefreshTokenWoFID.Key():       testRefreshTokenWoFID,
					testRefreshTokenWithFID.Key():     testRefreshTokenWithFID,
					testRefreshTokenWoFIDAltCID.Key(): testRefreshTokenWoFIDAltCID,
				},
			},
			args: args{
				homeAccountID: "hid",
				envAliases:    []string{"env"},
				familyID:      "",
				clientID:      "cid2",
			},
			want: testRefreshTokenWoFIDAltCID,
		},
	}
	m := &Manager{}
	for _, test := range tests {
		// Subtests run sequentially here, so sharing m and capturing test is safe.
		t.Run(test.name, func(t *testing.T) {
			m.update(test.contract)
			got, err := m.readRefreshToken(test.args.homeAccountID, test.args.envAliases, test.args.familyID, test.args.clientID)
			switch {
			case test.err && err == nil:
				t.Fatalf("got err == nil, want err != nil")
			case !test.err && err != nil:
				t.Fatalf("got err == %s, want err == nil", err)
			case err != nil:
				return
			}
			if diff := pretty.Compare(test.want, got); diff != "" {
				t.Errorf("-want/+got:\n%s", diff)
			}
		})
	}
}
// TestWriteRefreshToken verifies that writeRefreshToken stores the token under its Key().
func TestWriteRefreshToken(t *testing.T) {
	mgr := newForTest(nil)
	rt := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")

	if err := mgr.writeRefreshToken(rt); err != nil {
		t.Errorf("Error should be nil, but it is %v", err)
	}
	if stored := mgr.contract.RefreshTokens[rt.Key()]; !reflect.DeepEqual(stored, rt) {
		t.Errorf("Added refresh token %v differs from expected refresh token %v", stored, rt)
	}
}
func TestStorageManagerSerialize(t *testing.T) {
contract := &Contract{
AccessTokens: map[string]AccessToken{
"an-entry": {
AdditionalFields: map[string]interface{}{
"foo": "bar",
},
},
"uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": {
Environment: defaultEnvironment,
CredentialType: "AccessToken",
Secret: accessTokenSecret,
Realm: defaultRealm,
Scopes: defaultScopes,
ClientID: defaultClientID,
CachedAt: internalTime.Unix{T: atCached},
HomeAccountID: defaultHID,
ExpiresOn: internalTime.Unix{T: atExpires},
ExtendedExpiresOn: internalTime.Unix{T: atExpires},
},
},
RefreshTokens: map[string]accesstokens.RefreshToken{
"uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": {
Target: defaultScopes,
Environment: defaultEnvironment,
CredentialType: "RefreshToken",
Secret: rtSecret,
ClientID: defaultClientID,
HomeAccountID: defaultHID,
},
},
IDTokens: map[string]IDToken{
"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": {
Realm: defaultRealm,
Environment: defaultEnvironment,
CredentialType: idCred,
Secret: idSecret,
ClientID: defaultClientID,
HomeAccountID: defaultHID,
},
},
Accounts: map[string]shared.Account{
"uid.utid-login.windows.net-contoso": {
PreferredUsername: accUser,
LocalAccountID: accLID,
Realm: defaultRealm,
Environment: defaultEnvironment,
HomeAccountID: defaultHID,
AuthorityType: accAuth,
},
},
AppMetaData: map[string]AppMetaData{
"AppMetadata-login.windows.net-my_client_id": {
Environment: defaultEnvironment,
FamilyID: "",
ClientID: defaultClientID,
},
},
}
manager := newForTest(nil)
manager.update(contract)
_, err := manager.Marshal()
if err != nil {
t.Errorf("Error should be nil; instead it is %v", err)
}
}
// TestUnmarshal loads the serialized cache fixture named by testFile into a
// fresh Manager and spot-checks one value from each cache section: the access
// token secret, refresh token secret, ID token secret, account username and
// app metadata family ID.
func TestUnmarshal(t *testing.T) {
	manager := newForTest(nil)
	b, err := os.ReadFile(testFile)
	if err != nil {
		// Fail this test instead of panicking, which would abort the whole
		// test binary.
		t.Fatalf("TestUnmarshal(read fixture): got err == %s, want err == nil", err)
	}
	err = manager.Unmarshal(b)
	if err != nil {
		t.Fatalf("TestUnmarshal(unmarshal): got err == %s, want err == nil", err)
	}
	actualAccessTokenSecret := manager.contract.AccessTokens["uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3"].Secret
	if accessTokenSecret != actualAccessTokenSecret {
		t.Errorf("TestUnmarshal(access token secret):got %q, want %q", actualAccessTokenSecret, accessTokenSecret)
	}
	actualRTSecret := manager.contract.RefreshTokens["uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3"].Secret
	if diff := pretty.Compare(rtSecret, actualRTSecret); diff != "" {
		t.Errorf("TestUnmarshal(refresh token secret): -want/+got:\n%s", diff)
	}
	actualIDSecret := manager.contract.IDTokens["uid.utid-login.windows.net-idtoken-my_client_id-contoso-"].Secret
	if diff := pretty.Compare(idSecret, actualIDSecret); diff != "" {
		t.Errorf("TestUnmarshal(id secret): -want/+got:\n%s", diff)
	}
	actualUser := manager.contract.Accounts["uid.utid-login.windows.net-contoso"].PreferredUsername
	// want first, got second, to match the "-want/+got" label.
	if diff := pretty.Compare(accUser, actualUser); diff != "" {
		t.Errorf("TestUnmarshal(actual user): -want/+got:\n%s", diff)
	}
	if manager.contract.AppMetaData["AppMetadata-login.windows.net-my_client_id"].FamilyID != "" {
		t.Errorf("TestUnmarshal(app metadata family id): got %q, want empty string", manager.contract.AppMetaData["AppMetadata-login.windows.net-my_client_id"].FamilyID)
	}
}
// TestIsAccessTokenValid checks AccessToken.Validate against three tokens:
//   - cached now and expiring in 1000s, which must validate;
//   - cached now but expiring in only 200s, which must be rejected;
//   - expiring in 1000s but with CachedAt 500s in the future, which must be rejected.
func TestIsAccessTokenValid(t *testing.T) {
	cachedAt := time.Now()
	badCachedAt := time.Now().Add(500 * time.Second)
	expiresOn := time.Now().Add(1000 * time.Second)
	// NOTE(review): a token expiring 200s from now is expected to fail
	// validation, which suggests Validate applies an expiry margin larger than
	// 200s. Confirm against AccessToken.Validate.
	badExpiresOn := time.Now().Add(200 * time.Second)
	extended := time.Now()
	tests := []struct {
		desc  string
		token AccessToken
		err   bool
	}{
		{
			desc:  "Success",
			token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, expiresOn, extended, "openid", "secret"),
		},
		{
			desc:  "ExpiresOnUnixTimestamp has expired",
			token: NewAccessToken("hid", "env", "realm", "cid", cachedAt, badExpiresOn, extended, "openid", "secret"),
			err:   true,
		},
		{
			// Distinct description so a failure here is not confused with the
			// valid "Success" case above.
			desc:  "CachedAt is in the future",
			token: NewAccessToken("hid", "env", "realm", "cid", badCachedAt, expiresOn, extended, "openid", "secret"),
			err:   true,
		},
	}
	for _, test := range tests {
		err := test.token.Validate()
		switch {
		case err == nil && test.err:
			t.Errorf("TestIsAccessTokenValid(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestIsAccessTokenValid(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}
// TestRead populates a Manager with one cache entry of each kind for home
// account "hid" in environment "env", then calls Read through a fake
// instance-discovery responder. A failing discovery must surface as an error.
// A discovery response whose aliases include "env" must return the cached
// access token, refresh token, ID token and account.
func TestRead(t *testing.T) {
	cachedAT := NewAccessToken(
		"hid", "env", "realm", "cid",
		time.Now(),
		time.Now().Add(1000*time.Second),
		time.Now(),
		"openid profile",
		"secret",
	)
	cachedID := NewIDToken("hid", "env", "realm", "cid", "secret")
	cachedMeta := NewAppMetaData("fid", "cid", "env")
	cachedRT := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	cachedAccount := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")

	contract := &Contract{
		RefreshTokens: map[string]accesstokens.RefreshToken{cachedRT.Key(): cachedRT},
		Accounts:      map[string]shared.Account{cachedAccount.Key(): cachedAccount},
		AppMetaData:   map[string]AppMetaData{cachedMeta.Key(): cachedMeta},
		IDTokens:      map[string]IDToken{cachedID.Key(): cachedID},
		AccessTokens:  map[string]AccessToken{cachedAT.Key(): cachedAT},
	}

	params := authority.AuthParams{
		HomeAccountID: "hid",
		AuthorityInfo: authority.Info{Host: "env", Tenant: "realm"},
		ClientID:      "cid",
		Scopes:        []string{"openid", "profile"},
	}

	cases := []struct {
		desc        string
		discRespErr bool
		discResp    authority.InstanceDiscoveryResponse
		err         bool
		want        TokenResponse
	}{
		{
			desc:        "Error: AAD Discovery Fails",
			discRespErr: true,
			err:         true,
		},
		{
			desc: "Success",
			discResp: authority.InstanceDiscoveryResponse{
				TenantDiscoveryEndpoint: "tenant",
				Metadata: []authority.InstanceDiscoveryMetadata{
					{Aliases: []string{"env", "alias2"}},
					{Aliases: []string{"alias3", "alias4"}},
				},
			},
			want: TokenResponse{
				AccessToken:  cachedAT,
				RefreshToken: cachedRT,
				IDToken:      cachedID,
				Account:      cachedAccount,
			},
		},
	}

	for _, tc := range cases {
		m := newForTest(&fakeDiscoveryResponser{err: tc.discRespErr, ret: tc.discResp})
		m.update(contract)

		got, err := m.Read(context.Background(), params)
		if err != nil {
			if !tc.err {
				t.Errorf("TestRead(%s): got err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestRead(%s): got err == nil, want err != nil", tc.desc)
			continue
		}

		if diff := pretty.Compare(tc.want, got); diff != "" {
			t.Errorf("TestRead(%s): -want/+got:\n%s", tc.desc, diff)
		}
	}
}
func removeSubSeconds(t time.Time) time.Time {
t = t.Add(-time.Duration(t.Nanosecond()))
return t
}
// TestWrite stores a token response with Manager.Write and verifies the
// returned account plus every cache entry derived from the response: the
// refresh token, access token, ID token, account and app metadata.
func TestWrite(t *testing.T) {
	// NOTE(review): sub-second precision is dropped, presumably because cached
	// times round-trip with whole-second resolution. Confirm.
	now := removeSubSeconds(time.Now().UTC())
	cacheManager := newForTest(nil)
	clientInfo := accesstokens.ClientInfo{
		UID:  "testUID",
		UTID: "testUtid",
	}
	idToken := accesstokens.IDToken{
		RawToken:          "idToken",
		Oid:               "lid",
		PreferredUsername: "username",
	}
	expiresOn := internalTime.DurationTime{T: now.Add(1000 * time.Second)}
	tokenResponse := accesstokens.TokenResponse{
		AccessToken:   "accessToken",
		RefreshToken:  "refreshToken",
		IDToken:       idToken,
		FamilyID:      "fid",
		ClientInfo:    clientInfo,
		GrantedScopes: accesstokens.Scopes{Slice: []string{"openid", "profile"}},
		ExpiresOn:     expiresOn,
		ExtExpiresOn:  internalTime.DurationTime{T: now},
	}
	authInfo := authority.Info{Host: "env", Tenant: "realm", AuthorityType: accAuth}
	authParams := authority.AuthParams{
		AuthorityInfo: authInfo,
		ClientID:      "cid",
	}

	// Expected cache entries. The test expects the home account ID to be the
	// client info's UID and UTID joined by ".": "testUID.testUtid".
	testRefreshToken := accesstokens.NewRefreshToken(
		"testUID.testUtid",
		"env",
		"cid",
		"refreshToken",
		"fid",
	)
	// Named testAccessToken (not AccessToken) so the local variable does not
	// shadow the AccessToken type inside this function.
	testAccessToken := NewAccessToken(
		"testUID.testUtid",
		"env",
		"realm",
		"cid",
		now,
		now.Add(1000*time.Second),
		now,
		"openid profile",
		"accessToken",
	)
	testIDToken := NewIDToken(
		"testUID.testUtid",
		"env",
		"realm",
		"cid",
		"idToken",
	)
	testAccount := shared.NewAccount("testUID.testUtid", "env", "realm", "lid", accAuth, "username")
	testAppMeta := NewAppMetaData("fid", "cid", "env")

	actualAccount, err := cacheManager.Write(authParams, tokenResponse)
	if err != nil {
		// Every check below depends on a successful write. Continuing would
		// only add misleading follow-on failures.
		t.Fatalf("Error should be nil; instead, it is %v", err)
	}
	if !reflect.DeepEqual(actualAccount, testAccount) {
		t.Errorf("Actual account %+v differs from expected account %+v", actualAccount, testAccount)
	}

	gotRefresh, ok := cacheManager.contract.RefreshTokens[testRefreshToken.Key()]
	if !ok {
		t.Fatalf("TestWrite(refresh token): refresh token was not written as expected")
	}
	if diff := pretty.Compare(testRefreshToken, gotRefresh); diff != "" {
		t.Fatalf("TestWrite(refresh token): -want/+got\n%s", diff)
	}

	gotAccess, ok := cacheManager.contract.AccessTokens[testAccessToken.Key()]
	if !ok {
		t.Fatalf("TestWrite(access token): access token was not written as expected")
	}
	// CachedAt is generated for this exact moment, not from input. We would need to
	// fake time.Now() call with a var now = time.Now() in the package in order to
	// control this or we can just ignore this value. We are going to simply check its
	// not zero and then zero it for our got/want comparison.
	if gotAccess.CachedAt.T.IsZero() {
		t.Fatalf("TestWrite(access token): AccessToken.CachedAt is the zero value, which is incorrect")
	}
	gotAccess.CachedAt = internalTime.Unix{}
	testAccessToken.CachedAt = internalTime.Unix{}
	if diff := pretty.Compare(testAccessToken, gotAccess); diff != "" {
		t.Fatalf("TestWrite(access token): -want/+got\n%s", diff)
	}

	gotToken, ok := cacheManager.contract.IDTokens[testIDToken.Key()]
	if !ok {
		t.Fatalf("TestWrite(id token): id token was not written as expected")
	}
	if diff := pretty.Compare(testIDToken, gotToken); diff != "" {
		t.Fatalf("TestWrite(id token): -want/+got\n%s", diff)
	}

	gotAccount, ok := cacheManager.contract.Accounts[testAccount.Key()]
	if !ok {
		t.Fatalf("TestWrite(account): account was not written as expected")
	}
	if diff := pretty.Compare(testAccount, gotAccount); diff != "" {
		t.Fatalf("TestWrite(account): -want/+got\n%s", diff)
	}

	gotMeta, ok := cacheManager.contract.AppMetaData[testAppMeta.Key()]
	if !ok {
		t.Fatalf("TestWrite(app metadata): metadata was not written as expected")
	}
	if diff := pretty.Compare(testAppMeta, gotMeta); diff != "" {
		t.Fatalf("TestWrite(app metadata): -want/+got\n%s", diff)
	}
}
// TestRemoveRefreshTokens checks that after removeRefreshTokens("hid", "env",
// "cid") the refresh token built from those same values is no longer in the
// contract.
func TestRemoveRefreshTokens(t *testing.T) {
	m := newForTest(nil)
	rt := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	rtKey := rt.Key()

	m.update(&Contract{
		RefreshTokens: map[string]accesstokens.RefreshToken{rtKey: rt},
	})
	m.removeRefreshTokens("hid", "env", "cid")

	leftover, found := m.contract.RefreshTokens[rtKey]
	if found {
		t.Fatalf("TestRemoveRefreshTokens: got refreshToken == %s, want refreshToken == empty", leftover)
	}
}
// TestRemoveAccessTokens checks that after removeAccessTokens("hid", "env")
// the access token created for that home account and environment is no
// longer in the contract.
func TestRemoveAccessTokens(t *testing.T) {
	ts := time.Now()
	m := newForTest(nil)
	at := NewAccessToken("hid", "env", "realm", "cid", ts, ts, ts, "openid", "secret")
	atKey := at.Key()

	m.update(&Contract{
		AccessTokens: map[string]AccessToken{atKey: at},
	})
	m.removeAccessTokens("hid", "env")

	leftover, found := m.contract.AccessTokens[atKey]
	if found {
		t.Fatalf("TestRemoveAccessTokens: got accessToken == %s, want accessToken == empty", leftover)
	}
}
// TestRemoveIDTokens checks that after removeIDTokens("hid", "env") the ID
// token created for that home account and environment is no longer in the
// contract.
func TestRemoveIDTokens(t *testing.T) {
	m := newForTest(nil)
	idt := NewIDToken("hid", "env", "realm", "cid", "secret")
	idtKey := idt.Key()

	m.update(&Contract{
		IDTokens: map[string]IDToken{idtKey: idt},
	})
	m.removeIDTokens("hid", "env")

	leftover, found := m.contract.IDTokens[idtKey]
	if found {
		t.Fatalf("TestRemoveIDTokens: got IDToken == %s, want IDToken == empty", leftover)
	}
}
// TestRemoveAccountObject checks that after removeAccounts("hid", "env") the
// account record created for that home account and environment is no longer
// in the contract.
func TestRemoveAccountObject(t *testing.T) {
	m := newForTest(nil)
	acct := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")
	acctKey := acct.Key()

	m.update(&Contract{
		Accounts: map[string]shared.Account{acctKey: acct},
	})
	m.removeAccounts("hid", "env")

	leftover, found := m.contract.Accounts[acctKey]
	if found {
		t.Fatalf("TestRemoveAccountObject: got Account == %s, want Account == empty", leftover)
	}
}
// TestRemoveAccount fills a Manager with one entry of every cache kind for
// home account "hid" in environment "env", calls RemoveAccount for that
// account and client "cid", and checks that the refresh token, access token,
// ID token and account record are all gone. App metadata is not checked.
func TestRemoveAccount(t *testing.T) {
	ts := time.Now()
	at := NewAccessToken("hid", "env", "realm", "cid", ts, ts, ts, "openid profile", "secret")
	idt := NewIDToken("hid", "env", "realm", "cid", "secret")
	meta := NewAppMetaData("fid", "cid", "env")
	rt := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	acct := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")

	m := newForTest(nil)
	m.update(&Contract{
		RefreshTokens: map[string]accesstokens.RefreshToken{rt.Key(): rt},
		Accounts:      map[string]shared.Account{acct.Key(): acct},
		AppMetaData:   map[string]AppMetaData{meta.Key(): meta},
		IDTokens:      map[string]IDToken{idt.Key(): idt},
		AccessTokens:  map[string]AccessToken{at.Key(): at},
	})
	m.RemoveAccount(acct, "cid")

	if leftover, found := m.contract.RefreshTokens[rt.Key()]; found {
		t.Fatalf("TestRemoveAccount: got refreshToken == %s, want refreshToken == empty", leftover)
	}
	if leftover, found := m.contract.AccessTokens[at.Key()]; found {
		t.Fatalf("TestRemoveAccount: got accessToken == %s, want accessToken == empty", leftover)
	}
	if leftover, found := m.contract.IDTokens[idt.Key()]; found {
		t.Fatalf("TestRemoveAccount: got IDToken == %s, want IDToken == empty", leftover)
	}
	if leftover, found := m.contract.Accounts[acct.Key()]; found {
		t.Fatalf("TestRemoveAccount: got Account == %s, want Account == empty", leftover)
	}
}
// TestRemoveEmptyAccount fills a Manager with one entry of every cache kind
// for home account "hid", then calls RemoveAccount with a zero-value account.
// It checks that the refresh token, access token, ID token and account record
// are all still present afterwards.
func TestRemoveEmptyAccount(t *testing.T) {
	ts := time.Now()
	at := NewAccessToken("hid", "env", "realm", "cid", ts, ts, ts, "openid profile", "secret")
	idt := NewIDToken("hid", "env", "realm", "cid", "secret")
	meta := NewAppMetaData("fid", "cid", "env")
	rt := accesstokens.NewRefreshToken("hid", "env", "cid", "secret", "fid")
	acct := shared.NewAccount("hid", "env", "realm", "lid", accAuth, "username")

	m := newForTest(nil)
	m.update(&Contract{
		RefreshTokens: map[string]accesstokens.RefreshToken{rt.Key(): rt},
		Accounts:      map[string]shared.Account{acct.Key(): acct},
		AppMetaData:   map[string]AppMetaData{meta.Key(): meta},
		IDTokens:      map[string]IDToken{idt.Key(): idt},
		AccessTokens:  map[string]AccessToken{at.Key(): at},
	})
	m.RemoveAccount(shared.Account{}, "cid")

	if _, found := m.contract.RefreshTokens[rt.Key()]; !found {
		t.Fatalf("TestRemoveEmptyAccount: got refreshToken == empty, want refreshToken == %s", rt)
	}
	if _, found := m.contract.AccessTokens[at.Key()]; !found {
		t.Fatalf("TestRemoveEmptyAccount: got accessToken == empty, want accessToken == %s", at)
	}
	if _, found := m.contract.IDTokens[idt.Key()]; !found {
		t.Fatalf("TestRemoveEmptyAccount: got IDToken == empty, want IDToken == %s", idt)
	}
	if _, found := m.contract.Accounts[acct.Key()]; !found {
		t.Fatalf("TestRemoveEmptyAccount: got Account == empty, want Account == %s", acct)
	}
}
test_serialized_cache.json 0000664 0000000 0000000 00000003411 14420263624 0035467 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/base/internal/storage {
"Account": {
"uid.utid-login.windows.net-contoso": {
"username": "John Doe",
"local_account_id": "object1234",
"realm": "contoso",
"environment": "login.windows.net",
"home_account_id": "uid.utid",
"authority_type": "MSSTS"
}
},
"RefreshToken": {
"uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": {
"target": "s2 s1 s3",
"environment": "login.windows.net",
"credential_type": "RefreshToken",
"secret": "a refresh token",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"AccessToken": {
"an-entry": {
"foo": "bar"
},
"uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": {
"environment": "login.windows.net",
"credential_type": "AccessToken",
"secret": "an access token",
"realm": "contoso",
"target": "s2 s1 s3",
"client_id": "my_client_id",
"cached_at": "1000",
"home_account_id": "uid.utid",
"extended_expires_on": "4600",
"expires_on": "4600"
}
},
"IdToken": {
"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": {
"realm": "contoso",
"environment": "login.windows.net",
"credential_type": "IdToken",
"secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"unknownEntity": {"field1":"1","field2":"whats"},
"AppMetadata": {
"AppMetadata-login.windows.net-my_client_id": {
"environment": "login.windows.net",
"client_id": "my_client_id"
}
}
} microsoft-authentication-library-for-go-1.0.0/apps/internal/exported/ 0000775 0000000 0000000 00000000000 14420263624 0025777 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/exported/exported.go 0000664 0000000 0000000 00000002315 14420263624 0030161 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package exported contains internal types that are re-exported from a public package.
package exported
// AssertionRequestOptions has the information required to generate a client
// assertion. Each field names the assertion claim(s) it is used for.
type AssertionRequestOptions struct {
	// ClientID identifies the application for which an assertion is requested. Used as the assertion's "iss" and "sub" claims.
	ClientID string

	// TokenEndpoint is the intended token endpoint. Used as the assertion's "aud" claim.
	TokenEndpoint string
}
// TokenProviderParameters holds the authentication parameters passed to token
// providers.
type TokenProviderParameters struct {
	// Claims contains any additional claims requested for the token.
	Claims string
	// CorrelationID is the correlation ID of the authentication request.
	CorrelationID string
	// Scopes lists the scopes requested for the token.
	Scopes []string
	// TenantID identifies the tenant in which to authenticate.
	TenantID string
}
// TokenProviderResult is the authentication result returned by custom token
// providers.
type TokenProviderResult struct {
	// AccessToken is the requested token.
	AccessToken string
	// ExpiresInSeconds is the lifetime of the token in seconds.
	ExpiresInSeconds int
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/ 0000775 0000000 0000000 00000000000 14420263624 0025116 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/json/design.md 0000664 0000000 0000000 00000012705 14420263624 0026716 0 ustar 00root root 0000000 0000000 # JSON Package Design
Author: John Doak(jdoak@microsoft.com)
## Why?
This project needs a special type of marshal/unmarshal not directly supported
by the encoding/json package.
The need revolves around a few key wants/needs:
- unmarshal and marshal structs representing JSON messages
- fields in the message not in the struct must be maintained when unmarshalled
- those same fields must be marshalled back when encoded again
The initial version used map[string]interface{} to put in the keys that
were known and then any other keys were put into a field called AdditionalFields.
This has a few negatives:
- Dual marshaling/unmarshalling is required
- Adding a struct field requires manually adding a key by name to be encoded/decoded from the map (which is a loosely coupled construct), which can lead to bugs that aren't detected or have bad side effects
- Tests can become quickly disconnected if those keys aren't put
in tests as well. So you think you have support working, but you
don't. Existing tests were found that didn't test the marshalling output.
- There is no enforcement that if AdditionalFields is required on one struct, it should be on all containers
that don't have custom marshal/unmarshal.
This package aims to support our needs by providing custom Marshal()/Unmarshal() functions.
This prevents all the negatives in the initial solution listed above. However, it does add its own negative:
- Custom encoding/decoding via reflection is messy (as can be seen in encoding/json itself)
Go proverb: Reflection is never clear
Suggested reading: https://blog.golang.org/laws-of-reflection
## Important design decisions
- We don't want to understand all JSON decoding rules
- We don't want to deal with all the quoting, commas, etc on decode
- Need support for json.Marshaler/Unmarshaler, so we can support types like time.Time
- If struct does not implement json.Unmarshaler, it must have AdditionalFields defined
- We only support root level objects that are \*struct or struct
To facilitate these goals, we will utilize the json.Encoder and json.Decoder.
They provide streaming processing (efficient) and return errors on bad JSON.
Support for json.Marshaler/Unmarshaler allows for us to use non-basic types
that must be specially encoded/decoded (like time.Time objects).
We don't support types that neither implement custom unmarshalling nor have an
AdditionalFields field, in order to prevent future devs from forgetting that
important field and generating bad return values.
Support for root level objects of \*struct or struct simply acknowledges the
fact that this is designed only for the purposes listed in the "Why?" section.
Anything outside that (like encoding a lone number) should be done with the
regular json package (as it will not have additional fields).
We don't support a few things on json supported reference types and structs:
- \*map: no need for pointers to maps
- \*slice: no need for pointers to slices
- any further pointers on struct after \*struct
There should never be a need for this in Go.
## Design
## State Machines
This uses state machine designs that are based upon the Rob Pike talk on
lexers and parsers: https://www.youtube.com/watch?v=HxaD_trXwRE
This is the most common pattern for state machines in Go and
the model to follow closely when dealing with streaming
processing of textual data.
Our state machines are based on the type:
```go
type stateFn func() (stateFn, error)
```
The state machine itself is simply a struct that has methods that
satisfy stateFn.
Our state machines have a few standard calls
- run(): runs the state machine
- start(): always the first stateFn to be called
All state machines have the following logic:
* run() is called
* start() is called and returns the next stateFn or error
* stateFn is called
- If returned stateFn(next state) is non-nil, call it
- If error is non-nil, run() returns the error
- If stateFn == nil and err == nil, run() returns a nil error
## Supporting types
Marshalling/Unmarshalling must support (within the top-level struct):
- struct
- \*struct
- []struct
- []\*struct
- []map[string]structContainer
- [][]structContainer
**Term note:** structContainer == type that has a struct or \*struct inside it
We specifically do not support []interface or map[string]interface
where the interface value would hold some value with a struct in it.
Those will still marshal/unmarshal, but without support for
AdditionalFields.
## Marshalling
The marshalling design will be based around a state machine design.
The basic logic is as follows:
* If struct has custom marshaller, call it and return
* If struct has field "AdditionalFields", it must be a map[string]interface{}
* If struct does not have "AdditionalFields", give an error
* Get struct tag detailing json names to go names, create mapping
* For each public field name
- Write field name out
- If field value is a struct, recursively call our state machine
- Otherwise, use the json.Encoder to write out the value
## Unmarshalling
The unmarshalling design is also based around a state machine design. The
basic logic is as follows:
* If struct has a custom unmarshaller, call it
* If struct has field "AdditionalFields", it must be a map[string]interface{}
* Get struct tag detailing json names to go names, create mapping
* For each key found
- If key exists,
- If value is basic type, extract value into struct field using Decoder
- If value is struct type, recursively call statemachine
- If key doesn't exist, add it to AdditionalFields if it exists using Decoder
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/json.go 0000664 0000000 0000000 00000011762 14420263624 0026425 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package json provides functions for marshalling and unmarshalling types to JSON. These functions are meant to
// be utilized inside of structs that implement json.Unmarshaler and json.Marshaler interfaces.
// This package provides the additional functionality of writing fields that are not in the struct when unmarshalling
// to a field called AdditionalFields if that field exists and is a map[string]interface{}.
// When marshalling, if the struct has all the same prerequisites, it will use the keys in AdditionalFields as
// extra fields. This package uses encoding/json underneath.
package json
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strings"
)
const addField = "AdditionalFields"
const (
marshalJSON = "MarshalJSON"
unmarshalJSON = "UnmarshalJSON"
)
var (
leftBrace = []byte("{")[0]
rightBrace = []byte("}")[0]
comma = []byte(",")[0]
leftParen = []byte("[")[0]
rightParen = []byte("]")[0]
)
var mapStrInterType = reflect.TypeOf(map[string]interface{}{})
// stateFn defines a state machine function. This will be used in all state
// machines in this package.
type stateFn func() (stateFn, error)
// Marshal encodes i, a struct or *struct, as JSON. Entries of a field named
// "AdditionalFields" of type map[string]interface{} (tagged `json:"-"`) are
// written out as if they were fields of the struct itself. HTML characters
// are not escaped in the output.
func Marshal(i interface{}) ([]byte, error) {
	var out bytes.Buffer
	encoder := json.NewEncoder(&out)
	encoder.SetEscapeHTML(false)
	encoder.SetIndent("", "")

	val := reflect.ValueOf(i)
	// NOTE(review): a Value produced by reflect.ValueOf is never addressable,
	// so this branch does not appear reachable. It is kept for parity.
	if val.Kind() != reflect.Ptr && val.CanAddr() {
		val = val.Addr()
	}

	if err := marshalStruct(val, &out, encoder); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}
// Unmarshal decodes the JSON in b into i, which must be a *struct. If the
// struct has a field named AdditionalFields of type map[string]interface{},
// JSON keys with no matching struct field are stored there as key/value pairs.
// Numbers are decoded with json.Decoder.UseNumber. An empty b is a no-op.
func Unmarshal(b []byte, i interface{}) error {
	// Nothing to decode: leave i untouched.
	if len(b) == 0 {
		return nil
	}
	dec := json.NewDecoder(bytes.NewReader(b))
	dec.UseNumber()
	return unmarshalStruct(dec, i)
}
// MarshalRaw marshals i with encoding/json and returns the result as a
// json.RawMessage. It panics if i cannot be marshalled. It exists to help test
// AdditionalFields values, which are stored as json.RawMessage.
func MarshalRaw(i interface{}) json.RawMessage {
	raw, err := json.Marshal(i)
	if err == nil {
		return raw
	}
	panic(err)
}
// isDelim simply tests to see if a json.Token is a delimeter.
func isDelim(got json.Token) bool {
switch got.(type) {
case json.Delim:
return true
}
return false
}
// delimIs tests got to see if it is want.
func delimIs(got json.Token, want rune) bool {
switch v := got.(type) {
case json.Delim:
if v == json.Delim(want) {
return true
}
}
return false
}
// hasMarshalJSON reports whether v, or a pointer to v, has a MarshalJSON
// method and implements json.Marshaler.
//
// If v's own method set lacks MarshalJSON:
//   - a pointer v has its pointee checked instead; a nil pointer has no
//     pointee and yields false;
//   - a non-pointer v has its address checked, but only when v is
//     addressable; otherwise the result is false.
func hasMarshalJSON(v reflect.Value) bool {
	if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid {
		_, ok := v.Interface().(json.Marshaler)
		return ok
	}

	if v.Kind() == reflect.Ptr {
		// Elem of a nil pointer is the zero Value, and calling MethodByName on
		// the zero Value panics, so report "no marshaller" instead.
		if v.IsNil() {
			return false
		}
		v = v.Elem()
	} else {
		if !v.CanAddr() {
			return false
		}
		v = v.Addr()
	}

	if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid {
		_, ok := v.Interface().(json.Marshaler)
		return ok
	}
	return false
}
// callMarshalJSON calls MarshalJSON on v, or on a pointer to v (or on v's
// pointee when v is a non-nil pointer), and returns its result.
//
// It panics if none of those values has a MarshalJSON method implementing
// json.Marshaler. Callers are expected to check hasMarshalJSON first.
func callMarshalJSON(v reflect.Value) ([]byte, error) {
	if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid {
		marsh := v.Interface().(json.Marshaler)
		return marsh.MarshalJSON()
	}

	if v.Kind() == reflect.Ptr {
		// A nil pointer has no pointee to inspect. Keep v as-is so that we fall
		// through to the descriptive panic below instead of panicking inside
		// reflect on a zero Value.
		if !v.IsNil() {
			v = v.Elem()
		}
	} else if v.CanAddr() {
		// MarshalJSON may be declared on the pointer receiver.
		v = v.Addr()
	}

	// Look up MarshalJSON, not UnmarshalJSON, on the adjusted value. Checking
	// UnmarshalJSON would panic for types that only declare a pointer-receiver
	// MarshalJSON.
	if method := v.MethodByName(marshalJSON); method.Kind() != reflect.Invalid {
		if marsh, ok := v.Interface().(json.Marshaler); ok {
			return marsh.MarshalJSON()
		}
	}
	panic(fmt.Sprintf("callMarshalJSON called on type %v that does not have MarshalJSON defined", v.Type()))
}
// hasUnmarshalJSON reports whether v (if it is a pointer) or a pointer to v
// (if v is addressable) has an UnmarshalJSON method and implements
// json.Unmarshaler. A non-pointer, non-addressable v always yields false,
// because unmarshalling requires a pointer receiver.
func hasUnmarshalJSON(v reflect.Value) bool {
	target := v
	if target.Kind() != reflect.Ptr {
		if !target.CanAddr() {
			return false
		}
		target = target.Addr()
	}

	if target.MethodByName(unmarshalJSON).Kind() == reflect.Invalid {
		return false
	}
	_, ok := target.Interface().(json.Unmarshaler)
	return ok
}
// hasOmitEmpty reports whether the comma-separated json struct tag value tag
// (as returned by reflect.StructTag.Get) contains an element that is exactly
// "omitempty". Every element is compared, including the leading name element.
func hasOmitEmpty(tag string) bool {
	opts := strings.Split(tag, ",")
	for i := range opts {
		if opts[i] == "omitempty" {
			return true
		}
	}
	return false
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/json_test.go 0000664 0000000 0000000 00000007317 14420263624 0027465 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/kylelemons/godebug/pretty"
)
// StructA is a test fixture combining a plain field, a field renamed by a
// json tag ("id"), a nested *StructB, and AdditionalFields, which TestUnmarshal
// expects to receive unknown keys.
type StructA struct {
	Name             string
	ID               int `json:"id"`
	Meta             *StructB
	AdditionalFields map[string]interface{}
}

// StructB is a test fixture nested under StructA.Meta. It has its own
// AdditionalFields so unknown keys in the nested object are kept separately.
type StructB struct {
	Address          string
	AdditionalFields map[string]interface{}
}

// StructC is a test fixture whose Time field is a time.Time, a type that
// implements json.Unmarshaler, alongside a nested struct value (not pointer).
type StructC struct {
	Time             time.Time
	Project          StructD
	AdditionalFields map[string]interface{}
}

// StructD is a test fixture nested in StructC, itself containing StructE, to
// exercise more than one level of struct nesting.
type StructD struct {
	Project          string
	Info             StructE
	AdditionalFields map[string]interface{}
}

// StructE is the innermost nested test fixture.
type StructE struct {
	Employees        int
	AdditionalFields map[string]interface{}
}
// TestUnmarshal drives Unmarshal with invalid receivers (a non-pointer, a
// pointer to a non-struct, and a non-map AdditionalFields), which must fail.
// It also covers two successful decodes: one where unknown keys land in
// AdditionalFields at each nesting level, and one that includes a
// json.Unmarshaler field (time.Time).
func TestUnmarshal(t *testing.T) {
	now := time.Now()
	nowJSON, err := now.MarshalJSON()
	if err != nil {
		// Report through the test framework instead of crashing the binary.
		t.Fatalf("TestUnmarshal: time.Now().MarshalJSON() failed: %s", err)
	}
	tests := []struct {
		desc string
		b    []byte
		got  interface{}
		want interface{}
		err  bool
	}{
		{
			desc: "receiver not a pointer",
			got:  StructA{},
			b:    []byte(`{"content": "value"}`),
			err:  true,
		},
		{
			desc: "receiver not a pointer to a struct",
			got:  new(string),
			b:    []byte(`{"content": "value"}`),
			err:  true,
		},
		{
			desc: "AdditionalFields not a map",
			b:    []byte(`{"content": "value"}`),
			got: &struct {
				AdditionalFields string
			}{},
			err: true,
		},
		{
			desc: "Success, no json.Unmarshaler types",
			b: []byte(
				`
{
"Name": "John",
"id": 3,
"Meta": {
"Address": "291 Street",
"unknown0": 3.2
},
"unknown0": 10,
"unknown1": "hello"
}
`,
			),
			got: &StructA{},
			want: &StructA{
				Name: "John",
				ID:   3,
				Meta: &StructB{
					Address: "291 Street",
					AdditionalFields: map[string]interface{}{
						"unknown0": MarshalRaw(3.2),
					},
				},
				AdditionalFields: map[string]interface{}{
					"unknown0": MarshalRaw(10),
					"unknown1": MarshalRaw("hello"),
				},
			},
		},
		{
			desc: "Success, a type has json.Unmarshaler",
			b: []byte(fmt.Sprintf(`
{
"Time":%s,
"Project": {
"Project":"myProject",
"Info":{
"Employees":2
}
}
}
`, string(nowJSON))),
			got: &StructC{},
			want: &StructC{
				Time: now,
				Project: StructD{
					Project: "myProject",
					Info: StructE{
						Employees: 2,
					},
				},
			},
		},
	}
	for _, test := range tests {
		err := Unmarshal(test.b, test.got)
		switch {
		case err == nil && test.err:
			t.Errorf("TestUnmarshal(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestUnmarshal(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, test.got); diff != "" {
			t.Errorf("TestUnmarshal(%s): -want/+got:\n%s", test.desc, diff)
		}
		// time.Time has only unexported fields, which the pretty.Config above
		// does not include, so the Time field has to be compared explicitly.
		if want, ok := test.want.(*StructC); ok {
			got := test.got.(*StructC)
			if !got.Time.Equal(want.Time) {
				t.Errorf("TestUnmarshal(%s): got Time == %v, want Time == %v", test.desc, got.Time, want.Time)
			}
		}
	}
}
// TestIsDelim checks that isDelim accepts a json.Delim token and rejects a
// plain token holding the same text.
func TestIsDelim(t *testing.T) {
	cases := []struct {
		desc  string
		token json.Token
		want  bool
	}{
		{desc: "Is delim", token: json.Delim('{'), want: true},
		{desc: "Not a delim", token: json.Token("{"), want: false},
	}

	for i := range cases {
		c := &cases[i]
		if got := isDelim(c.token); got != c.want {
			t.Errorf("TestIsDelim(%s): got %v, want %v", c.desc, got, c.want)
		}
	}
}
// TestDelimIs checks that delimIs matches a delimiter token against the same
// rune and rejects a different rune.
func TestDelimIs(t *testing.T) {
	cases := []struct {
		desc  string
		token json.Token
		delim rune
		want  bool
	}{
		{desc: "Token is a match", token: json.Delim('{'), delim: '{', want: true},
		{desc: "Token is not a match", token: json.Delim('{'), delim: '}', want: false},
	}

	for i := range cases {
		c := &cases[i]
		if got := delimIs(c.token, c.delim); got != c.want {
			t.Errorf("TestDelimIs(%s): got %v, want %v", c.desc, got, c.want)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/mapslice.go 0000664 0000000 0000000 00000016427 14420263624 0027254 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"encoding/json"
"fmt"
"reflect"
)
// unmarshalMap decodes the next JSON value from dec into the map that m
// points to, using the mapWalk state machine. m must be a reflect.Value
// holding a *map; anything else is a programming error and panics.
func unmarshalMap(dec *json.Decoder, m reflect.Value) error {
	isPtrToMap := m.Kind() == reflect.Ptr && m.Elem().Kind() == reflect.Map
	if !isPtrToMap {
		panic("unmarshalMap called on non-*map value")
	}

	walker := mapWalk{
		dec:       dec,
		m:         m,
		valueType: m.Elem().Type().Elem(),
	}
	return walker.run()
}
// mapWalk holds the state of the state machine (run/start/next/storeValue/...)
// that decodes a JSON object into a map for unmarshalMap.
type mapWalk struct {
	// dec is the decoder the JSON is read from.
	dec *json.Decoder
	// key is the most recently read object key; the store* states insert the
	// decoded value under it.
	key string
	// m is the destination; unmarshalMap requires it to be a *map.
	m reflect.Value
	// valueType is the element (value) type of the destination map.
	valueType reflect.Type
}
// run drives the decoding state machine from start until a state returns a
// nil next state (success) or a non-nil error (returned immediately).
func (m *mapWalk) run() error {
	for state := stateFn(m.start); state != nil; {
		next, err := state()
		if err != nil {
			return err
		}
		state = next
	}
	return nil
}
// start decides how the map is decoded. A map type with its own UnmarshalJSON,
// or one whose values are not (pointers to) structs, maps or slices, is decoded
// in one step by encoding/json. Otherwise the opening { is consumed and the
// machine moves on to next(). A JSON null ends decoding without changes.
func (m *mapWalk) start() (stateFn, error) {
	// maps can have custom unmarshalers.
	if hasUnmarshalJSON(m.m) {
		return nil, m.dec.Decode(m.m.Interface())
	}

	base, _ := m.valueBaseType()
	kind := base.Kind()
	if kind == reflect.Ptr {
		return nil, fmt.Errorf("do not support maps with values of '**type' or '*reference")
	}
	if kind != reflect.Struct && kind != reflect.Map && kind != reflect.Slice {
		// This is a basic map type, so just use Decode().
		return nil, m.dec.Decode(m.m.Interface())
	}

	tok, err := m.dec.Token()
	switch {
	case err != nil:
		return nil, err
	case tok == nil: // the value was JSON null
		return nil, nil
	case !delimIs(tok, '{'):
		return nil, fmt.Errorf("Unmarshal expected opening {, received %v", tok)
	}
	return m.next, nil
}
// next reads the next object key into m.key and moves to storeValue(), or,
// when the object has no more members, consumes the closing } and stops.
func (m *mapWalk) next() (stateFn, error) {
	if !m.dec.More() {
		_, err := m.dec.Token()
		return nil, err
	}
	tok, err := m.dec.Token()
	if err != nil {
		return nil, err
	}
	m.key = tok.(string)
	return m.storeValue, nil
}
// storeValue picks the state that decodes the current map value, based on
// the value type with all pointer levels removed.
func (m *mapWalk) storeValue() (stateFn, error) {
	base := m.valueType
	for base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	switch base.Kind() {
	case reflect.Struct:
		return m.storeStruct, nil
	case reflect.Map:
		return m.storeMap, nil
	case reflect.Slice:
		return m.storeSlice, nil
	default:
		return nil, fmt.Errorf("bug: mapWalk.storeValue() called on unsupported type: %v", base.Kind())
	}
}
// storeStruct decodes the next JSON value into a fresh struct and stores it
// under m.key, as a pointer or a value depending on the map's value type.
func (m *mapWalk) storeStruct() (stateFn, error) {
	ptr := newValue(m.valueType)
	if err := unmarshalStruct(m.dec, ptr.Interface()); err != nil {
		return nil, err
	}
	stored := ptr
	if m.valueType.Kind() != reflect.Ptr {
		stored = ptr.Elem()
	}
	m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), stored)
	return m.next, nil
}
// storeMap decodes the next JSON value into a new map and stores it in the
// destination map under m.key. When the destination's value type is a pointer
// to a map, the decoded map is stored behind a pointer.
func (m *mapWalk) storeMap() (stateFn, error) {
	base, isPtr := m.valueBaseType()
	// reflect.MakeMap needs a map type, so build from the base type rather than
	// m.valueType, which may be a pointer to a map.
	ptr := reflect.New(base)
	ptr.Elem().Set(reflect.MakeMap(base))
	if err := unmarshalMap(m.dec, ptr); err != nil {
		return nil, err
	}
	// Store whatever unmarshalMap left behind ptr (it may have replaced the
	// pre-allocated map, e.g. with nil for a JSON null).
	stored := ptr.Elem()
	if isPtr {
		stored = ptr
	}
	m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), stored)
	return m.next, nil
}
// storeSlice decodes the next JSON value into a new slice and stores it under
// m.key. Like storeStruct, it stores a pointer when the map's value type is a
// pointer (*[]T); otherwise it stores the slice itself.
func (m *mapWalk) storeSlice() (stateFn, error) {
	ptr := newValue(m.valueType) // always a *[]T, whatever m.valueType is
	if err := unmarshalSlice(m.dec, ptr); err != nil {
		return nil, err
	}
	stored := ptr.Elem()
	if m.valueType.Kind() == reflect.Ptr {
		stored = ptr
	}
	m.m.Elem().SetMapIndex(reflect.ValueOf(m.key), stored)
	return m.next, nil
}
// valueBaseType returns the map's value type with one level of pointer
// removed (so *struct yields struct), and whether a pointer was removed.
func (m *mapWalk) valueBaseType() (reflect.Type, bool) {
	if m.valueType.Kind() != reflect.Ptr {
		return m.valueType, false
	}
	return m.valueType.Elem(), true
}
// unmarshalSlice decodes the next JSON value (normally an array) into the slice
// that ptrSlice points to. ptrSlice must hold a pointer to a slice, for example
// one created with newValue(); anything else panics.
func unmarshalSlice(dec *json.Decoder, ptrSlice reflect.Value) error {
	if !(ptrSlice.Kind() == reflect.Ptr && ptrSlice.Elem().Kind() == reflect.Slice) {
		panic("unmarshalSlice called on non-*[]slice value")
	}
	walker := &sliceWalk{dec: dec, s: ptrSlice}
	walker.valueType = ptrSlice.Elem().Type().Elem()
	return walker.run()
}
// sliceWalk holds the state of the decoder state machine that fills a slice
// from a JSON array (see unmarshalSlice).
type sliceWalk struct {
	dec       *json.Decoder
	s         reflect.Value // *[]slice
	valueType reflect.Type  // the slice's element type
}
// run drives the decoder state machine, beginning at start, until a state
// returns either an error or a nil next state.
func (s *sliceWalk) run() error {
	state := s.start
	for state != nil {
		following, err := state()
		if err != nil {
			return err
		}
		state = following
	}
	return nil
}
// start decides how the slice is decoded. A slice type with its own
// UnmarshalJSON, or one whose elements are not (pointers to) structs, maps or
// slices, is decoded in one step by encoding/json. Otherwise the opening [ is
// consumed and the machine moves on to next(). A JSON null ends decoding
// without changes.
func (s *sliceWalk) start() (stateFn, error) {
	// slices can have custom unmarshalers.
	if hasUnmarshalJSON(s.s) {
		return nil, s.dec.Decode(s.s.Interface())
	}
	switch s.valueBaseType().Kind() {
	case reflect.Struct, reflect.Map, reflect.Slice:
		// Walked element by element below.
	case reflect.Ptr:
		return nil, fmt.Errorf("cannot unmarshal into a ** or *")
	default:
		return nil, s.dec.Decode(s.s.Interface())
	}
	tok, err := s.dec.Token()
	if err != nil {
		return nil, err
	}
	if tok == nil {
		// The value was JSON null.
		return nil, nil
	}
	if !delimIs(tok, '[') {
		return nil, fmt.Errorf("Unmarshal expected opening [, received %v", tok)
	}
	return s.next, nil
}
// next moves to storeValue() while the array has elements left; otherwise it
// consumes the closing ] and stops the machine.
func (s *sliceWalk) next() (stateFn, error) {
	if !s.dec.More() {
		_, err := s.dec.Token()
		return nil, err
	}
	return s.storeValue, nil
}
// storeValue picks the state that decodes the next array element, based on
// the element type with one level of pointer removed.
func (s *sliceWalk) storeValue() (stateFn, error) {
	kind := s.valueBaseType().Kind()
	switch kind {
	case reflect.Struct:
		return s.storeStruct, nil
	case reflect.Map:
		return s.storeMap, nil
	case reflect.Slice:
		return s.storeSlice, nil
	case reflect.Ptr:
		return nil, fmt.Errorf("do not support 'pointer to pointer' or 'pointer to reference' types")
	default:
		return nil, fmt.Errorf("bug: sliceWalk.storeValue() called on unsupported type: %v", kind)
	}
}
// storeStruct decodes the next array element into a fresh struct and appends
// it, as a pointer or a value depending on the slice's element type.
func (s *sliceWalk) storeStruct() (stateFn, error) {
	elem := newValue(s.valueType)
	if err := unmarshalStruct(s.dec, elem.Interface()); err != nil {
		return nil, err
	}
	if s.valueType.Kind() != reflect.Ptr {
		elem = elem.Elem()
	}
	s.s.Elem().Set(reflect.Append(s.s.Elem(), elem))
	return s.next, nil
}
// storeMap decodes the next array element into a new map and appends it. When
// the slice's element type is a pointer to a map, a pointer is appended.
func (s *sliceWalk) storeMap() (stateFn, error) {
	// reflect.MakeMap needs a map type, so build from the base type rather than
	// s.valueType, which may be a pointer to a map.
	base := s.valueBaseType()
	ptr := reflect.New(base)
	ptr.Elem().Set(reflect.MakeMap(base))
	if err := unmarshalMap(s.dec, ptr); err != nil {
		return nil, err
	}
	// Append whatever unmarshalMap left behind ptr (it may have replaced the
	// pre-allocated map, e.g. with nil for a JSON null).
	elem := ptr.Elem()
	if s.valueType.Kind() == reflect.Ptr {
		elem = ptr
	}
	s.s.Elem().Set(reflect.Append(s.s.Elem(), elem))
	return s.next, nil
}
// storeSlice decodes the next array element into a new slice and appends it.
// Like storeStruct, it appends a pointer when the element type is *[]T;
// otherwise it appends the slice itself.
func (s *sliceWalk) storeSlice() (stateFn, error) {
	ptr := newValue(s.valueType) // always a *[]T, whatever s.valueType is
	if err := unmarshalSlice(s.dec, ptr); err != nil {
		return nil, err
	}
	elem := ptr.Elem()
	if s.valueType.Kind() == reflect.Ptr {
		elem = ptr
	}
	s.s.Elem().Set(reflect.Append(s.s.Elem(), elem))
	return s.next, nil
}
// valueBaseType returns the slice's element type with one level of pointer
// removed, so an element type of *struct yields struct.
func (s *sliceWalk) valueBaseType() reflect.Type {
	if s.valueType.Kind() != reflect.Ptr {
		return s.valueType
	}
	return s.valueType.Elem()
}
// newValue() returns a new *type that represents type passed.
func newValue(valueType reflect.Type) reflect.Value {
if valueType.Kind() == reflect.Ptr {
return reflect.New(valueType.Elem())
}
return reflect.New(valueType)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/mapslice_test.go 0000664 0000000 0000000 00000014731 14420263624 0030307 0 ustar 00root root 0000000 0000000 package json
import (
"bytes"
"encoding/json"
"reflect"
"testing"
"github.com/kylelemons/godebug/pretty"
)
// StructWithUnmarshal is a test type that implements json.Unmarshaler, used to
// check that map/slice decoding defers to a custom UnmarshalJSON.
type StructWithUnmarshal struct {
	Name string
}
// StructName is a test type with an AdditionalFields map; JSON keys that do
// not match a field (such as "extra" in the tests) are collected there.
type StructName struct {
	Name             string
	AdditionalFields map[string]interface{}
}
// UnmarshalJSON implements json.Unmarshaler for StructWithUnmarshal.
func (s *StructWithUnmarshal) UnmarshalJSON(b []byte) error {
	// Calling json.Unmarshal on s itself would invoke this method again and
	// recurse forever, so decode into a local type that has the same fields
	// but no UnmarshalJSON method.
	type unmarshal struct {
		Name string
	}
	u := unmarshal{}
	if err := json.Unmarshal(b, &u); err != nil {
		// Report malformed input to the caller instead of panicking.
		return err
	}
	s.Name = u.Name
	return nil
}
// TestUnmarshalMap exercises unmarshalMap on maps whose values are plain
// interfaces, custom unmarshalers, structs, pointers to structs and slices of
// either, plus a struct value type without AdditionalFields (an error).
func TestUnmarshalMap(t *testing.T) {
	cases := []struct {
		desc  string
		input string
		got   interface{}
		want  interface{}
		err   bool
	}{
		{
			desc: "error: struct has no AdditionalFields",
			input: `
{
"key": {
"Name": "John"
}
}
`,
			got: &map[string]struct{ Name string }{},
			err: true,
		},
		{
			desc: "success: basic map[string]interface{}",
			input: `
{
"key": {
"Name": "John"
}
}
`,
			got: &map[string]interface{}{},
			want: map[string]interface{}{
				"key": map[string]interface{}{
					"Name": "John",
				},
			},
		},
		{
			desc: "success: struct has UnmarshalJSON",
			input: `
{
"key": {
"Name": "John"
}
}
`,
			got: &map[string]*StructWithUnmarshal{},
			want: map[string]*StructWithUnmarshal{
				"key": {
					Name: "John",
				},
			},
		},
		{
			desc: "success: map[string]struct",
			input: `
{
"key": {
"Name": "John",
"extra": "extra"
}
}
`,
			got: &map[string]StructName{},
			want: map[string]StructName{
				"key": {
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc: "success: map[string]*struct",
			input: `
{
"key": {
"Name": "John",
"extra": "extra"
}
}
`,
			got: &map[string]*StructName{},
			want: map[string]*StructName{
				"key": {
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc: "success: map[string][]struct",
			input: `
{
"key": [
{
"Name": "John",
"extra": "extra"
}
]
}
`,
			got: &map[string][]StructName{},
			want: map[string][]StructName{
				"key": {
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc: "success: map[string][]*struct",
			input: `
{
"key": [
{
"Name": "John",
"extra": "extra"
}
]
}
`,
			got: &map[string][]*StructName{},
			want: map[string][]*StructName{
				"key": {
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		dec := json.NewDecoder(bytes.NewBufferString(tc.input))
		err := unmarshalMap(dec, reflect.ValueOf(tc.got))
		if err != nil {
			if !tc.err {
				t.Errorf("TestUnmarshalMap(%s): err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestUnmarshalMap(%s): err == nil, want err != nil", tc.desc)
			continue
		}
		if diff := pretty.Compare(tc.want, tc.got); diff != "" {
			t.Errorf("TestUnmarshalMap(%s): -want/+got\n%s", tc.desc, diff)
		}
	}
}
// TestUnmarshalSlice exercises unmarshalSlice on slices of basic values,
// custom unmarshalers, structs, pointers to structs, nested slices and maps,
// plus a struct element type without AdditionalFields (an error).
func TestUnmarshalSlice(t *testing.T) {
	cases := []struct {
		desc  string
		input string
		got   interface{}
		want  interface{}
		err   bool
	}{
		{
			desc: "error: struct has no AdditionalFields",
			input: `
[
{
"Name": "John"
}
]
`,
			got: new([]struct{ Name string }),
			err: true,
		},
		{
			desc: "success: basic slice",
			input: `
[
"John",
"Steve"
]
`,
			got:  new([]string),
			want: []string{"John", "Steve"},
		},
		{
			desc: "success: struct has UnmarshalJSON",
			input: `
[
{
"Name": "John"
}
]
`,
			got: new([]*StructWithUnmarshal),
			want: []*StructWithUnmarshal{
				{
					Name: "John",
				},
			},
		},
		{
			desc: "success: []struct",
			input: `
[
{
"Name": "John",
"extra": "extra"
}
]
`,
			got: new([]StructName),
			want: []StructName{
				{
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc: "success: []*struct",
			input: `
[
{
"Name": "John",
"extra": "extra"
}
]
`,
			got: new([]*StructName),
			want: []*StructName{
				{
					Name: "John",
					AdditionalFields: map[string]interface{}{
						"extra": MarshalRaw("extra"),
					},
				},
			},
		},
		{
			desc: "success: [][]struct",
			input: `
[
[
{
"Name": "John",
"extra": "extra"
}
]
]
`,
			got: new([][]StructName),
			want: [][]StructName{
				{
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc: "success: [][]*struct",
			input: `
[
[
{
"Name": "John",
"extra": "extra"
}
]
]
`,
			got: new([][]*StructName),
			want: [][]*StructName{
				{
					{
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
		{
			desc: "success: []map[string]struct",
			input: `
[
{
"key": {
"Name": "John",
"extra": "extra"
}
}
]
`,
			got: new([]map[string]StructName),
			want: []map[string]StructName{
				{
					"key": {
						Name: "John",
						AdditionalFields: map[string]interface{}{
							"extra": MarshalRaw("extra"),
						},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		dec := json.NewDecoder(bytes.NewBufferString(tc.input))
		err := unmarshalSlice(dec, reflect.ValueOf(tc.got))
		if err != nil {
			if !tc.err {
				t.Errorf("TestUnmarshalSlice(%s): err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestUnmarshalSlice(%s): err == nil, want err != nil", tc.desc)
			continue
		}
		if diff := pretty.Compare(tc.want, tc.got); diff != "" {
			t.Errorf("TestUnmarshalSlice(%s): -want/+got\n%s", tc.desc, diff)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/marshal.go 0000664 0000000 0000000 00000017753 14420263624 0027111 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"unicode"
)
// marshalStruct marshals v, which must be a struct or a pointer to a struct,
// as a JSON object into buff. Some output is written to buff directly and some
// through enc (which is expected to write into buff, since the newline that
// Encode appends is trimmed from buff). Struct, map and slice fields are
// marshalled recursively, and the entries of an AdditionalFields map are
// written inline as members of the object. Nil pointer fields are written as
// null.
func marshalStruct(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	// We only care about custom Marshalling a struct.
	if v.Kind() != reflect.Struct {
		return fmt.Errorf("bug: marshal() received a non *struct or struct, received type %T", v.Interface())
	}

	if hasMarshalJSON(v) {
		b, err := callMarshalJSON(v)
		if err != nil {
			return err
		}
		buff.Write(b)
		return nil
	}

	t := v.Type()

	// If it has an AdditionalFields field make sure it's the right type.
	f := v.FieldByName(addField)
	if f.Kind() != reflect.Invalid {
		if f.Kind() != reflect.Map {
			return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", v.Interface())
		}
		if !f.Type().AssignableTo(mapStrInterType) {
			return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", v.Interface())
		}
	}

	translator, err := findFields(v)
	if err != nil {
		return err
	}

	buff.WriteByte(leftBrace)
	for x := 0; x < v.NumField(); x++ {
		field := v.Field(x)

		// We don't access private fields.
		if unicode.IsLower(rune(t.Field(x).Name[0])) {
			continue
		}

		if t.Field(x).Name == addField {
			if v.Field(x).Len() > 0 {
				if err := writeAddFields(field.Interface(), buff, enc); err != nil {
					return err
				}
				buff.WriteByte(comma)
			}
			continue
		}

		// If they have omitempty set, we don't write out the field if
		// it is the zero value.
		if hasOmitEmpty(t.Field(x).Tag.Get("json")) {
			if v.Field(x).IsZero() {
				continue
			}
		}

		// Write out the field name part.
		jsonName := translator.jsonName(t.Field(x).Name)
		buff.WriteString(fmt.Sprintf("%q:", jsonName))

		if field.Kind() == reflect.Ptr {
			// A nil pointer has no element to marshal (Elem() would be the zero
			// Value, whose Interface() panics), so encode it as JSON null.
			if field.IsNil() {
				buff.WriteString("null")
				buff.WriteByte(comma)
				continue
			}
			field = field.Elem()
		}
		if err := marshalStructField(field, buff, enc); err != nil {
			return err
		}
	}

	// Every member written above is followed by a comma; drop the final one.
	// If no member was written, the last byte is the opening brace, which
	// must be kept or the output would not be valid JSON.
	if b := buff.Bytes(); b[len(b)-1] == comma {
		buff.Truncate(buff.Len() - 1)
	}
	buff.WriteByte(rightBrace)
	return nil
}
// marshalStructField writes a single struct field value to buff. Struct, map
// and slice values are marshalled recursively; anything else is encoded by enc
// and the trailing newline that Encode adds is trimmed. A comma is always
// appended after the value (even when an error is returned), leaving the
// caller to trim the final one.
func marshalStructField(field reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	defer buff.WriteByte(comma)

	switch field.Kind() {
	case reflect.Struct:
		// Pass a pointer when possible so pointer-receiver methods are seen.
		if field.CanAddr() {
			field = field.Addr()
		}
		return marshalStruct(field, buff, enc)
	case reflect.Map:
		return marshalMap(field, buff, enc)
	case reflect.Slice:
		return marshalSlice(field, buff, enc)
	default:
		if err := enc.Encode(field.Interface()); err != nil {
			return err
		}
		buff.Truncate(buff.Len() - 1) // drop the newline Encode appends
		return nil
	}
}
// marshalMap writes the map held in v to buff as a JSON object. An empty (or
// nil) map is written as {}; anything else goes through the mapEncode state
// machine.
func marshalMap(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
	switch {
	case v.Kind() != reflect.Map:
		return fmt.Errorf("bug: marshalMap() called on %T", v.Interface())
	case v.Len() == 0:
		buff.WriteByte(leftBrace)
		buff.WriteByte(rightBrace)
		return nil
	}
	e := &mapEncode{m: v, buff: buff, enc: enc}
	return e.run()
}
// mapEncode holds the state of the encoder state machine that writes a
// non-empty map as a JSON object (see marshalMap).
type mapEncode struct {
	m             reflect.Value // the map being encoded
	buff          *bytes.Buffer // destination for the JSON output
	enc           *json.Encoder
	valueBaseType reflect.Type // map value type with one pointer level removed; set by start()
}
// run drives the encoder state machine, beginning at start, until a state
// returns either an error or a nil next state.
func (m *mapEncode) run() error {
	for state := m.start; state != nil; {
		next, err := state()
		if err != nil {
			return err
		}
		state = next
	}
	return nil
}
// start decides how the map is encoded. A map with its own MarshalJSON is
// written as-is; a map whose value type (ignoring one level of pointer) is a
// struct, map or slice is handed to encode(); any other map is encoded in one
// step with enc.
func (m *mapEncode) start() (stateFn, error) {
	if hasMarshalJSON(m.m) {
		out, err := callMarshalJSON(m.m)
		if err != nil {
			return nil, err
		}
		m.buff.Write(out)
		return nil, nil
	}

	base := m.m.Type().Elem()
	if base.Kind() == reflect.Ptr {
		base = base.Elem()
	}
	m.valueBaseType = base

	kind := base.Kind()
	if kind == reflect.Ptr {
		return nil, fmt.Errorf("Marshal does not support ** or *")
	}
	if kind == reflect.Struct || kind == reflect.Map || kind == reflect.Slice {
		return m.encode, nil
	}

	// Plain value types: let encoding/json do the work.
	if err := m.enc.Encode(m.m.Interface()); err != nil {
		return nil, err
	}
	m.buff.Truncate(m.buff.Len() - 1) // drop the newline Encode appends
	return nil, nil
}
// encode writes a non-empty map whose values are structs, maps or slices
// (optionally behind one pointer) as a JSON object, marshalling each value
// recursively. Keys of string kind are written as-is and integer keys are
// written in decimal, as encoding/json does; other key kinds are rejected.
// Nil pointer values are written as null. Members appear in Go map iteration
// order.
func (m *mapEncode) encode() (stateFn, error) {
	m.buff.WriteByte(leftBrace)
	iter := m.m.MapRange()
	for iter.Next() {
		// Write the key. reflect.Value.String only yields the string itself for
		// string kinds (other kinds give "<T Value>"), so format integers
		// explicitly.
		k := iter.Key()
		var name string
		switch k.Kind() {
		case reflect.String:
			name = k.String()
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			name = fmt.Sprint(k.Int())
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			name = fmt.Sprint(k.Uint())
		default:
			return nil, fmt.Errorf("Marshal does not support map key type %v", k.Type())
		}
		m.buff.WriteString(fmt.Sprintf("%q:", name))

		v := iter.Value()
		if v.Kind() == reflect.Ptr {
			// Nil pointers have nothing to marshal; non-nil ones are unwrapped so
			// *map and *slice values reach marshalMap/marshalSlice as map/slice.
			if v.IsNil() {
				m.buff.WriteString("null")
				m.buff.WriteByte(comma)
				continue
			}
			v = v.Elem()
		}

		switch m.valueBaseType.Kind() {
		case reflect.Struct:
			if v.CanAddr() {
				v = v.Addr()
			}
			if err := marshalStruct(v, m.buff, m.enc); err != nil {
				return nil, err
			}
		case reflect.Map:
			if err := marshalMap(v, m.buff, m.enc); err != nil {
				return nil, err
			}
		case reflect.Slice:
			if err := marshalSlice(v, m.buff, m.enc); err != nil {
				return nil, err
			}
		default:
			panic(fmt.Sprintf("critical bug: mapEncode.encode() called with value base type: %v", m.valueBaseType.Kind()))
		}
		m.buff.WriteByte(comma)
	}
	// marshalMap only reaches here for non-empty maps, so at least one comma
	// was written.
	m.buff.Truncate(m.buff.Len() - 1) // Remove final comma
	m.buff.WriteByte(rightBrace)
	return nil, nil
}
func marshalSlice(v reflect.Value, buff *bytes.Buffer, enc *json.Encoder) error {
if v.Kind() != reflect.Slice {
return fmt.Errorf("bug: marshalSlice() called on %T", v.Interface())
}
if v.Len() == 0 {
buff.WriteByte(leftParen)
buff.WriteByte(rightParen)
return nil
}
encoder := sliceEncode{s: v, buff: buff, enc: enc}
return encoder.run()
}
// sliceEncode holds the state of the encoder state machine that writes a
// non-empty slice as a JSON array (see marshalSlice).
type sliceEncode struct {
	s             reflect.Value // the slice being encoded
	buff          *bytes.Buffer // destination for the JSON output
	enc           *json.Encoder
	valueBaseType reflect.Type // element type with one pointer level removed; set by start()
}
// run drives the encoder state machine, beginning at start, until a state
// returns either an error or a nil next state.
func (s *sliceEncode) run() error {
	state := s.start
	for state != nil {
		following, err := state()
		if err != nil {
			return err
		}
		state = following
	}
	return nil
}
// start decides how the slice is encoded. A slice with its own MarshalJSON is
// written as-is; a slice whose element type (ignoring one level of pointer) is
// a struct, map or slice is handed to encode(); any other slice is encoded in
// one step with enc.
func (s *sliceEncode) start() (stateFn, error) {
	if hasMarshalJSON(s.s) {
		out, err := callMarshalJSON(s.s)
		if err != nil {
			return nil, err
		}
		s.buff.Write(out)
		return nil, nil
	}

	elem := s.s.Type().Elem()
	if elem.Kind() == reflect.Ptr {
		elem = elem.Elem()
	}
	s.valueBaseType = elem

	switch elem.Kind() {
	case reflect.Struct, reflect.Map, reflect.Slice:
		return s.encode, nil
	case reflect.Ptr:
		return nil, fmt.Errorf("Marshal does not support ** or *")
	}

	// The elements are not structs/maps/slices, so just Encode() the slice.
	if err := s.enc.Encode(s.s.Interface()); err != nil {
		return nil, err
	}
	s.buff.Truncate(s.buff.Len() - 1) // drop the newline Encode appends
	return nil, nil
}
// encode writes a non-empty slice whose elements are structs, maps or slices
// (optionally behind one pointer) as a JSON array, marshalling each element
// recursively. Nil pointer elements are written as null.
func (s *sliceEncode) encode() (stateFn, error) {
	s.buff.WriteByte(leftParen)
	for i := 0; i < s.s.Len(); i++ {
		v := s.s.Index(i)
		if v.Kind() == reflect.Ptr {
			// Nil pointers have nothing to marshal; non-nil ones are unwrapped so
			// *map and *slice elements reach marshalMap/marshalSlice as map/slice.
			if v.IsNil() {
				s.buff.WriteString("null")
				s.buff.WriteByte(comma)
				continue
			}
			v = v.Elem()
		}

		switch s.valueBaseType.Kind() {
		case reflect.Struct:
			if v.CanAddr() {
				v = v.Addr()
			}
			if err := marshalStruct(v, s.buff, s.enc); err != nil {
				return nil, err
			}
		case reflect.Map:
			if err := marshalMap(v, s.buff, s.enc); err != nil {
				return nil, err
			}
		case reflect.Slice:
			if err := marshalSlice(v, s.buff, s.enc); err != nil {
				return nil, err
			}
		default:
			panic(fmt.Sprintf("critical bug: sliceEncode.encode() called with value base type: %v", s.valueBaseType.Kind()))
		}
		s.buff.WriteByte(comma)
	}
	// marshalSlice only reaches here for non-empty slices, so at least one
	// comma was written.
	s.buff.Truncate(s.buff.Len() - 1) // Remove final comma
	s.buff.WriteByte(rightParen)
	return nil, nil
}
// writeAddFields writes each entry of the AdditionalFields map i to buff as a
// `"key":value` member, separated by commas with no leading or trailing comma.
// Values are encoded with enc. i must be a map[string]interface{} or this
// panics.
func writeAddFields(i interface{}, buff *bytes.Buffer, enc *json.Encoder) error {
	fields := i.(map[string]interface{})
	first := true
	for name, val := range fields {
		if !first {
			buff.WriteByte(comma)
		}
		first = false

		buff.WriteString(fmt.Sprintf("%q:", name))
		if err := enc.Encode(val); err != nil {
			return err
		}
		buff.Truncate(buff.Len() - 1) // drop the newline Encode appends
	}
	return nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/marshal_test.go 0000664 0000000 0000000 00000012020 14420263624 0030126 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"encoding/json"
"testing"
"github.com/kylelemons/godebug/pretty"
)
// TestMarshalStruct marshals each case with Marshal, decodes the output back
// into a map with encoding/json and compares it with the expected map.
func TestMarshalStruct(t *testing.T) {
	cases := []struct {
		desc  string
		value interface{}
		want  map[string]interface{}
		err   bool
	}{
		{
			desc: "struct with no additional fields",
			value: struct {
				Name string
				Int  int
			}{
				Name: "my name",
				Int:  5,
			},
			want: map[string]interface{}{
				"Name": "my name",
				"Int":  5,
			},
		},
		{
			desc: "*struct with AdditionalFields",
			value: &struct {
				Name             string
				Int              int
				AdditionalFields map[string]interface{} `json:"-"`
			}{
				Name: "John Doak",
				Int:  45,
				AdditionalFields: map[string]interface{}{
					"Hello": "World",
					"Float": 3.2,
				},
			},
			want: map[string]interface{}{
				"Name":  "John Doak",
				"Int":   45,
				"Float": 3.2,
				"Hello": "World",
			},
		},
		{
			desc: "AdditionalFields is not a map",
			value: struct {
				AdditionalFields string `json:"-"`
			}{
				AdditionalFields: "hello",
			},
			err: true,
		},
		{
			desc: "AdditionalFields is not a map[string]interface{}",
			value: struct {
				AdditionalFields map[string]string `json:"-"`
			}{
				AdditionalFields: map[string]string{
					"Hello": "World",
				},
			},
			err: true,
		},
		{
			desc: "Multiple Structs",
			value: &StructA{
				Name: "John",
				ID:   3,
				Meta: &StructB{
					Address: "291 Street",
					AdditionalFields: map[string]interface{}{
						"unknown0": MarshalRaw(3.2),
					},
				},
				AdditionalFields: map[string]interface{}{
					"unknown0": MarshalRaw(10),
					"unknown1": MarshalRaw("hello"),
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"id":   3,
				"Meta": map[string]interface{}{
					"Address":  "291 Street",
					"unknown0": 3.2,
				},
				"unknown0": 10,
				"unknown1": "hello",
			},
		},
		{
			desc: "Struct with map[string]interface{}",
			value: struct {
				Name             string
				Map              map[string]interface{}
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]interface{}{
					"key": "value",
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": "value",
				},
			},
		},
		{
			desc: "Struct with map[string]struct{}",
			value: struct {
				Name             string
				Map              map[string]StructB
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]StructB{
					"key": {
						Address: "addr",
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": map[string]interface{}{
						"Address": "addr",
					},
				},
			},
		},
		{
			desc: "Struct with map[string][]",
			value: struct {
				Name             string
				Map              map[string]interface{}
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string]interface{}{
					"key": []string{
						"apples",
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": []string{"apples"},
				},
			},
		},
		{
			desc: "Struct with map[string][]struct",
			value: struct {
				Name             string
				Map              map[string][]StructB
				AdditionalFields map[string]interface{}
			}{
				Name: "John",
				Map: map[string][]StructB{
					"key": {
						{Address: "addr"},
					},
				},
			},
			want: map[string]interface{}{
				"Name": "John",
				"Map": map[string]interface{}{
					"key": []interface{}{
						map[string]interface{}{
							"Address": "addr",
						},
					},
				},
			},
		},
	}

	for _, tc := range cases {
		b, err := Marshal(tc.value)
		if err != nil {
			if !tc.err {
				t.Errorf("TestMarshal(%s): got err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestMarshal(%s): got err == nil, want err != nil", tc.desc)
			continue
		}

		got := map[string]interface{}{}
		if err := json.Unmarshal(b, &got); err != nil {
			t.Errorf("TestMarshal(%s): Marshal produced invalid JSON:\n%s\n%s", tc.desc, err, string(b))
			continue
		}
		if diff := pretty.Compare(tc.want, got); diff != "" {
			t.Errorf("TestMarshal(%s): -want/+got:\n%s", tc.desc, diff)
		}
	}
}
// TestEmptyTypes checks that a struct with empty/zero-valued map, slice and
// int fields alongside populated ones survives a Marshal/Unmarshal round trip.
func TestEmptyTypes(t *testing.T) {
	type structA struct {
		EmptyMap         map[string]bool
		EmptySlice       []string
		Slice            []string
		EmptyInt         int
		Int              int
		AdditionalFields map[string]interface{}
	}

	val := structA{
		EmptyMap: map[string]bool{},
		Slice:    []string{"hello"},
		Int:      1,
	}

	b, err := Marshal(val)
	if err != nil {
		t.Fatalf("TestEmptyTypes: unexpected error on Marshal: %v", err)
	}

	got := structA{}
	if err := Unmarshal(b, &got); err != nil {
		t.Fatalf("TestEmptyTypes: unexpected error when Unmarshalling: %v", err)
	}

	// pretty.Compare(a, b) marks a's lines with "-" and b's with "+", so the
	// expected value goes first to match the "-want/+got" message.
	if diff := pretty.Compare(val, got); diff != "" {
		t.Fatalf("TestEmptyTypes: -want/+got:\n%s", diff)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/struct.go 0000664 0000000 0000000 00000016414 14420263624 0026777 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"encoding/json"
"fmt"
"reflect"
"strings"
)
// unmarshalStruct decodes the next JSON value from jdec into i, which must be a
// pointer to a struct. Structs implementing json.Unmarshaler are decoded by
// encoding/json; every other struct must have an AdditionalFields field of type
// map[string]interface{}, which receives JSON keys that match no struct field.
func unmarshalStruct(jdec *json.Decoder, i interface{}) error {
	ptr := reflect.ValueOf(i)
	if ptr.Kind() != reflect.Ptr || ptr.Elem().Kind() != reflect.Struct {
		return fmt.Errorf("Unmarshal() received type %T, which is not a *struct", i)
	}
	target := ptr.Elem()

	// Indicates that this type has a custom Unmarshaler.
	if hasUnmarshalJSON(target) {
		return jdec.Decode(target.Addr().Interface())
	}

	extra := target.FieldByName(addField)
	switch {
	case extra.Kind() == reflect.Invalid:
		return fmt.Errorf("Unmarshal(%T) only supports structs that have the field AdditionalFields or implements json.Unmarshaler", i)
	case extra.Kind() != reflect.Map, !extra.Type().AssignableTo(mapStrInterType):
		return fmt.Errorf("type %T has field 'AdditionalFields' that is not a map[string]interface{}", i)
	}

	return newDecoder(jdec, target).run()
}
// decoder holds the state of the decoder state machine that fills a struct
// from a JSON object (see unmarshalStruct).
type decoder struct {
	dec        *json.Decoder
	value      reflect.Value   // This will be a reflect.Struct
	translator translateFields // JSON name <-> Go field name pairs; filled in by start()
	key        string          // JSON object key most recently read by next()
}
// newDecoder returns a decoder that reads JSON from dec into value.
func newDecoder(dec *json.Decoder, value reflect.Value) *decoder {
	d := new(decoder)
	d.dec = dec
	d.value = value
	return d
}
// run drives the decoder state machine, beginning at start, until a state
// returns either an error or a nil next state.
func (d *decoder) run() error {
	for state := d.start; state != nil; {
		next, err := state()
		if err != nil {
			return err
		}
		state = next
	}
	return nil
}
// start builds the JSON-name/Go-name mapping for d.value, then looks for the
// opening delimiter '{' and transitions to next() to loop through the object's
// members. A JSON null leaves the struct untouched, as encoding/json does for
// structs.
func (d *decoder) start() (stateFn, error) {
	var err error
	d.translator, err = findFields(d.value)
	if err != nil {
		return nil, err
	}

	delim, err := d.dec.Token()
	if err != nil {
		return nil, err
	}
	// JSON null: nothing to decode. mapWalk.start and sliceWalk.start treat
	// null the same way.
	if delim == nil {
		return nil, nil
	}
	if !delimIs(delim, '{') {
		return nil, fmt.Errorf("Unmarshal expected opening {, received %v", delim)
	}
	return d.next, nil
}
// next reads the next object key into d.key and moves to storeValue(), or,
// when the object has no more members, consumes the closing } and stops.
func (d *decoder) next() (stateFn, error) {
	if d.dec.More() {
		tok, err := d.dec.Token()
		if err != nil {
			return nil, err
		}
		d.key = tok.(string)
		return d.storeValue, nil
	}
	_, err := d.dec.Token()
	return nil, err
}
// storeValue decodes the value for JSON key d.key into the matching struct
// field. Keys that match no exported field are handed to storeAdditional() so
// they end up in AdditionalFields. Fields with a custom UnmarshalJSON are
// decoded by encoding/json; struct, map and slice fields (optionally behind
// one pointer) are decoded recursively; any other field is decoded by
// encoding/json.
func (d *decoder) storeValue() (stateFn, error) {
	goName := d.translator.goName(d.key)
	if goName == "" {
		goName = d.key
	}

	// We don't have the field in the struct, so it goes in AdditionalFields.
	// Unexported fields are treated the same way: reflect cannot set them or
	// hand out their address and would panic, so input JSON must not be able
	// to crash the decoder by naming one.
	sf, ok := d.value.Type().FieldByName(goName)
	if !ok || sf.PkgPath != "" {
		return d.storeAdditional, nil
	}
	f := d.value.FieldByName(goName)
	if f.Kind() == reflect.Invalid {
		return d.storeAdditional, nil
	}

	// Indicates that this type has a custom Unmarshaler.
	if hasUnmarshalJSON(f) {
		if err := d.dec.Decode(f.Addr().Interface()); err != nil {
			return nil, err
		}
		return d.next, nil
	}

	t, isPtr, err := fieldBaseType(d.value, goName)
	if err != nil {
		return nil, fmt.Errorf("type(%s) had field(%s) %w", d.value.Type().Name(), goName, err)
	}

	switch t.Kind() {
	case reflect.Struct:
		// We need to recursively call ourselves on any *struct or struct.
		if isPtr {
			if f.IsNil() {
				f.Set(reflect.New(t))
			}
		} else {
			f = f.Addr()
		}
		if err := unmarshalStruct(d.dec, f.Interface()); err != nil {
			return nil, err
		}
		return d.next, nil
	case reflect.Map, reflect.Slice:
		// Decode into a fresh *T, where T is the map/slice type with any pointer
		// removed (reflect.MakeMap/MakeSlice need T itself, not *T), then store
		// T or *T to match the field's declared type.
		ptr := reflect.New(t)
		if t.Kind() == reflect.Map {
			ptr.Elem().Set(reflect.MakeMap(t))
			err = unmarshalMap(d.dec, ptr)
		} else {
			ptr.Elem().Set(reflect.MakeSlice(t, 0, 0))
			err = unmarshalSlice(d.dec, ptr)
		}
		if err != nil {
			return nil, err
		}
		if isPtr {
			f.Set(ptr)
		} else {
			f.Set(ptr.Elem())
		}
		return d.next, nil
	}

	// Decode needs a pointer: take the field's address for value fields, and
	// make sure pointer fields are non-nil before decoding into them.
	if !isPtr {
		f = f.Addr()
	}
	if f.IsNil() {
		f.Set(reflect.New(t))
	}
	if err := d.dec.Decode(f.Interface()); err != nil {
		return nil, err
	}
	return d.next, nil
}
// storeAdditional reads the next JSON value verbatim and records it, as a
// json.RawMessage, in the struct's AdditionalFields map under d.key,
// allocating the map first if it is nil.
func (d *decoder) storeAdditional() (stateFn, error) {
	raw := json.RawMessage{}
	if err := d.dec.Decode(&raw); err != nil {
		return nil, err
	}
	extras := d.value.FieldByName(addField)
	if extras.IsNil() {
		extras.Set(reflect.MakeMap(extras.Type()))
	}
	extras.SetMapIndex(reflect.ValueOf(d.key), reflect.ValueOf(raw))
	return d.next, nil
}
func fieldBaseType(v reflect.Value, fieldName string) (t reflect.Type, isPtr bool, err error) {
sf, ok := v.Type().FieldByName(fieldName)
if !ok {
return nil, false, fmt.Errorf("bug: fieldBaseType() lookup of field(%s) on type(%s): do not have field", fieldName, v.Type().Name())
}
t = sf.Type
if t.Kind() == reflect.Ptr {
t = t.Elem()
isPtr = true
}
if t.Kind() == reflect.Ptr {
return nil, isPtr, fmt.Errorf("received pointer to pointer type, not supported")
}
return t, isPtr, nil
}
// translateField pairs a struct field's Go name with the name used for it in
// JSON (the first element of its `json` tag, or the Go name when the tag is
// empty or "-"; see findFields).
type translateField struct {
	jsonName string
	goName   string
}
// translateFields is a list of translateField entries with handy lookup
// methods (goName and jsonName).
type translateFields []translateField
// goName returns the Go field name recorded for jsonName, or "" when no entry
// matches. A linear scan is used rather than a map: at this size a slice is
// faster, even in tight loops.
func (t translateFields) goName(jsonName string) string {
	for i := range t {
		if t[i].jsonName == jsonName {
			return t[i].goName
		}
	}
	return ""
}
// jsonName returns the JSON name recorded for goName, or "" when no entry
// matches. A linear scan is used rather than a map: at this size a slice is
// faster, even in tight loops.
func (t translateFields) jsonName(goName string) string {
	for i := range t {
		if t[i].goName == goName {
			return t[i].jsonName
		}
	}
	return ""
}
// umarshalerType is the reflect.Type of the json.Unmarshaler interface, for
// use with reflect.Type.Implements.
var umarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()
// findFields returns the Go-name/JSON-name pair for every field of v, which
// must be a struct or a pointer to a struct (anything else is an error). The
// JSON name is the first element of the field's `json` tag; an empty tag or a
// tag of "-" falls back to the Go name.
//
// NOTE(review): the loop returns an error when a struct (or non-nil *struct)
// field's type DOES implement json.Unmarshaler, yet the error text says the
// type "doesn't implement" it. Because f.Type() is the struct type rather than
// a pointer, UnmarshalJSON methods declared on a pointer receiver are not in
// its method set, so the check seldom fires. Confirm the intended rule before
// relying on or changing this.
func findFields(v reflect.Value) (translateFields, error) {
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return nil, fmt.Errorf("findFields received a %s type, expected *struct or struct", v.Type().Name())
	}

	tfs := make([]translateField, 0, v.NumField())
	for i := 0; i < v.NumField(); i++ {
		tf := translateField{
			goName:   v.Type().Field(i).Name,
			jsonName: parseTag(v.Type().Field(i).Tag.Get("json")),
		}
		switch tf.jsonName {
		case "", "-":
			tf.jsonName = tf.goName
		}
		tfs = append(tfs, tf)

		// Look through one level of pointer. A nil pointer yields an invalid
		// Value, whose Kind is not Struct, so it skips the check below.
		f := v.Field(i)
		if f.Kind() == reflect.Ptr {
			f = f.Elem()
		}
		if f.Kind() == reflect.Struct {
			if f.Type().Implements(umarshalerType) {
				return nil, fmt.Errorf("struct type %q which has field %q which "+
					"doesn't implement json.Unmarshaler", v.Type().Name(), v.Type().Field(i).Name)
			}
		}
	}
	return tfs, nil
}
// parseTag returns the name part of a `json` struct tag value (as returned by
// reflect.StructTag.Get): everything before the first comma, or the whole tag
// when it has no comma.
func parseTag(tag string) string {
	return strings.SplitN(tag, ",", 2)[0]
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/struct_test.go 0000664 0000000 0000000 00000014623 14420263624 0030036 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package json
import (
"bytes"
"encoding/json"
"reflect"
"runtime"
"testing"
"github.com/kylelemons/godebug/pretty"
)
// TestDecoderStart checks that decoder.start rejects empty or non-object input
// and transitions to next() when it finds an opening brace.
func TestDecoderStart(t *testing.T) {
	cases := []struct {
		desc    string
		b       []byte
		i       interface{}
		stateFn stateFn
		err     bool
	}{
		{
			desc:    "No content to decode",
			i:       &StructA{},
			stateFn: nil,
			err:     true,
		},
		{
			desc:    "No opening brace",
			b:       []byte("3"),
			i:       &StructA{},
			stateFn: nil,
			err:     true,
		},
		{
			desc:    "Success",
			b:       []byte(`{"Name": "value"}`),
			i:       &StructA{},
			stateFn: (new(decoder).next),
		},
	}

	for _, tc := range cases {
		d := newDecoder(json.NewDecoder(bytes.NewBuffer(tc.b)), reflect.ValueOf(tc.i))
		next, err := d.start()
		if err != nil {
			if !tc.err {
				t.Errorf("TestDecoderStart(%s): got err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestDecoderStart(%s): got err == nil, want err != nil", tc.desc)
			continue
		}

		// Compare state functions by the name of the function they point at.
		got := runtime.FuncForPC(reflect.ValueOf(next).Pointer()).Name()
		want := runtime.FuncForPC(reflect.ValueOf(tc.stateFn).Pointer()).Name()
		if got != want {
			t.Errorf("TestDecoderStart(%s): got(stateFn) %s, want %s", tc.desc, got, want)
		}
	}
}
// TestDecoderNext checks decoder.next once the opening '{' has been consumed: it must
// fail on empty or malformed input, return a nil state without error at '}', and for a
// key record it in .key and hand off to storeValue.
func TestDecoderNext(t *testing.T) {
	// funcName resolves a state function to its runtime symbol name so state functions
	// (which are not comparable) can be compared.
	funcName := func(fn stateFn) string {
		return runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
	}

	tests := []struct {
		desc string
		b    []byte
		// advToken is how many Token() calls to make before calling next(), so the
		// decoder is positioned where next() expects to run; the decoder only works on
		// well formed JSON.
		advToken int
		i        interface{}
		key      string
		stateFn  stateFn
		err      bool
	}{
		{
			desc:    "No content to decode",
			i:       &StructA{},
			stateFn: nil,
			err:     true,
		},
		{
			desc:     "Bad ] found",
			b:        []byte("{]"),
			advToken: 1,
			i:        &StructA{},
			stateFn:  nil,
			err:      true,
		},
		{
			desc:     "Closing brace",
			b:        []byte("{}"),
			advToken: 1,
			i:        &StructA{},
			stateFn:  nil,
			err:      false,
		},
		{
			desc:     "Success",
			b:        []byte(`{"Name": "value"}`),
			advToken: 1,
			i:        &StructA{},
			key:      "Name",
			stateFn:  (new(decoder).storeValue),
		},
	}

	for _, test := range tests {
		dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(test.i))
		for i := 0; i < test.advToken; i++ {
			// A setup failure is reported as a test failure instead of a panic that would
			// abort the whole test binary.
			if _, err := dec.dec.Token(); err != nil {
				t.Fatalf("TestDecoderNext(%s): Token() call %d before next() failed: %s", test.desc, i+1, err)
			}
		}
		next, err := dec.next()
		switch {
		case err == nil && test.err:
			t.Errorf("TestDecoderNext(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestDecoderNext(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if dec.key != test.key {
			t.Errorf("TestDecoderNext(%s): got(.key) %s, want %s", test.desc, dec.key, test.key)
		}
		if got, want := funcName(next), funcName(test.stateFn); got != want {
			t.Errorf("TestDecoderNext(%s): got(stateFn) %s, want %s", test.desc, got, want)
		}
	}
}
func TestDecoderStoreValue(t *testing.T) {
tests := []struct {
desc string
b []byte
want StructA
stateFn stateFn
}{
{
desc: "Field found, no struct tag",
b: []byte(`{"Name": "myName"}`),
want: StructA{Name: "myName"},
stateFn: (new(decoder).next),
},
{
desc: "Field found, using struct tag",
b: []byte(`{"id": 3}`),
want: StructA{ID: 3},
stateFn: (new(decoder).next),
},
{
desc: "Field not found, go to storeAdditional()",
b: []byte(`{"blah": 3}`),
want: StructA{},
stateFn: (new(decoder).storeAdditional),
},
}
for _, test := range tests {
got := StructA{}
dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(&got).Elem())
_, err := dec.start() // populates our translator field
if err != nil {
panic(err)
}
_, err = dec.next()
if err != nil {
panic(err)
}
stateFn, err := dec.storeValue()
if err != nil {
t.Errorf("TestDecoderStoreValue(%s): got err == %s, want err == nil", test.desc, err)
continue
}
if diff := pretty.Compare(test.want, got); diff != "" {
t.Errorf("TestDecoderStoreValue(%s): -want/+got:\n%s", test.desc, diff)
continue
}
gotStateFn := runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name()
wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name()
if gotStateFn != wantStateFn {
t.Errorf("TestDecoderStoreValue(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn)
}
}
}
func TestDecoderStoreAdditional(t *testing.T) {
tests := []struct {
desc string
b []byte
got StructA
want StructA
stateFn stateFn
}{
{
desc: "Map not initialized",
b: []byte(`{"blah": "whatever"}`),
got: StructA{},
want: StructA{
AdditionalFields: map[string]interface{}{
"blah": json.RawMessage(`"whatever"`),
},
},
stateFn: (new(decoder).next),
},
{
desc: "Map exists",
b: []byte(`{"blah": "whatever"}`),
got: StructA{
AdditionalFields: map[string]interface{}{
"else": json.RawMessage(`"if"`),
},
},
want: StructA{
AdditionalFields: map[string]interface{}{
"else": json.RawMessage(`"if"`),
"blah": json.RawMessage(`"whatever"`),
},
},
stateFn: (new(decoder).next),
},
}
for _, test := range tests {
dec := newDecoder(json.NewDecoder(bytes.NewBuffer(test.b)), reflect.ValueOf(&test.got).Elem())
_, err := dec.start() // populates our translator field
if err != nil {
panic(err)
}
_, err = dec.next()
if err != nil {
panic(err)
}
stateFn, err := dec.storeAdditional()
if err != nil {
t.Errorf("TestDecoderStoreAdditional(%s): got err == %s, want err == nil", test.desc, err)
continue
}
if diff := pretty.Compare(test.want, test.got); diff != "" {
t.Errorf("TestDecoderStoreAdditional(%s): -want/+got:\n%s", test.desc, diff)
continue
}
gotStateFn := runtime.FuncForPC(reflect.ValueOf(stateFn).Pointer()).Name()
wantStateFn := runtime.FuncForPC(reflect.ValueOf(test.stateFn).Pointer()).Name()
if gotStateFn != wantStateFn {
t.Errorf("TestDecoderStoreAdditional(%s): got(stateFn) %s, want %s", test.desc, gotStateFn, wantStateFn)
}
}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/ 0000775 0000000 0000000 00000000000 14420263624 0026262 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/time/ 0000775 0000000 0000000 00000000000 14420263624 0027220 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/json/types/time/time.go 0000664 0000000 0000000 00000004344 14420263624 0030512 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package time provides for custom types to translate time from JSON and other formats
// into time.Time objects.
package time
import (
"fmt"
"strconv"
"strings"
"time"
)
// Unix converts between a JSON value holding seconds since the Unix epoch
// (quoted, e.g. "1600000000", or bare) and a time.Time.
type Unix struct {
	T time.Time
}

// MarshalJSON implements encoding/json.Marshaler. A non-zero T is written as its epoch
// seconds in a quoted decimal string; the zero time is written as empty output.
// NOTE(review): encoding/json.Marshal rejects empty MarshalJSON output; presumably the
// project's own json package tolerates it — confirm before marshaling with the stdlib.
func (u Unix) MarshalJSON() ([]byte, error) {
	if !u.T.IsZero() {
		secs := strconv.FormatInt(u.T.Unix(), 10)
		return []byte(strconv.Quote(secs)), nil
	}
	return []byte(""), nil
}

// UnmarshalJSON implements encoding/json.Unmarshaler. Surrounding double quotes are
// stripped, the rest is parsed as a base-10 integer count of epoch seconds, and T is set
// to that instant.
func (u *Unix) UnmarshalJSON(b []byte) error {
	raw := string(b)
	secs, err := strconv.Atoi(strings.Trim(raw, `"`))
	if err != nil {
		return fmt.Errorf("unix time(%s) could not be converted from string to int: %w", raw, err)
	}
	u.T = time.Unix(int64(secs), 0)
	return nil
}
// DurationTime converts between a JSON count of seconds from now (quoted or bare, such
// as an "expires_in" value) and an absolute time.Time.
//
// Design note: the server returns a field called "expires_in" holding the seconds from
// now until expiry, and we turn that into an absolute time (.ExpiresOn). An alternative
// would be to record when the token was received (e.g. .TokenReceived), keep .ExpiresIn
// as a duration and offer an ExpiresOn() method. Neither is ideal, because the server
// doesn't return a concrete time; this approach is the cleaner of the two, but it's not
// great either.
type DurationTime struct {
	T time.Time
}

// MarshalJSON implements encoding/json.Marshaler. A non-zero T is written as the whole
// number of seconds remaining until T (truncated toward zero; negative once T has
// passed); the zero time is written as empty output.
func (d DurationTime) MarshalJSON() ([]byte, error) {
	if d.T.IsZero() {
		return []byte(""), nil
	}
	// time.Until yields nanoseconds; dividing by time.Second converts to whole seconds.
	// (Multiplying instead scales by another 1e9 and overflows int64 past ~9s.)
	secs := int64(time.Until(d.T) / time.Second)
	return []byte(strconv.FormatInt(secs, 10)), nil
}

// UnmarshalJSON implements encoding/json.Unmarshaler. Surrounding double quotes are
// stripped, the rest is parsed as a base-10 integer number of seconds, and T is set to
// that many seconds after the current time.
func (d *DurationTime) UnmarshalJSON(b []byte) error {
	i, err := strconv.Atoi(strings.Trim(string(b), `"`))
	if err != nil {
		return fmt.Errorf("duration time(%s) could not be converted from string to int: %w", string(b), err)
	}
	d.T = time.Now().Add(time.Duration(i) * time.Second)
	return nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/local/ 0000775 0000000 0000000 00000000000 14420263624 0025237 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/local/server.go 0000664 0000000 0000000 00000010014 14420263624 0027070 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package local contains a local HTTP server used with interactive authentication.
package local
import (
	"context"
	"fmt"
	"html"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"
)
// okPage is the page written to the browser when the redirect carries the expected state
// and a non-empty authorization code (see handler).
var okPage = []byte(`
Authentication Complete
Authentication complete. You can return to the application. Feel free to close this browser tab.
`)

// failPage is a fmt format string written to the browser when the redirect carries an
// "error" query parameter. Its two %s verbs receive the "error" and "error_description"
// query values, in that order (see handler).
const failPage = `
Authentication Failed
Authentication failed. You can return to the application. Feel free to close this browser tab.
Error details: error %s error_description: %s
`
// Result is the outcome of the redirect handled by Server: the handler sets either Code
// (on success) or Err (on failure).
type Result struct {
	// Code is the code sent by the authority server, taken from the "code" query parameter.
	Code string
	// Err is set if there was an error.
	Err error
}
// Server is an HTTP server on localhost that receives the authorization redirect and
// reports its outcome as a Result.
type Server struct {
	// Addr is the address the server is listening on, in the form http://localhost:<port>.
	Addr string
	// resultCh carries outcomes to Result; New creates it with a buffer of one and
	// putResult drops an outcome when the buffer is already full.
	resultCh chan Result
	// s is the underlying HTTP server; New sets its Handler to handler.
	s *http.Server
	// reqState is the OAuth state value the redirect's "state" parameter must equal.
	reqState string
}
// New creates a local HTTP server and starts it. When port is positive the server
// listens on localhost:port; otherwise it asks the OS for a free port, making up to 10
// attempts. reqState is the OAuth state value the redirect must echo back.
func New(reqState string, port int) (*Server, error) {
	var (
		listener net.Listener
		err      error
		portStr  string
	)
	if port <= 0 {
		// find a free port
		for attempt := 0; attempt < 10; attempt++ {
			if listener, err = net.Listen("tcp", "localhost:0"); err == nil {
				hostPort := listener.Addr().String()
				portStr = hostPort[strings.LastIndex(hostPort, ":")+1:]
				break
			}
		}
	} else {
		// use port provided by caller
		listener, err = net.Listen("tcp", fmt.Sprintf("localhost:%d", port))
		portStr = strconv.Itoa(port)
	}
	if err != nil {
		return nil, err
	}

	srv := &Server{
		Addr:     fmt.Sprintf("http://localhost:%s", portStr),
		s:        &http.Server{Addr: "localhost:0", ReadHeaderTimeout: time.Second},
		reqState: reqState,
		resultCh: make(chan Result, 1),
	}
	srv.s.Handler = http.HandlerFunc(srv.handler)
	if err = srv.start(listener); err != nil {
		return nil, err
	}
	return srv, nil
}
// start serves HTTP on l in a background goroutine. If Serve returns an error, it is
// offered to resultCh without blocking. start itself always returns nil.
func (s *Server) start(l net.Listener) error {
	go func() {
		if serveErr := s.s.Serve(l); serveErr != nil {
			s.putResult(Result{Err: serveErr})
		}
	}()
	return nil
}
// Result blocks until the server produces an outcome for the redirect or ctx is done,
// whichever comes first; in the latter case the returned Result carries ctx.Err().
// Each outcome is delivered to a single call. Result does not shut the server down;
// call Shutdown when finished with it.
func (s *Server) Result(ctx context.Context) Result {
	select {
	case <-ctx.Done():
		return Result{Err: ctx.Err()}
	case r := <-s.resultCh:
		return r
	}
}
// Shutdown gracefully stops the server via http.Server.Shutdown, using a context with
// no deadline. Any error it reports is discarded.
func (s *Server) Shutdown() {
	// Do not call this from handler() (e.g. as a defer): http.Server.Shutdown waits for
	// active connections to go idle, including the one running the handler.
	background := context.Background()
	_ = s.s.Shutdown(background)
}
// putResult offers r to resultCh without blocking; if the channel's buffer is already
// full, r is dropped.
func (s *Server) putResult(r Result) {
	select {
	default:
		// A result is already waiting to be read; keep it and discard r.
	case s.resultCh <- r:
	}
}
// handler serves the authorization redirect.
//   - If the query has an "error" parameter, it writes failPage and records an error
//     Result whose message is the "error_description" value.
//   - Otherwise "state" must equal the requested state and "code" must be non-empty.
//     On failure it responds through s.error; on success it writes okPage and records
//     the code.
func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
	q := r.URL.Query()

	headerErr := q.Get("error")
	if headerErr != "" {
		desc := q.Get("error_description")
		// Both values come from a query string anyone can craft; escape them before
		// embedding them in the page to prevent reflected cross-site scripting.
		// Note: It is a little weird we handle some errors by not going to the failPage. If they all should,
		// change this to s.error() and make s.error() write the failPage instead of an error code.
		page := fmt.Sprintf(failPage, html.EscapeString(headerErr), html.EscapeString(desc))
		_, _ = w.Write([]byte(page))
		// desc is data, not a format string: a '%' in it must not be treated as a verb.
		s.putResult(Result{Err: fmt.Errorf("%s", desc)})
		return
	}

	respState := q.Get("state")
	switch respState {
	case s.reqState:
	case "":
		s.error(w, http.StatusInternalServerError, "server didn't send OAuth state")
		return
	default:
		s.error(w, http.StatusInternalServerError, "mismatched OAuth state, req(%s), resp(%s)", s.reqState, respState)
		return
	}

	code := q.Get("code")
	if code == "" {
		s.error(w, http.StatusInternalServerError, "authorization code missing in query string")
		return
	}

	_, _ = w.Write(okPage)
	s.putResult(Result{Code: code})
}
// error formats str with i into an error, writes its message to the client via
// http.Error with the given status code, and records it as the server's Result.
func (s *Server) error(w http.ResponseWriter, code int, str string, i ...interface{}) {
	failure := fmt.Errorf(str, i...)
	msg := failure.Error()
	http.Error(w, msg, code)
	s.putResult(Result{Err: failure})
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/local/server_test.go 0000664 0000000 0000000 00000006602 14420263624 0030137 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package local
import (
"context"
"io"
"net/http"
"net/url"
"strings"
"testing"
"time"
"github.com/kylelemons/godebug/pretty"
)
// TestServer sends a real HTTP request to a fresh Server for each case, then checks the
// status code, the page served, and the Result the server records.
func TestServer(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	tests := []struct {
		desc       string
		reqState   string
		port       int
		q          url.Values
		failPage   bool
		statusCode int
	}{
		{
			desc:       "Error: Query Values has 'error' key",
			reqState:   "state",
			port:       0,
			q:          url.Values{"state": []string{"state"}, "error": []string{"error"}},
			statusCode: 200,
			failPage:   true,
		},
		{
			desc:       "Error: Query Values missing 'state' key",
			reqState:   "state",
			port:       0,
			q:          url.Values{"code": []string{"code"}},
			statusCode: http.StatusInternalServerError,
		},
		{
			desc:       "Error: Query Values missing had 'state' key value that was different that requested",
			reqState:   "state",
			port:       0,
			q:          url.Values{"state": []string{"etats"}, "code": []string{"code"}},
			statusCode: http.StatusInternalServerError,
		},
		{
			desc:       "Error: Query Values missing 'code' key",
			reqState:   "state",
			port:       0,
			q:          url.Values{"state": []string{"state"}},
			statusCode: http.StatusInternalServerError,
		},
		{
			desc:       "Success",
			reqState:   "state",
			port:       0,
			q:          url.Values{"state": []string{"state"}, "code": []string{"code"}},
			statusCode: 200,
		},
	}

	for _, test := range tests {
		test := test
		// Each case is a subtest so its deferred cleanup (closing the response body,
		// shutting the server down) runs when the case ends, not when TestServer returns.
		t.Run(test.desc, func(t *testing.T) {
			serv, err := New(test.reqState, test.port)
			if err != nil {
				t.Fatalf("TestServer(%s): New() failed: %s", test.desc, err)
			}
			defer serv.Shutdown()
			if !strings.HasPrefix(serv.Addr, "http://localhost") {
				t.Fatalf("unexpected server address %s", serv.Addr)
			}
			u, err := url.Parse(serv.Addr)
			if err != nil {
				t.Fatalf("TestServer(%s): url.Parse(%q) failed: %s", test.desc, serv.Addr, err)
			}
			u.RawQuery = test.q.Encode()
			resp, err := http.DefaultClient.Do(&http.Request{Method: "GET", URL: u})
			if err != nil {
				t.Fatalf("TestServer(%s): GET %s failed: %s", test.desc, u, err)
			}
			defer resp.Body.Close()

			// wantErr fails the case unless the server recorded an error Result.
			wantErr := func() {
				if res := serv.Result(ctx); res.Err == nil {
					t.Errorf("TestServer(%s): Result.Err == nil, want Result.Err != nil", test.desc)
				}
			}

			if resp.StatusCode != test.statusCode {
				t.Errorf("TestServer(%s): got StatusCode == %d, want StatusCode == %d", test.desc, resp.StatusCode, test.statusCode)
				wantErr()
				return
			}
			if resp.StatusCode != http.StatusOK {
				// Every non-200 response is produced by Server.error, which also records
				// an error Result.
				wantErr()
				return
			}

			content, err := io.ReadAll(resp.Body)
			if err != nil {
				t.Fatalf("TestServer(%s): reading body failed: %s", test.desc, err)
			}
			if test.failPage {
				if !strings.Contains(string(content), "Authentication Failed") {
					t.Errorf("TestServer(%s): got okay page, want failed page", test.desc)
				}
				wantErr()
				return
			}
			if !strings.Contains(string(content), "Authentication Complete") {
				t.Errorf("TestServer(%s): got failed page, want okay page", test.desc)
			}
			if diff := pretty.Compare(Result{Code: "code"}, serv.Result(ctx)); diff != "" {
				t.Errorf("TestServer(%s): -want/+got:\n%s", test.desc, diff)
			}
		})
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/mock/ 0000775 0000000 0000000 00000000000 14420263624 0025076 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/mock/mock.go 0000664 0000000 0000000 00000011215 14420263624 0026356 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package mock
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type response struct {
body []byte
callback func(*http.Request)
code int
headers http.Header
}
type responseOption interface {
apply(*response)
}
type respOpt func(*response)
func (fn respOpt) apply(r *response) {
fn(r)
}
// WithBody sets the HTTP response's body to the specified value.
func WithBody(b []byte) responseOption {
return respOpt(func(r *response) {
r.body = b
})
}
// WithCallback sets a callback to invoke before returning the response.
func WithCallback(callback func(*http.Request)) responseOption {
return respOpt(func(r *response) {
r.callback = callback
})
}
// Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence.
type Client struct {
resp []response
}
func (c *Client) AppendResponse(opts ...responseOption) {
r := response{code: http.StatusOK, headers: http.Header{}}
for _, o := range opts {
o.apply(&r)
}
c.resp = append(c.resp, r)
}
func (c *Client) Do(req *http.Request) (*http.Response, error) {
if len(c.resp) == 0 {
panic(fmt.Sprintf(`no response for "%s"`, req.URL.String()))
}
resp := c.resp[0]
c.resp = c.resp[1:]
if resp.callback != nil {
resp.callback(req)
}
res := http.Response{Header: resp.headers, StatusCode: resp.code}
res.Body = io.NopCloser(bytes.NewReader(resp.body))
return &res, nil
}
// CloseIdleConnections implements the comm.HTTPClient interface
func (*Client) CloseIdleConnections() {}
// GetAccessTokenBody returns a JSON token-response body with access_token, expires_in and
// expires_on (now plus expiresIn seconds, as Unix seconds). client_info, id_token and
// refresh_token are included, in that order, only when non-empty.
func GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo string, expiresIn int) []byte {
	expiresOn := time.Now().Add(time.Duration(expiresIn) * time.Second).Unix()
	var sb strings.Builder
	fmt.Fprintf(&sb, `{"access_token": "%s","expires_in": %d,"expires_on": %d`, accessToken, expiresIn, expiresOn)
	optional := []struct{ format, value string }{
		{`, "client_info": "%s"`, clientInfo},
		{`, "id_token": "%s"`, idToken},
		{`, "refresh_token": "%s"`, refreshToken},
	}
	for _, field := range optional {
		if field.value != "" {
			fmt.Fprintf(&sb, field.format, field.value)
		}
	}
	sb.WriteString("}")
	return []byte(sb.String())
}
// GetIDToken returns a fake three-part ID token, "header.<payload>.signature". The payload
// is unpadded standard base64 of JSON claims: aud and tid set to tenant, iss set to issuer,
// iat set to now, and exp one hour later (Unix seconds).
// NOTE(review): JWT segments are normally base64url-encoded; confirm the token parser
// accepts standard-alphabet output before changing either side.
func GetIDToken(tenant, issuer string) string {
	iat := time.Now().Unix()
	claims := fmt.Sprintf(`{"aud": "%s","exp": %d,"iat": %d,"iss": "%s","tid": "%s"}`, tenant, iat+3600, iat, issuer, tenant)
	encoded := base64.RawStdEncoding.EncodeToString([]byte(claims))
	return "header." + encoded + ".signature"
}
// GetInstanceDiscoveryBody returns a canned instance-discovery JSON body for the authority
// https://<host>/<tenant>, whose single metadata entry lists host as the preferred network,
// the preferred cache and the only alias.
func GetInstanceDiscoveryBody(host, tenant string) []byte {
	authority := fmt.Sprintf("https://%s/%s", host, tenant)
	// The previously built Content-Type header was never returned or used, so it is gone;
	// callers receive only the body.
	body := fmt.Sprintf(`{"tenant_discovery_endpoint": "%s/v2.0/.well-known/openid-configuration","api-version": "1.1","metadata": [{"preferred_network": "%s","preferred_cache": "%s","aliases": ["%s"]}]}`,
		authority, host, host, host,
	)
	return []byte(body)
}
// GetTenantDiscoveryBody returns a canned OpenID configuration document in which every
// "{authority}" placeholder is replaced by https://<host>/<tenant>.
func GetTenantDiscoveryBody(host, tenant string) []byte {
	// discoveryTemplate is the document text; "{authority}" marks each substitution point.
	const discoveryTemplate = `{"token_endpoint": "{authority}/oauth2/v2.0/token",
"token_endpoint_auth_methods_supported": [
"client_secret_post",
"private_key_jwt",
"client_secret_basic"
],
"jwks_uri": "{authority}/discovery/v2.0/keys",
"response_modes_supported": [
"query",
"fragment",
"form_post"
],
"subject_types_supported": [
"pairwise"
],
"id_token_signing_alg_values_supported": [
"RS256"
],
"response_types_supported": [
"code",
"id_token",
"code id_token",
"id_token token"
],
"scopes_supported": [
"openid",
"profile",
"email",
"offline_access"
],
"issuer": "{authority}/v2.0",
"request_uri_parameter_supported": false,
"userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo",
"authorization_endpoint": "{authority}/oauth2/v2.0/authorize",
"device_authorization_endpoint": "{authority}/oauth2/v2.0/devicecode",
"http_logout_supported": true,
"frontchannel_logout_supported": true,
"end_session_endpoint": "{authority}/oauth2/v2.0/logout",
"claims_supported": [
"sub",
"iss",
"cloud_instance_name",
"cloud_instance_host_name",
"cloud_graph_host_name",
"msgraph_host",
"aud",
"exp",
"iat",
"auth_time",
"acr",
"nonce",
"preferred_username",
"name",
"tid",
"ver",
"at_hash",
"c_hash",
"email"
],
"kerberos_endpoint": "{authority}/kerberos",
"tenant_region_scope": "NA",
"cloud_instance_name": "microsoftonline.com",
"cloud_graph_host_name": "graph.windows.net",
"msgraph_host": "graph.microsoft.com",
"rbac_url": "https://pas.windows.net"
}`
	authority := fmt.Sprintf("https://%s/%s", host, tenant)
	filled := strings.NewReplacer("{authority}", authority).Replace(discoveryTemplate)
	return []byte(filled)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ 0000775 0000000 0000000 00000000000 14420263624 0025265 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/fake/ 0000775 0000000 0000000 00000000000 14420263624 0026173 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/fake/fake.go 0000664 0000000 0000000 00000014514 14420263624 0027435 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package fake
import (
"context"
"errors"
"fmt"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
)
// ResolveEndpoints is a fake implementation of the oauth.ResolveEndpointer interface.
type ResolveEndpoints struct {
	// Set this to true to have all APIs return an error.
	Err bool
	// fake result to return
	Endpoints authority.Endpoints
}

// ResolveEndpoints ignores its arguments and returns f.Endpoints, or an error when f.Err is set.
func (f ResolveEndpoints) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
	if !f.Err {
		return f.Endpoints, nil
	}
	return authority.Endpoints{}, errors.New("error")
}
// AccessTokens is a fake implementation of the oauth.AccessTokens interface.
type AccessTokens struct {
	// Set this to true to have every API return an error, except FromDeviceCodeResult,
	// which is driven by Result instead.
	Err bool
	// Result is for use with FromDeviceCodeResult. On each call it returns
	// the next item in this slice. They must be either an error or nil.
	Result []error
	// Next is the index in Result of the value the next FromDeviceCodeResult call uses.
	Next int
	// fake result to return
	AccessToken accesstokens.TokenResponse
	// fake result to return
	DeviceCode accesstokens.DeviceCodeResult
	// FromRefreshTokenCallback is an optional callback invoked by FromRefreshToken
	FromRefreshTokenCallback func(appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string)
	// ValidateAssertion is an optional callback for validating an assertion generated by confidential.Client
	ValidateAssertion func(string)
}

// FromUsernamePassword returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromAuthCode returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromRefreshToken invokes FromRefreshTokenCallback (if set) with its arguments, then
// returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error) {
	// The callback runs before the Err check, so it is invoked even when Err is set.
	if cb := f.FromRefreshTokenCallback; cb != nil {
		cb(appType, authParams, cc, refreshToken)
	}
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromClientSecret returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromAssertion returns an error when f.Err is set; otherwise it passes assertion to
// ValidateAssertion (if set) and returns f.AccessToken.
func (f *AccessTokens) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error) {
	if !f.Err {
		if validate := f.ValidateAssertion; validate != nil {
			validate(assertion)
		}
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromUserAssertionClientSecret returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion, clientSecret string) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// FromUserAssertionClientCertificate returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion, assertion string) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}

// DeviceCodeResult returns f.DeviceCode, or an error when f.Err is set.
func (f *AccessTokens) DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error) {
	if !f.Err {
		return f.DeviceCode, nil
	}
	return accesstokens.DeviceCodeResult{}, fmt.Errorf("error")
}

// FromDeviceCodeResult consumes f.Result[f.Next] and advances f.Next: a nil entry yields
// f.AccessToken, a non-nil entry is returned as the error. It panics once Result is exhausted.
func (f *AccessTokens) FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error) {
	if f.Next >= len(f.Result) {
		panic("AccessTokens.FromDeviceCodeResult() asked for more return values than provided")
	}
	queued := f.Result[f.Next]
	f.Next++
	if queued != nil {
		return accesstokens.TokenResponse{}, queued
	}
	return f.AccessToken, nil
}

// FromSamlGrant returns f.AccessToken, or an error when f.Err is set.
func (f *AccessTokens) FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error) {
	if !f.Err {
		return f.AccessToken, nil
	}
	return accesstokens.TokenResponse{}, fmt.Errorf("error")
}
// Authority is a fake implementation of the oauth.FetchAuthority interface.
type Authority struct {
	// Set this to true to have all APIs return an error.
	Err bool
	// The fake UserRealm to return from the UserRealm() API.
	Realm authority.UserRealm
	// fake result to return
	InstanceResp authority.InstanceDiscoveryResponse
}

// UserRealm returns f.Realm, or an error when f.Err is set.
func (f Authority) UserRealm(ctx context.Context, params authority.AuthParams) (authority.UserRealm, error) {
	if !f.Err {
		return f.Realm, nil
	}
	return authority.UserRealm{}, errors.New("error")
}

// AADInstanceDiscovery returns f.InstanceResp, or an error when f.Err is set.
func (f Authority) AADInstanceDiscovery(ctx context.Context, info authority.Info) (authority.InstanceDiscoveryResponse, error) {
	if !f.Err {
		return f.InstanceResp, nil
	}
	return authority.InstanceDiscoveryResponse{}, errors.New("error")
}
// WSTrust is a fake implementation of the oauth.FetchWSTrust interface.
type WSTrust struct {
	// Set these to true to have their respective APIs return an error.
	GetMexErr, GetSAMLTokenInfoErr bool
	// fake result to return
	MexDocument defs.MexDocument
	// fake result to return
	SamlTokenInfo wstrust.SamlTokenInfo
}

// Mex returns f.MexDocument, or an error when f.GetMexErr is set.
func (f WSTrust) Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error) {
	if !f.GetMexErr {
		return f.MexDocument, nil
	}
	return defs.MexDocument{}, errors.New("error")
}

// SAMLTokenInfo returns f.SamlTokenInfo, or an error when f.GetSAMLTokenInfoErr is set.
func (f WSTrust) SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error) {
	if !f.GetSAMLTokenInfoErr {
		return f.SamlTokenInfo, nil
	}
	return wstrust.SamlTokenInfo{}, errors.New("error")
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/oauth.go 0000664 0000000 0000000 00000033534 14420263624 0026744 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package oauth
import (
"context"
"encoding/json"
"fmt"
"io"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
"github.com/google/uuid"
)
// ResolveEndpointer contains the methods for resolving authority endpoints.
// New fills Client.Resolver with newAuthorityEndpoint's result; tests can substitute
// fake.ResolveEndpoints.
type ResolveEndpointer interface {
	// ResolveEndpoints returns the endpoints for the given authority (and, presumably,
	// the user's principal name where it affects discovery — TODO confirm).
	ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
}
// AccessTokens contains the methods for fetching tokens from different sources.
// New fills Client.AccessTokens with ops.New(httpClient).AccessTokens().
type AccessTokens interface {
	DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
	FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
	// FromAuthCode is used by Client.AuthCode.
	FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
	FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
	// FromClientSecret is used by Client.Credential when the credential has a secret.
	FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
	// FromAssertion is used by Client.Credential with the JWT built from the credential.
	FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientSecret is used by Client.OnBehalfOf when the credential has a secret.
	FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
	// FromUserAssertionClientCertificate is used by Client.OnBehalfOf with the credential's JWT.
	FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
	FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
	FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
}
// FetchAuthority will be implemented by authority.Authority.
// New fills Client.Authority with ops.New(httpClient).Authority().
type FetchAuthority interface {
	// UserRealm looks up the user realm for the given auth parameters.
	UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
	// AADInstanceDiscovery performs instance discovery; Client.AADInstanceDiscovery delegates to it.
	AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
}
// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
// New fills Client.WSTrust with ops.New(httpClient).WSTrust().
type FetchWSTrust interface {
	// Mex fetches the metadata exchange document at federationMetadataURL.
	Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error)
	// SAMLTokenInfo requests SAML token information from the given WS-Trust endpoint.
	SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error)
}
// Client provides tokens for various types of token requests.
type Client struct {
	// Resolver resolves authority endpoints; used by ResolveEndpoints.
	Resolver ResolveEndpointer
	// AccessTokens performs the token request for each grant type.
	AccessTokens AccessTokens
	// Authority performs user realm lookups and AAD instance discovery.
	Authority FetchAuthority
	// WSTrust fetches MEX documents and SAML token information.
	WSTrust FetchWSTrust
}
// New constructs a Client whose dependencies are all derived from a single
// ops.New(httpClient) value.
func New(httpClient ops.HTTPClient) *Client {
	opsClient := ops.New(httpClient)
	c := &Client{Resolver: newAuthorityEndpoint(opsClient)}
	c.AccessTokens = opsClient.AccessTokens()
	c.Authority = opsClient.Authority()
	c.WSTrust = opsClient.WSTrust()
	return c
}
// ResolveEndpoints gets the authorization and token endpoints for authorityInfo by
// delegating to t.Resolver.
func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
	resolver := t.Resolver
	return resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName)
}
// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an
// authorization endpoint) by delegating to t.Authority. AAD instance discovery allows
// aliasing of hosts, so several host names may map to the same authority.
func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
	fetcher := t.Authority
	return fetcher.AADInstanceDiscovery(ctx, authorityInfo)
}
// AuthCode returns a token based on an authorization code. It validates the scopes,
// resolves the authority's endpoints into req.AuthParams, and then redeems the code via
// t.AccessTokens.FromAuthCode, wrapping any error from that call.
func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
	var empty accesstokens.TokenResponse
	if err := scopeError(req.AuthParams); err != nil {
		return empty, err
	}
	if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil {
		return empty, err
	}
	resp, err := t.AccessTokens.FromAuthCode(ctx, req)
	if err != nil {
		return empty, fmt.Errorf("could not retrieve token from auth code: %w", err)
	}
	return resp, nil
}
// Credential acquires a token from the authority using a client credentials grant.
//
// If cred.TokenProvider is set, the token comes from that callback. No endpoint
// resolution happens, and the expiry is computed from the time taken just before the
// call. Otherwise the endpoints are resolved and the token is requested with
// cred.Secret when it is non-empty, else with a JWT assertion built by cred.JWT.
func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
	if cred.TokenProvider == nil {
		if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
			return accesstokens.TokenResponse{}, err
		}
		if cred.Secret != "" {
			return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret)
		}
		jwt, err := cred.JWT(ctx, authParams)
		if err != nil {
			return accesstokens.TokenResponse{}, err
		}
		return t.AccessTokens.FromAssertion(ctx, authParams, jwt)
	}

	start := time.Now()
	// The provider receives its own copy of the scopes so it cannot mutate authParams.Scopes.
	scopes := make([]string, len(authParams.Scopes))
	copy(scopes, authParams.Scopes)
	tr, err := cred.TokenProvider(ctx, exported.TokenProviderParameters{
		Claims:        authParams.Claims,
		CorrelationID: uuid.New().String(),
		Scopes:        scopes,
		TenantID:      authParams.AuthorityInfo.Tenant,
	})
	if err != nil {
		if len(scopes) == 0 {
			err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
		}
		return accesstokens.TokenResponse{}, err
	}
	lifetime := time.Duration(tr.ExpiresInSeconds) * time.Second
	return accesstokens.TokenResponse{
		AccessToken:   tr.AccessToken,
		ExpiresOn:     internalTime.DurationTime{T: start.Add(lifetime)},
		GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
	}, nil
}
// OnBehalfOf acquires a token with the on-behalf-of flow, exchanging the user
// assertion in authParams.UserAssertion for a new token.
//
// Requests with no scopes are rejected, and the token endpoint is resolved into
// authParams. The client authenticates with cred.Secret when it is non-empty,
// otherwise with the JWT assertion returned by cred.JWT.
func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
	if err := scopeError(authParams); err != nil {
		return accesstokens.TokenResponse{}, err
	}
	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
		return accesstokens.TokenResponse{}, err
	}
	if cred.Secret != "" {
		return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret)
	}
	jwt, err := cred.JWT(ctx, authParams)
	if err != nil {
		return accesstokens.TokenResponse{}, err
	}
	tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt)
	if err != nil {
		// Return an empty response rather than whatever partial value came back.
		return accesstokens.TokenResponse{}, err
	}
	return tr, nil
}
// Refresh redeems refreshToken for a new token.
//
// reqType is forwarded to FromRefreshToken along with cc, which supplies the
// client credential for confidential clients. Requests with no scopes are
// rejected before any endpoint is resolved. On any error an empty TokenResponse
// is returned.
func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) {
	var tr accesstokens.TokenResponse
	err := scopeError(authParams)
	if err == nil {
		err = t.resolveEndpoint(ctx, &authParams, "")
	}
	if err == nil {
		tr, err = t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret)
	}
	if err != nil {
		return accesstokens.TokenResponse{}, err
	}
	return tr, nil
}
// UsernamePassword retrieves a token where a username and password is used.
//
// For ADFS authorities, the endpoint is resolved using the username and the
// password grant is used directly. For other authorities, the user realm is
// looked up first:
//   - "Federated": a SAML token is obtained via WS-Trust (MEX document, then
//     SAMLTokenInfo) and redeemed with FromSamlGrant.
//   - "Managed": the password grant is used.
//   - any other account type: an error is returned.
func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) {
	if err := scopeError(authParams); err != nil {
		return accesstokens.TokenResponse{}, err
	}
	if authParams.AuthorityInfo.AuthorityType == authority.ADFS {
		if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil {
			return accesstokens.TokenResponse{}, err
		}
		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
	}
	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
		return accesstokens.TokenResponse{}, err
	}
	userRealm, err := t.Authority.UserRealm(ctx, authParams)
	if err != nil {
		return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err)
	}
	switch userRealm.AccountType {
	case authority.Federated:
		mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL)
		if err != nil {
			err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err)
			return accesstokens.TokenResponse{}, err
		}
		saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint)
		if err != nil {
			err = fmt.Errorf("problem getting SAML token info: %w", err)
			return accesstokens.TokenResponse{}, err
		}
		tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml)
		if err != nil {
			return accesstokens.TokenResponse{}, err
		}
		return tr, nil
	case authority.Managed:
		// scopeError above already guarantees len(authParams.Scopes) > 0, and nothing
		// since then modifies Scopes, so no further scope check is needed here.
		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
	}
	return accesstokens.TokenResponse{}, errors.New("unknown account type")
}
// DeviceCode is the result of a call to Client.DeviceCode().
type DeviceCode struct {
	// Result is the device code result from the first call in the device code flow. This allows
	// the caller to retrieve the displayed code that is used to authorize on the second device.
	Result accesstokens.DeviceCodeResult
	// authParams are the request parameters, including the resolved endpoints,
	// that Token uses when polling.
	authParams authority.AuthParams
	// accessTokens performs the polling requests. Token treats nil as "not created
	// by Client.DeviceCode" and refuses to run.
	accessTokens AccessTokens
}
// Token returns a token AFTER the user uses the user code on the second device. This will block
// until either: (1) the code is input by the user and the service releases a token, (2) the device
// code expires (d.Result.ExpiresOn), (3) the Context passed to Token() is cancelled or expires,
// (4) some other service error occurs.
//
// The first poll happens after 50ms. Each wait after that is triple the previous one, capped at 5s.
// A poll is retried only when isWaitDeviceCodeErr reports the error as a "keep waiting" response.
func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) {
	// A DeviceCode not produced by Client.DeviceCode has no client to poll with.
	if d.accessTokens == nil {
		return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid")
	}
	// Bound polling by the device code's expiry unless ctx already has an earlier deadline.
	var cancel context.CancelFunc
	if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) {
		ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn)
	} else {
		ctx, cancel = context.WithCancel(ctx)
	}
	defer cancel()
	var interval = 50 * time.Millisecond
	timer := time.NewTimer(interval)
	defer timer.Stop()
	for {
		// NOTE(review): on the first iteration this resets a timer that has not fired yet;
		// time.Timer.Reset documents use on stopped/drained timers — confirm this is benign.
		timer.Reset(interval)
		select {
		case <-ctx.Done():
			return accesstokens.TokenResponse{}, ctx.Err()
		case <-timer.C:
			// interval += interval*2 triples the wait before the next poll.
			interval += interval * 2
			if interval > 5*time.Second {
				interval = 5 * time.Second
			}
		}
		token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result)
		if err != nil && isWaitDeviceCodeErr(err) {
			continue
		}
		return token, err // This handles if it was a non-wait error or success
	}
}
// deviceCodeError holds the "error" field of a JSON error body, as decoded by
// isWaitDeviceCodeErr.
type deviceCodeError struct {
	Error string `json:"error"`
}
// isWaitDeviceCodeErr reports whether err is a CallErr carrying an HTTP 400
// response whose JSON body has an "error" of "authorization_pending" or
// "slow_down". Token uses this to decide that polling should continue.
//
// The response body is read and closed here, so it cannot be read again afterwards.
// A CallErr with no response or no body is treated as "not a wait error" rather
// than causing a panic.
func isWaitDeviceCodeErr(err error) bool {
	var c errors.CallErr
	if !errors.As(err, &c) {
		return false
	}
	if c.Resp == nil || c.Resp.StatusCode != 400 || c.Resp.Body == nil {
		return false
	}
	defer c.Resp.Body.Close()
	body, err := io.ReadAll(c.Resp.Body)
	if err != nil {
		return false
	}
	var dCErr deviceCodeError
	if err := json.Unmarshal(body, &dCErr); err != nil {
		return false
	}
	return dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down"
}
// DeviceCode starts the device code flow.
//
// It rejects requests with no scopes, resolves the token endpoint, and asks
// t.AccessTokens for a device code. The returned DeviceCode holds the code to
// show the user in Result. Once the user has entered the code on the second
// device, its Token method polls for the token.
func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) {
	err := scopeError(authParams)
	if err == nil {
		err = t.resolveEndpoint(ctx, &authParams, "")
	}
	if err != nil {
		return DeviceCode{}, err
	}
	result, err := t.AccessTokens.DeviceCodeResult(ctx, authParams)
	if err != nil {
		return DeviceCode{}, err
	}
	dc := DeviceCode{Result: result}
	dc.authParams = authParams
	dc.accessTokens = t.AccessTokens
	return dc, nil
}
// resolveEndpoint resolves the endpoints for authParams.AuthorityInfo and stores
// them in authParams.Endpoints. userPrincipalName is passed through to the resolver.
//
// The resolver's error is wrapped with %w (not %s), so callers can still match it
// with errors.Is / errors.As. The message text is the same as before.
func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error {
	endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName)
	if err != nil {
		return fmt.Errorf("unable to resolve an endpoint: %w", err)
	}
	authParams.Endpoints = endpoints
	return nil
}
// scopeError returns an error when a.Scopes is empty, and nil otherwise.
func scopeError(a authority.AuthParams) error {
	// TODO(someone): we could look deeper at the message to determine if
	// it's a scope error, but this is a good start.
	/*
		{error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope
		openid offline_access profile is not valid. Client credential flows must have a scope value
		with /.default suffixed to the resource identifier (application ID URI)...}
	*/
	if len(a.Scopes) != 0 {
		return nil
	}
	return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid")
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/oauth_test.go 0000664 0000000 0000000 00000025160 14420263624 0027777 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package oauth
// NOTE: These tests cover how we handle errors from other, lower-level modules.
// We don't actually care about a TokenResponse{}; that is gathered from a remote system
// and is tested via integration tests (data retrieved from one system and passed
// to another). We care about execution behavior (service X says there is an error and we handle it,
// we require .X is set and input doesn't have it, ...)
import (
"bytes"
"context"
"crypto/x509"
"io"
"net/http"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)
// testScopes is a non-empty scope list, so requests built with it get past scopeError.
var testScopes = []string{"scope"}
// TestAuthCode checks that Client.AuthCode surfaces errors from endpoint
// resolution and from the access token request, and succeeds otherwise.
func TestAuthCode(t *testing.T) {
	cases := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	client := &Client{}
	for _, tc := range cases {
		client.AccessTokens = tc.at
		client.Resolver = tc.re
		req := accesstokens.AuthCodeRequest{AuthParams: authority.AuthParams{Scopes: testScopes}}
		_, err := client.AuthCode(context.Background(), req)
		if err == nil && tc.err {
			t.Errorf("TestAuthCode(%s): got err == nil, want err != nil", tc.desc)
		} else if err != nil && !tc.err {
			t.Errorf("TestAuthCode(%s): got err == %s, want err == nil", tc.desc, err)
		}
	}
}
// TestCredential checks that Client.Credential surfaces errors from endpoint
// resolution, from building the JWT assertion, and from the access token request
// on both the secret and assertion paths, and succeeds otherwise.
func TestCredential(t *testing.T) {
	callback := func(context.Context, exported.AssertionRequestOptions) (string, error) {
		return "assertion", nil
	}
	tests := []struct {
		desc       string
		re         fake.ResolveEndpoints
		at         *fake.AccessTokens
		authParams authority.AuthParams
		cred       *accesstokens.Credential
		err        bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
			err: true,
		},
		{
			desc: "Error: REST access token error on secret",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			// A secret routes the request through FromClientSecret, the path under test.
			cred: &accesstokens.Credential{
				Secret: "secret",
			},
			err: true,
		},
		{
			desc: "Error: could not generate JWT from cred assertion",
			re:   fake.ResolveEndpoints{},
			// The access token fake succeeds, so any error must come from Credential.JWT().
			at: &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				Cert: &x509.Certificate{},
				// No AssertionCallback, so JWT() signs with Key. Key is nil, which
				// causes token.SignedString(c.Key) to fail in Credential.JWT().
			},
			err: true,
		},
		{
			desc: "Error: REST access token error on assertion",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
			err: true,
		},
		{
			desc: "Success: secret cred",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				Secret: "secret",
			},
		},
		{
			desc: "Success: assertion cred",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
			cred: &accesstokens.Credential{
				AssertionCallback: callback,
			},
		},
	}
	token := &Client{}
	for _, test := range tests {
		token.AccessTokens = test.at
		token.Resolver = test.re
		_, err := token.Credential(context.Background(), test.authParams, test.cred)
		switch {
		case err == nil && test.err:
			t.Errorf("TestCredential(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestCredential(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}
// TestRefresh checks that Client.Refresh surfaces errors from endpoint
// resolution and from the refresh token request, and succeeds otherwise.
func TestRefresh(t *testing.T) {
	cases := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	client := &Client{}
	for _, tc := range cases {
		client.AccessTokens = tc.at
		client.Resolver = tc.re
		params := authority.AuthParams{Scopes: testScopes}
		_, err := client.Refresh(context.Background(), accesstokens.ATPublic, params, &accesstokens.Credential{}, accesstokens.RefreshToken{})
		if err == nil && tc.err {
			t.Errorf("TestRefresh(%s): got err == nil, want err != nil", tc.desc)
		} else if err != nil && !tc.err {
			t.Errorf("TestRefresh(%s): got err == %s, want err == nil", tc.desc, err)
		}
	}
}
// TestUsernamePassword checks Client.UsernamePassword for both Managed and
// Federated user realms, including failures in endpoint resolution, WS-Trust
// (MEX and SAML token info) and the final access token request.
func TestUsernamePassword(t *testing.T) {
	cases := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		au   fake.Authority
		ws   fake.WSTrust
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and Mex() error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			ws:   fake.WSTrust{GetMexErr: true},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and SAMLTokenInfo() error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			ws:   fake.WSTrust{GetSAMLTokenInfoErr: true},
			err:  true,
		},
		{
			desc: "Error: authority.Federated and GetAccessTokenFromSamlGrant() error",
			re:   fake.ResolveEndpoints{Err: false},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Error: authority.Managed and REST access token error",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{Err: true},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
			err:  true,
		},
		{
			desc: "Success: authority.Managed",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Managed}},
		},
		{
			desc: "Success: authority.Federated",
			re:   fake.ResolveEndpoints{Err: false},
			at:   &fake.AccessTokens{},
			au:   fake.Authority{Realm: authority.UserRealm{AccountType: authority.Federated}},
		},
	}

	client := &Client{}
	for _, tc := range cases {
		client.AccessTokens = tc.at
		client.Authority = tc.au
		client.Resolver = tc.re
		client.WSTrust = tc.ws
		_, err := client.UsernamePassword(context.Background(), authority.AuthParams{Scopes: testScopes})
		if err == nil && tc.err {
			t.Errorf("TestUsernamePassword(%s): got err == nil, want err != nil", tc.desc)
		} else if err != nil && !tc.err {
			t.Errorf("TestUsernamePassword(%s): got err == %s, want err == nil", tc.desc, err)
		}
	}
}
// TestDeviceCode checks DeviceCode.Token:
//   - a zero-value DeviceCode is rejected;
//   - polling continues through "authorization_pending" and "slow_down";
//   - any other error stops polling and is returned.
func TestDeviceCode(t *testing.T) {
	tests := []struct {
		desc string
		dc   DeviceCode
		err  bool
	}{
		{
			desc: "Error: .accessTokens == nil",
			dc:   DeviceCode{},
			err:  true,
		},
		{
			desc: "Error: FromDeviceCodeResult() returned a !isWaitDeviceCodeErr",
			dc: DeviceCode{
				// Without a future ExpiresOn the context would already be past its
				// deadline, and Token would fail before ever seeing "bad_error".
				Result: accesstokens.DeviceCodeResult{
					ExpiresOn: time.Now().Add(5 * time.Minute),
				},
				accessTokens: &fake.AccessTokens{
					Result: []error{
						errors.CallErr{
							Resp: &http.Response{
								StatusCode: 400,
								Body:       io.NopCloser(bytes.NewReader([]byte(`{"error": "authorization_pending"}`))),
							},
						},
						errors.CallErr{
							Resp: &http.Response{
								StatusCode: 400,
								Body:       io.NopCloser(bytes.NewReader([]byte(`{"error": "slow_down"}`))),
							},
						},
						errors.CallErr{
							Resp: &http.Response{
								StatusCode: 400,
								Body:       io.NopCloser(bytes.NewReader([]byte(`{"error": "bad_error"}`))),
							},
						},
						nil,
					},
				},
			},
			err: true,
		},
		{
			desc: "Success",
			dc: DeviceCode{
				Result: accesstokens.DeviceCodeResult{
					ExpiresOn: time.Now().Add(5 * time.Minute),
				},
				accessTokens: &fake.AccessTokens{
					Result: []error{
						errors.CallErr{
							Resp: &http.Response{
								StatusCode: 400,
								Body:       io.NopCloser(bytes.NewReader([]byte(`{"error": "authorization_pending"}`))),
							},
						},
						errors.CallErr{
							Resp: &http.Response{
								StatusCode: 400,
								Body:       io.NopCloser(bytes.NewReader([]byte(`{"error": "slow_down"}`))),
							},
						},
						nil,
					},
				},
			},
		},
	}
	for _, test := range tests {
		_, err := test.dc.Token(context.Background())
		switch {
		case err == nil && test.err:
			t.Errorf("TestDeviceCode(%s): got err == nil, want err != nil", test.desc)
		case err != nil && !test.err:
			t.Errorf("TestDeviceCode(%s): got err == %s, want err == nil", test.desc, err)
		}
	}
}
// TestDeviceCodeToken checks that Client.DeviceCode surfaces resolver and access
// token errors, and that a successful call returns a DeviceCode with its
// accessTokens client populated.
func TestDeviceCodeToken(t *testing.T) {
	cases := []struct {
		desc string
		re   fake.ResolveEndpoints
		at   *fake.AccessTokens
		err  bool
	}{
		{
			desc: "Error: Unable to resolve endpoints",
			re:   fake.ResolveEndpoints{Err: true},
			at:   &fake.AccessTokens{},
			err:  true,
		},
		{
			desc: "Error: REST access token error",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{Err: true},
			err:  true,
		},
		{
			desc: "Success",
			re:   fake.ResolveEndpoints{},
			at:   &fake.AccessTokens{},
		},
	}

	client := &Client{}
	for _, tc := range cases {
		client.AccessTokens = tc.at
		client.Resolver = tc.re
		dc, err := client.DeviceCode(context.Background(), authority.AuthParams{Scopes: testScopes})
		if err != nil {
			if !tc.err {
				t.Errorf("TestDeviceCodeToken(%s): got err == %s, want err == nil", tc.desc, err)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestDeviceCodeToken(%s): got err == nil, want err != nil", tc.desc)
			continue
		}
		if dc.accessTokens == nil {
			t.Errorf("TestDeviceCodeToken(%s): got DeviceCode{} back that did not have accessTokens set", tc.desc)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/ 0000775 0000000 0000000 00000000000 14420263624 0026066 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/ 0000775 0000000 0000000 00000000000 14420263624 0030553 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/accesstokens.go 0000664 0000000 0000000 00000035676 14420263624 0033610 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package accesstokens exposes a REST client for querying backend systems to get various types of
access tokens (oauth) for use in authentication.
These calls are of type "application/x-www-form-urlencoded". This means we use url.Values to
represent arguments and then encode them into the POST body message. We receive JSON in
return for the requests. The request definition is defined in https://tools.ietf.org/html/rfc7521#section-4.2 .
*/
package accesstokens
import (
"context"
"crypto"
/* #nosec */
"crypto/sha1"
"crypto/x509"
"encoding/base64"
"encoding/json"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)
// Form field names used when building token requests.
const (
	grantType  = "grant_type"
	deviceCode = "device_code"
	clientID   = "client_id"
	clientInfo = "client_info"
	username   = "username"
	password   = "password"
)

// clientInfoVal is the value sent in the client_info field.
const clientInfoVal = "1"
//go:generate stringer -type=AppType

// AppType is whether the authorization code flow is for a public or confidential client.
type AppType int8

// The values are spelled out explicitly; they are the same 0, 1, 2 that iota produced.
const (
	// ATUnknown is the zero value when the type hasn't been set.
	ATUnknown AppType = 0
	// ATPublic indicates this is for the Public.Client.
	ATPublic AppType = 1
	// ATConfidential indicates this is for the Confidential.Client.
	ATConfidential AppType = 2
)
// urlFormCaller sends form as a URL-form-encoded request to addr and decodes the
// reply into out.
type urlFormCaller interface {
	URLFormCall(ctx context.Context, addr string, form url.Values, out interface{}) error
}
// DeviceCodeResponse represents the HTTP response received from the device code endpoint.
type DeviceCodeResponse struct {
	// OAuthResponseBase is embedded; presumably it carries the standard OAuth error fields — confirm in authority.
	authority.OAuthResponseBase

	// UserCode is the code the user enters on the second device.
	UserCode string `json:"user_code"`
	// DeviceCode is the code this client later presents when polling for the token.
	DeviceCode string `json:"device_code"`
	// VerificationURL is the page where the user enters UserCode.
	VerificationURL string `json:"verification_url"`
	// ExpiresIn is the code's lifetime in seconds (Convert multiplies it by time.Second).
	ExpiresIn int `json:"expires_in"`
	// Interval is passed through to the DeviceCodeResult by Convert.
	Interval int `json:"interval"`
	// Message is the human-readable instruction text returned by the service.
	Message string `json:"message"`
	// AdditionalFields has no json tag. NOTE(review): presumably the project's json
	// package fills it with unrecognized fields — confirm.
	AdditionalFields map[string]interface{}
}
// Convert converts the DeviceCodeResponse to a DeviceCodeResult for clientID and
// scopes. The expiry is ExpiresIn seconds from the current UTC time.
func (dcr DeviceCodeResponse) Convert(clientID string, scopes []string) DeviceCodeResult {
	lifetime := time.Duration(dcr.ExpiresIn) * time.Second
	expiresOn := time.Now().UTC().Add(lifetime)
	return NewDeviceCodeResult(
		dcr.UserCode,
		dcr.DeviceCode,
		dcr.VerificationURL,
		expiresOn,
		dcr.Interval,
		dcr.Message,
		clientID,
		scopes,
	)
}
// Credential represents the credential used in confidential client flows. This can be either
// a Secret or Cert/Key.
//
// Selection order as seen in this package: a non-empty Secret is preferred by prepURLVals.
// Otherwise JWT uses AssertionCallback when set, and falls back to signing with Cert/Key.
// TokenProvider is checked by the oauth Client before any of these.
type Credential struct {
	// Secret contains the credential secret if we are doing auth by secret.
	Secret string

	// Cert is the public certificate, if we're authenticating by certificate.
	Cert *x509.Certificate
	// Key is the private key for signing, if we're authenticating by certificate.
	Key crypto.PrivateKey
	// X5c is the JWT assertion's x5c header value, required for SN/I authentication.
	X5c []string

	// AssertionCallback is a function provided by the application, if we're authenticating by assertion.
	AssertionCallback func(context.Context, exported.AssertionRequestOptions) (string, error)

	// TokenProvider is a function provided by the application that implements custom authentication
	// logic for a confidential client
	TokenProvider func(context.Context, exported.TokenProviderParameters) (exported.TokenProviderResult, error)
}
// JWT gets the jwt assertion when the credential is not using a secret.
//
// If AssertionCallback is set, its result is returned unchanged. Otherwise an RS256 JWT is
// built and signed with Key. Its claims are:
//   - aud: the token endpoint
//   - iss, sub: the client ID
//   - jti: a random UUID
//   - nbf: now
//   - exp: now + 10 minutes
//
// The header carries the certificate's SHA-1 thumbprint as x5t, plus X5c as x5c when
// authParams.SendX5C is true. An error is returned when there is no certificate or
// signing fails.
func (c *Credential) JWT(ctx context.Context, authParams authority.AuthParams) (string, error) {
	if c.AssertionCallback != nil {
		options := exported.AssertionRequestOptions{
			ClientID:      authParams.ClientID,
			TokenEndpoint: authParams.Endpoints.TokenEndpoint,
		}
		return c.AssertionCallback(ctx, options)
	}
	// thumbprint dereferences the certificate, so a credential with neither a
	// callback nor a certificate must fail here instead of panicking.
	if c.Cert == nil {
		return "", fmt.Errorf("unable to create a JWT assertion: credential has no certificate")
	}

	// Read the clock once so that exp is exactly 10 minutes after nbf.
	now := time.Now()
	token := jwt.NewWithClaims(jwt.SigningMethodRS256, jwt.MapClaims{
		"aud": authParams.Endpoints.TokenEndpoint,
		"exp": json.Number(strconv.FormatInt(now.Add(10*time.Minute).Unix(), 10)),
		"iss": authParams.ClientID,
		"jti": uuid.New().String(),
		"nbf": json.Number(strconv.FormatInt(now.Unix(), 10)),
		"sub": authParams.ClientID,
	})
	token.Header = map[string]interface{}{
		"alg": "RS256",
		"typ": "JWT",
		"x5t": base64.StdEncoding.EncodeToString(thumbprint(c.Cert)),
	}

	if authParams.SendX5C {
		token.Header["x5c"] = c.X5c
	}

	assertion, err := token.SignedString(c.Key)
	if err != nil {
		return "", fmt.Errorf("unable to sign a JWT token using private key: %w", err)
	}
	return assertion, nil
}
// thumbprint runs the asn1.Der bytes through sha1 for use in the x5t parameter of JWT.
// https://tools.ietf.org/html/rfc7517#section-4.8
func thumbprint(cert *x509.Certificate) []byte {
/* #nosec */
a := sha1.Sum(cert.Raw)
return a[:]
}
// Client represents the REST calls to get tokens from token generator backends.
type Client struct {
	// Comm provides the HTTP transport client.
	Comm urlFormCaller
	// testing, when true, makes doTokenResp skip resp.Validate().
	testing bool
}
// FromUsernamePassword uses a username and password to get an access token via the
// password grant. Claims (if any) and the scope list, including the default scopes,
// are added to the form.
func (c Client) FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (TokenResponse, error) {
	form := url.Values{}
	if err := addClaims(form, authParameters); err != nil {
		return TokenResponse{}, err
	}
	fields := map[string]string{
		grantType:  grant.Password,
		username:   authParameters.Username,
		password:   authParameters.Password,
		clientID:   authParameters.ClientID,
		clientInfo: clientInfoVal,
	}
	for key, value := range fields {
		form.Set(key, value)
	}
	addScopeQueryParam(form, authParameters)
	return c.doTokenResp(ctx, authParameters, form)
}
// AuthCodeRequest stores the values required to request a token from the authority using an authorization code
type AuthCodeRequest struct {
	// AuthParams supplies the endpoints, client ID, redirect URI, scopes and claims.
	AuthParams authority.AuthParams
	// Code is the authorization code being redeemed.
	Code string
	// CodeChallenge is sent as the "code_verifier" form field by FromAuthCode.
	CodeChallenge string
	// Credential authenticates the client; FromAuthCode requires it for ATConfidential.
	Credential *Credential
	// AppType selects public vs confidential handling; ATUnknown is rejected.
	AppType AppType
}
// NewCodeChallengeRequest returns an AuthCodeRequest that uses a code challenge.
// It fails when appType is ATUnknown.
func NewCodeChallengeRequest(params authority.AuthParams, appType AppType, cc *Credential, code, challenge string) (AuthCodeRequest, error) {
	var req AuthCodeRequest
	if appType == ATUnknown {
		return req, fmt.Errorf("bug: NewCodeChallengeRequest() called with AppType == ATUnknown")
	}
	req.AuthParams = params
	req.AppType = appType
	req.Code = code
	req.CodeChallenge = challenge
	req.Credential = cc
	return req, nil
}
// FromAuthCode uses an authorization code to retrieve an access token.
//
// Client authentication depends on req.AppType:
//   - confidential: req.Credential is required, and prepURLVals adds either its
//     secret or a JWT assertion;
//   - public: no client credential is sent;
//   - ATUnknown or any unrecognized value: an error is returned.
func (c Client) FromAuthCode(ctx context.Context, req AuthCodeRequest) (TokenResponse, error) {
	var qv url.Values
	switch req.AppType {
	case ATUnknown:
		return TokenResponse{}, fmt.Errorf("bug: Token.AuthCode() received request with AppType == ATUnknown")
	case ATConfidential:
		if req.Credential == nil {
			return TokenResponse{}, fmt.Errorf("AuthCodeRequest had nil Credential for Confidential app")
		}
		var err error
		qv, err = prepURLVals(ctx, req.Credential, req.AuthParams)
		if err != nil {
			return TokenResponse{}, err
		}
	case ATPublic:
		qv = url.Values{}
	default:
		return TokenResponse{}, fmt.Errorf("bug: Token.AuthCode() received request with AppType == %v, which we do not recognize", req.AppType)
	}
	qv.Set(grantType, grant.AuthCode)
	qv.Set("code", req.Code)
	// The value sent as code_verifier comes from req.CodeChallenge.
	qv.Set("code_verifier", req.CodeChallenge)
	qv.Set("redirect_uri", req.AuthParams.Redirecturi)
	qv.Set(clientID, req.AuthParams.ClientID)
	qv.Set(clientInfo, clientInfoVal)
	addScopeQueryParam(qv, req.AuthParams)
	if err := addClaims(qv, req.AuthParams); err != nil {
		return TokenResponse{}, err
	}
	return c.doTokenResp(ctx, req.AuthParams, qv)
}
// FromRefreshToken uses a refresh token (for refreshing credentials) to get a new access token.
//
// For ATConfidential, cc is required; prepURLVals adds its secret or JWT assertion to the form.
// A nil cc is reported as an error, matching FromAuthCode, instead of panicking.
// Other app types send no client credential.
func (c Client) FromRefreshToken(ctx context.Context, appType AppType, authParams authority.AuthParams, cc *Credential, refreshToken string) (TokenResponse, error) {
	qv := url.Values{}
	if appType == ATConfidential {
		if cc == nil {
			return TokenResponse{}, fmt.Errorf("refresh token request had nil Credential for Confidential app")
		}
		var err error
		qv, err = prepURLVals(ctx, cc, authParams)
		if err != nil {
			return TokenResponse{}, err
		}
	}
	if err := addClaims(qv, authParams); err != nil {
		return TokenResponse{}, err
	}
	qv.Set(grantType, grant.RefreshToken)
	qv.Set(clientID, authParams.ClientID)
	qv.Set(clientInfo, clientInfoVal)
	qv.Set("refresh_token", refreshToken)
	addScopeQueryParam(qv, authParams)
	return c.doTokenResp(ctx, authParams, qv)
}
// FromClientSecret uses a client's secret (aka password) to get a new token through
// the client credentials grant. Errors from the token request are prefixed with
// "FromClientSecret(): ". Unlike most requests here, client_info is not sent.
func (c Client) FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (TokenResponse, error) {
	form := url.Values{}
	if err := addClaims(form, authParameters); err != nil {
		return TokenResponse{}, err
	}
	form.Set(grantType, grant.ClientCredential)
	form.Set("client_secret", clientSecret)
	form.Set(clientID, authParameters.ClientID)
	addScopeQueryParam(form, authParameters)

	token, err := c.doTokenResp(ctx, authParameters, form)
	if err == nil {
		return token, nil
	}
	return token, fmt.Errorf("FromClientSecret(): %w", err)
}
// FromAssertion gets a token through the client credentials grant, authenticating
// the client with the signed assertion. Errors from the token request are prefixed
// with "FromAssertion(): ".
func (c Client) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (TokenResponse, error) {
	qv := url.Values{}
	if err := addClaims(qv, authParameters); err != nil {
		return TokenResponse{}, err
	}
	for _, kv := range [][2]string{
		{grantType, grant.ClientCredential},
		{"client_assertion_type", grant.ClientAssertion},
		{"client_assertion", assertion},
		{clientID, authParameters.ClientID},
		{clientInfo, clientInfoVal},
	} {
		qv.Set(kv[0], kv[1])
	}
	addScopeQueryParam(qv, authParameters)

	token, err := c.doTokenResp(ctx, authParameters, qv)
	if err == nil {
		return token, nil
	}
	return token, fmt.Errorf("FromAssertion(): %w", err)
}
// FromUserAssertionClientSecret performs the on-behalf-of exchange using the JWT
// bearer grant. It sends userAssertion as "assertion" and authenticates the client
// with clientSecret.
func (c Client) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (TokenResponse, error) {
	qv := url.Values{}
	if err := addClaims(qv, authParameters); err != nil {
		return TokenResponse{}, err
	}
	for _, kv := range [][2]string{
		{grantType, grant.JWT},
		{clientID, authParameters.ClientID},
		{"client_secret", clientSecret},
		{"assertion", userAssertion},
		{clientInfo, clientInfoVal},
		{"requested_token_use", "on_behalf_of"},
	} {
		qv.Set(kv[0], kv[1])
	}
	addScopeQueryParam(qv, authParameters)
	return c.doTokenResp(ctx, authParameters, qv)
}
// FromUserAssertionClientCertificate performs the on-behalf-of exchange using the
// JWT bearer grant. It sends userAssertion as "assertion" and authenticates the
// client with the signed client assertion.
func (c Client) FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (TokenResponse, error) {
	qv := url.Values{}
	if err := addClaims(qv, authParameters); err != nil {
		return TokenResponse{}, err
	}
	for _, kv := range [][2]string{
		{grantType, grant.JWT},
		{"client_assertion_type", grant.ClientAssertion},
		{"client_assertion", assertion},
		{clientID, authParameters.ClientID},
		{"assertion", userAssertion},
		{clientInfo, clientInfoVal},
		{"requested_token_use", "on_behalf_of"},
	} {
		qv.Set(kv[0], kv[1])
	}
	addScopeQueryParam(qv, authParameters)
	return c.doTokenResp(ctx, authParameters, qv)
}
// DeviceCodeResult requests a device code from the device code endpoint and
// converts the response to a DeviceCodeResult.
//
// The device code endpoint is derived from the token endpoint by deviceCodeEndpoint.
func (c Client) DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (DeviceCodeResult, error) {
	qv := url.Values{}
	if err := addClaims(qv, authParameters); err != nil {
		return DeviceCodeResult{}, err
	}
	qv.Set(clientID, authParameters.ClientID)
	addScopeQueryParam(qv, authParameters)

	endpoint := deviceCodeEndpoint(authParameters.Endpoints.TokenEndpoint)

	resp := DeviceCodeResponse{}
	err := c.Comm.URLFormCall(ctx, endpoint, qv, &resp)
	if err != nil {
		return DeviceCodeResult{}, err
	}

	return resp.Convert(authParameters.ClientID, authParameters.Scopes), nil
}

// deviceCodeEndpoint derives the device code endpoint from a token endpoint by
// replacing the LAST occurrence of "token" (the path's final "/token" segment)
// with "devicecode". Replacing every occurrence would corrupt hosts or tenants
// whose names contain "token". If "token" is absent, the input is returned unchanged.
func deviceCodeEndpoint(tokenEndpoint string) string {
	const old, repl = "token", "devicecode"
	i := strings.LastIndex(tokenEndpoint, old)
	if i < 0 {
		return tokenEndpoint
	}
	return tokenEndpoint[:i] + repl + tokenEndpoint[i+len(old):]
}
// FromDeviceCodeResult polls the token endpoint once with the device code from
// deviceCodeResult, using the device code grant.
func (c Client) FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult DeviceCodeResult) (TokenResponse, error) {
	form := url.Values{}
	if err := addClaims(form, authParameters); err != nil {
		return TokenResponse{}, err
	}
	fields := map[string]string{
		grantType:  grant.DeviceCode,
		deviceCode: deviceCodeResult.DeviceCode,
		clientID:   authParameters.ClientID,
		clientInfo: clientInfoVal,
	}
	for key, value := range fields {
		form.Set(key, value)
	}
	addScopeQueryParam(form, authParameters)
	return c.doTokenResp(ctx, authParameters, form)
}
// FromSamlGrant redeems a WS-Trust SAML assertion for a token.
//
// The assertion is base64-encoded (standard alphabet, padded). The grant type is
// grant.SAMLV1 or grant.SAMLV2, matching samlGrant.AssertionType; any other
// assertion type is an error.
func (c Client) FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (TokenResponse, error) {
	qv := url.Values{}
	if err := addClaims(qv, authParameters); err != nil {
		return TokenResponse{}, err
	}

	var grantValue string
	switch samlGrant.AssertionType {
	case grant.SAMLV1:
		grantValue = grant.SAMLV1
	case grant.SAMLV2:
		grantValue = grant.SAMLV2
	default:
		return TokenResponse{}, fmt.Errorf("GetAccessTokenFromSamlGrant returned unknown SAML assertion type: %q", samlGrant.AssertionType)
	}

	qv.Set(grantType, grantValue)
	qv.Set(username, authParameters.Username)
	qv.Set(password, authParameters.Password)
	qv.Set(clientID, authParameters.ClientID)
	qv.Set(clientInfo, clientInfoVal)
	// StdEncoding already pads with '=', so no explicit WithPadding is needed.
	qv.Set("assertion", base64.StdEncoding.EncodeToString([]byte(samlGrant.Assertion)))
	addScopeQueryParam(qv, authParameters)
	return c.doTokenResp(ctx, authParameters, qv)
}
// doTokenResp POSTs qv to the token endpoint and decodes the reply into a
// TokenResponse. It then computes the granted scopes from authParams and, unless
// c.testing is set, validates the response. On a transport error, whatever was
// decoded so far is returned alongside the error.
func (c Client) doTokenResp(ctx context.Context, authParams authority.AuthParams, qv url.Values) (TokenResponse, error) {
	var resp TokenResponse
	if err := c.Comm.URLFormCall(ctx, authParams.Endpoints.TokenEndpoint, qv, &resp); err != nil {
		return resp, err
	}
	resp.ComputeScope(authParams)
	if !c.testing {
		return resp, resp.Validate()
	}
	return resp, nil
}
// prepURLVals returns the client-authentication form fields for cc:
//   - "client_secret" when cc.Secret is non-empty;
//   - otherwise "client_assertion" plus "client_assertion_type", using the JWT
//     from cc.JWT.
//
// It returns nil and the error if the JWT cannot be produced.
func prepURLVals(ctx context.Context, cc *Credential, authParams authority.AuthParams) (url.Values, error) {
	if cc.Secret != "" {
		return url.Values{"client_secret": {cc.Secret}}, nil
	}
	assertion, err := cc.JWT(ctx, authParams)
	if err != nil {
		return nil, err
	}
	return url.Values{
		"client_assertion":      {assertion},
		"client_assertion_type": {grant.ClientAssertion},
	}, nil
}
// defaultScopes are appended to every scope list built by AppendDefaultScopes:
//   - openid is required to get an ID token
//   - offline_access is required to get a refresh token
//   - profile is required to get the client_info field back
var defaultScopes = []string{"openid", "offline_access", "profile"}

// detectDefaultScopes is the set form of defaultScopes, used to drop caller-supplied
// copies of the default scopes so they are not sent twice.
var detectDefaultScopes = map[string]bool{
	"offline_access": true,
	"openid":         true,
	"profile":        true,
}
// AppendDefaultScopes returns the scopes from authParameters.Scopes followed by the
// default scopes (openid, offline_access, profile) in their canonical order.
// Each requested scope is trimmed of surrounding whitespace. Blank entries are skipped, and
// requested scopes that are themselves default scopes are dropped so they are not sent
// twice. authParameters.Scopes is not modified.
func AppendDefaultScopes(authParameters authority.AuthParams) []string {
	scopes := make([]string, 0, len(authParameters.Scopes)+len(defaultScopes))
	for _, scope := range authParameters.Scopes {
		// Use the trimmed form for both the lookup and the result. Then " openid" is
		// recognised as a default scope, and stray whitespace never ends up in the
		// space-delimited "scope" parameter.
		s := strings.TrimSpace(scope)
		if s == "" || detectDefaultScopes[s] {
			continue
		}
		scopes = append(scopes, s)
	}
	return append(scopes, defaultScopes...)
}
// addClaims merges the client capabilities and claims held in ap and, when the merged
// value is non-empty, stores it under the "claims" key of v. If merging fails, the error
// is returned and v is left untouched.
func addClaims(v url.Values, ap authority.AuthParams) error {
	merged, err := ap.MergeCapabilitiesAndClaims()
	if err != nil {
		return err
	}
	if merged != "" {
		v.Set("claims", merged)
	}
	return nil
}
// addScopeQueryParam sets the "scope" form value to the space-separated list returned by
// AppendDefaultScopes for authParameters, replacing any previous value.
func addScopeQueryParam(queryParams url.Values, authParameters authority.AuthParams) {
	queryParams.Set("scope", strings.Join(AppendDefaultScopes(authParameters), " "))
}
accesstokens_test.go 0000664 0000000 0000000 00000063166 14420263624 0034563 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package accesstokens
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net/url"
"reflect"
"strings"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/kylelemons/godebug/pretty"
)
// testAuthorityEndpoints is the fixed set of authority endpoints shared by these tests.
var testAuthorityEndpoints = authority.NewEndpoints(
	"https://login.microsoftonline.com/v2.0/authorize",
	"https://login.microsoftonline.com/v2.0/token",
	"https://login.microsoftonline.com/v2.0",
	"login.microsoftonline.com",
)

// jwtDecoderFake stands in for decodeJWT. It fails for the literal input "error" and
// otherwise returns the input bytes unchanged, with no base64 decoding.
var jwtDecoderFake = func(s string) ([]byte, error) {
	switch s {
	case "error":
		return nil, errors.New("error")
	default:
		return []byte(s), nil
	}
}
// fakeURLCaller is a test double for the comm layer. It records the endpoint, form values
// and response target of the last successful URLFormCall, or fails every call when err
// is set.
type fakeURLCaller struct {
	err         bool
	gotEndpoint string
	gotQV       url.Values
	gotResp     interface{}
}

// URLFormCall records its arguments and succeeds. When f.err is true it records nothing
// and returns an error instead.
func (f *fakeURLCaller) URLFormCall(ctx context.Context, endpoint string, qv url.Values, resp interface{}) error {
	if !f.err {
		f.gotEndpoint, f.gotQV, f.gotResp = endpoint, qv, resp
		return nil
	}
	return errors.New("error")
}

// compare reports, as an error, the first mismatch between the recorded call and the
// wanted endpoint and form values. Endpoint mismatches take precedence over form diffs.
func (f *fakeURLCaller) compare(endpoint string, qv url.Values) error {
	if f.gotEndpoint != endpoint {
		return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint)
	}
	diff := pretty.Compare(qv, f.gotQV)
	if diff == "" {
		return nil
	}
	return fmt.Errorf("qv -want/+got:\n%s", diff)
}
// TestAccessTokenFromUsernamePassword checks the form FromUsernamePassword sends to the
// token endpoint, and that comm-layer failures are surfaced.
func TestAccessTokenFromUsernamePassword(t *testing.T) {
	authParams := authority.AuthParams{
		Username:  "username",
		Password:  "password",
		Endpoints: testAuthorityEndpoints,
		ClientID:  "clientID",
	}

	type testCase struct {
		desc      string
		err       bool
		commErr   bool
		createErr bool
		qv        url.Values
	}
	cases := []testCase{
		{
			desc:    "Error: comm returns error",
			err:     true,
			commErr: true,
		},
		{
			desc: "Success",
			qv: url.Values{
				grantType:  []string{grant.Password},
				username:   []string{authParams.Username},
				password:   []string{authParams.Password},
				clientID:   []string{authParams.ClientID},
				clientInfo: []string{clientInfoVal},
			},
		},
	}

	for _, tc := range cases {
		if tc.qv != nil {
			addScopeQueryParam(tc.qv, authParams)
		}
		caller := &fakeURLCaller{err: tc.commErr}
		c := Client{Comm: caller, testing: true}
		// Only the request handed to the comm layer is checked here; turning the JSON reply
		// into a TokenResponse is the comm package's job.
		_, gotErr := c.FromUsernamePassword(context.Background(), authParams)
		if gotErr != nil {
			if !tc.err {
				t.Errorf("TestAccessTokenFromUsernamePassword(%s): got err == %s , want err == nil", tc.desc, gotErr)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestAccessTokenFromUsernamePassword(%s): got err == nil , want err != nil", tc.desc)
			continue
		}
		if err := caller.compare(authParams.Endpoints.TokenEndpoint, tc.qv); err != nil {
			t.Errorf("TestAccessTokenFromUsernamePassword(%s): %s", tc.desc, err)
		}
	}
}
// TestAccessTokenFromAuthCode checks the form FromAuthCode sends when a confidential
// client redeems an authorization code. It also checks that comm failures and a missing
// credential are reported as errors.
func TestAccessTokenFromAuthCode(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	tests := []struct {
		desc string
		err  bool
		// commErr causes the comm call to return an error.
		commErr bool
		// createErr causes the TokenResponse creation to error.
		createErr       bool
		authCodeRequest AuthCodeRequest
		authCode        string
		codeVerifier    string
		qv              url.Values
	}{
		{
			desc:    "Error: comm returns error",
			err:     true,
			commErr: true,
			authCodeRequest: AuthCodeRequest{
				AuthParams:    authParams,
				Code:          "authCode",
				CodeChallenge: "codeVerifier",
				Credential:    &Credential{Secret: "secret"},
				AppType:       ATConfidential,
			},
			qv: url.Values{
				"code":          []string{"authCode"},
				"code_verifier": []string{"codeVerifier"},
				"redirect_uri":  []string{"redirectURI"},
				grantType:       []string{grant.AuthCode},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
		},
		{
			// The comm layer succeeds here, so the error must come from FromAuthCode
			// rejecting the confidential request that has no Credential.
			desc: "Error: Credential is nil",
			authCodeRequest: AuthCodeRequest{
				AuthParams:    authParams,
				Code:          "authCode",
				CodeChallenge: "codeVerifier",
				AppType:       ATConfidential,
			},
			qv: url.Values{
				"code":          []string{"authCode"},
				"code_verifier": []string{"codeVerifier"},
				"redirect_uri":  []string{"redirectURI"},
				grantType:       []string{grant.AuthCode},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
			err: true,
		},
		{
			desc: "Success",
			authCodeRequest: AuthCodeRequest{
				AuthParams:    authParams,
				Code:          "authCode",
				CodeChallenge: "codeVerifier",
				AppType:       ATConfidential,
				Credential:    &Credential{Secret: "secret"},
			},
			qv: url.Values{
				"code":          []string{"authCode"},
				"code_verifier": []string{"codeVerifier"},
				"redirect_uri":  []string{"redirectURI"},
				"client_secret": []string{"secret"},
				grantType:       []string{grant.AuthCode},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
		},
	}
	for _, test := range tests {
		if test.qv != nil {
			addScopeQueryParam(test.qv, authParams)
		}
		// Only commErr drives the fake comm layer. Using test.err here would make every
		// expected-error case pass via the fake, even if FromAuthCode accepted bad input.
		fake := &fakeURLCaller{err: test.commErr}
		client := Client{Comm: fake, testing: true}
		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.FromAuthCode(context.Background(), test.authCodeRequest)
		switch {
		case err == nil && test.err:
			t.Errorf("TestAccessTokenFromAuthCode(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestAccessTokenFromAuthCode(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil {
			t.Errorf("TestAccessTokenFromAuthCode(%s): %s", test.desc, err)
		}
	}
}
// TestAccessTokenFromRefreshToken checks the form FromRefreshToken sends for both public
// and confidential apps. For a confidential app the client secret from the credential must
// also be sent.
func TestAccessTokenFromRefreshToken(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	tests := []struct {
		desc      string
		err       bool
		commErr   bool
		createErr bool
		// appType selects the public or confidential code path of FromRefreshToken.
		appType      AppType
		cred         *Credential
		refreshToken string
		qv           url.Values
	}{
		{
			desc:         "Error: comm returns error",
			err:          true,
			commErr:      true,
			appType:      ATPublic,
			refreshToken: "refreshToken",
			qv: url.Values{
				"refresh_token": []string{"refreshToken"},
				grantType:       []string{grant.RefreshToken},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
		},
		{
			desc:         "Success(public app)",
			appType:      ATPublic,
			refreshToken: "refreshToken",
			qv: url.Values{
				"refresh_token": []string{"refreshToken"},
				grantType:       []string{grant.RefreshToken},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
		},
		{
			// Previously identical to the public case. It now really runs as a confidential
			// app, so the credential's secret is expected in the form.
			desc:         "Success(confidential app)",
			appType:      ATConfidential,
			cred:         &Credential{Secret: "secret"},
			refreshToken: "refreshToken",
			qv: url.Values{
				"refresh_token": []string{"refreshToken"},
				"client_secret": []string{"secret"},
				grantType:       []string{grant.RefreshToken},
				clientID:        []string{authParams.ClientID},
				clientInfo:      []string{clientInfoVal},
			},
		},
	}
	for _, test := range tests {
		if test.qv != nil {
			addScopeQueryParam(test.qv, authParams)
		}
		fake := &fakeURLCaller{err: test.commErr}
		client := Client{Comm: fake, testing: true}
		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.FromRefreshToken(context.Background(), test.appType, authParams, test.cred, test.refreshToken)
		switch {
		case err == nil && test.err:
			t.Errorf("TestAccessTokenFromRefreshToken(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestAccessTokenFromRefreshToken(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil {
			t.Errorf("TestAccessTokenFromRefreshToken(%s): %s", test.desc, err)
		}
	}
}
// TestAccessTokenWithClientSecret checks the form FromClientSecret sends for a
// client-credentials grant, and that comm-layer failures are surfaced.
func TestAccessTokenWithClientSecret(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	tests := []struct {
		desc         string
		err          bool
		commErr      bool
		createErr    bool
		clientSecret string
		qv           url.Values
	}{
		{
			desc:         "Error: comm returns error",
			err:          true,
			commErr:      true,
			clientSecret: "clientSecret",
			qv: url.Values{
				"client_secret": []string{"clientSecret"},
				grantType:       []string{grant.ClientCredential},
				clientID:        []string{authParams.ClientID},
			},
		},
		{
			desc:         "Success",
			clientSecret: "clientSecret",
			qv: url.Values{
				"client_secret": []string{"clientSecret"},
				grantType:       []string{grant.ClientCredential},
				clientID:        []string{authParams.ClientID},
			},
		},
	}
	for _, test := range tests {
		if test.qv != nil {
			addScopeQueryParam(test.qv, authParams)
		}
		// The fake's failure mode is driven by commErr (like the sibling tests), not by the
		// expected outcome err.
		fake := &fakeURLCaller{err: test.commErr}
		client := Client{Comm: fake, testing: true}
		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.FromClientSecret(context.Background(), authParams, test.clientSecret)
		switch {
		case err == nil && test.err:
			t.Errorf("TestAccessTokenWithClientSecret(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestAccessTokenWithClientSecret(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if err := fake.compare(authParams.Endpoints.TokenEndpoint, test.qv); err != nil {
			t.Errorf("TestAccessTokenWithClientSecret(%s): %s", test.desc, err)
		}
	}
}
// TestAccessTokenWithAssertion checks the form FromAssertion sends when authenticating
// with a client assertion, and that comm-layer failures are surfaced.
func TestAccessTokenWithAssertion(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	// wantForm builds a fresh expected form. Each case needs its own map because the scope
	// parameter is added to it in place below.
	wantForm := func() url.Values {
		return url.Values{
			"client_assertion_type": []string{grant.ClientAssertion},
			"client_assertion":      []string{"assertion"},
			grantType:               []string{grant.ClientCredential},
			clientInfo:              []string{clientInfoVal},
			clientID:                []string{authParams.ClientID},
		}
	}
	cases := []struct {
		desc      string
		err       bool
		commErr   bool
		createErr bool
		assertion string
		params    url.Values
		qv        url.Values
	}{
		{desc: "Error: comm returns error", err: true, commErr: true, assertion: "assertion", qv: wantForm()},
		{desc: "Success", assertion: "assertion", qv: wantForm()},
	}

	for _, tc := range cases {
		if tc.qv != nil {
			addScopeQueryParam(tc.qv, authParams)
		}
		caller := &fakeURLCaller{err: tc.commErr}
		c := Client{Comm: caller, testing: true}
		// Only the request handed to the comm layer is checked here.
		_, gotErr := c.FromAssertion(context.Background(), authParams, tc.assertion)
		if gotErr != nil {
			if !tc.err {
				t.Errorf("TestAccessTokenWithAssertion(%s): got err == %s , want err == nil", tc.desc, gotErr)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestAccessTokenWithAssertion(%s): got err == nil , want err != nil", tc.desc)
			continue
		}
		if err := caller.compare(authParams.Endpoints.TokenEndpoint, tc.qv); err != nil {
			t.Errorf("TestAccessTokenWithAssertion(%s): %s", tc.desc, err)
		}
	}
}
// TestDeviceCodeResult checks that DeviceCodeResult posts the client ID and scopes to the
// device-code endpoint, and that comm-layer failures are surfaced.
func TestDeviceCodeResult(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	// The device-code endpoint is the token endpoint with "token" replaced by "devicecode".
	wantEndpoint := strings.ReplaceAll(authParams.Endpoints.TokenEndpoint, "token", "devicecode")
	wantForm := func() url.Values {
		return url.Values{clientID: []string{authParams.ClientID}}
	}
	cases := []struct {
		desc      string
		err       bool
		commErr   bool
		createErr bool
		assertion string
		params    url.Values
		qv        url.Values
	}{
		{desc: "Error: comm returns error", err: true, commErr: true, qv: wantForm()},
		{desc: "Success", qv: wantForm()},
	}

	for _, tc := range cases {
		if tc.qv != nil {
			addScopeQueryParam(tc.qv, authParams)
		}
		caller := &fakeURLCaller{err: tc.commErr}
		c := Client{Comm: caller, testing: true}
		// Only the request handed to the comm layer is checked here.
		_, gotErr := c.DeviceCodeResult(context.Background(), authParams)
		if gotErr != nil {
			if !tc.err {
				t.Errorf("TestDeviceCodeResult(%s): got err == %s , want err == nil", tc.desc, gotErr)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestDeviceCodeResult(%s): got err == nil , want err != nil", tc.desc)
			continue
		}
		if err := caller.compare(wantEndpoint, tc.qv); err != nil {
			t.Errorf("TestDeviceCodeResult(%s): %s", tc.desc, err)
		}
	}
}
// TestFromDeviceCodeResult checks the form FromDeviceCodeResult sends when redeeming a
// device code, and that comm-layer failures are surfaced.
func TestFromDeviceCodeResult(t *testing.T) {
	authParams := authority.AuthParams{
		Endpoints:   testAuthorityEndpoints,
		ClientID:    "clientID",
		Redirecturi: "redirectURI",
	}
	newResult := func() DeviceCodeResult {
		return NewDeviceCodeResult(
			"userCode",
			"deviceCode",
			"verificationURL",
			time.Now(),
			1,
			"message",
			"clientID",
			nil,
		)
	}
	// wantForm returns a fresh map per case because the scope is added to it in place.
	wantForm := func() url.Values {
		return url.Values{
			deviceCode: []string{"deviceCode"},
			grantType:  []string{grant.DeviceCode},
			clientID:   []string{authParams.ClientID},
			clientInfo: []string{clientInfoVal},
		}
	}
	cases := []struct {
		desc             string
		err              bool
		commErr          bool
		createErr        bool
		deviceCodeResult DeviceCodeResult
		qv               url.Values
	}{
		{desc: "Error: comm returns error", err: true, commErr: true, deviceCodeResult: newResult(), qv: wantForm()},
		{desc: "Success", deviceCodeResult: newResult(), qv: wantForm()},
	}

	for _, tc := range cases {
		if tc.qv != nil {
			addScopeQueryParam(tc.qv, authParams)
		}
		caller := &fakeURLCaller{err: tc.commErr}
		c := Client{Comm: caller, testing: true}
		// Only the request handed to the comm layer is checked here.
		_, gotErr := c.FromDeviceCodeResult(context.Background(), authParams, tc.deviceCodeResult)
		if gotErr != nil {
			if !tc.err {
				t.Errorf("TestFromDeviceCodeResult(%s): got err == %s , want err == nil", tc.desc, gotErr)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestFromDeviceCodeResult(%s): got err == nil , want err != nil", tc.desc)
			continue
		}
		if err := caller.compare(authParams.Endpoints.TokenEndpoint, tc.qv); err != nil {
			t.Errorf("TestFromDeviceCodeResult(%s): %s", tc.desc, err)
		}
	}
}
// TestAccessTokenFromSamlGrant checks the form FromSamlGrant sends for SAML v1 and v2
// assertions. It also checks that an unknown assertion type and comm-layer failures are
// reported as errors.
func TestAccessTokenFromSamlGrant(t *testing.T) {
	authParams := authority.AuthParams{
		Username:  "username",
		Password:  "password",
		Endpoints: testAuthorityEndpoints,
		ClientID:  "clientID",
	}
	encodedAssertion := base64.StdEncoding.WithPadding(base64.StdPadding).EncodeToString([]byte("assertion"))
	// formWithGrant returns a fresh expected form using the given SAML grant type. A new map
	// is needed per case because the scope parameter is added to it in place.
	formWithGrant := func(samlGrantType string) url.Values {
		return url.Values{
			username:    []string{"username"},
			password:    []string{"password"},
			grantType:   []string{samlGrantType},
			clientID:    []string{authParams.ClientID},
			clientInfo:  []string{clientInfoVal},
			"assertion": []string{encodedAssertion},
		}
	}
	cases := []struct {
		desc      string
		err       bool
		commErr   bool
		createErr bool
		samlGrant wstrust.SamlTokenInfo
		qv        url.Values
	}{
		{
			desc:      "Error: comm returns error",
			err:       true,
			commErr:   true,
			samlGrant: wstrust.SamlTokenInfo{AssertionType: grant.SAMLV1, Assertion: "assertion"},
			qv:        formWithGrant(grant.SAMLV1),
		},
		{
			desc:      "Error: unknown grant type(empty space)",
			err:       true,
			samlGrant: wstrust.SamlTokenInfo{Assertion: "assertion"},
			qv:        formWithGrant(grant.SAMLV1),
		},
		{
			desc:      "Success: SAMLV1Grant",
			samlGrant: wstrust.SamlTokenInfo{AssertionType: grant.SAMLV1, Assertion: "assertion"},
			qv:        formWithGrant(grant.SAMLV1),
		},
		{
			desc:      "Success: SAMLV2Grant",
			samlGrant: wstrust.SamlTokenInfo{AssertionType: grant.SAMLV2, Assertion: "assertion"},
			qv:        formWithGrant(grant.SAMLV2),
		},
	}

	for _, tc := range cases {
		if tc.qv != nil {
			addScopeQueryParam(tc.qv, authParams)
		}
		caller := &fakeURLCaller{err: tc.commErr}
		c := Client{Comm: caller, testing: true}
		// Only the request handed to the comm layer is checked here.
		_, gotErr := c.FromSamlGrant(context.Background(), authParams, tc.samlGrant)
		if gotErr != nil {
			if !tc.err {
				t.Errorf("TestAccessTokenFromSamlGrant(%s): got err == %s , want err == nil", tc.desc, gotErr)
			}
			continue
		}
		if tc.err {
			t.Errorf("TestAccessTokenFromSamlGrant(%s): got err == nil , want err != nil", tc.desc)
			continue
		}
		if err := caller.compare(authParams.Endpoints.TokenEndpoint, tc.qv); err != nil {
			t.Errorf("TestAccessTokenFromSamlGrant(%s): %s", tc.desc, err)
		}
	}
}
// TestDecodeJWT checks that decodeJWT accepts unpadded base64url input containing the
// URL-safe '-' and '_' characters.
func TestDecodeJWT(t *testing.T) {
	const encoded = "A-z_4ME"
	want := []byte{3, 236, 255, 224, 193}

	got, err := decodeJWT(encoded)
	if err != nil {
		t.Errorf("Error should be nil but it is %v", err)
	}
	if !reflect.DeepEqual(want, got) {
		t.Errorf("Actual decoded string %s differs from expected decoded string %s", got, want)
	}
}
// TestLocalAccountID checks that the oid claim takes precedence over sub, and that sub is
// the fallback when oid is empty.
func TestLocalAccountID(t *testing.T) {
	token := &IDToken{Subject: "sub"}

	if got := token.LocalAccountID(); got != "sub" {
		t.Errorf("Expected local account ID sub differs from actual local account ID %s", got)
	}

	token.Oid = "oid"
	if got := token.LocalAccountID(); got != "oid" {
		t.Errorf("Expected local account ID oid differs from actual local account ID %s", got)
	}
}
func TestTokenResponseUnmarshal(t *testing.T) {
tests := []struct {
desc string
payload string
want TokenResponse
jwtDecoder func(data string) ([]byte, error)
err bool
}{
{
desc: "Error: decodeJWT is going to error",
payload: `
{
"access_token": "secret",
"expires_in": 86399,
"ext_expires_in": 86399,
"client_info": error,
"scope": "openid profile"
}`,
err: true,
jwtDecoder: jwtDecoderFake,
},
{
desc: "Success",
payload: `
{
"access_token": "secret",
"expires_in": 86399,
"ext_expires_in": 86399,
"client_info": {"uid": "uid","utid": "utid"},
"scope": "openid profile"
}`,
want: TokenResponse{
AccessToken: "secret",
ExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
ExtExpiresOn: internalTime.DurationTime{T: time.Unix(86399, 0)},
GrantedScopes: Scopes{Slice: []string{"openid", "profile"}},
ClientInfo: ClientInfo{
UID: "uid",
UTID: "utid",
},
},
jwtDecoder: jwtDecoderFake,
},
}
for _, test := range tests {
jwtDecoder = test.jwtDecoder
got := TokenResponse{}
err := json.Unmarshal([]byte(test.payload), &got)
switch {
case err == nil && test.err:
t.Errorf("TestCreateTokenResponse(%s): got err == nil, want err != nil", test.desc)
continue
case err != nil && !test.err:
t.Errorf("TestCreateTokenResponse(%s): got err == %v, want err == nil", test.desc, err)
continue
case err != nil:
continue
}
// Note: IncludeUnexported prevents minor differences in time.Time due to internal fields.
if diff := (&pretty.Config{IncludeUnexported: false}).Compare(test.want, got); diff != "" {
t.Errorf("TestCreateTokenResponse: -want/+got:\n%s", diff)
}
}
}
// TestTokenResponseValidate checks that Validate rejects a response carrying a server
// error, a missing access token, or scopes that were never computed.
func TestTokenResponseValidate(t *testing.T) {
	cases := []struct {
		desc  string
		input TokenResponse
		err   bool
	}{
		{
			desc: "Error: TokenResponse had .Error set",
			input: TokenResponse{
				OAuthResponseBase: authority.OAuthResponseBase{Error: "error"},
				AccessToken:       "token",
				scopesComputed:    true,
			},
			err: true,
		},
		{
			desc:  "Error: .AccessToken was empty",
			input: TokenResponse{scopesComputed: true},
			err:   true,
		},
		{
			desc:  "Error: .scopesComputed was false",
			input: TokenResponse{AccessToken: "token", scopesComputed: false},
			err:   true,
		},
		{
			desc:  "Success",
			input: TokenResponse{AccessToken: "token", scopesComputed: true},
		},
	}

	for _, tc := range cases {
		gotErr := tc.input.Validate()
		if gotErr == nil && tc.err {
			t.Errorf("TestTokenResponseValidate(%s): got err == nil, want err != nil", tc.desc)
		} else if gotErr != nil && !tc.err {
			t.Errorf("TestTokenResponseValidate(%s): got err == %s, want err == nil", tc.desc, gotErr)
		}
	}
}
// TestComputeScopes checks how ComputeScope combines requested and granted scopes. Missing
// granted scopes mean "all requested". Declined scopes are the requested ones the server
// did not grant, compared case-insensitively.
func TestComputeScopes(t *testing.T) {
	cases := []struct {
		desc       string
		authParams authority.AuthParams
		input      TokenResponse
		want       TokenResponse
	}{
		{
			desc:       "authParam scopes copied in, no declined scopes",
			authParams: authority.AuthParams{Scopes: []string{"scope0", "scope1"}},
			input:      TokenResponse{},
			want: TokenResponse{
				GrantedScopes:  Scopes{Slice: []string{"scope0", "scope1"}},
				scopesComputed: true,
			},
		},
		{
			desc:       "a few declined scopes",
			authParams: authority.AuthParams{Scopes: []string{"scope0", "scope1", "scope2"}},
			input: TokenResponse{
				GrantedScopes: Scopes{Slice: []string{"scope0", "scope1"}},
			},
			want: TokenResponse{
				GrantedScopes:  Scopes{Slice: []string{"scope0", "scope1"}},
				DeclinedScopes: []string{"scope2"},
				scopesComputed: true,
			},
		},
		{
			desc:       "no declined scopes case insensitive",
			authParams: authority.AuthParams{Scopes: []string{"scope0", "scope1"}},
			input: TokenResponse{
				GrantedScopes: Scopes{Slice: []string{"Scope0", "Scope1"}},
			},
			want: TokenResponse{
				GrantedScopes:  Scopes{Slice: []string{"Scope0", "Scope1"}},
				DeclinedScopes: nil,
				scopesComputed: true,
			},
		},
	}

	for _, tc := range cases {
		got := tc.input
		got.ComputeScope(tc.authParams)
		if diff := pretty.Compare(tc.want, got); diff != "" {
			t.Errorf("TestComputeScopes(%s): -want/+got:\n%s", tc.desc, diff)
		}
	}
}
// TestHomeAccountID checks every combination of present/absent UID and UTID.
func TestHomeAccountID(t *testing.T) {
	cases := []struct {
		desc string
		ci   ClientInfo
		want string
	}{
		{desc: "UID and UTID is not set"},
		{desc: "UID is not set", ci: ClientInfo{UTID: "utid"}},
		{desc: "UTID is not set", ci: ClientInfo{UID: "uid"}, want: "uid.uid"},
		{desc: "UID and UTID are set", ci: ClientInfo{UID: "uid", UTID: "utid"}, want: "uid.utid"},
	}

	for _, tc := range cases {
		if got := tc.ci.HomeAccountID(); got != tc.want {
			t.Errorf("TestHomeAccountID(%s): got %q, want %q", tc.desc, got, tc.want)
		}
	}
}
// TestFindDeclinedScopes checks that a requested scope missing from the granted list is
// reported as declined.
func TestFindDeclinedScopes(t *testing.T) {
	want := []string{"openid"}
	got := findDeclinedScopes([]string{"user.read", "openid"}, []string{"user.read"})
	if !reflect.DeepEqual(want, got) {
		t.Errorf("Actual declined scopes %v differ from expected declined scopes %v", got, want)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/apptype_string.go0000664 0000000 0000000 00000001227 14420263624 0034154 0 ustar 00root root 0000000 0000000 // Code generated by "stringer -type=AppType"; DO NOT EDIT.
package accesstokens
import "strconv"
// _ is a compile-time guard generated by stringer. x has length 1, so x[C-v] compiles
// only when constant C still equals v. If any AppType value drifts, this function no
// longer compiles. (This file is generated; manual edits are lost on regeneration.)
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[ATUnknown-0]
	_ = x[ATPublic-1]
	_ = x[ATConfidential-2]
}
// _AppType_name holds every AppType name concatenated. _AppType_index holds the byte
// offsets of the name boundaries, so the name of value i is
// _AppType_name[_AppType_index[i]:_AppType_index[i+1]].
const _AppType_name = "ATUnknownATPublicATConfidential"

var _AppType_index = [...]uint8{0, 9, 17, 31}

// String returns the identifier of the AppType constant (e.g. "ATPublic"). Values outside
// the known range are formatted as "AppType(N)".
func (i AppType) String() string {
	// Guard against indexing past the table for unknown or negative values.
	if i < 0 || i >= AppType(len(_AppType_index)-1) {
		return "AppType(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _AppType_name[_AppType_index[i]:_AppType_index[i+1]]
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/accesstokens/tokens.go 0000664 0000000 0000000 00000025105 14420263624 0032410 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package accesstokens
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"time"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
)
// IDToken holds the claims decoded from the payload segment of an ID token JWT, one
// field per claim named in its json struct tag.
// https://docs.microsoft.com/azure/active-directory/develop/id-tokens .
type IDToken struct {
	PreferredUsername string `json:"preferred_username,omitempty"`
	GivenName string `json:"given_name,omitempty"`
	FamilyName string `json:"family_name,omitempty"`
	MiddleName string `json:"middle_name,omitempty"`
	Name string `json:"name,omitempty"`
	// Oid, when non-empty, is preferred over Subject by LocalAccountID.
	Oid string `json:"oid,omitempty"`
	TenantID string `json:"tid,omitempty"`
	Subject string `json:"sub,omitempty"`
	UPN string `json:"upn,omitempty"`
	Email string `json:"email,omitempty"`
	AlternativeID string `json:"alternative_id,omitempty"`
	Issuer string `json:"iss,omitempty"`
	Audience string `json:"aud,omitempty"`
	ExpirationTime int64 `json:"exp,omitempty"`
	IssuedAt int64 `json:"iat,omitempty"`
	NotBefore int64 `json:"nbf,omitempty"`
	// RawToken is the full, undecoded JWT string; UnmarshalJSON sets it.
	RawToken string
	// AdditionalFields: UnmarshalJSON does not gather unrecognised claims into this map.
	// NOTE(review): presumably populated by other (cache) serialization code — confirm.
	AdditionalFields map[string]interface{}
}
// null is the raw JSON literal null.
var null = []byte("null")

// UnmarshalJSON implements json.Unmarshaler. The input is expected to be a JSON string
// holding a JWT. The JWT's second segment is base64url-decoded with decodeJWT and its
// claims are unmarshaled into i, and RawToken is set to the JWT string itself.
// A JSON null leaves i unchanged. Input with fewer than two dot-separated segments is
// rejected.
func (i *IDToken) UnmarshalJSON(b []byte) error {
	if bytes.Equal(null, b) {
		return nil
	}

	raw := strings.Trim(string(b), `"`)
	segments := strings.Split(raw, ".")
	if len(segments) < 2 {
		return errors.New("IDToken returned from server is invalid")
	}
	payload, err := decodeJWT(segments[1])
	if err != nil {
		return fmt.Errorf("unable to unmarshal IDToken, problem decoding JWT: %w", err)
	}

	// plainIDToken has IDToken's fields but none of its methods. json.Unmarshal therefore
	// uses default decoding instead of recursing back into this UnmarshalJSON.
	type plainIDToken IDToken
	var claims plainIDToken
	if err := json.Unmarshal(payload, &claims); err != nil {
		return fmt.Errorf("unable to unmarshal IDToken: %w", err)
	}
	claims.RawToken = raw
	*i = IDToken(claims)
	return nil
}
// IsZero reports whether every field of the IDToken is unset. A map or slice field counts
// as unset when it is nil or has length zero.
func (i IDToken) IsZero() bool {
	val := reflect.ValueOf(i)
	for idx := 0; idx < val.NumField(); idx++ {
		f := val.Field(idx)
		if f.IsZero() {
			continue
		}
		// Non-nil but empty maps/slices are still treated as zero.
		if k := f.Kind(); (k == reflect.Map || k == reflect.Slice) && f.Len() == 0 {
			continue
		}
		return false
	}
	return true
}
// LocalAccountID returns the account's local ID: the oid claim when it is set, otherwise
// the sub claim.
func (i IDToken) LocalAccountID() string {
	id := i.Subject
	if i.Oid != "" {
		id = i.Oid
	}
	return id
}
// jwtDecoder decodes the encoded client_info value in ClientInfo.UnmarshalJSON. It is a
// package variable (defaulting to decodeJWT) so tests can substitute their own decoder.
var jwtDecoder = decodeJWT

// ClientInfo is used to create a Home Account ID for an account (see HomeAccountID).
// UID and UTID are read from the "uid" and "utid" keys of the decoded client_info JSON.
type ClientInfo struct {
	UID string `json:"uid"`
	UTID string `json:"utid"`
	// AdditionalFields is not filled from unknown keys by ClientInfo.UnmarshalJSON.
	// NOTE(review): presumably populated by other serialization code — confirm.
	AdditionalFields map[string]interface{}
}
// UnmarshalJSON implements json.Unmarshaler. The input is expected to be a JSON string
// holding encoded JSON (the client_info value). It is decoded with jwtDecoder and the
// resulting object is unmarshaled into c. A JSON null or an empty string leaves c
// unchanged.
func (c *ClientInfo) UnmarshalJSON(b []byte) error {
	// Follow the json.Unmarshaler convention that null is a no-op. Otherwise the bytes
	// "null" would be handed to the base64 decoder and fail with a misleading error.
	if string(b) == "null" {
		return nil
	}
	s := strings.Trim(string(b), `"`)
	// Client info may be empty in some flows, e.g. certificate exchange.
	if len(s) == 0 {
		return nil
	}
	// Because we have a custom unmarshaler, you
	// cannot directly call json.Unmarshal here. If you do, it will call this function
	// recursively until reach our recursion limit. We have to create a new type
	// that doesn't have this method in order to use json.Unmarshal.
	type clientInfo2 ClientInfo
	raw, err := jwtDecoder(s)
	if err != nil {
		return fmt.Errorf("TokenResponse client_info field had JWT decode error: %w", err)
	}
	var c2 clientInfo2
	if err := json.Unmarshal(raw, &c2); err != nil {
		return fmt.Errorf("was unable to unmarshal decoded JWT in TokenResponse to ClientInfo: %w", err)
	}
	*c = ClientInfo(c2)
	return nil
}
// HomeAccountID returns "<UID>.<UTID>". When UTID is empty the UID is used in its place,
// and when UID is empty the result is "".
func (c ClientInfo) HomeAccountID() string {
	switch {
	case c.UID == "":
		return ""
	case c.UTID == "":
		return c.UID + "." + c.UID
	default:
		return c.UID + "." + c.UTID
	}
}
// Scopes represents the space-delimited "scope" value of a TokenResponse as a slice.
type Scopes struct {
	Slice []string
}

// UnmarshalJSON implements json.Unmarshaler. The surrounding quotes are stripped and the
// value is split on runs of whitespace into s.Slice. A JSON null, an empty string, or a
// whitespace-only string leaves s unchanged. ComputeScope therefore treats such responses
// as having granted every requested scope.
func (s *Scopes) UnmarshalJSON(b []byte) error {
	// json.Unmarshaler convention: null is a no-op. Without this check the scope would be
	// recorded as a single scope literally named "null".
	if string(b) == "null" {
		return nil
	}
	// strings.Fields (rather than Split on " ") never yields empty scope names when the
	// server separates scopes with more than one space.
	names := strings.Fields(strings.Trim(string(b), `"`))
	if len(names) == 0 {
		return nil
	}
	s.Slice = names
	return nil
}
// TokenResponse is the information that is returned from a token endpoint during a token acquisition flow.
// Error details from the server are carried by the embedded OAuthResponseBase and checked
// by Validate.
type TokenResponse struct {
	authority.OAuthResponseBase
	AccessToken string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	FamilyID string `json:"foci"`
	IDToken IDToken `json:"id_token"`
	ClientInfo ClientInfo `json:"client_info"`
	// ExpiresOn and ExtExpiresOn are decoded from the relative expires_in/ext_expires_in
	// values by internalTime.DurationTime (NOTE(review): conversion rules live in that
	// package — confirm before relying on them).
	ExpiresOn internalTime.DurationTime `json:"expires_in"`
	ExtExpiresOn internalTime.DurationTime `json:"ext_expires_in"`
	GrantedScopes Scopes `json:"scope"`
	DeclinedScopes []string // This is derived
	AdditionalFields map[string]interface{}
	// scopesComputed is set by ComputeScope; Validate fails while it is false.
	scopesComputed bool
}
// ComputeScope derives the final granted and declined scopes from the server's response
// and the scopes requested in authParams, then marks the response as computed.
// Per OAuth 2.0 (RFC 6749 section 3.3), a response without a scope value means every
// requested scope was granted. That happens in client-assertion flows but may occur in
// any flow, so in that case the requested scopes are copied in. Otherwise DeclinedScopes
// lists the requested scopes the server did not grant.
// https://tools.ietf.org/html/rfc6749#section-3.3
func (tr *TokenResponse) ComputeScope(authParams authority.AuthParams) {
	if granted := tr.GrantedScopes.Slice; len(granted) > 0 {
		tr.DeclinedScopes = findDeclinedScopes(authParams.Scopes, granted)
	} else {
		tr.GrantedScopes = Scopes{Slice: authParams.Scopes}
	}
	tr.scopesComputed = true
}
// Validate checks that the TokenResponse has basic valid values. It must be called after
// ComputeScope. It returns an error when the server reported an error (using the
// server's error code and description), when no access token was returned, or when
// ComputeScope has not been called.
func (tr *TokenResponse) Validate() error {
	if tr.Error != "" {
		return fmt.Errorf("%s: %s", tr.Error, tr.ErrorDescription)
	}
	if tr.AccessToken == "" {
		return errors.New("response is missing access_token")
	}
	if !tr.scopesComputed {
		// Name the method that actually exists so the message is actionable.
		return errors.New("TokenResponse hasn't had ComputeScope() called")
	}
	return nil
}
// CacheKey returns the key under which this response is cached for the given request:
//   - on-behalf-of flows: the hash of the user assertion
//   - client credentials: the application key
//   - confidential clients and refresh-token flows: the home account ID from client_info
//   - otherwise: the empty string
func (tr *TokenResponse) CacheKey(authParams authority.AuthParams) string {
	switch {
	case authParams.AuthorizationType == authority.ATOnBehalfOf:
		return authParams.AssertionHash()
	case authParams.AuthorizationType == authority.ATClientCredentials:
		return authParams.AppKey()
	case authParams.IsConfidentialClient, authParams.AuthorizationType == authority.ATRefreshToken:
		return tr.ClientInfo.HomeAccountID()
	}
	return ""
}
func findDeclinedScopes(requestedScopes []string, grantedScopes []string) []string {
declined := []string{}
grantedMap := map[string]bool{}
for _, s := range grantedScopes {
grantedMap[strings.ToLower(s)] = true
}
// Comparing the requested scopes with the granted scopes to see if there are any scopes that have been declined.
for _, r := range requestedScopes {
if !grantedMap[strings.ToLower(r)] {
declined = append(declined, r)
}
}
return declined
}
// decodeJWT decodes one segment of a JWT (header or payload) into the bytes of the JSON object
// it encodes. JWT segments are base64url encoded without padding:
// https://tools.ietf.org/html/rfc7519#section-3 and
// https://tools.ietf.org/html/rfc7515#section-2
// Trailing "=" padding, which some token issuers emit anyway, is tolerated. "=" is not part of
// the base64url alphabet, so trimming it can't change a valid unpadded segment.
func decodeJWT(data string) ([]byte, error) {
	// https://tools.ietf.org/html/rfc7515#appendix-C
	return base64.RawURLEncoding.DecodeString(strings.TrimRight(data, "="))
}
// RefreshToken is the JSON representation of a MSAL refresh token for encoding to storage.
type RefreshToken struct {
	HomeAccountID string `json:"home_account_id,omitempty"`
	Environment   string `json:"environment,omitempty"`
	// CredentialType is "RefreshToken" for values built by NewRefreshToken.
	CredentialType string `json:"credential_type,omitempty"`
	ClientID       string `json:"client_id,omitempty"`
	// FamilyID, when set, replaces ClientID in the cache key (see Key).
	FamilyID string `json:"family_id,omitempty"`
	// Secret is the refresh token itself.
	Secret            string `json:"secret,omitempty"`
	Realm             string `json:"realm,omitempty"`
	Target            string `json:"target,omitempty"`
	UserAssertionHash string `json:"user_assertion_hash,omitempty"`
	// AdditionalFields presumably preserves unrecognized JSON fields — confirm against the cache serializer.
	AdditionalFields map[string]interface{}
}
// NewRefreshToken builds a RefreshToken whose CredentialType is "RefreshToken" and whose Secret
// is the given refresh token. All other fields are left empty.
func NewRefreshToken(homeID, env, clientID, refreshToken, familyID string) RefreshToken {
	var rt RefreshToken
	rt.HomeAccountID = homeID
	rt.Environment = env
	rt.CredentialType = "RefreshToken"
	rt.ClientID = clientID
	rt.FamilyID = familyID
	rt.Secret = refreshToken
	return rt
}
// Key returns the key that uniquely identifies this entry in a map. It joins the home account ID,
// environment, credential type and either the family ID (when set) or the client ID with
// shared.CacheKeySeparator.
func (rt RefreshToken) Key() string {
	id := rt.ClientID
	if rt.FamilyID != "" {
		id = rt.FamilyID
	}
	parts := []string{rt.HomeAccountID, rt.Environment, rt.CredentialType, id}
	return strings.Join(parts, shared.CacheKeySeparator)
}
// GetSecret returns the refresh token value stored in Secret.
func (rt RefreshToken) GetSecret() string {
	secret := rt.Secret
	return secret
}
// DeviceCodeResult stores the response from the STS device code endpoint.
type DeviceCodeResult struct {
	// UserCode is the code the user needs to provide when authenticating at the verification URI.
	UserCode string
	// DeviceCode is the code used in the access token request.
	DeviceCode string
	// VerificationURL is the URL where the user can authenticate.
	VerificationURL string
	// ExpiresOn is the time at which the device code expires.
	ExpiresOn time.Time
	// Interval is the interval at which the STS should be polled (presumably in seconds — confirm).
	Interval int
	// Message is the message which should be displayed to the user.
	Message string
	// ClientID is the UUID issued by the authorization server for your application.
	ClientID string
	// Scopes is the OpenID scopes used to request access a protected API.
	Scopes []string
}
// NewDeviceCodeResult creates a DeviceCodeResult instance from its individual fields.
// Keyed fields keep the constructor correct even if the struct's field order changes.
func NewDeviceCodeResult(userCode, deviceCode, verificationURL string, expiresOn time.Time, interval int, message, clientID string, scopes []string) DeviceCodeResult {
	return DeviceCodeResult{
		UserCode:        userCode,
		DeviceCode:      deviceCode,
		VerificationURL: verificationURL,
		ExpiresOn:       expiresOn,
		Interval:        interval,
		Message:         message,
		ClientID:        clientID,
		Scopes:          scopes,
	}
}
// String renders the user code, device code, verification URL and message, one per line.
func (dcr DeviceCodeResult) String() string {
	const layout = "UserCode: (%v)\nDeviceCode: (%v)\nURL: (%v)\nMessage: (%v)\n"
	return fmt.Sprintf(layout,
		dcr.UserCode,
		dcr.DeviceCode,
		dcr.VerificationURL,
		dcr.Message,
	)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/ 0000775 0000000 0000000 00000000000 14420263624 0030116 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/authority.go 0000664 0000000 0000000 00000046061 14420263624 0032504 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package authority
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path"
"strings"
"time"
"github.com/google/uuid"
)
const (
	// authorizationEndpoint is formatted with the authority host and tenant.
	authorizationEndpoint = "https://%v/%v/oauth2/v2.0/authorize"
	// instanceDiscoveryEndpoint is formatted with the host that serves instance discovery.
	instanceDiscoveryEndpoint = "https://%v/common/discovery/instance"
	// tenantDiscoveryEndpointWithRegion is formatted with the region, the environment host and the tenant.
	tenantDiscoveryEndpointWithRegion = "https://%s.%s/%s/v2.0/.well-known/openid-configuration"
	// regionName is the environment variable detectRegion reads before querying IMDS.
	regionName = "REGION_NAME"
	// defaultAPIVersion is the api-version query parameter sent to the IMDS endpoint.
	defaultAPIVersion = "2021-10-01"
	// imdsEndpoint is the IMDS URL detectRegion queries for the compute location (region).
	imdsEndpoint = "http://169.254.169.254/metadata/instance/compute/location?format=text&api-version=" + defaultAPIVersion
	// autoDetectRegion is the Info.Region value that makes AADInstanceDiscovery call detectRegion.
	autoDetectRegion = "TryAutoDetect"
)
// These are various hosts that host AAD Instance discovery endpoints.
// AADInstanceDiscovery maps each of them to loginMicrosoft when building a regional endpoint.
const (
	defaultHost          = "login.microsoftonline.com"
	loginMicrosoft       = "login.microsoft.com"
	loginWindows         = "login.windows.net"
	loginSTSWindows      = "sts.windows.net"
	loginMicrosoftOnline = defaultHost
)
// jsonCaller is an interface that allows us to mock the JSONCall method.
// Client uses it to call endpoint with the given headers and query values, decoding the reply into resp.
type jsonCaller interface {
	JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error
}
// aadTrustedHostList is the set of AAD hosts known to serve instance discovery themselves;
// AADInstanceDiscovery sends discovery requests for any other host to defaultHost.
var aadTrustedHostList = map[string]bool{
	"login.windows.net":            true, // Microsoft Azure Worldwide - Used in validation scenarios where host is not in this list
	"login.chinacloudapi.cn":       true, // Microsoft Azure China
	"login.microsoftonline.de":     true, // Microsoft Azure Blackforest
	"login-us.microsoftonline.com": true, // Microsoft Azure US Government - Legacy
	"login.microsoftonline.us":     true, // Microsoft Azure US Government
	"login.microsoftonline.com":    true, // Microsoft Azure Worldwide
	"login.cloudgovapi.us":         true, // Microsoft Azure US Government
}
// TrustedHost reports whether host appears in the list of trusted AAD hosts.
// The comparison is exact, so host must already be lowercase.
func TrustedHost(host string) bool {
	_, trusted := aadTrustedHostList[host]
	return trusted
}
// OAuthResponseBase is the base JSON return message for an OAuth call.
// This is embedded in other calls to get the base fields from every response.
type OAuthResponseBase struct {
	// Error is the OAuth error code; an empty value means the server reported no error.
	Error            string `json:"error"`
	SubError         string `json:"suberror"`
	ErrorDescription string `json:"error_description"`
	ErrorCodes       []int  `json:"error_codes"`
	CorrelationID    string `json:"correlation_id"`
	// Claims presumably carries a claims challenge returned with the error — confirm.
	Claims string `json:"claims"`
}
// TenantDiscoveryResponse is the tenant endpoints from the OpenID configuration endpoint.
// Validate requires AuthorizationEndpoint, TokenEndpoint and Issuer to be non-empty.
type TenantDiscoveryResponse struct {
	OAuthResponseBase

	AuthorizationEndpoint string `json:"authorization_endpoint"`
	TokenEndpoint         string `json:"token_endpoint"`
	Issuer                string `json:"issuer"`

	// AdditionalFields presumably collects JSON fields not mapped above — confirm.
	AdditionalFields map[string]interface{}
}
// Validate checks, in order, that the authorization endpoint, token endpoint and issuer are present,
// returning an error that names the first missing value.
func (r *TenantDiscoveryResponse) Validate() error {
	if r.AuthorizationEndpoint == "" {
		return errors.New("TenantDiscoveryResponse: authorize endpoint was not found in the openid configuration")
	}
	if r.TokenEndpoint == "" {
		return errors.New("TenantDiscoveryResponse: token endpoint was not found in the openid configuration")
	}
	if r.Issuer == "" {
		return errors.New("TenantDiscoveryResponse: issuer was not found in the openid configuration")
	}
	return nil
}
// InstanceDiscoveryMetadata describes a group of host aliases returned by instance discovery.
// For regional authorities AADInstanceDiscovery synthesizes one entry: PreferredNetwork is
// "<region>.<host>", PreferredCache is the host, and Aliases holds both.
type InstanceDiscoveryMetadata struct {
	PreferredNetwork string   `json:"preferred_network"`
	PreferredCache   string   `json:"preferred_cache"`
	Aliases          []string `json:"aliases"`
	AdditionalFields map[string]interface{}
}
// InstanceDiscoveryResponse is the result of AADInstanceDiscovery: the tenant's OpenID
// configuration endpoint and the host alias metadata.
type InstanceDiscoveryResponse struct {
	TenantDiscoveryEndpoint string                      `json:"tenant_discovery_endpoint"`
	Metadata                []InstanceDiscoveryMetadata `json:"metadata"`
	AdditionalFields        map[string]interface{}
}
//go:generate stringer -type=AuthorizeType
// AuthorizeType represents the type of token flow.
type AuthorizeType int

// These are all the types of token flows.
const (
	ATUnknown AuthorizeType = iota
	ATUsernamePassword
	ATWindowsIntegrated
	ATAuthCode
	ATInteractive
	ATClientCredentials
	ATDeviceCode
	ATRefreshToken
	// AccountByID lacks the AT prefix of its siblings; AuthParams.CacheKey treats it like ATRefreshToken.
	AccountByID
	ATOnBehalfOf
)
// These are all authority types.
// NewInfoFromAuthorityURI selects ADFS when the tenant path segment is "adfs" and AAD otherwise.
const (
	AAD  = "MSSTS"
	ADFS = "ADFS"
)
// AuthParams represents the parameters used for authorization for token acquisition.
type AuthParams struct {
	// AuthorityInfo describes the authority (host, tenant, type, region) tokens are requested from.
	AuthorityInfo Info
	// CorrelationID is sent as the client-request-id header; NewAuthParams sets it to a random UUID.
	CorrelationID string
	// Endpoints holds the authority's authorization and token endpoints.
	Endpoints Endpoints
	// ClientID identifies the application requesting tokens.
	ClientID string
	// Redirecturi is used for auth flows that specify a redirect URI (e.g. local server for interactive auth flow).
	Redirecturi string
	// HomeAccountID identifies the account; CacheKey returns it for refresh token and AccountByID flows.
	HomeAccountID string
	// Username is the user-name portion for username/password auth flow.
	Username string
	// Password is the password portion for username/password auth flow.
	Password string
	// Scopes is the list of scopes the user consents to.
	Scopes []string
	// AuthorizationType specifies the auth flow being used.
	AuthorizationType AuthorizeType
	// State is a random value used to prevent cross-site request forgery attacks.
	State string
	// CodeChallenge is derived from a code verifier and is sent in the auth request.
	CodeChallenge string
	// CodeChallengeMethod describes the method used to create the CodeChallenge.
	CodeChallengeMethod string
	// Prompt specifies the user prompt type during interactive auth.
	Prompt string
	// IsConfidentialClient specifies if it is a confidential client.
	IsConfidentialClient bool
	// SendX5C specifies if x5c claim(public key of the certificate) should be sent to STS.
	SendX5C bool
	// UserAssertion is the access token used to acquire token on behalf of user
	UserAssertion string
	// Capabilities the client will include with each token request, for example "CP1".
	// Call [NewClientCapabilities] to construct a value for this field.
	Capabilities ClientCapabilities
	// Claims required for an access token to satisfy a conditional access policy
	Claims string
	// KnownAuthorityHosts don't require metadata discovery because they're known to the user
	KnownAuthorityHosts []string
	// LoginHint is a username with which to pre-populate account selection during interactive auth
	LoginHint string
	// DomainHint is a directive that can be used to accelerate the user to their federated IdP sign-in page
	DomainHint string
}
// NewAuthParams creates an authorization parameters object for the given client and authority,
// assigning it a fresh random correlation ID.
func NewAuthParams(clientID string, authorityInfo Info) AuthParams {
	params := AuthParams{AuthorityInfo: authorityInfo, ClientID: clientID}
	params.CorrelationID = uuid.New().String()
	return params
}
// WithTenant returns a copy of the AuthParams having the specified tenant ID. If the given
// ID is empty or equals the current tenant (ignoring case), the copy is identical to the
// original. This function returns an error in several cases:
//   - ID isn't specific (for example, it's "common"). The check ignores case because
//     NewInfoFromAuthorityURI lowercases the authority.
//   - ID is non-empty and the authority doesn't support tenants (for example, it's an ADFS authority)
//   - the client is configured to authenticate only Microsoft accounts via the "consumers" endpoint
//   - the resulting authority URL is invalid, or its host differs from the original host
//     (for example, because ID contains ".." path elements)
func (p AuthParams) WithTenant(ID string) (AuthParams, error) {
	if ID == "" || strings.EqualFold(ID, p.AuthorityInfo.Tenant) {
		// keep the default tenant because the caller didn't override it
		return p, nil
	}
	switch strings.ToLower(ID) {
	case "common", "consumers", "organizations":
		if p.AuthorityInfo.AuthorityType == AAD {
			return p, fmt.Errorf(`tenant ID must be a specific tenant, not "%s"`, ID)
		}
		// else we'll return a better error below
	}
	if p.AuthorityInfo.AuthorityType != AAD {
		return p, errors.New("the authority doesn't support tenants")
	}
	if p.AuthorityInfo.Tenant == "consumers" {
		return p, errors.New(`client is configured to authenticate only personal Microsoft accounts, via the "consumers" endpoint`)
	}
	authority := "https://" + path.Join(p.AuthorityInfo.Host, ID)
	info, err := NewInfoFromAuthorityURI(authority, p.AuthorityInfo.ValidateAuthority, p.AuthorityInfo.InstanceDiscoveryDisabled)
	if err != nil {
		return p, err
	}
	// path.Join cleans ".." elements, so an ID like "x/../../other.host/t" would otherwise
	// silently send authentication to a different host.
	if !strings.EqualFold(info.Host, p.AuthorityInfo.Host) {
		return p, fmt.Errorf("invalid tenant ID %q", ID)
	}
	info.Region = p.AuthorityInfo.Region
	p.AuthorityInfo = info
	return p, nil
}
// MergeCapabilitiesAndClaims combines client capabilities and challenge claims into a value suitable
// for an authentication request's "claims" parameter:
//   - with no capabilities, the claims are returned unchanged
//   - with capabilities but no claims, the precomputed capabilities JSON is returned
//   - otherwise the claims must be a JSON object, and the capabilities are merged into it
//
// A claims value of JSON null is treated as an empty object.
func (p AuthParams) MergeCapabilitiesAndClaims() (string, error) {
	claims := p.Claims
	if len(p.Capabilities.asMap) > 0 {
		if claims == "" {
			// without claims the result is simply the capabilities
			return p.Capabilities.asJSON, nil
		}
		// Otherwise, merge claims and capabilities into a single JSON object.
		// We handle the claims challenge as a map because we don't know its structure.
		var challenge map[string]any
		if err := json.Unmarshal([]byte(claims), &challenge); err != nil {
			return "", fmt.Errorf(`claims must be JSON. Are they base64 encoded? json.Unmarshal returned "%v"`, err)
		}
		// json.Unmarshal leaves challenge nil for the literal null, and merge would panic
		// writing to a nil map.
		if challenge == nil {
			challenge = map[string]any{}
		}
		if err := merge(p.Capabilities.asMap, challenge); err != nil {
			return "", err
		}
		b, err := json.Marshal(challenge)
		if err != nil {
			return "", err
		}
		claims = string(b)
	}
	return claims, nil
}
// merge merges a into b without overwriting b's values. Keys missing from b are copied from a.
// Keys present in both are merged recursively when both values are JSON objects (map[string]any).
// It returns an error when a and b share a key for which either has a non-object value; b may
// already be partially updated in that case.
func merge(a, b map[string]any) error {
	for k, av := range a {
		bv, ok := b[k]
		if !ok {
			// b doesn't contain this key => simply set it to a's value
			b[k] = av
			continue
		}
		// b does contain this key => recursively merge a[k] into b[k], provided both are maps. If a[k] or b[k] isn't
		// a map, return an error because merging would overwrite some value in b. Errors shouldn't occur in practice
		// because the challenge will be from AAD, which knows the capabilities format.
		A, aIsMap := av.(map[string]any)
		B, bIsMap := bv.(map[string]any)
		if !aIsMap || !bIsMap {
			return errors.New("challenge claims conflict with client capabilities")
		}
		// Keep iterating after a successful nested merge; returning here would skip a's remaining keys.
		if err := merge(A, B); err != nil {
			return err
		}
	}
	return nil
}
// ClientCapabilities stores capabilities in the formats used by AuthParams.MergeCapabilitiesAndClaims.
// [NewClientCapabilities] precomputes these representations because capabilities are static for the
// lifetime of a client and are included with every authentication request i.e., these computations
// always have the same result and would otherwise have to be repeated for every request.
type ClientCapabilities struct {
	// asJSON is for the common case: adding the capabilities to an auth request with no challenge claims
	asJSON string
	// asMap is for merging the capabilities with challenge claims
	asMap map[string]any
}

// NewClientCapabilities precomputes the claims representation of capabilities. For example,
// ["CP1"] becomes {"access_token":{"xms_cc":{"values":["CP1"]}}}. Each capability is JSON-encoded,
// so values containing quotes, backslashes or control characters still produce valid JSON.
// An empty or nil slice yields the zero ClientCapabilities.
func NewClientCapabilities(capabilities []string) (ClientCapabilities, error) {
	c := ClientCapabilities{}
	if len(capabilities) == 0 {
		return c, nil
	}
	// json.Marshal escapes every string; wrapping values in bare quotes would turn a
	// capability such as `a"b` into invalid JSON.
	values, err := json.Marshal(capabilities)
	if err != nil {
		return c, err
	}
	c.asJSON = fmt.Sprintf(`{"access_token":{"xms_cc":{"values":%s}}}`, values)
	err = json.Unmarshal([]byte(c.asJSON), &c.asMap)
	return c, err
}
// Info consists of information about the authority.
type Info struct {
	// Host is the authority's host, including any port.
	Host string
	// CanonicalAuthorityURI has the form "https://<host>/<tenant>/" when built by NewInfoFromAuthorityURI.
	CanonicalAuthorityURI string
	// AuthorityType is AAD or ADFS when built by NewInfoFromAuthorityURI.
	AuthorityType string
	// UserRealmURIPrefix has the form "https://<hostname>/common/userrealm/" (the port is omitted).
	UserRealmURIPrefix string
	ValidateAuthority  bool
	// Tenant is the first path segment of the authority URL.
	Tenant string
	// Region is a region name, autoDetectRegion ("TryAutoDetect"), or empty for a non-regional authority.
	Region string
	// InstanceDiscoveryDisabled presumably skips instance discovery when true — confirm at the call sites.
	InstanceDiscoveryDisabled bool
}
func firstPathSegment(u *url.URL) (string, error) {
pathParts := strings.Split(u.EscapedPath(), "/")
if len(pathParts) >= 2 {
return pathParts[1], nil
}
return "", errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`)
}
// NewInfoFromAuthorityURI creates an AuthorityInfo instance from the authority URL provided.
// The URL is lowercased before parsing and must use the https scheme. Its first path segment
// becomes the tenant, and a tenant of "adfs" selects the ADFS authority type (AAD otherwise).
func NewInfoFromAuthorityURI(authority string, validateAuthority bool, instanceDiscoveryDisabled bool) (Info, error) {
	parsed, parseErr := url.Parse(strings.ToLower(authority))
	if parseErr != nil || parsed.Scheme != "https" {
		return Info{}, errors.New(`authority must be an https URL such as "https://login.microsoftonline.com/"`)
	}
	tenant, err := firstPathSegment(parsed)
	if err != nil {
		return Info{}, err
	}
	info := Info{
		// parsed.Host includes the port, if any, which is required for private cloud deployments
		Host:                      parsed.Host,
		CanonicalAuthorityURI:     fmt.Sprintf("https://%v/%v/", parsed.Host, tenant),
		AuthorityType:             AAD,
		UserRealmURIPrefix:        fmt.Sprintf("https://%v/common/userrealm/", parsed.Hostname()),
		ValidateAuthority:         validateAuthority,
		Tenant:                    tenant,
		InstanceDiscoveryDisabled: instanceDiscoveryDisabled,
	}
	if tenant == "adfs" {
		info.AuthorityType = ADFS
	}
	return info, nil
}
// Endpoints consists of the endpoints from the tenant discovery response.
type Endpoints struct {
	AuthorizationEndpoint string
	TokenEndpoint         string
	selfSignedJwtAudience string
	// authorityHost is the host Client.UserRealm sends user realm requests to.
	authorityHost string
}
// NewEndpoints creates an Endpoints object from its four component values.
func NewEndpoints(authorizationEndpoint string, tokenEndpoint string, selfSignedJwtAudience string, authorityHost string) Endpoints {
	return Endpoints{
		AuthorizationEndpoint: authorizationEndpoint,
		TokenEndpoint:         tokenEndpoint,
		selfSignedJwtAudience: selfSignedJwtAudience,
		authorityHost:         authorityHost,
	}
}
// UserRealmAccountType refers to the type of user realm.
type UserRealmAccountType string

// These are the different types of user realms.
// UserRealm.validate rejects Unknown and requires federation details for Federated.
const (
	Unknown   UserRealmAccountType = ""
	Federated UserRealmAccountType = "Federated"
	Managed   UserRealmAccountType = "Managed"
)
// UserRealm is used for the username password request to determine user type
type UserRealm struct {
	AccountType       UserRealmAccountType `json:"account_type"`
	DomainName        string               `json:"domain_name"`
	CloudInstanceName string               `json:"cloud_instance_name"`
	CloudAudienceURN  string               `json:"cloud_audience_urn"`

	// required if accountType is Federated
	FederationProtocol    string `json:"federation_protocol"`
	FederationMetadataURL string `json:"federation_metadata_url"`

	// AdditionalFields presumably collects JSON fields not mapped above — confirm.
	AdditionalFields map[string]interface{}
}
// validate checks that the fields every realm needs are present and, for federated realms,
// that the federation protocol and metadata URL are present too. It returns an error
// describing the first missing value.
func (u UserRealm) validate() error {
	if u.AccountType == "" {
		return errors.New("the account type (Federated or Managed) is missing")
	}
	if u.DomainName == "" {
		return errors.New("domain name of user realm is missing")
	}
	if u.CloudInstanceName == "" {
		return errors.New("cloud instance name of user realm is missing")
	}
	if u.CloudAudienceURN == "" {
		return errors.New("cloud Instance URN is missing")
	}
	if u.AccountType != Federated {
		return nil
	}
	if u.FederationProtocol == "" {
		return errors.New("federation protocol of user realm is missing")
	}
	if u.FederationMetadataURL == "" {
		return errors.New("federation metadata URL of user realm is missing")
	}
	return nil
}
// Client represents the REST calls to authority backends.
type Client struct {
	// Comm provides the HTTP transport client.
	// Any jsonCaller works, which lets tests substitute a fake.
	Comm jsonCaller // *comm.Client
}
// UserRealm queries the common UserRealm endpoint of the authority host for authParams.Username
// (API version 1.0), sending the correlation ID as the client-request-id header. A transport
// error is returned as-is; otherwise the decoded realm is returned with its validation result.
func (c Client) UserRealm(ctx context.Context, authParams AuthParams) (UserRealm, error) {
	var realm UserRealm
	endpoint := fmt.Sprintf("https://%s/common/UserRealm/%s", authParams.Endpoints.authorityHost, url.PathEscape(authParams.Username))
	headers := http.Header{"client-request-id": []string{authParams.CorrelationID}}
	query := url.Values{"api-version": []string{"1.0"}}
	if err := c.Comm.JSONCall(ctx, endpoint, headers, query, nil, &realm); err != nil {
		return realm, err
	}
	return realm, realm.validate()
}
// GetTenantDiscoveryResponse fetches and decodes the OpenID configuration document at
// openIDConfigurationEndpoint. The response is not validated here; callers can use
// TenantDiscoveryResponse.Validate.
func (c Client) GetTenantDiscoveryResponse(ctx context.Context, openIDConfigurationEndpoint string) (TenantDiscoveryResponse, error) {
	var discovery TenantDiscoveryResponse
	err := c.Comm.JSONCall(ctx, openIDConfigurationEndpoint, http.Header{}, nil, nil, &discovery)
	return discovery, err
}
// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
// This is done by AAD, which allows aliasing of authority hosts (for example sts.windows.net and login.windows.net).
//
// When authorityInfo.Region names a region, or is autoDetectRegion and a region is detected, no
// request is made. Instead a regional tenant discovery endpoint and alias metadata are synthesized;
// the well-known public cloud hosts use login.microsoft.com in the regional endpoint. Otherwise the
// instance discovery endpoint is queried, on the authority host if it is trusted and on defaultHost if not.
func (c Client) AADInstanceDiscovery(ctx context.Context, authorityInfo Info) (InstanceDiscoveryResponse, error) {
	var region string
	switch authorityInfo.Region {
	case "":
		// no regional authority requested
	case autoDetectRegion:
		region = detectRegion(ctx)
	default:
		region = authorityInfo.Region
	}

	var resp InstanceDiscoveryResponse
	if region == "" {
		discoveryHost := defaultHost
		if TrustedHost(authorityInfo.Host) {
			discoveryHost = authorityInfo.Host
		}
		qv := url.Values{}
		qv.Set("api-version", "1.1")
		qv.Set("authorization_endpoint", fmt.Sprintf(authorizationEndpoint, authorityInfo.Host, authorityInfo.Tenant))
		err := c.Comm.JSONCall(ctx, fmt.Sprintf(instanceDiscoveryEndpoint, discoveryHost), http.Header{}, qv, nil, &resp)
		return resp, err
	}

	environment := authorityInfo.Host
	switch environment {
	case loginMicrosoft, loginWindows, loginSTSWindows, defaultHost:
		environment = loginMicrosoft
	}
	regionalHost := fmt.Sprintf("%v.%v", region, authorityInfo.Host)
	resp.TenantDiscoveryEndpoint = fmt.Sprintf(tenantDiscoveryEndpointWithRegion, region, environment, authorityInfo.Tenant)
	resp.Metadata = []InstanceDiscoveryMetadata{{
		PreferredNetwork: regionalHost,
		PreferredCache:   authorityInfo.Host,
		Aliases:          []string{regionalHost, authorityInfo.Host},
	}}
	return resp, nil
}
// detectRegion returns the region this process runs in, or "" when it can't be determined.
// The REGION_NAME environment variable takes precedence; its value is returned lowercased with
// spaces removed. Otherwise the IMDS endpoint is queried, at most twice. Each attempt is limited to
// 2 seconds and is also bounded by ctx. The body of a 200 response is returned unmodified.
func detectRegion(ctx context.Context) string {
	region := os.Getenv(regionName)
	if region != "" {
		region = strings.ReplaceAll(region, " ", "")
		return strings.ToLower(region)
	}
	// HTTP call to IMDS endpoint to get region
	// Refer : https://identitydivision.visualstudio.com/DevEx/_git/AuthLibrariesApiReview?path=%2FPinAuthToRegion%2FAAD%20SDK%20Proposal%20to%20Pin%20Auth%20to%20region.md&_a=preview&version=GBdev
	// Set a 2 second timeout for this http client which only does calls to IMDS endpoint
	client := http.Client{
		Timeout: 2 * time.Second,
	}
	// Bind the request to ctx so a canceled caller doesn't wait on IMDS.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, imdsEndpoint, nil)
	if err != nil {
		return ""
	}
	req.Header.Set("Metadata", "true")
	// If the request times out or there is an error, it is retried once
	for attempt := 0; attempt < 2; attempt++ {
		resp, err := client.Do(req)
		if err != nil {
			continue
		}
		if resp.StatusCode != http.StatusOK {
			// Close the body of a rejected response; otherwise its connection leaks.
			resp.Body.Close()
			continue
		}
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()
		if err != nil {
			return ""
		}
		return string(body)
	}
	return ""
}
// CacheKey returns the key used to partition the token cache for this request:
//   - on-behalf-of flows: the hash of the user assertion
//   - client credentials, or any request against the app cache: the application key
//   - refresh token and AccountByID flows: the home account ID
//   - otherwise: the empty string
func (a *AuthParams) CacheKey(isAppCache bool) string {
	switch {
	case a.AuthorizationType == ATOnBehalfOf:
		return a.AssertionHash()
	case a.AuthorizationType == ATClientCredentials || isAppCache:
		return a.AppKey()
	case a.AuthorizationType == ATRefreshToken, a.AuthorizationType == AccountByID:
		return a.HomeAccountID
	default:
		return ""
	}
}
// AssertionHash returns the SHA-256 digest of UserAssertion, encoded with padded URL-safe base64.
func (a *AuthParams) AssertionHash() string {
	digest := sha256.Sum256([]byte(a.UserAssertion))
	return base64.URLEncoding.EncodeToString(digest[:])
}
// AppKey returns the app token cache key "<clientID>_<tenant>_AppTokenCache"; the tenant part is
// left empty when the authority has no tenant.
func (a *AuthParams) AppKey() string {
	tenant := a.AuthorityInfo.Tenant
	if tenant == "" {
		return fmt.Sprintf("%s__AppTokenCache", a.ClientID)
	}
	return fmt.Sprintf("%s_%s_AppTokenCache", a.ClientID, tenant)
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority/authority_test.go 0000664 0000000 0000000 00000034174 14420263624 0033545 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package authority
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"strings"
"testing"
"github.com/kylelemons/godebug/pretty"
)
// fakeJSONCaller is a jsonCaller test double that records the arguments of its last JSONCall
// and optionally decodes a canned response.
type fakeJSONCaller struct {
	// err makes JSONCall fail immediately, without recording anything.
	err bool
	// resp, when non-nil, is JSON decoded into the caller's response value.
	resp []byte

	// Arguments recorded by the last successful JSONCall, checked by compare.
	gotEndpoint string
	gotHeaders  http.Header
	gotQV       url.Values
	gotBody     interface{}
	gotResp     interface{}
}
// JSONCall implements jsonCaller. It fails when f.err is set. Otherwise it records its arguments
// and decodes f.resp (if any) into resp, returning any decoding error.
func (f *fakeJSONCaller) JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}
	f.gotEndpoint, f.gotHeaders, f.gotQV = endpoint, headers, qv
	f.gotBody, f.gotResp = body, resp
	if f.resp == nil {
		return nil
	}
	return json.Unmarshal(f.resp, resp)
}
// compare returns an error describing the first mismatch between the expected arguments and those
// recorded by the last JSONCall. For resp only the name of the pointed-to type is compared, and
// the recorded value must be a pointer.
func (f *fakeJSONCaller) compare(endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if endpoint != f.gotEndpoint {
		return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint)
	}
	if d := pretty.Compare(headers, f.gotHeaders); len(d) > 0 {
		return fmt.Errorf("headers -want/+got:\n%s", d)
	}
	if d := pretty.Compare(qv, f.gotQV); len(d) > 0 {
		return fmt.Errorf("qv -want/+got:\n%s", d)
	}
	if d := pretty.Compare(body, f.gotBody); len(d) > 0 {
		return fmt.Errorf("body -want/+got:\n%s", d)
	}

	recorded := reflect.ValueOf(f.gotResp)
	if recorded.Kind() != reflect.Ptr {
		return fmt.Errorf("resp cannot be a non-pointer type")
	}
	gotName := recorded.Elem().Type().Name()
	wantName := reflect.ValueOf(resp).Elem().Type().Name()
	if gotName == wantName {
		return nil
	}
	return fmt.Errorf("resp type was %s, want %s", gotName, wantName)
}
// testAuthorityEndpoints are fixed public-cloud endpoints shared by the tests; the authority host
// "login.microsoftonline.com" determines the UserRealm endpoint TestUserRealm expects.
var testAuthorityEndpoints = NewEndpoints(
	"https://login.microsoftonline.com/v2.0/authorize",
	"https://login.microsoftonline.com/v2.0/token",
	"https://login.microsoftonline.com/v2.0",
	"login.microsoftonline.com",
)
// TestUserRealm checks that Client.UserRealm surfaces transport errors. On success it checks that
// the endpoint, headers and query values reach the JSON caller, and that the response is decoded
// into a UserRealm.
func TestUserRealm(t *testing.T) {
	authParams := AuthParams{
		Username:      "username",
		Endpoints:     testAuthorityEndpoints,
		CorrelationID: "id",
	}

	tests := []struct {
		desc     string
		err      bool
		endpoint string
		// jsonResp, when set, is marshaled and served by the fake; it must pass UserRealm.validate.
		jsonResp *UserRealm
		headers  http.Header
		qv       url.Values
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success",
			endpoint: fmt.Sprintf("https://login.microsoftonline.com/common/UserRealm/%s", url.PathEscape(authParams.Username)),
			headers: http.Header{
				"client-request-id": []string{"id"},
			},
			qv: url.Values{
				"api-version": []string{"1.0"},
			},
			jsonResp: &UserRealm{
				AccountType:       "Managed",
				DomainName:        "microsoftonline.com",
				CloudInstanceName: "instance",
				CloudAudienceURN:  "urn",
			},
			resp: &UserRealm{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}
		if test.jsonResp != nil {
			b, err := json.Marshal(test.jsonResp)
			if err != nil {
				panic(err)
			}
			fake.resp = b
		}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.UserRealm(context.Background(), authParams)
		switch {
		case err == nil && test.err:
			t.Errorf("TestUserRealm(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestUserRealm(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			// the expected error occurred; there is nothing recorded to compare
			continue
		}

		if err := fake.compare(test.endpoint, test.headers, test.qv, nil, test.resp); err != nil {
			t.Errorf("TestUserRealm(%s): %s", test.desc, err)
		}
	}
}
// TestTenantDiscoveryResponse checks that Client.GetTenantDiscoveryResponse surfaces transport errors.
// On success it checks that the endpoint reaches the JSON caller with empty headers and no query.
func TestTenantDiscoveryResponse(t *testing.T) {
	tests := []struct {
		desc     string
		err      bool
		endpoint string
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success",
			endpoint: "endpoint",
			resp:     &TenantDiscoveryResponse{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.GetTenantDiscoveryResponse(context.Background(), "endpoint")
		switch {
		case err == nil && test.err:
			t.Errorf("TestTenantDiscoveryResponse(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestTenantDiscoveryResponse(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			// the expected error occurred; there is nothing recorded to compare
			continue
		}

		if err := fake.compare(test.endpoint, http.Header{}, nil, nil, test.resp); err != nil {
			t.Errorf("TestTenantDiscoveryResponse(%s): %s", test.desc, err)
		}
	}
}
// TestAADInstanceDiscovery covers non-regional discovery. It checks that a trusted authority host
// is queried directly, that an untrusted one is sent to defaultHost, and that transport errors
// are surfaced.
func TestAADInstanceDiscovery(t *testing.T) {
	tests := []struct {
		desc     string
		err      bool
		authInfo Info
		endpoint string
		qv       url.Values
		resp     interface{}
	}{
		{
			desc: "Error: comm returns error",
			err:  true,
		},
		{
			desc:     "Success with authorityInfo.Host not in trusted list",
			endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, defaultHost),
			authInfo: Info{
				Host:   "host",
				Tenant: "tenant",
			},
			qv: url.Values{
				"api-version":            []string{"1.1"},
				"authorization_endpoint": []string{fmt.Sprintf(authorizationEndpoint, "host", "tenant")},
			},
			resp: &InstanceDiscoveryResponse{},
		},
		{
			desc:     "Success with authorityInfo.Host in trusted list",
			endpoint: fmt.Sprintf(instanceDiscoveryEndpoint, "login.microsoftonline.de"),
			authInfo: Info{
				Host:   "login.microsoftonline.de",
				Tenant: "tenant",
			},
			qv: url.Values{
				"api-version":            []string{"1.1"},
				"authorization_endpoint": []string{fmt.Sprintf(authorizationEndpoint, "login.microsoftonline.de", "tenant")},
			},
			resp: &InstanceDiscoveryResponse{},
		},
	}

	for _, test := range tests {
		fake := &fakeJSONCaller{err: test.err}
		client := Client{fake}

		// We don't care about the result, that is just a translation from the JSON handled
		// automatically in the comm package. We care only that the comm package got what
		// it needed.
		_, err := client.AADInstanceDiscovery(context.Background(), test.authInfo)
		switch {
		case err == nil && test.err:
			t.Errorf("AADInstanceDiscovery(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("AADInstanceDiscovery(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			// the expected error occurred; there is nothing recorded to compare
			continue
		}

		if err := fake.compare(test.endpoint, http.Header{}, test.qv, nil, test.resp); err != nil {
			t.Errorf("AADInstanceDiscovery(%s): %s", test.desc, err)
		}
	}
}
// TestAADInstanceDiscoveryWithRegion checks the synthesized regional discovery response for several
// hosts: public cloud aliases map to login.microsoft.com in the tenant discovery endpoint, and
// other hosts keep their own name.
func TestAADInstanceDiscoveryWithRegion(t *testing.T) {
	client := Client{&fakeJSONCaller{}}
	region := "region"
	discoveryPath := "tenant/v2.0/.well-known/openid-configuration"
	publicCloudEndpoint := fmt.Sprintf("https://%s.login.microsoft.com/%s", region, discoveryPath)
	for _, test := range []struct{ host, expectedEndpoint string }{
		{"login.chinacloudapi.cn", fmt.Sprintf("https://%s.login.chinacloudapi.cn/%s", region, discoveryPath)},
		{"login.microsoft.com", publicCloudEndpoint},
		{"login.microsoftonline.com", publicCloudEndpoint},
		{"login.windows.net", publicCloudEndpoint},
		{"login.windows-ppe.net", fmt.Sprintf("https://%s.login.windows-ppe.net/%s", region, discoveryPath)},
		{"sts.windows.net", publicCloudEndpoint},
	} {
		t.Run(test.host, func(t *testing.T) {
			authInfo := Info{Host: test.host, Tenant: "tenant", Region: region}
			resp, err := client.AADInstanceDiscovery(context.Background(), authInfo)
			if err != nil {
				// Fatal, not Error: the checks below index resp.Metadata, which may be empty after an error.
				t.Fatalf("AADInstanceDiscoveryWithRegion failing with %s", err)
			}
			if len(resp.Metadata) != 1 {
				t.Fatalf("AADInstanceDiscoveryWithRegion expected 1 metadata entry, got %d", len(resp.Metadata))
			}
			expectedPreferredNetwork := fmt.Sprintf("%v.%v", region, test.host)
			expectedPreferredCache := test.host
			if resp.TenantDiscoveryEndpoint != test.expectedEndpoint {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect TenantDiscoveryEndpoint: got: %s, want: %s", resp.TenantDiscoveryEndpoint, test.expectedEndpoint)
			}
			if resp.Metadata[0].PreferredNetwork != expectedPreferredNetwork {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect Preferred Network got: %s, want: %s", resp.Metadata[0].PreferredNetwork, expectedPreferredNetwork)
			}
			if resp.Metadata[0].PreferredCache != expectedPreferredCache {
				t.Errorf("AADInstanceDiscoveryWithRegion incorrect Preferred Cache got: %s, want: %s", resp.Metadata[0].PreferredCache, expectedPreferredCache)
			}
		})
	}
}
func TestCreateAuthorityInfoFromAuthorityUri(t *testing.T) {
const authorityURI = "https://login.microsoftonline.com/common/"
want := Info{
Host: "login.microsoftonline.com",
CanonicalAuthorityURI: authorityURI,
AuthorityType: "MSSTS",
UserRealmURIPrefix: "https://login.microsoftonline.com/common/userrealm/",
Tenant: "common",
ValidateAuthority: true,
}
got, err := NewInfoFromAuthorityURI(authorityURI, true, false)
if err != nil {
t.Fatalf("TestCreateAuthorityInfoFromAuthorityUri: got err == %s, want err == nil", err)
}
if diff := pretty.Compare(want, got); diff != "" {
t.Errorf("TestCreateAuthorityInfoFromAuthorityUri: -want/+got:\n%s", diff)
}
}
// TestAuthParamsWithTenant verifies AuthParams.WithTenant.
//
// It checks that WithTenant replaces the tenant of "common", "organizations" and
// tenant-specific authorities. It also checks that WithTenant rejects two things: replacing
// a specific tenant with "common"/"organizations", and overriding the "adfs" and
// "consumers" authorities. Finally, it checks that fields unrelated to the tenant are
// preserved.
func TestAuthParamsWithTenant(t *testing.T) {
	uuid1 := "00000000-0000-0000-0000-000000000000"
	uuid2 := strings.ReplaceAll(uuid1, "0", "1")
	host := "https://localhost/"
	for _, test := range []struct {
		authority, expectedAuthority, tenant string
		expectError                          bool
	}{
		{authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
		{authority: host + uuid1, tenant: "common", expectError: true},
		{authority: host + uuid1, tenant: "organizations", expectError: true},
		{authority: host + "adfs", tenant: uuid1, expectError: true},
		{authority: host + "consumers", tenant: uuid1, expectError: true},
	} {
		// Name each subtest "<authority path>_<tenant>" so a failure identifies its case
		// (an empty name only yields "#00", "#01", ...). Slashes are avoided because
		// `go test -run` treats them as subtest separators.
		name := strings.TrimPrefix(test.authority, host) + "_" + test.tenant
		t.Run(name, func(t *testing.T) {
			info, err := NewInfoFromAuthorityURI(test.authority, false, false)
			if err != nil {
				t.Fatal(err)
			}
			params := NewAuthParams("client-id", info)
			p, err := params.WithTenant(test.tenant)
			if test.expectError {
				if err == nil {
					t.Fatalf("expected an error overriding the tenant of %s with %q", test.authority, test.tenant)
				}
				return
			}
			if err != nil {
				t.Fatal(err)
			}
			if v := strings.TrimSuffix(p.AuthorityInfo.CanonicalAuthorityURI, "/"); v != test.expectedAuthority {
				t.Fatalf("unexpected authority %q, want %q", v, test.expectedAuthority)
			}
		})
	}
	// WithTenant shouldn't change AuthorityInfo fields unrelated to the tenant, such as Region
	t.Run("AuthorityInfo", func(t *testing.T) {
		a := "A"
		b := "B"
		before, err := NewInfoFromAuthorityURI("https://localhost/"+a, true, false)
		if err != nil {
			t.Fatal(err)
		}
		before.Region = "region"
		params := NewAuthParams("client-id", before)
		p, err := params.WithTenant(b)
		if err != nil {
			t.Fatal(err)
		}
		after := p.AuthorityInfo
		// these values should be different because they contain the tenant (this is tested above)
		after.CanonicalAuthorityURI = before.CanonicalAuthorityURI
		after.Tenant = before.Tenant
		// With those fields equal, we can compare the before and after Infos without enumerating
		// their fields i.e., we can implicitly compare all the other fields at once. With this
		// approach, when Info gets a new field, this test needs an update only if that field
		// contains the tenant, in which case this test will break so maintainers don't overlook it.
		if diff := pretty.Compare(before, after); diff != "" {
			t.Fatal(diff)
		}
	})
}
// TestMergeCapabilitiesAndClaims verifies that AuthParams.MergeCapabilitiesAndClaims merges
// client capabilities (as access_token.xms_cc) into a JSON claims challenge. It also checks
// that a non-JSON (encoded) challenge is rejected.
func TestMergeCapabilitiesAndClaims(t *testing.T) {
	for _, test := range []struct {
		capabilities              []string
		challenge, desc, expected string
		err                       bool
	}{
		{
			desc:     "no capabilities or challenge",
			expected: "",
		},
		{
			desc:         "encoded challenge",
			capabilities: []string{"cp1"},
			challenge:    "eyJpZF90b2tlbiI6eyJhdXRoX3RpbWUiOnsiZXNzZW50aWFsIjp0cnVlfX19",
			err:          true,
		},
		{
			desc:         "only capabilities",
			capabilities: []string{"cp1"},
			expected:     `{"access_token":{"xms_cc":{"values":["cp1"]}}}`,
		},
		{
			desc:      "only challenge",
			challenge: `{"id_token":{"auth_time":{"essential":true}}}`,
			expected:  `{"id_token":{"auth_time":{"essential":true}}}`,
		},
		{
			desc:         "overlapping claim", // i.e. capabilities and claims are siblings
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
			expected:     `{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`,
		},
		{
			desc:         "non-overlapping claim",
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"id_token":{"auth_time":{"essential":true}}}`,
			expected:     `{"id_token":{"auth_time":{"essential":true}}, "access_token":{"xms_cc":{"values":["cp1","cp2"]}}}`,
		},
		{
			desc:         "overlapping and non-overlapping claims",
			capabilities: []string{"cp1", "cp2"},
			challenge:    `{"id_token":{"auth_time":{"essential":true}},"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
			expected:     `{"id_token":{"auth_time":{"essential":true}},"access_token":{"nbf":{"essential":true, "value":"42"},"xms_cc":{"values":["cp1","cp2"]}}}`,
		},
	} {
		cpb, err := NewClientCapabilities(test.capabilities)
		if err != nil {
			t.Fatal(err)
		}
		ap := AuthParams{Capabilities: cpb, Claims: test.challenge}
		t.Run(test.desc, func(t *testing.T) {
			var expected map[string]any
			if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" {
				t.Fatal("test bug: the expected result must be JSON or an empty string")
			}
			merged, err := ap.MergeCapabilitiesAndClaims()
			if test.err {
				// Fail explicitly rather than falling through to a confusing JSON diff.
				if err == nil {
					t.Fatalf("expected an error, got %q", merged)
				}
				return
			}
			if err != nil {
				t.Fatal(err)
			}
			if merged == test.expected {
				return
			}
			// The strings differ, so compare them semantically as JSON (key order and
			// whitespace may legitimately differ).
			var actual map[string]any
			if err = json.Unmarshal([]byte(merged), &actual); err != nil {
				t.Fatal(err)
			}
			if diff := pretty.Compare(expected, actual); diff != "" {
				t.Fatal(diff)
			}
		})
	}
}
authorizetype_string.go 0000664 0000000 0000000 00000001703 14420263624 0034671 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/authority // Code generated by "stringer -type=AuthorizeType"; DO NOT EDIT.
package authority
import "strconv"
// _ is a compile-time guard emitted by stringer. Each line indexes a one-element array with
// (constant - expected value), so the build fails if any AuthorizeType constant's value changes.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ATUnknown-0]
_ = x[ATUsernamePassword-1]
_ = x[ATWindowsIntegrated-2]
_ = x[ATAuthCode-3]
_ = x[ATInteractive-4]
_ = x[ATClientCredentials-5]
_ = x[ATDeviceCode-6]
_ = x[ATRefreshToken-7]
}
// _AuthorizeType_name is the names of all AuthorizeType constants concatenated in value order.
const _AuthorizeType_name = "ATUnknownATUsernamePasswordATWindowsIntegratedATAuthCodeATInteractiveATClientCredentialsATDeviceCodeATRefreshToken"
// _AuthorizeType_index holds byte offsets into _AuthorizeType_name: the name of value i is
// _AuthorizeType_name[_AuthorizeType_index[i]:_AuthorizeType_index[i+1]].
var _AuthorizeType_index = [...]uint8{0, 9, 27, 46, 56, 69, 88, 100, 114}
// String returns the identifier of the AuthorizeType constant i, or "AuthorizeType(N)" for a
// value outside the generated range.
func (i AuthorizeType) String() string {
if i < 0 || i >= AuthorizeType(len(_AuthorizeType_index)-1) {
return "AuthorizeType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _AuthorizeType_name[_AuthorizeType_index[i]:_AuthorizeType_index[i+1]]
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/ 0000775 0000000 0000000 00000000000 14420263624 0027702 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/ 0000775 0000000 0000000 00000000000 14420263624 0030635 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/comm.go 0000664 0000000 0000000 00000022651 14420263624 0032125 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package comm provides helpers for communicating with HTTP backends.
package comm
import (
"bytes"
"context"
"encoding/json"
"encoding/xml"
"fmt"
"io"
"net/http"
"net/url"
"reflect"
"runtime"
"strings"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
customJSON "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/version"
"github.com/google/uuid"
)
// HTTPClient represents an HTTP client.
// It's usually an *http.Client from the standard library.
type HTTPClient interface {
// Do sends an HTTP request and returns an HTTP response.
Do(req *http.Request) (*http.Response, error)
// CloseIdleConnections closes any idle connections in a "keep-alive" state.
CloseIdleConnections()
}
// Client provides a wrapper to our *http.Client that handles compression and serialization needs.
type Client struct {
client HTTPClient
}
// New returns a new Client object.
func New(httpClient HTTPClient) *Client {
if httpClient == nil {
panic("http.Client cannot == nil")
}
return &Client{client: httpClient}
}
// JSONCall sends a request to endpoint and JSON-decodes the response body into resp.
//
// qv (which may be nil) is encoded as the URL query string. headers is sent with the request
// and is modified in place: the standard MSAL headers and, when body is non-nil, a JSON
// Content-Type are set on it. A nil headers is treated as an empty header set.
//
// When body is nil the request is a GET. Otherwise body is JSON-encoded into the request
// body and the request is a POST. gzip-compressed responses are decompressed automatically.
//
// resp must be a pointer to a struct. If that struct has a field named "AdditionalFields",
// the package's custom JSON marshal/unmarshal functions are used for both body and resp
// instead of encoding/json.
func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, body, resp interface{}) error {
	if qv == nil {
		qv = url.Values{}
	}
	if headers == nil {
		// addStdHeaders and headers.Set write into the map, which would panic on a nil map.
		headers = http.Header{}
	}
	v := reflect.ValueOf(resp)
	if err := c.checkResp(v); err != nil {
		return err
	}

	// Choose a JSON marshal/unmarshal depending on whether resp's struct has AdditionalFields.
	// v.Type().Elem() (rather than v.Elem().Type()) also works for a nil pointer.
	marshal, unmarshal := json.Marshal, json.Unmarshal
	if _, ok := v.Type().Elem().FieldByName("AdditionalFields"); ok {
		marshal, unmarshal = customJSON.Marshal, customJSON.Unmarshal
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}
	u.RawQuery = qv.Encode()

	addStdHeaders(headers)
	req := &http.Request{Method: http.MethodGet, URL: u, Header: headers}
	if body != nil {
		// The request body is sent uncompressed because it isn't known whether the
		// services accept gzip-encoded requests.
		headers.Set("Content-Type", "application/json; charset=utf-8")
		data, err := marshal(body)
		if err != nil {
			return fmt.Errorf("bug: conn.Call(): could not marshal the body object: %w", err)
		}
		req.Method = http.MethodPost
		// An explicit ContentLength avoids chunked transfer encoding, and GetBody lets the
		// transport resend the body when following a redirect (as URLFormCall does).
		req.ContentLength = int64(len(data))
		req.Body = io.NopCloser(bytes.NewReader(data))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(data)), nil
		}
	}

	data, err := c.do(ctx, req)
	if err != nil {
		return err
	}
	// checkResp has ensured resp is a pointer, so it is never a nil interface here.
	if err := unmarshal(data, resp); err != nil {
		return fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))
	}
	return nil
}
// XMLCall sends a GET request to endpoint, with qv (which may be nil) encoded as the query
// string, and XML-decodes the response body into resp, which must be a pointer to a struct.
// Use it for endpoints that return application/xml; for XML sent via SOAP, use SOAPCall().
//
// headers is modified in place: Content-Type and the standard MSAL headers are set on it.
// A nil headers is treated as an empty header set.
func (c *Client) XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if err := c.checkResp(reflect.ValueOf(resp)); err != nil {
		return err
	}
	if qv == nil {
		qv = url.Values{}
	}
	if headers == nil {
		// headers.Set and addStdHeaders would panic on a nil map.
		headers = http.Header{}
	}
	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}
	u.RawQuery = qv.Encode()
	headers.Set("Content-Type", "application/xml; charset=utf-8") // This was not set in the original Mex(), but...
	addStdHeaders(headers)
	// An empty body makes xmlCall issue a GET.
	return c.xmlCall(ctx, u, headers, "", resp)
}
// SOAPCall POSTs the SOAP envelope body to endpoint, with qv (which may be nil) encoded as
// the query string, and XML-decodes the response into resp. resp must be a pointer to a
// struct, and body must be non-empty.
//
// headers is modified in place: Content-Type, SOAPAction (set to action) and the standard
// MSAL headers are set on it. A nil headers is treated as an empty header set.
func (c *Client) SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error {
	if body == "" {
		return fmt.Errorf("cannot make a SOAP call with body set to empty string")
	}
	if err := c.checkResp(reflect.ValueOf(resp)); err != nil {
		return err
	}
	if qv == nil {
		qv = url.Values{}
	}
	if headers == nil {
		// headers.Set and addStdHeaders would panic on a nil map.
		headers = http.Header{}
	}
	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}
	u.RawQuery = qv.Encode()
	headers.Set("Content-Type", "application/soap+xml; charset=utf-8")
	headers.Set("SOAPAction", action)
	addStdHeaders(headers)
	// body is non-empty, so xmlCall issues a POST.
	return c.xmlCall(ctx, u, headers, body, resp)
}
// xmlCall is the transport shared by XMLCall and SOAPCall. It sends a GET to u when body is
// empty, or a POST carrying body otherwise, and XML-decodes the response into resp. Callers
// are responsible for headers such as Content-Type and SOAPAction.
func (c *Client) xmlCall(ctx context.Context, u *url.URL, headers http.Header, body string, resp interface{}) error {
	req := &http.Request{Method: http.MethodGet, URL: u, Header: headers}
	if len(body) > 0 {
		req.Method = http.MethodPost
		// An explicit ContentLength avoids chunked transfer encoding, and GetBody lets the
		// transport resend the body when following a redirect, matching URLFormCall.
		req.ContentLength = int64(len(body))
		req.Body = io.NopCloser(strings.NewReader(body))
		req.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(strings.NewReader(body)), nil
		}
	}
	data, err := c.do(ctx, req)
	if err != nil {
		return err
	}
	if err := xml.Unmarshal(data, resp); err != nil {
		// %w keeps the underlying xml error inspectable with errors.As.
		return fmt.Errorf("xml decode error: %w", err)
	}
	return nil
}
// URLFormCall POSTs qv to endpoint as an application/x-www-form-urlencoded body and
// JSON-decodes the response into resp.
//
// qv must be non-empty, and resp must be a pointer to a struct. If that struct has a field
// named "AdditionalFields", the package's custom JSON unmarshaler is used instead of
// encoding/json.
func (c *Client) URLFormCall(ctx context.Context, endpoint string, qv url.Values, resp interface{}) error {
	if len(qv) == 0 {
		return fmt.Errorf("URLFormCall() requires qv to have non-zero length")
	}
	v := reflect.ValueOf(resp)
	if err := c.checkResp(v); err != nil {
		return err
	}
	u, err := url.Parse(endpoint)
	if err != nil {
		return fmt.Errorf("could not parse path URL(%s): %w", endpoint, err)
	}

	headers := http.Header{}
	headers.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8")
	addStdHeaders(headers)

	enc := qv.Encode()
	req := &http.Request{
		Method:        http.MethodPost,
		URL:           u,
		Header:        headers,
		ContentLength: int64(len(enc)),
		Body:          io.NopCloser(strings.NewReader(enc)),
		// GetBody lets the transport resend the form when following a redirect.
		GetBody: func() (io.ReadCloser, error) {
			return io.NopCloser(strings.NewReader(enc)), nil
		},
	}

	data, err := c.do(ctx, req)
	if err != nil {
		return err
	}

	// resp was already validated above; v.Type().Elem() also works for a nil pointer.
	unmarshal := json.Unmarshal
	if _, ok := v.Type().Elem().FieldByName("AdditionalFields"); ok {
		unmarshal = customJSON.Unmarshal
	}
	if err := unmarshal(data, resp); err != nil {
		return fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data))
	}
	return nil
}
// do sends req through the wrapped HTTPClient and returns the response body, fully read and
// decompressed. If ctx carries no deadline, a 30 second timeout is applied to the call.
//
// Any status other than 200 or 201 yields an errors.CallErr. It carries the request and the
// response, whose Body is replaced by an in-memory copy of the data so it can be read again.
// Its message includes the trimmed response body when that is non-empty.
func (c *Client) do(ctx context.Context, req *http.Request) ([]byte, error) {
	if _, hasDeadline := ctx.Deadline(); !hasDeadline {
		timeoutCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
		defer cancel()
		ctx = timeoutCtx
	}
	req = req.WithContext(ctx)

	reply, err := c.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("server response error:\n %w", err)
	}
	defer reply.Body.Close()

	payload, err := c.readBody(reply)
	if err != nil {
		return nil, fmt.Errorf("could not read the body of an HTTP Response: %w", err)
	}
	// Put the consumed body back so anyone inspecting the response can still read it.
	reply.Body = io.NopCloser(bytes.NewBuffer(payload))

	// The status is checked only after the body is read so that the server's error message
	// can be included in the returned error.
	if reply.StatusCode == 200 || reply.StatusCode == 201 {
		return payload, nil
	}
	msg := fmt.Sprintf("http call(%s)(%s) error: reply status code was %d", req.URL.String(), req.Method, reply.StatusCode)
	if trimmed := strings.TrimSpace(string(payload)); trimmed != "" {
		// We probably have the error in the body.
		msg += ":\n" + trimmed
	}
	return nil, errors.CallErr{
		Req:  req,
		Resp: reply,
		Err:  fmt.Errorf("%s", msg),
	}
}
// checkResp checks that v holds a non-nil pointer to a struct, which is the response target
// the decoders in this package require. It returns an error rather than panicking for a nil
// interface, a nil pointer, a non-pointer, or a pointer to a non-struct.
func (c *Client) checkResp(v reflect.Value) error {
	if !v.IsValid() {
		// reflect.ValueOf(nil) is the zero Value; calling Interface() on it would panic.
		return fmt.Errorf("bug: resp argument must be a *struct, was nil")
	}
	if v.Kind() != reflect.Ptr {
		return fmt.Errorf("bug: resp argument must be a *struct, was %s", v.Type())
	}
	// Inspect the pointer's static element type so that a nil pointer doesn't produce a zero
	// Value (whose Interface() would panic).
	if v.Type().Elem().Kind() != reflect.Struct {
		return fmt.Errorf("bug: resp argument must be a *struct, was %s", v.Type())
	}
	if v.IsNil() {
		return fmt.Errorf("bug: resp argument must be a non-nil *struct, was a nil %s", v.Type())
	}
	return nil
}
// readBody reads the whole body of resp, decompressing it when the server reports gzip
// content encoding.
//
// Content codings are compared case-insensitively (RFC 9110 §8.4.1). An absent header or
// "identity" means the body is not encoded, and "x-gzip" is treated as "gzip". Any other
// coding is reported as an error.
func (c *Client) readBody(resp *http.Response) ([]byte, error) {
	var reader io.Reader = resp.Body
	encoding := resp.Header.Get("Content-Encoding")
	switch strings.ToLower(strings.TrimSpace(encoding)) {
	case "", "identity":
		// The body is not encoded.
	case "gzip", "x-gzip":
		reader = gzipDecompress(resp.Body)
	default:
		return nil, fmt.Errorf("comm.Client.readBody(): content was sent with unsupported content-encoding %q", encoding)
	}
	return io.ReadAll(reader)
}
// testID, when non-empty, is sent as the client-request-id instead of a random UUID, so that
// tests can predict the header value.
var testID string

// addStdHeaders sets the headers sent on every call on headers and returns headers. These
// are the accepted encoding, a client-request-id, and the client SKU, OS, CPU and version.
func addStdHeaders(headers http.Header) http.Header {
	requestID := testID
	if requestID == "" {
		requestID = uuid.New().String()
	}
	headers.Set("client-request-id", requestID)
	headers.Set("Return-Client-Request-Id", "false")
	headers.Set("Accept-Encoding", "gzip")
	headers.Set("x-client-sku", "MSAL.Go")
	headers.Set("x-client-ver", version.Version)
	headers.Set("x-client-os", runtime.GOOS)
	headers.Set("x-client-cpu", runtime.GOARCH)
	return headers
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/comm_test.go 0000664 0000000 0000000 00000032553 14420263624 0033166 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package comm
import (
"context"
"encoding/json"
"encoding/xml"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
customJSON "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
"github.com/kylelemons/godebug/diff"
"github.com/kylelemons/godebug/pretty"
)
type recorder struct {
xml bool
statusCode int
ret interface{}
gotMethod string
gotQV url.Values
gotBody []byte
gotHeaders http.Header
}
func (rec *recorder) reset() {
rec.statusCode = 0
rec.ret = nil
rec.gotMethod = ""
rec.gotQV = nil
rec.gotBody = nil
rec.gotHeaders = nil
}
// ServeHTTP implements http.Handler for tests.
//
// Unless rec.statusCode is http.StatusOK, it replies 400 with the body "error" and records
// nothing. Otherwise it records the request's method, query values, body and headers, then
// writes rec.ret encoded as XML (when rec.xml is set) or JSON. Failures panic, which is
// acceptable in a test-only handler.
func (rec *recorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if rec.statusCode != http.StatusOK {
		http.Error(w, "error", http.StatusBadRequest)
		return
	}
	rec.gotMethod = r.Method
	rec.gotQV = r.URL.Query()
	payload, err := io.ReadAll(r.Body)
	if err != nil {
		panic(err)
	}
	rec.gotBody = payload

	// These headers are not set by the code under test (the HTTP stack adds them), so they
	// are dropped before the headers are recorded.
	for _, name := range [...]string{"User-Agent", "Content-Length"} {
		delete(r.Header, name)
	}
	rec.gotHeaders = r.Header

	encode := customJSON.Marshal
	if rec.xml {
		encode = xml.Marshal
	}
	reply, err := encode(rec.ret)
	if err != nil {
		panic(err)
	}
	if _, err := w.Write(reply); err != nil {
		panic(err)
	}
}
// init pins the client-request-id chosen by addStdHeaders to a fixed value. This lets the
// expected headers built in these tests match what the server receives.
func init() {
	testID = "testID"
}

// SampleData is the request/response payload used throughout these tests.
type SampleData struct {
	Ok string
}
// TestJSONCall drives Client.JSONCall against a recording test server. It checks argument
// validation, the HTTP method (GET without a body, POST with one), the query values, the
// headers, the request body, and the decoded JSON result.
func TestJSONCall(t *testing.T) {
	tests := []struct {
		desc          string
		statusCode    int
		headers       http.Header
		qv            url.Values
		body, resp    interface{}
		expectMethod  string
		expectHeaders http.Header
		expectBody    interface{}
		want          interface{}
		err           bool
	}{
		{
			desc:       "Error: non-struct resp value",
			statusCode: http.StatusOK,
			resp:       new(int),
			err:        true,
		},
		{
			desc:       "Error: non-pointer resp value",
			statusCode: http.StatusOK,
			resp:       SampleData{},
			err:        true,
		},
		{
			desc:          "Body == nil[http Get]",
			statusCode:    http.StatusOK,
			headers:       http.Header{"header": []string{"here"}},
			qv:            url.Values{"key": []string{"value"}},
			resp:          &SampleData{},
			expectMethod:  http.MethodGet,
			expectHeaders: addStdHeaders(http.Header{"Header": []string{"here"}}),
			want:          &SampleData{Ok: "true"},
		},
		{
			desc:         "Body != nil[http Post]",
			statusCode:   http.StatusOK,
			headers:      http.Header{"header": []string{"here"}},
			qv:           url.Values{"key": []string{"value"}},
			body:         &SampleData{Ok: "false"},
			resp:         &SampleData{},
			expectMethod: http.MethodPost,
			expectHeaders: addStdHeaders(
				http.Header{
					"Header":       []string{"here"},
					"Content-Type": []string{"application/json; charset=utf-8"},
				},
			),
			expectBody: SampleData{Ok: "false"},
			want:       &SampleData{Ok: "true"},
		},
		{
			desc:       "Error: non-200 response",
			statusCode: http.StatusBadRequest,
			headers:    http.Header{},
			qv:         url.Values{},
			resp:       &SampleData{},
			err:        true,
		},
	}
	rec := &recorder{}
	serv := httptest.NewServer(rec)
	defer serv.Close()
	for _, test := range tests {
		rec.reset()
		rec.statusCode = test.statusCode
		// Serve the expected value and decode into a separate zero-valued resp, so the
		// result comparison below proves the response was actually decoded.
		rec.ret = test.want
		comm := New(serv.Client())
		err := comm.JSONCall(context.Background(), serv.URL, test.headers, test.qv, test.body, test.resp)
		switch {
		case err == nil && test.err:
			t.Errorf("TestJSONCall(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestJSONCall(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if test.expectMethod != rec.gotMethod {
			t.Errorf("TestJSONCall(%s): got method == %s, want http method == %s", test.desc, rec.gotMethod, test.expectMethod)
			continue
		}
		if diff := pretty.Compare(test.qv, rec.gotQV); diff != "" {
			t.Errorf("TestJSONCall(%s): query values: -want/+got:\n%s", test.desc, diff)
			continue
		}
		if test.expectHeaders != nil {
			if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" {
				t.Errorf("TestJSONCall(%s): headers: -want/+got:\n%s", test.desc, diff)
				continue
			}
		}
		if test.expectBody != nil {
			gotBody := SampleData{}
			if err := json.Unmarshal(rec.gotBody, &gotBody); err != nil {
				t.Fatalf("TestJSONCall(%s): could not decode request body: %s", test.desc, err)
			}
			if diff := pretty.Compare(test.expectBody, gotBody); diff != "" {
				t.Errorf("TestJSONCall(%s): body: -want/+got:\n%s", test.desc, diff)
				continue
			}
		}
		if diff := pretty.Compare(test.want, test.resp); diff != "" {
			t.Errorf("TestJSONCall(%s): result: -want/+got:\n%s", test.desc, diff)
		}
	}
}
// TestXMLCall exercises Client.XMLCall against a recording httptest server. Invalid resp
// arguments and non-200 replies must fail. A successful call must issue a GET with the given
// query values and headers, and must decode the XML reply into resp.
func TestXMLCall(t *testing.T) {
	cases := []struct {
		desc          string
		statusCode    int
		headers       http.Header
		qv            url.Values
		resp          interface{}
		expectHeaders http.Header
		expectBody    interface{}
		want          interface{}
		err           bool
	}{
		{desc: "Error: non-struct resp value", statusCode: http.StatusOK, resp: new(int), err: true},
		{desc: "Error: non-pointer resp value", statusCode: http.StatusOK, resp: SampleData{}, err: true},
		{
			desc:       "Success",
			statusCode: http.StatusOK,
			headers:    http.Header{"header": []string{"here"}},
			qv:         url.Values{"key": []string{"value"}},
			resp:       &SampleData{Ok: "true"},
			expectHeaders: addStdHeaders(http.Header{
				"Header":       []string{"here"},
				"Content-Type": []string{"application/xml; charset=utf-8"},
			}),
			want: &SampleData{Ok: "true"},
		},
		{
			desc:       "Error: non-200 response",
			statusCode: http.StatusBadRequest,
			headers:    http.Header{},
			qv:         url.Values{},
			resp:       &SampleData{Ok: "true"},
			err:        true,
		},
	}

	rec := &recorder{xml: true}
	serv := httptest.NewServer(rec)
	defer serv.Close()
	// httptest.Server.Client returns the same client on every call, so one wrapper serves all cases.
	client := New(serv.Client())

	for _, tc := range cases {
		rec.reset()
		rec.statusCode, rec.ret = tc.statusCode, tc.resp

		err := client.XMLCall(context.Background(), serv.URL, tc.headers, tc.qv, tc.resp)
		if tc.err {
			if err == nil {
				t.Errorf("TestXMLCall(%s): got err == nil, want err != nil", tc.desc)
			}
			continue
		}
		if err != nil {
			t.Errorf("TestXMLCall(%s): got err == %s, want err == nil", tc.desc, err)
			continue
		}

		if rec.gotMethod != http.MethodGet {
			t.Errorf("TestXMLCall(%s): got method == %s, want http method == GET", tc.desc, rec.gotMethod)
			continue
		}
		if d := pretty.Compare(tc.qv, rec.gotQV); d != "" {
			t.Errorf("TestXMLCall(%s): query values: -want/+got:\n%s", tc.desc, d)
			continue
		}
		if tc.expectHeaders != nil {
			if d := pretty.Compare(tc.expectHeaders, rec.gotHeaders); d != "" {
				t.Errorf("TestXMLCall(%s): headers: -want/+got:\n%s", tc.desc, d)
				continue
			}
		}
		if tc.expectBody != nil {
			var sent SampleData
			if err := xml.Unmarshal(rec.gotBody, &sent); err != nil {
				panic(err)
			}
			if d := pretty.Compare(tc.expectBody, sent); d != "" {
				t.Errorf("TestXMLCall(%s): body: -want/+got:\n%s", tc.desc, d)
				continue
			}
		}
		if d := pretty.Compare(tc.want, tc.resp); d != "" {
			t.Errorf("TestXMLCall(%s): result: -want/+got:\n%s", tc.desc, d)
		}
	}
}
// TestSoapCall drives Client.SOAPCall against a recording test server. It checks argument
// validation, that the request is a POST carrying the SOAP body, and that the query values,
// headers (including SOAPAction) and decoded XML result are correct. Each error case
// supplies a valid body and action so that it fails for the reason its description names.
func TestSoapCall(t *testing.T) {
	const soapActionDefault = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue"
	req := SampleData{Ok: "whatever"}
	body, err := xml.Marshal(req)
	if err != nil {
		t.Fatal(err)
	}
	tests := []struct {
		desc          string
		statusCode    int
		action        string
		body          string
		headers       http.Header
		qv            url.Values
		resp          interface{}
		expectHeaders http.Header
		expectBody    interface{}
		want          interface{}
		err           bool
	}{
		{
			desc:       "Error: non-struct resp value",
			statusCode: http.StatusOK,
			action:     soapActionDefault,
			body:       string(body),
			resp:       new(int),
			err:        true,
		},
		{
			desc:       "Error: non-pointer resp value",
			statusCode: http.StatusOK,
			action:     soapActionDefault,
			body:       string(body),
			resp:       SampleData{},
			err:        true,
		},
		{
			desc:       "Error: body arg was empty string",
			statusCode: http.StatusOK,
			action:     soapActionDefault,
			headers:    http.Header{"header": []string{"here"}},
			qv:         url.Values{"key": []string{"value"}},
			resp:       &SampleData{},
			err:        true,
		},
		{
			desc:       "Success",
			statusCode: http.StatusOK,
			headers:    http.Header{"header": []string{"here"}},
			qv:         url.Values{"key": []string{"value"}},
			action:     soapActionDefault,
			body:       string(body),
			resp:       &SampleData{},
			expectHeaders: addStdHeaders(
				http.Header{
					"Header":       []string{"here"},
					"Content-Type": []string{"application/soap+xml; charset=utf-8"},
					"Soapaction":   []string{soapActionDefault},
				},
			),
			expectBody: req,
			want:       &SampleData{Ok: "true"},
		},
		{
			desc:       "Error: non-200 response",
			statusCode: http.StatusBadRequest,
			action:     soapActionDefault,
			body:       string(body),
			headers:    http.Header{},
			qv:         url.Values{},
			resp:       &SampleData{},
			err:        true,
		},
	}
	rec := &recorder{xml: true}
	serv := httptest.NewServer(rec)
	defer serv.Close()
	for _, test := range tests {
		rec.reset()
		rec.statusCode = test.statusCode
		// Serve the expected value and decode into a separate zero-valued resp, so the
		// result comparison below proves the response was actually decoded.
		rec.ret = test.want
		comm := New(serv.Client())
		err := comm.SOAPCall(context.Background(), serv.URL, test.action, test.headers, test.qv, test.body, test.resp)
		switch {
		case err == nil && test.err:
			t.Errorf("TestSoapCall(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestSoapCall(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if rec.gotMethod != http.MethodPost {
			t.Errorf("TestSoapCall(%s): got method == %s, want http method == POST", test.desc, rec.gotMethod)
			continue
		}
		if diff := pretty.Compare(test.qv, rec.gotQV); diff != "" {
			t.Errorf("TestSoapCall(%s): query values: -want/+got:\n%s", test.desc, diff)
			continue
		}
		if test.expectHeaders != nil {
			if diff := pretty.Compare(test.expectHeaders, rec.gotHeaders); diff != "" {
				t.Errorf("TestSoapCall(%s): headers: -want/+got:\n%s", test.desc, diff)
				continue
			}
		}
		if test.expectBody != nil {
			gotBody := SampleData{}
			if err := xml.Unmarshal(rec.gotBody, &gotBody); err != nil {
				t.Fatalf("TestSoapCall(%s): could not decode request body: %s", test.desc, err)
			}
			if diff := pretty.Compare(test.expectBody, gotBody); diff != "" {
				t.Errorf("TestSoapCall(%s): body: -want/+got:\n%s", test.desc, diff)
				continue
			}
		}
		if diff := pretty.Compare(test.want, test.resp); diff != "" {
			t.Errorf("TestSoapCall(%s): result: -want/+got:\n%s", test.desc, diff)
		}
	}
}
// TestURLFormCall drives Client.URLFormCall against a recording test server. It checks
// argument validation, that the request is a POST whose body is the encoded form, the
// headers, and the decoded JSON result. The resp-validation cases supply non-empty query
// values so they fail on resp rather than on the empty-qv check.
func TestURLFormCall(t *testing.T) {
	tests := []struct {
		desc          string
		statusCode    int
		qv            url.Values
		resp          interface{}
		expectHeaders http.Header
		want          interface{}
		err           bool
	}{
		{
			desc:       "Error: non-struct resp value",
			statusCode: http.StatusOK,
			qv:         url.Values{"key": []string{"value"}},
			resp:       new(int),
			err:        true,
		},
		{
			desc:       "Error: non-pointer resp value",
			statusCode: http.StatusOK,
			qv:         url.Values{"key": []string{"value"}},
			resp:       SampleData{},
			err:        true,
		},
		{
			desc:       "Error: empty query values",
			statusCode: http.StatusOK,
			resp:       &SampleData{},
			err:        true,
		},
		{
			desc:       "Success",
			statusCode: http.StatusOK,
			qv:         url.Values{"key": []string{"value"}},
			resp:       &SampleData{},
			expectHeaders: addStdHeaders(
				http.Header{
					"Content-Type": []string{"application/x-www-form-urlencoded; charset=utf-8"},
				},
			),
			want: &SampleData{Ok: "true"},
		},
		{
			desc:       "Error: non-200 response",
			statusCode: http.StatusBadRequest,
			qv:         url.Values{"key": []string{"value"}},
			resp:       &SampleData{},
			err:        true,
		},
	}
	rec := &recorder{}
	serv := httptest.NewServer(rec)
	defer serv.Close()
	for _, test := range tests {
		rec.reset()
		rec.statusCode = test.statusCode
		// Serve the expected value and decode into a separate zero-valued resp, so the
		// result comparison below proves the response was actually decoded.
		rec.ret = test.want
		comm := New(serv.Client())
		err := comm.URLFormCall(context.Background(), serv.URL, test.qv, test.resp)
		switch {
		case err == nil && test.err:
			t.Errorf("TestURLFormCall(%s): got err == nil, want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestURLFormCall(%s): got err == %s, want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if rec.gotMethod != http.MethodPost {
			t.Errorf("TestURLFormCall(%s): got method == %s, want http method == POST", test.desc, rec.gotMethod)
			continue
		}
		if test.expectHeaders != nil {
			if d := pretty.Compare(test.expectHeaders, rec.gotHeaders); d != "" {
				t.Errorf("TestURLFormCall(%s): headers: -want/+got:\n%s", test.desc, d)
				continue
			}
		}
		want := test.qv.Encode()
		got := string(rec.gotBody)
		if d := diff.Diff(want, got); d != "" {
			t.Errorf("TestURLFormCall(%s): body: -want/+got:\n%s", test.desc, d)
			continue
		}
		if d := pretty.Compare(test.want, test.resp); d != "" {
			t.Errorf("TestURLFormCall(%s): result: -want/+got:\n%s", test.desc, d)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/comm/compress.go 0000664 0000000 0000000 00000001304 14420263624 0033015 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package comm
import (
"compress/gzip"
"io"
)
func gzipDecompress(r io.Reader) io.Reader {
gzipReader, _ := gzip.NewReader(r)
pipeOut, pipeIn := io.Pipe()
go func() {
// decompression bomb would have to come from Azure services.
// If we want to limit, we should do that in comm.do().
_, err := io.Copy(pipeIn, gzipReader) //nolint
if err != nil {
// don't need the error.
pipeIn.CloseWithError(err) //nolint
gzipReader.Close()
return
}
if err := gzipReader.Close(); err != nil {
// don't need the error.
pipeIn.CloseWithError(err) //nolint
return
}
pipeIn.Close()
}()
return pipeOut
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/grant/ 0000775 0000000 0000000 00000000000 14420263624 0031015 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/internal/grant/grant.go 0000664 0000000 0000000 00000001176 14420263624 0032464 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package grant holds types of grants issued by authorization services.
package grant
// String identifiers for OAuth grant types. ClientAssertion is the exception: its URN names
// a client assertion type rather than a grant type.
const (
	AuthCode         = "authorization_code"
	ClientCredential = "client_credentials"
	DeviceCode       = "device_code"
	Password         = "password"
	RefreshToken     = "refresh_token"

	// URN-style identifiers.
	JWT             = "urn:ietf:params:oauth:grant-type:jwt-bearer"
	SAMLV1          = "urn:ietf:params:oauth:grant-type:saml1_1-bearer"
	SAMLV2          = "urn:ietf:params:oauth:grant-type:saml2-bearer"
	ClientAssertion = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
)
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/ops.go 0000664 0000000 0000000 00000003544 14420263624 0027224 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package ops provides operations to various backend services using REST clients.
The REST type provides several clients that can be used to communicate to backends.
Usage is simple:
rest := ops.New(http.DefaultClient)
// Creates an authority client and calls the UserRealm() method.
userRealm, err := rest.Authority().UserRealm(ctx, authParameters)
if err != nil {
// Do something
}
*/
package ops
import (
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/comm"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
)
// HTTPClient is the HTTP client abstraction accepted by New. It aliases comm.HTTPClient and
// is usually satisfied by an *http.Client from the standard library.
type HTTPClient = comm.HTTPClient

// REST hands out REST clients for the various backends MSAL talks to. Every client it
// returns shares the same underlying *comm.Client.
type REST struct {
	client *comm.Client
}

// New returns a REST whose clients send requests through httpClient. comm.New panics when
// httpClient is nil.
func New(httpClient HTTPClient) *REST {
	wrapped := comm.New(httpClient)
	return &REST{client: wrapped}
}
// Authority returns an authority.Client for querying information about authorities. It is
// backed by r's shared HTTP client.
func (r *REST) Authority() authority.Client {
	var c authority.Client
	c.Comm = r.client
	return c
}
// AccessTokens returns an accesstokens.Client for requesting the various access tokens used
// for authorization. It is backed by r's shared HTTP client.
func (r *REST) AccessTokens() accesstokens.Client {
	var c accesstokens.Client
	c.Comm = r.client
	return c
}
// WSTrust returns a wstrust.Client for reading metadata from a WS-Trust service. That data
// can then be used to obtain tokens from SAML data through the client returned by
// AccessTokens(). The client is backed by r's shared HTTP client.
func (r *REST) WSTrust() wstrust.Client {
	var c wstrust.Client
	c.Comm = r.client
	return c
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/ 0000775 0000000 0000000 00000000000 14420263624 0027621 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs/ 0000775 0000000 0000000 00000000000 14420263624 0030542 5 ustar 00root root 0000000 0000000 endpointtype_string.go 0000664 0000000 0000000 00000001335 14420263624 0035124 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs // Code generated by "stringer -type=endpointType"; DO NOT EDIT.
package defs
import "strconv"
// _ is a compile-time guard emitted by stringer. Each line indexes a one-element array with
// (constant - expected value), so the build fails if any endpointType constant's value changes.
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[etUnknown-0]
_ = x[etUsernamePassword-1]
_ = x[etWindowsTransport-2]
}
// _endpointType_name is the names of all endpointType constants concatenated in value order.
const _endpointType_name = "etUnknownetUsernamePasswordetWindowsTransport"
// _endpointType_index holds byte offsets into _endpointType_name: the name of value i is
// _endpointType_name[_endpointType_index[i]:_endpointType_index[i+1]].
var _endpointType_index = [...]uint8{0, 9, 27, 45}
// String returns the identifier of the endpointType constant i, or "endpointType(N)" for a
// value outside the generated range.
func (i endpointType) String() string {
if i < 0 || i >= endpointType(len(_endpointType_index)-1) {
return "endpointType(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _endpointType_name[_endpointType_index[i]:_endpointType_index[i+1]]
}
mex_document_definitions.go 0000664 0000000 0000000 00000027700 14420263624 0036102 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package defs
import "encoding/xml"
type Definitions struct {
XMLName xml.Name `xml:"definitions"`
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
TargetNamespace string `xml:"targetNamespace,attr"`
WSDL string `xml:"wsdl,attr"`
XSD string `xml:"xsd,attr"`
T string `xml:"t,attr"`
SOAPENC string `xml:"soapenc,attr"`
SOAP string `xml:"soap,attr"`
TNS string `xml:"tns,attr"`
MSC string `xml:"msc,attr"`
WSAM string `xml:"wsam,attr"`
SOAP12 string `xml:"soap12,attr"`
WSA10 string `xml:"wsa10,attr"`
WSA string `xml:"wsa,attr"`
WSAW string `xml:"wsaw,attr"`
WSX string `xml:"wsx,attr"`
WSAP string `xml:"wsap,attr"`
WSU string `xml:"wsu,attr"`
Trust string `xml:"trust,attr"`
WSP string `xml:"wsp,attr"`
Policy []Policy `xml:"Policy"`
Types Types `xml:"types"`
Message []Message `xml:"message"`
PortType []PortType `xml:"portType"`
Binding []Binding `xml:"binding"`
Service Service `xml:"service"`
}
type Policy struct {
Text string `xml:",chardata"`
ID string `xml:"Id,attr"`
ExactlyOne ExactlyOne `xml:"ExactlyOne"`
}
type ExactlyOne struct {
Text string `xml:",chardata"`
All All `xml:"All"`
}
type All struct {
Text string `xml:",chardata"`
NegotiateAuthentication NegotiateAuthentication `xml:"NegotiateAuthentication"`
TransportBinding TransportBinding `xml:"TransportBinding"`
UsingAddressing Text `xml:"UsingAddressing"`
EndorsingSupportingTokens EndorsingSupportingTokens `xml:"EndorsingSupportingTokens"`
WSS11 WSS11 `xml:"Wss11"`
Trust10 Trust10 `xml:"Trust10"`
SignedSupportingTokens SignedSupportingTokens `xml:"SignedSupportingTokens"`
Trust13 WSTrust13 `xml:"Trust13"`
SignedEncryptedSupportingTokens SignedEncryptedSupportingTokens `xml:"SignedEncryptedSupportingTokens"`
}
type NegotiateAuthentication struct {
Text string `xml:",chardata"`
HTTP string `xml:"http,attr"`
XMLName xml.Name
}
type TransportBinding struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy TransportBindingPolicy `xml:"Policy"`
}
type TransportBindingPolicy struct {
Text string `xml:",chardata"`
TransportToken TransportToken `xml:"TransportToken"`
AlgorithmSuite AlgorithmSuite `xml:"AlgorithmSuite"`
Layout Layout `xml:"Layout"`
IncludeTimestamp Text `xml:"IncludeTimestamp"`
}
type TransportToken struct {
Text string `xml:",chardata"`
Policy TransportTokenPolicy `xml:"Policy"`
}
type TransportTokenPolicy struct {
Text string `xml:",chardata"`
HTTPSToken HTTPSToken `xml:"HttpsToken"`
}
type HTTPSToken struct {
Text string `xml:",chardata"`
RequireClientCertificate string `xml:"RequireClientCertificate,attr"`
}
type AlgorithmSuite struct {
Text string `xml:",chardata"`
Policy AlgorithmSuitePolicy `xml:"Policy"`
}
type AlgorithmSuitePolicy struct {
Text string `xml:",chardata"`
Basic256 Text `xml:"Basic256"`
Basic128 Text `xml:"Basic128"`
}
type Layout struct {
Text string `xml:",chardata"`
Policy LayoutPolicy `xml:"Policy"`
}
type LayoutPolicy struct {
Text string `xml:",chardata"`
Strict Text `xml:"Strict"`
}
type EndorsingSupportingTokens struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy EndorsingSupportingTokensPolicy `xml:"Policy"`
}
type EndorsingSupportingTokensPolicy struct {
Text string `xml:",chardata"`
X509Token X509Token `xml:"X509Token"`
RSAToken RSAToken `xml:"RsaToken"`
SignedParts SignedParts `xml:"SignedParts"`
KerberosToken KerberosToken `xml:"KerberosToken"`
IssuedToken IssuedToken `xml:"IssuedToken"`
KeyValueToken KeyValueToken `xml:"KeyValueToken"`
}
type X509Token struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
Policy X509TokenPolicy `xml:"Policy"`
}
type X509TokenPolicy struct {
Text string `xml:",chardata"`
RequireThumbprintReference Text `xml:"RequireThumbprintReference"`
WSSX509V3Token10 Text `xml:"WssX509V3Token10"`
}
type RSAToken struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
Optional string `xml:"Optional,attr"`
MSSP string `xml:"mssp,attr"`
}
type SignedParts struct {
Text string `xml:",chardata"`
Header SignedPartsHeader `xml:"Header"`
}
type SignedPartsHeader struct {
Text string `xml:",chardata"`
Name string `xml:"Name,attr"`
Namespace string `xml:"Namespace,attr"`
}
type KerberosToken struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
Policy KerberosTokenPolicy `xml:"Policy"`
}
type KerberosTokenPolicy struct {
Text string `xml:",chardata"`
WSSGSSKerberosV5ApReqToken11 Text `xml:"WssGssKerberosV5ApReqToken11"`
}
type IssuedToken struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
RequestSecurityTokenTemplate RequestSecurityTokenTemplate `xml:"RequestSecurityTokenTemplate"`
Policy IssuedTokenPolicy `xml:"Policy"`
}
type RequestSecurityTokenTemplate struct {
Text string `xml:",chardata"`
KeyType Text `xml:"KeyType"`
EncryptWith Text `xml:"EncryptWith"`
SignatureAlgorithm Text `xml:"SignatureAlgorithm"`
CanonicalizationAlgorithm Text `xml:"CanonicalizationAlgorithm"`
EncryptionAlgorithm Text `xml:"EncryptionAlgorithm"`
KeySize Text `xml:"KeySize"`
KeyWrapAlgorithm Text `xml:"KeyWrapAlgorithm"`
}
type IssuedTokenPolicy struct {
Text string `xml:",chardata"`
RequireInternalReference Text `xml:"RequireInternalReference"`
}
type KeyValueToken struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
Optional string `xml:"Optional,attr"`
}
type WSS11 struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy Wss11Policy `xml:"Policy"`
}
type Wss11Policy struct {
Text string `xml:",chardata"`
MustSupportRefThumbprint Text `xml:"MustSupportRefThumbprint"`
}
type Trust10 struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy Trust10Policy `xml:"Policy"`
}
type Trust10Policy struct {
Text string `xml:",chardata"`
MustSupportIssuedTokens Text `xml:"MustSupportIssuedTokens"`
RequireClientEntropy Text `xml:"RequireClientEntropy"`
RequireServerEntropy Text `xml:"RequireServerEntropy"`
}
type SignedSupportingTokens struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy SupportingTokensPolicy `xml:"Policy"`
}
type SupportingTokensPolicy struct {
Text string `xml:",chardata"`
UsernameToken UsernameToken `xml:"UsernameToken"`
}
type UsernameToken struct {
Text string `xml:",chardata"`
IncludeToken string `xml:"IncludeToken,attr"`
Policy UsernameTokenPolicy `xml:"Policy"`
}
type UsernameTokenPolicy struct {
Text string `xml:",chardata"`
WSSUsernameToken10 WSSUsernameToken10 `xml:"WssUsernameToken10"`
}
type WSSUsernameToken10 struct {
Text string `xml:",chardata"`
XMLName xml.Name
}
type WSTrust13 struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy WSTrust13Policy `xml:"Policy"`
}
type WSTrust13Policy struct {
Text string `xml:",chardata"`
MustSupportIssuedTokens Text `xml:"MustSupportIssuedTokens"`
RequireClientEntropy Text `xml:"RequireClientEntropy"`
RequireServerEntropy Text `xml:"RequireServerEntropy"`
}
type SignedEncryptedSupportingTokens struct {
Text string `xml:",chardata"`
SP string `xml:"sp,attr"`
Policy SupportingTokensPolicy `xml:"Policy"`
}
type Types struct {
Text string `xml:",chardata"`
Schema Schema `xml:"schema"`
}
type Schema struct {
Text string `xml:",chardata"`
TargetNamespace string `xml:"targetNamespace,attr"`
Import []Import `xml:"import"`
}
type Import struct {
Text string `xml:",chardata"`
SchemaLocation string `xml:"schemaLocation,attr"`
Namespace string `xml:"namespace,attr"`
}
type Message struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Part Part `xml:"part"`
}
type Part struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Element string `xml:"element,attr"`
}
type PortType struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Operation Operation `xml:"operation"`
}
type Operation struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Input OperationIO `xml:"input"`
Output OperationIO `xml:"output"`
}
type OperationIO struct {
Text string `xml:",chardata"`
Action string `xml:"Action,attr"`
Message string `xml:"message,attr"`
Body OperationIOBody `xml:"body"`
}
type OperationIOBody struct {
Text string `xml:",chardata"`
Use string `xml:"use,attr"`
}
type Binding struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Type string `xml:"type,attr"`
PolicyReference PolicyReference `xml:"PolicyReference"`
Binding DefinitionsBinding `xml:"binding"`
Operation BindingOperation `xml:"operation"`
}
type PolicyReference struct {
Text string `xml:",chardata"`
URI string `xml:"URI,attr"`
}
type DefinitionsBinding struct {
Text string `xml:",chardata"`
Transport string `xml:"transport,attr"`
}
type BindingOperation struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Operation BindingOperationOperation `xml:"operation"`
Input BindingOperationIO `xml:"input"`
Output BindingOperationIO `xml:"output"`
}
type BindingOperationOperation struct {
Text string `xml:",chardata"`
SoapAction string `xml:"soapAction,attr"`
Style string `xml:"style,attr"`
}
type BindingOperationIO struct {
Text string `xml:",chardata"`
Body OperationIOBody `xml:"body"`
}
type Service struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Port []Port `xml:"port"`
}
type Port struct {
Text string `xml:",chardata"`
Name string `xml:"name,attr"`
Binding string `xml:"binding,attr"`
Address Address `xml:"address"`
EndpointReference PortEndpointReference `xml:"EndpointReference"`
}
type Address struct {
Text string `xml:",chardata"`
Location string `xml:"location,attr"`
}
type PortEndpointReference struct {
Text string `xml:",chardata"`
Address Text `xml:"Address"`
Identity Identity `xml:"Identity"`
}
type Identity struct {
Text string `xml:",chardata"`
XMLNS string `xml:"xmlns,attr"`
SPN Text `xml:"Spn"`
}
saml_assertion_definitions.go 0000664 0000000 0000000 00000017136 14420263624 0036440 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package defs
import "encoding/xml"
// TODO(msal): Someone (and it ain't gonna be me) needs to document these attributes or
// at the least put a link to RFC.
type SAMLDefinitions struct {
XMLName xml.Name `xml:"Envelope"`
Text string `xml:",chardata"`
S string `xml:"s,attr"`
A string `xml:"a,attr"`
U string `xml:"u,attr"`
Header Header `xml:"Header"`
Body Body `xml:"Body"`
}
type Header struct {
Text string `xml:",chardata"`
Action Action `xml:"Action"`
Security Security `xml:"Security"`
}
type Action struct {
Text string `xml:",chardata"`
MustUnderstand string `xml:"mustUnderstand,attr"`
}
type Security struct {
Text string `xml:",chardata"`
MustUnderstand string `xml:"mustUnderstand,attr"`
O string `xml:"o,attr"`
Timestamp Timestamp `xml:"Timestamp"`
}
type Timestamp struct {
Text string `xml:",chardata"`
ID string `xml:"Id,attr"`
Created Text `xml:"Created"`
Expires Text `xml:"Expires"`
}
type Text struct {
Text string `xml:",chardata"`
}
type Body struct {
Text string `xml:",chardata"`
RequestSecurityTokenResponseCollection RequestSecurityTokenResponseCollection `xml:"RequestSecurityTokenResponseCollection"`
}
type RequestSecurityTokenResponseCollection struct {
Text string `xml:",chardata"`
Trust string `xml:"trust,attr"`
RequestSecurityTokenResponse []RequestSecurityTokenResponse `xml:"RequestSecurityTokenResponse"`
}
type RequestSecurityTokenResponse struct {
Text string `xml:",chardata"`
Lifetime Lifetime `xml:"Lifetime"`
AppliesTo AppliesTo `xml:"AppliesTo"`
RequestedSecurityToken RequestedSecurityToken `xml:"RequestedSecurityToken"`
RequestedAttachedReference RequestedAttachedReference `xml:"RequestedAttachedReference"`
RequestedUnattachedReference RequestedUnattachedReference `xml:"RequestedUnattachedReference"`
TokenType Text `xml:"TokenType"`
RequestType Text `xml:"RequestType"`
KeyType Text `xml:"KeyType"`
}
type Lifetime struct {
Text string `xml:",chardata"`
Created WSUTimestamp `xml:"Created"`
Expires WSUTimestamp `xml:"Expires"`
}
type WSUTimestamp struct {
Text string `xml:",chardata"`
Wsu string `xml:"wsu,attr"`
}
type AppliesTo struct {
Text string `xml:",chardata"`
Wsp string `xml:"wsp,attr"`
EndpointReference EndpointReference `xml:"EndpointReference"`
}
type EndpointReference struct {
Text string `xml:",chardata"`
Wsa string `xml:"wsa,attr"`
Address Text `xml:"Address"`
}
type RequestedSecurityToken struct {
Text string `xml:",chardata"`
AssertionRawXML string `xml:",innerxml"`
Assertion Assertion `xml:"Assertion"`
}
type Assertion struct {
XMLName xml.Name // Normally its `xml:"Assertion"`, but I think they want to capture the xmlns
Text string `xml:",chardata"`
MajorVersion string `xml:"MajorVersion,attr"`
MinorVersion string `xml:"MinorVersion,attr"`
AssertionID string `xml:"AssertionID,attr"`
Issuer string `xml:"Issuer,attr"`
IssueInstant string `xml:"IssueInstant,attr"`
Saml string `xml:"saml,attr"`
Conditions Conditions `xml:"Conditions"`
AttributeStatement AttributeStatement `xml:"AttributeStatement"`
AuthenticationStatement AuthenticationStatement `xml:"AuthenticationStatement"`
Signature Signature `xml:"Signature"`
}
type Conditions struct {
Text string `xml:",chardata"`
NotBefore string `xml:"NotBefore,attr"`
NotOnOrAfter string `xml:"NotOnOrAfter,attr"`
AudienceRestrictionCondition AudienceRestrictionCondition `xml:"AudienceRestrictionCondition"`
}
type AudienceRestrictionCondition struct {
Text string `xml:",chardata"`
Audience Text `xml:"Audience"`
}
type AttributeStatement struct {
Text string `xml:",chardata"`
Subject Subject `xml:"Subject"`
Attribute []Attribute `xml:"Attribute"`
}
type Subject struct {
Text string `xml:",chardata"`
NameIdentifier NameIdentifier `xml:"NameIdentifier"`
SubjectConfirmation SubjectConfirmation `xml:"SubjectConfirmation"`
}
type NameIdentifier struct {
Text string `xml:",chardata"`
Format string `xml:"Format,attr"`
}
type SubjectConfirmation struct {
Text string `xml:",chardata"`
ConfirmationMethod Text `xml:"ConfirmationMethod"`
}
type Attribute struct {
Text string `xml:",chardata"`
AttributeName string `xml:"AttributeName,attr"`
AttributeNamespace string `xml:"AttributeNamespace,attr"`
AttributeValue Text `xml:"AttributeValue"`
}
type AuthenticationStatement struct {
Text string `xml:",chardata"`
AuthenticationMethod string `xml:"AuthenticationMethod,attr"`
AuthenticationInstant string `xml:"AuthenticationInstant,attr"`
Subject Subject `xml:"Subject"`
}
type Signature struct {
Text string `xml:",chardata"`
Ds string `xml:"ds,attr"`
SignedInfo SignedInfo `xml:"SignedInfo"`
SignatureValue Text `xml:"SignatureValue"`
KeyInfo KeyInfo `xml:"KeyInfo"`
}
type SignedInfo struct {
Text string `xml:",chardata"`
CanonicalizationMethod Method `xml:"CanonicalizationMethod"`
SignatureMethod Method `xml:"SignatureMethod"`
Reference Reference `xml:"Reference"`
}
type Method struct {
Text string `xml:",chardata"`
Algorithm string `xml:"Algorithm,attr"`
}
type Reference struct {
Text string `xml:",chardata"`
URI string `xml:"URI,attr"`
Transforms Transforms `xml:"Transforms"`
DigestMethod Method `xml:"DigestMethod"`
DigestValue Text `xml:"DigestValue"`
}
type Transforms struct {
Text string `xml:",chardata"`
Transform []Method `xml:"Transform"`
}
type KeyInfo struct {
Text string `xml:",chardata"`
Xmlns string `xml:"xmlns,attr"`
X509Data X509Data `xml:"X509Data"`
}
type X509Data struct {
Text string `xml:",chardata"`
X509Certificate Text `xml:"X509Certificate"`
}
type RequestedAttachedReference struct {
Text string `xml:",chardata"`
SecurityTokenReference SecurityTokenReference `xml:"SecurityTokenReference"`
}
type SecurityTokenReference struct {
Text string `xml:",chardata"`
TokenType string `xml:"TokenType,attr"`
O string `xml:"o,attr"`
K string `xml:"k,attr"`
KeyIdentifier KeyIdentifier `xml:"KeyIdentifier"`
}
type KeyIdentifier struct {
Text string `xml:",chardata"`
ValueType string `xml:"ValueType,attr"`
}
type RequestedUnattachedReference struct {
Text string `xml:",chardata"`
SecurityTokenReference SecurityTokenReference `xml:"SecurityTokenReference"`
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs/version_string.go0000664 0000000 0000000 00000001212 14420263624 0034140 0 ustar 00root root 0000000 0000000 // Code generated by "stringer -type=Version"; DO NOT EDIT.
package defs
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[TrustUnknown-0]
_ = x[Trust2005-1]
_ = x[Trust13-2]
}
const _Version_name = "TrustUnknownTrust2005Trust13"
var _Version_index = [...]uint8{0, 12, 21, 28}
func (i Version) String() string {
if i < 0 || i >= Version(len(_Version_index)-1) {
return "Version(" + strconv.FormatInt(int64(i), 10) + ")"
}
return _Version_name[_Version_index[i]:_Version_index[i+1]]
}
wstrust_endpoint.go 0000664 0000000 0000000 00000015075 14420263624 0034455 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package defs
import (
"encoding/xml"
"fmt"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
uuid "github.com/google/uuid"
)
//go:generate stringer -type=Version
type Version int
const (
TrustUnknown Version = iota
Trust2005
Trust13
)
// Endpoint represents a WSTrust endpoint.
type Endpoint struct {
// Version is the version of the endpoint.
Version Version
// URL is the URL of the endpoint.
URL string
}
type wsTrustTokenRequestEnvelope struct {
XMLName xml.Name `xml:"s:Envelope"`
Text string `xml:",chardata"`
S string `xml:"xmlns:s,attr"`
Wsa string `xml:"xmlns:wsa,attr"`
Wsu string `xml:"xmlns:wsu,attr"`
Header struct {
Text string `xml:",chardata"`
Action struct {
Text string `xml:",chardata"`
MustUnderstand string `xml:"s:mustUnderstand,attr"`
} `xml:"wsa:Action"`
MessageID struct {
Text string `xml:",chardata"`
} `xml:"wsa:messageID"`
ReplyTo struct {
Text string `xml:",chardata"`
Address struct {
Text string `xml:",chardata"`
} `xml:"wsa:Address"`
} `xml:"wsa:ReplyTo"`
To struct {
Text string `xml:",chardata"`
MustUnderstand string `xml:"s:mustUnderstand,attr"`
} `xml:"wsa:To"`
Security struct {
Text string `xml:",chardata"`
MustUnderstand string `xml:"s:mustUnderstand,attr"`
Wsse string `xml:"xmlns:wsse,attr"`
Timestamp struct {
Text string `xml:",chardata"`
ID string `xml:"wsu:Id,attr"`
Created struct {
Text string `xml:",chardata"`
} `xml:"wsu:Created"`
Expires struct {
Text string `xml:",chardata"`
} `xml:"wsu:Expires"`
} `xml:"wsu:Timestamp"`
UsernameToken struct {
Text string `xml:",chardata"`
ID string `xml:"wsu:Id,attr"`
Username struct {
Text string `xml:",chardata"`
} `xml:"wsse:Username"`
Password struct {
Text string `xml:",chardata"`
} `xml:"wsse:Password"`
} `xml:"wsse:UsernameToken"`
} `xml:"wsse:Security"`
} `xml:"s:Header"`
Body struct {
Text string `xml:",chardata"`
RequestSecurityToken struct {
Text string `xml:",chardata"`
Wst string `xml:"xmlns:wst,attr"`
AppliesTo struct {
Text string `xml:",chardata"`
Wsp string `xml:"xmlns:wsp,attr"`
EndpointReference struct {
Text string `xml:",chardata"`
Address struct {
Text string `xml:",chardata"`
} `xml:"wsa:Address"`
} `xml:"wsa:EndpointReference"`
} `xml:"wsp:AppliesTo"`
KeyType struct {
Text string `xml:",chardata"`
} `xml:"wst:KeyType"`
RequestType struct {
Text string `xml:",chardata"`
} `xml:"wst:RequestType"`
} `xml:"wst:RequestSecurityToken"`
} `xml:"s:Body"`
}
func buildTimeString(t time.Time) string {
// Golang time formats are weird: https://stackoverflow.com/questions/20234104/how-to-format-current-time-using-a-yyyymmddhhmmss-format
return t.Format("2006-01-02T15:04:05.000Z")
}
func (wte *Endpoint) buildTokenRequestMessage(authType authority.AuthorizeType, cloudAudienceURN string, username string, password string) (string, error) {
var soapAction string
var trustNamespace string
var keyType string
var requestType string
createdTime := time.Now().UTC()
expiresTime := createdTime.Add(10 * time.Minute)
switch wte.Version {
case Trust2005:
soapAction = trust2005Spec
trustNamespace = "http://schemas.xmlsoap.org/ws/2005/02/trust"
keyType = "http://schemas.xmlsoap.org/ws/2005/05/identity/NoProofKey"
requestType = "http://schemas.xmlsoap.org/ws/2005/02/trust/Issue"
case Trust13:
soapAction = trust13Spec
trustNamespace = "http://docs.oasis-open.org/ws-sx/ws-trust/200512"
keyType = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearer"
requestType = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue"
default:
return "", fmt.Errorf("buildTokenRequestMessage had Version == %q, which is not recognized", wte.Version)
}
var envelope wsTrustTokenRequestEnvelope
messageUUID := uuid.New()
envelope.S = "http://www.w3.org/2003/05/soap-envelope"
envelope.Wsa = "http://www.w3.org/2005/08/addressing"
envelope.Wsu = "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-utility-1.0.xsd"
envelope.Header.Action.MustUnderstand = "1"
envelope.Header.Action.Text = soapAction
envelope.Header.MessageID.Text = "urn:uuid:" + messageUUID.String()
envelope.Header.ReplyTo.Address.Text = "http://www.w3.org/2005/08/addressing/anonymous"
envelope.Header.To.MustUnderstand = "1"
envelope.Header.To.Text = wte.URL
switch authType {
case authority.ATUnknown:
return "", fmt.Errorf("buildTokenRequestMessage had no authority type(%v)", authType)
case authority.ATUsernamePassword:
endpointUUID := uuid.New()
var trustID string
if wte.Version == Trust2005 {
trustID = "UnPwSecTok2005-" + endpointUUID.String()
} else {
trustID = "UnPwSecTok13-" + endpointUUID.String()
}
envelope.Header.Security.MustUnderstand = "1"
envelope.Header.Security.Wsse = "http://docs.oasis-open.org/wss/2004/01/oasis-200401-wss-wssecurity-secext-1.0.xsd"
envelope.Header.Security.Timestamp.ID = "MSATimeStamp"
envelope.Header.Security.Timestamp.Created.Text = buildTimeString(createdTime)
envelope.Header.Security.Timestamp.Expires.Text = buildTimeString(expiresTime)
envelope.Header.Security.UsernameToken.ID = trustID
envelope.Header.Security.UsernameToken.Username.Text = username
envelope.Header.Security.UsernameToken.Password.Text = password
default:
// This is just to note that we don't do anything for other cases.
// We aren't missing anything I know of.
}
envelope.Body.RequestSecurityToken.Wst = trustNamespace
envelope.Body.RequestSecurityToken.AppliesTo.Wsp = "http://schemas.xmlsoap.org/ws/2004/09/policy"
envelope.Body.RequestSecurityToken.AppliesTo.EndpointReference.Address.Text = cloudAudienceURN
envelope.Body.RequestSecurityToken.KeyType.Text = keyType
envelope.Body.RequestSecurityToken.RequestType.Text = requestType
output, err := xml.Marshal(envelope)
if err != nil {
return "", err
}
return string(output), nil
}
func (wte *Endpoint) BuildTokenRequestMessageWIA(cloudAudienceURN string) (string, error) {
return wte.buildTokenRequestMessage(authority.ATWindowsIntegrated, cloudAudienceURN, "", "")
}
func (wte *Endpoint) BuildTokenRequestMessageUsernamePassword(cloudAudienceURN string, username string, password string) (string, error) {
return wte.buildTokenRequestMessage(authority.ATUsernamePassword, cloudAudienceURN, username, password)
}
wstrust_mex_document.go 0000664 0000000 0000000 00000010531 14420263624 0035314 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/defs // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package defs
import (
"errors"
"fmt"
"strings"
)
//go:generate stringer -type=endpointType
type endpointType int
const (
etUnknown endpointType = iota
etUsernamePassword
etWindowsTransport
)
type wsEndpointData struct {
Version Version
EndpointType endpointType
}
const trust13Spec string = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue"
const trust2005Spec string = "http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue"
type MexDocument struct {
UsernamePasswordEndpoint Endpoint
WindowsTransportEndpoint Endpoint
policies map[string]endpointType
bindings map[string]wsEndpointData
}
func updateEndpoint(cached *Endpoint, found Endpoint) {
if cached == nil || cached.Version == TrustUnknown {
*cached = found
return
}
if (*cached).Version == Trust2005 && found.Version == Trust13 {
*cached = found
return
}
}
// TODO(msal): Someone needs to write tests for everything below.
// NewFromDef creates a new MexDocument.
func NewFromDef(defs Definitions) (MexDocument, error) {
policies, err := policies(defs)
if err != nil {
return MexDocument{}, err
}
bindings, err := bindings(defs, policies)
if err != nil {
return MexDocument{}, err
}
userPass, windows, err := endpoints(defs, bindings)
if err != nil {
return MexDocument{}, err
}
return MexDocument{
UsernamePasswordEndpoint: userPass,
WindowsTransportEndpoint: windows,
policies: policies,
bindings: bindings,
}, nil
}
func policies(defs Definitions) (map[string]endpointType, error) {
policies := make(map[string]endpointType, len(defs.Policy))
for _, policy := range defs.Policy {
if policy.ExactlyOne.All.NegotiateAuthentication.XMLName.Local != "" {
if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
policies["#"+policy.ID] = etWindowsTransport
}
}
if policy.ExactlyOne.All.SignedEncryptedSupportingTokens.Policy.UsernameToken.Policy.WSSUsernameToken10.XMLName.Local != "" {
if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
policies["#"+policy.ID] = etUsernamePassword
}
}
if policy.ExactlyOne.All.SignedSupportingTokens.Policy.UsernameToken.Policy.WSSUsernameToken10.XMLName.Local != "" {
if policy.ExactlyOne.All.TransportBinding.SP != "" && policy.ID != "" {
policies["#"+policy.ID] = etUsernamePassword
}
}
}
if len(policies) == 0 {
return policies, errors.New("no policies for mex document")
}
return policies, nil
}
func bindings(defs Definitions, policies map[string]endpointType) (map[string]wsEndpointData, error) {
bindings := make(map[string]wsEndpointData, len(defs.Binding))
for _, binding := range defs.Binding {
policyName := binding.PolicyReference.URI
transport := binding.Binding.Transport
if transport == "http://schemas.xmlsoap.org/soap/http" {
if policy, ok := policies[policyName]; ok {
bindingName := binding.Name
specVersion := binding.Operation.Operation.SoapAction
if specVersion == trust13Spec {
bindings[bindingName] = wsEndpointData{Trust13, policy}
} else if specVersion == trust2005Spec {
bindings[bindingName] = wsEndpointData{Trust2005, policy}
} else {
return nil, errors.New("found unknown spec version in mex document")
}
}
}
}
return bindings, nil
}
func endpoints(defs Definitions, bindings map[string]wsEndpointData) (userPass, windows Endpoint, err error) {
for _, port := range defs.Service.Port {
bindingName := port.Binding
index := strings.Index(bindingName, ":")
if index != -1 {
bindingName = bindingName[index+1:]
}
if binding, ok := bindings[bindingName]; ok {
url := strings.TrimSpace(port.EndpointReference.Address.Text)
if url == "" {
return Endpoint{}, Endpoint{}, fmt.Errorf("MexDocument cannot have blank URL endpoint")
}
if binding.Version == TrustUnknown {
return Endpoint{}, Endpoint{}, fmt.Errorf("endpoint version unknown")
}
endpoint := Endpoint{Version: binding.Version, URL: url}
switch binding.EndpointType {
case etUsernamePassword:
updateEndpoint(&userPass, endpoint)
case etWindowsTransport:
updateEndpoint(&windows, endpoint)
default:
return Endpoint{}, Endpoint{}, errors.New("found unknown port type in MEX document")
}
}
}
return userPass, windows, nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/wstrust.go 0000664 0000000 0000000 00000011623 14420263624 0031706 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package wstrust provides a client for communicating with a WSTrust (https://en.wikipedia.org/wiki/WS-Trust#:~:text=WS%2DTrust%20is%20a%20WS,in%20a%20secure%20message%20exchange.)
for the purposes of extracting metadata from the service. This data can be used to acquire
tokens using the accesstokens.Client.GetAccessTokenFromSamlGrant() call.
*/
package wstrust
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/internal/grant"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
)
type xmlCaller interface {
XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error
SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error
}
type SamlTokenInfo struct {
AssertionType string // Should be either constants SAMLV1Grant or SAMLV2Grant.
Assertion string
}
// Client represents the REST calls to get tokens from token generator backends.
type Client struct {
// Comm provides the HTTP transport client.
Comm xmlCaller
}
// TODO(msal): This allows me to call Mex without having a real Def file on line 45.
// This would fail because policies() would not find a policy. This is easy enough to
// fix in test data, but.... Definitions is defined with built in structs. That needs
// to be pulled apart and until then I have this hack in.
var newFromDef = defs.NewFromDef
// Mex provides metadata about a wstrust service.
func (c Client) Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error) {
resp := defs.Definitions{}
err := c.Comm.XMLCall(
ctx,
federationMetadataURL,
http.Header{},
nil,
&resp,
)
if err != nil {
return defs.MexDocument{}, err
}
return newFromDef(resp)
}
const (
SoapActionDefault = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issue"
// Note: Commented out because this action is not supported. It was in the original code
// but only used in a switch where it errored. Since there was only one value, a default
// worked better. However, buildTokenRequestMessage() had 2005 support. I'm not actually
// sure what's going on here. It like we have half support. For now this is here just
// for documentation purposes in case we are going to add support.
//
// SoapActionWSTrust2005 = "http://schemas.xmlsoap.org/ws/2005/02/trust/RST/Issue"
)
// SAMLTokenInfo provides SAML information that is used to generate a SAML token.
//
// It builds a WS-Trust request message for authParameters.AuthorizationType (Windows
// integrated or username/password), sends it as a SOAP call to endpoint.URL and extracts
// the SAML assertion from the response. Only WS-Trust 1.3 endpoints are supported;
// WS-Trust 2005 endpoints and unknown versions are rejected before any network call.
func (c Client) SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (SamlTokenInfo, error) {
	var (
		requestMsg string
		err        error
	)
	switch authParameters.AuthorizationType {
	case authority.ATWindowsIntegrated:
		requestMsg, err = endpoint.BuildTokenRequestMessageWIA(cloudAudienceURN)
	case authority.ATUsernamePassword:
		requestMsg, err = endpoint.BuildTokenRequestMessageUsernamePassword(
			cloudAudienceURN, authParameters.Username, authParameters.Password)
	default:
		return SamlTokenInfo{}, fmt.Errorf("unknown auth type %v", authParameters.AuthorizationType)
	}
	if err != nil {
		return SamlTokenInfo{}, err
	}

	// Only WS-Trust 1.3 has a SOAP action wired up (SoapActionDefault).
	if endpoint.Version != defs.Trust13 {
		if endpoint.Version == defs.Trust2005 {
			return SamlTokenInfo{}, errors.New("WS Trust 2005 support is not implemented")
		}
		return SamlTokenInfo{}, fmt.Errorf("the SOAP endpoint for a wstrust call had an invalid version: %v", endpoint.Version)
	}

	var samlDefs defs.SAMLDefinitions
	if err := c.Comm.SOAPCall(ctx, endpoint.URL, SoapActionDefault, http.Header{}, nil, requestMsg, &samlDefs); err != nil {
		return SamlTokenInfo{}, err
	}
	return c.samlAssertion(samlDefs)
}
// SAML assertion namespaces recognised by samlAssertion. An assertion's Saml value
// selects grant.SAMLV1 or grant.SAMLV2; any other value is rejected.
const (
	samlv1Assertion = "urn:oasis:names:tc:SAML:1.0:assertion"
	samlv2Assertion = "urn:oasis:names:tc:SAML:2.0:assertion"
)
// samlAssertion returns the first SAML assertion found in def's
// RequestSecurityTokenResponse entries, together with its assertion type (SAML 1.0 or 2.0).
//
// Entries without an Assertion element (empty XMLName.Local) are skipped. It returns an
// error if the first assertion found has an unrecognised SAML namespace, or if no entry
// contains an assertion at all.
func (c Client) samlAssertion(def defs.SAMLDefinitions) (SamlTokenInfo, error) {
	for _, tokenResponse := range def.Body.RequestSecurityTokenResponseCollection.RequestSecurityTokenResponse {
		token := tokenResponse.RequestedSecurityToken
		if token.Assertion.XMLName.Local == "" {
			continue
		}
		assertion := token.AssertionRawXML
		samlVersion := token.Assertion.Saml
		switch samlVersion {
		case samlv1Assertion:
			return SamlTokenInfo{AssertionType: grant.SAMLV1, Assertion: assertion}, nil
		case samlv2Assertion:
			return SamlTokenInfo{AssertionType: grant.SAMLV2, Assertion: assertion}, nil
		}
		return SamlTokenInfo{}, fmt.Errorf("couldn't parse SAML assertion, version unknown: %q", samlVersion)
	}
	// Reaching here means the response had no assertion, not an unknown WS-Trust version.
	return SamlTokenInfo{}, errors.New("no SAML assertion found in the WS-Trust response")
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/ops/wstrust/wstrust_test.go 0000664 0000000 0000000 00000036240 14420263624 0032747 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package wstrust
import (
"context"
"encoding/xml"
"errors"
"fmt"
"net/http"
"net/url"
"reflect"
"regexp"
"strings"
"testing"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
"github.com/kylelemons/godebug/diff"
"github.com/kylelemons/godebug/pretty"
)
// testAuthorityEndpoints is a fixed set of authority endpoints used to populate the
// AuthParams handed to the client under test.
var testAuthorityEndpoints = authority.NewEndpoints(
	"https://login.microsoftonline.com/v2.0/authorize",
	"https://login.microsoftonline.com/v2.0/token",
	"https://login.microsoftonline.com/v2.0",
	"login.microsoftonline.com",
)
// fakeXMLCaller is a test double for the comm layer used by Client. It records the
// arguments of the most recent XMLCall/SOAPCall so the compare* helpers can check them.
type fakeXMLCaller struct {
	// err makes every call fail with an error before anything is recorded.
	err bool
	// giveResp, when non-nil, is XML round-tripped into SOAPCall's resp argument.
	giveResp interface{}
	// gotAction and gotBody are only recorded by SOAPCall.
	gotAction   string
	gotEndpoint string
	gotQV       url.Values
	gotHeaders  http.Header
	gotBody     interface{}
	gotResp     interface{}
}
// XMLCall records its endpoint, headers, query values and response target, or fails
// immediately with a generic error when f.err is set. It never fills resp.
func (f *fakeXMLCaller) XMLCall(ctx context.Context, endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}
	f.gotEndpoint, f.gotHeaders, f.gotQV, f.gotResp = endpoint, headers, qv, resp
	return nil
}
// SOAPCall records every argument, or fails immediately with a generic error when f.err
// is set. If f.giveResp is non-nil it is marshalled to XML and unmarshalled into resp,
// simulating a server reply; marshalling problems panic because they are test bugs.
func (f *fakeXMLCaller) SOAPCall(ctx context.Context, endpoint, action string, headers http.Header, qv url.Values, body string, resp interface{}) error {
	if f.err {
		return errors.New("error")
	}
	f.gotEndpoint, f.gotAction = endpoint, action
	f.gotHeaders, f.gotQV = headers, qv
	f.gotBody, f.gotResp = body, resp

	if f.giveResp == nil {
		return nil
	}
	encoded, err := xml.MarshalIndent(f.giveResp, "", "\t")
	if err != nil {
		panic(err)
	}
	if err := xml.Unmarshal(encoded, resp); err != nil {
		panic(err)
	}
	return nil
}
// compareBase checks the recorded endpoint, headers and query values against the wanted
// ones, and checks that the recorded response target is a pointer whose element type has
// the same name as resp's element type.
func (f *fakeXMLCaller) compareBase(endpoint string, headers http.Header, qv url.Values, resp interface{}) error {
	if f.gotEndpoint != endpoint {
		return fmt.Errorf("got endpoint == %s, want endpoint == %s", f.gotEndpoint, endpoint)
	}
	if d := pretty.Compare(headers, f.gotHeaders); d != "" {
		return fmt.Errorf("headers -want/+got:\n%s", d)
	}
	if d := pretty.Compare(qv, f.gotQV); d != "" {
		return fmt.Errorf("qv -want/+got:\n%s", d)
	}

	recorded := reflect.ValueOf(f.gotResp)
	if recorded.Kind() != reflect.Ptr {
		return fmt.Errorf("resp cannot be a non-pointer type")
	}
	gotName, wantName := recorded.Elem().Type().Name(), reflect.ValueOf(resp).Elem().Type().Name()
	if gotName != wantName {
		return fmt.Errorf("resp type was %s, want %s", gotName, wantName)
	}
	return nil
}
// compareXML checks the last XMLCall: the given endpoint, empty headers, empty query
// values and a response target of resp's type.
func (f *fakeXMLCaller) compareXML(endpoint string, resp interface{}) error {
	return f.compareBase(endpoint, http.Header{}, url.Values{}, resp)
}
// replaceURNRE matches "urn:uuid:" and, because `.*` is greedy, everything after it up to
// the end of the line — not just the UUID itself.
var replaceURNRE = regexp.MustCompile(`urn:uuid:.*`)
// compareSOAP checks the last SOAPCall: the given endpoint and SOAP action, empty headers,
// nil query values and a response target of resp's type.
//
// NOTE(review): the body check below compares body with itself rather than with
// f.gotBody, so it can never fail. Switching to f.gotBody.(string) also requires real
// request XML in the test fixtures and a tighter replaceURNRE (its `.*` strips everything
// after the first "urn:uuid:"), so that change is left for a follow-up.
func (f *fakeXMLCaller) compareSOAP(action, endpoint string, body, resp interface{}) error {
	if err := f.compareBase(endpoint, http.Header{}, nil, resp); err != nil {
		return err
	}
	if f.gotAction != action {
		return fmt.Errorf("got action == %s, want action == %s", f.gotAction, action)
	}
	// Removes a uuid that will change every time.
	// example: `urn:uuid:373ea2fa-d586-4cad-8bb8-10392ddbb5c6`
	bodyStr := replaceURNRE.ReplaceAllString(body.(string), ``)
	gotBodyStr := replaceURNRE.ReplaceAllString(body.(string), ``)
	if diff := diff.Diff(
		strings.ReplaceAll(bodyStr, ">", ">\n"),    // So we can do a line by line comparison
		strings.ReplaceAll(gotBodyStr, ">", ">\n"), // So we can do a line by line comparison
	); diff != "" {
		return fmt.Errorf("body -want/+got:\n%s", diff)
	}
	return nil
}
// TestMex verifies that Client.Mex passes the federation metadata URL, empty headers and
// query values, and a *defs.Definitions target to the comm layer. It also checks that
// errors from the comm layer and from newFromDef are both propagated.
func TestMex(t *testing.T) {
	tests := []struct {
		desc string
		// err is whether Mex is expected to return an error.
		err bool
		// commErr makes the fake comm layer fail (distinct from newFromDef failing).
		commErr               bool
		newFromDef            func(d defs.Definitions) (defs.MexDocument, error)
		federationMetadataURL string
	}{
		{
			desc:    "Error: comm returns error",
			commErr: true,
			err:     true,
		},
		{
			desc:                  "Definition was bad",
			federationMetadataURL: "",
			newFromDef: func(d defs.Definitions) (defs.MexDocument, error) {
				return defs.MexDocument{}, errors.New("error")
			},
			err: true,
		},
		{
			desc:                  "Success",
			federationMetadataURL: "",
			newFromDef: func(d defs.Definitions) (defs.MexDocument, error) {
				return defs.MexDocument{}, nil
			},
		},
	}

	// Restore the real converter once all cases have run.
	defer func() { newFromDef = defs.NewFromDef }()
	for _, test := range tests {
		newFromDef = test.newFromDef
		// Only fail the comm layer when the case asks for it, so "Definition was bad"
		// actually reaches newFromDef.
		fake := &fakeXMLCaller{err: test.commErr}
		client := Client{Comm: fake}
		// We don't care about the result, that is just a translation from the XML handled
		// by newFromDef (stubbed above). We care only that the comm package got the right inputs.
		_, err := client.Mex(context.Background(), "http://something")
		switch {
		case err == nil && test.err:
			t.Errorf("TestMex(%s): got err == nil , want err != nil", test.desc)
			continue
		case err != nil && !test.err:
			t.Errorf("TestMex(%s): got err == %s , want err == nil", test.desc, err)
			continue
		case err != nil:
			continue
		}
		if err := fake.compareXML("http://something", &defs.Definitions{}); err != nil {
			t.Errorf("TestMex(%s): %s", test.desc, err)
		}
	}
}
// TestSAMLTokenInfo verifies that Client.SAMLTokenInfo sends the SOAP request with the
// expected action, endpoint and body, and that it propagates comm errors and rejects
// WS-Trust 2005 endpoints.
//
// Error paths inside the request-message builders are not covered: they can only fail if
// XML marshalling fails.
func TestSAMLTokenInfo(t *testing.T) {
	authParams := authority.AuthParams{
		Username:  "username",
		Password:  "password",
		Endpoints: testAuthorityEndpoints,
		ClientID:  "clientID",
	}

	// wantBody is the expected request body shared by every case.
	const wantBody = "http://docs.oasis-open.org/ws-sx/ws-trust/200512/RST/Issueurn:uuid:fb8ec65b-f117-468f-b4e8-50c5e802affehttp://www.w3.org/2005/08/addressing/anonymousupEndpointurnhttp://docs.oasis-open.org/ws-sx/ws-trust/200512/Bearerhttp://docs.oasis-open.org/ws-sx/ws-trust/200512/Issue"

	// samlResponse builds a canned WS-Trust response holding a single assertion whose SAML
	// namespace is samlNS.
	samlResponse := func(samlNS string) defs.SAMLDefinitions {
		return defs.SAMLDefinitions{
			Body: defs.Body{
				RequestSecurityTokenResponseCollection: defs.RequestSecurityTokenResponseCollection{
					RequestSecurityTokenResponse: []defs.RequestSecurityTokenResponse{
						{
							RequestedSecurityToken: defs.RequestedSecurityToken{
								Assertion: defs.Assertion{
									Text:    "hello",
									XMLName: xml.Name{Local: "Assertion"},
									Saml:    samlNS,
								},
							},
						},
					},
				},
			},
		}
	}

	tests := []struct {
		desc              string
		err               bool
		commErr           bool
		endpoint          defs.Endpoint
		body              string
		action            string
		authorizationType authority.AuthorizeType
		giveResp          defs.SAMLDefinitions
	}{
		{
			desc:              "Error: comm returns error",
			err:               true,
			commErr:           true,
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          samlResponse(samlv1Assertion),
		},
		{
			desc:              "Error: Trust2005 endpoint, which isn't supported",
			err:               true,
			endpoint:          defs.Endpoint{Version: defs.Trust2005, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          samlResponse(samlv1Assertion),
		},
		{
			desc:              "Success: SAMLV1 assertion with AuthorizationTypeWindowsIntegratedAuth",
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATWindowsIntegrated,
			body:              wantBody,
			giveResp:          samlResponse(samlv1Assertion),
		},
		{
			desc:              "Success: SAMLV2 assertion with AuthorizationTypeUsernamePassword",
			endpoint:          defs.Endpoint{Version: defs.Trust13, URL: "upEndpoint"},
			action:            SoapActionDefault,
			authorizationType: authority.ATUsernamePassword,
			body:              wantBody,
			giveResp:          samlResponse(samlv2Assertion),
		},
	}

	for _, tc := range tests {
		fake := &fakeXMLCaller{err: tc.commErr, giveResp: tc.giveResp}
		client := Client{Comm: fake}
		authParams.AuthorizationType = tc.authorizationType

		// Only the inputs handed to the comm layer are checked; the returned token info is
		// derived from the canned response.
		_, err := client.SAMLTokenInfo(context.Background(), authParams, "urn", tc.endpoint)
		if tc.err {
			if err == nil {
				t.Errorf("TestSAMLTokenInfo(%s): got err == nil , want err != nil", tc.desc)
			}
			continue
		}
		if err != nil {
			t.Errorf("TestSAMLTokenInfo(%s): got err == %s , want err == nil", tc.desc, err)
			continue
		}
		if err := fake.compareSOAP(tc.action, tc.endpoint.URL, tc.body, &defs.SAMLDefinitions{}); err != nil {
			t.Errorf("TestSAMLTokenInfo(%s): %s", tc.desc, err)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/oauth/resolvers.go 0000664 0000000 0000000 00000011462 14420263624 0027644 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// TODO(msal): Write some tests. The original code this came from didn't have tests and I'm too
// tired at this point to do it. Like much of the other *Manager code I found, it was broken
// because it lacked mutex protection.
package oauth
import (
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)
// ADFS is an active directory federation service authority type.
const ADFS = "ADFS"

// cacheEntry holds resolved authority endpoints and, for ADFS authorities, the set of
// UPN domains the entry has been resolved for.
type cacheEntry struct {
	Endpoints authority.Endpoints
	// ValidForDomainsInList is keyed by UPN domain (see adfsDomainFromUpn). It is only
	// populated for ADFS authorities, by addCachedEndpoints.
	ValidForDomainsInList map[string]bool
}
// createcacheEntry wraps endpoints in a cacheEntry with an empty, non-nil domain set,
// ready for addCachedEndpoints to populate.
func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
	return cacheEntry{
		Endpoints:             endpoints,
		ValidForDomainsInList: make(map[string]bool),
	}
}
// authorityEndpoint retrieves endpoints from an authority for auth and token acquisition,
// caching the results per canonical authority URI.
type authorityEndpoint struct {
	rest *ops.REST
	// mu guards cache; cachedEndpoints and addCachedEndpoints hold it while accessing cache.
	mu    sync.Mutex
	cache map[string]cacheEntry
}
// newAuthorityEndpoint is the constructor for authorityEndpoint. It starts with an empty
// endpoint cache and issues its discovery requests through rest.
func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
	return &authorityEndpoint{
		rest:  rest,
		cache: make(map[string]cacheEntry),
	}
}
// ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance.
//
// A cached result is returned when available. Otherwise it:
//  1. locates the OpenID configuration document,
//  2. fetches and validates it,
//  3. substitutes authorityInfo.Tenant for every "{tenant}" placeholder,
//  4. caches the result.
func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
	if cached, hit := m.cachedEndpoints(authorityInfo, userPrincipalName); hit {
		return cached, nil
	}

	configURL, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
	if err != nil {
		return authority.Endpoints{}, err
	}
	discovery, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, configURL)
	if err != nil {
		return authority.Endpoints{}, err
	}
	if err := discovery.Validate(); err != nil {
		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
	}

	withTenant := func(s string) string {
		return strings.ReplaceAll(s, "{tenant}", authorityInfo.Tenant)
	}
	resolved := authority.NewEndpoints(
		withTenant(discovery.AuthorizationEndpoint),
		withTenant(discovery.TokenEndpoint),
		withTenant(discovery.Issuer),
		authorityInfo.Host,
	)
	m.addCachedEndpoints(authorityInfo, userPrincipalName, resolved)
	return resolved, nil
}
// cachedEndpoints returns the cached endpoints for authorityInfo, if they exist. If not,
// it reports false.
//
// For ADFS authorities an entry is only valid for the UPN domains it was resolved for
// (recorded by addCachedEndpoints). If a domain can be extracted from userPrincipalName
// and it is not in the entry's list, this reports a miss. The caller then re-resolves and
// records the new domain. If no domain can be extracted, the cached entry is returned.
func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
	m.mu.Lock()
	defer m.mu.Unlock()

	entry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]
	if !ok {
		return authority.Endpoints{}, false
	}
	if authorityInfo.AuthorityType == ADFS {
		domain, err := adfsDomainFromUpn(userPrincipalName)
		if err == nil {
			if _, known := entry.ValidForDomainsInList[domain]; !known {
				return authority.Endpoints{}, false
			}
		}
	}
	return entry.Endpoints, true
}
// addCachedEndpoints stores endpoints under authorityInfo's canonical URI, replacing any
// previous entry. For ADFS authorities the new entry keeps every domain the previous
// entry was valid for and adds the domain of userPrincipalName, if one can be extracted.
func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
	m.mu.Lock()
	defer m.mu.Unlock()

	key := authorityInfo.CanonicalAuthorityURI
	fresh := createcacheEntry(endpoints)
	if authorityInfo.AuthorityType == ADFS {
		// A backend call was just made, so take the endpoints from this (newer) response
		// while carrying over the domains already known to be valid.
		if previous, found := m.cache[key]; found {
			for knownDomain := range previous.ValidForDomainsInList {
				fresh.ValidForDomainsInList[knownDomain] = true
			}
		}
		if domain, err := adfsDomainFromUpn(userPrincipalName); err == nil {
			fresh.ValidForDomainsInList[domain] = true
		}
	}
	m.cache[key] = fresh
}
// openIDConfigurationEndpoint returns the URL of the OpenID configuration document for
// authorityInfo:
//   - the ADFS well-known URL on authorityInfo.Host when the tenant is "adfs";
//   - the tenant discovery endpoint from AAD instance discovery when authority validation
//     is requested for an untrusted host, or when a region is set;
//   - otherwise the v2.0 well-known URL under the canonical authority URI.
func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
	switch {
	case authorityInfo.Tenant == "adfs":
		return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
	case authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host),
		authorityInfo.Region != "":
		discovery, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
		if err != nil {
			return "", err
		}
		return discovery.TenantDiscoveryEndpoint, nil
	default:
		return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
	}
}
// adfsDomainFromUpn returns the domain part of a user principal name ("user@domain"),
// taken as the text between the first and second '@'. It returns an error when the UPN
// contains no '@', or when the domain part is empty. An empty domain would otherwise be
// recorded as a valid ADFS cache domain under the key "".
func adfsDomainFromUpn(userPrincipalName string) (string, error) {
	parts := strings.Split(userPrincipalName, "@")
	if len(parts) < 2 {
		return "", errors.New("no @ present in user principal name")
	}
	if parts[1] == "" {
		return "", errors.New("no domain present in user principal name")
	}
	return parts[1], nil
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/options/ 0000775 0000000 0000000 00000000000 14420263624 0025640 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/options/options.go 0000664 0000000 0000000 00000002757 14420263624 0027675 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package options
import (
"errors"
"fmt"
)
// CallOption implements an optional argument to a method call. See
// https://blog.devgenius.io/go-call-option-that-can-be-used-with-multiple-methods-6c81734f3dbe
// for an explanation of the usage pattern.
type CallOption interface {
	// Do applies the option to the options value passed by ApplyOptions (a pointer to
	// the receiving method's options struct). Options built with NewCallOption in this
	// module return an error for options types they don't support.
	Do(any) error
	// callOption is an unexported marker method implemented by callOption.
	callOption()
}
// ApplyOptions applies all the callOptions to options. options must be a pointer to a struct and
// callOptions must be a list of objects that implement CallOption.
//
// Options are applied in order; processing stops at the first element that is not a
// CallOption, or at the first error returned by an option's Do method.
func ApplyOptions[O, C any](options O, callOptions []C) error {
	for _, candidate := range callOptions {
		opt, ok := any(candidate).(CallOption)
		if !ok {
			return fmt.Errorf("unexpected option type %T", candidate)
		}
		if err := opt.Do(options); err != nil {
			return err
		}
	}
	return nil
}
// NewCallOption returns a new CallOption whose Do() method calls function "f".
//
// A nil f yields an option whose Do always returns an error, so the method receiving the
// option reports the problem instead of panicking on a nil function call.
func NewCallOption(f func(any) error) CallOption {
	if f != nil {
		return callOption(f)
	}
	// Only an MSAL maintainer can get here, by implementing a do-nothing option.
	return callOption(func(any) error {
		return errors.New("invalid option: missing implementation")
	})
}
// callOption adapts a plain function to the CallOption interface.
type callOption func(any) error

// Do calls the wrapped function with target and returns its error.
func (fn callOption) Do(target any) error {
	return fn(target)
}

// callOption is the unexported marker method required by CallOption.
func (callOption) callOption() {}
microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/ 0000775 0000000 0000000 00000000000 14420263624 0025413 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/shared.go 0000664 0000000 0000000 00000003757 14420263624 0027224 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package shared
import (
"net/http"
"reflect"
"strings"
)
const (
	// CacheKeySeparator is used in creating the keys of the cache.
	CacheKeySeparator = "-"
)

// Account is the account record kept in the token cache. Named fields map to the JSON
// keys in their struct tags. AdditionalFields holds any other keys found when decoding
// with the internal json package; per the shared tests, those keys are written back at
// the top level on encode.
type Account struct {
	HomeAccountID     string `json:"home_account_id,omitempty"`
	Environment       string `json:"environment,omitempty"`
	Realm             string `json:"realm,omitempty"`
	LocalAccountID    string `json:"local_account_id,omitempty"`
	AuthorityType     string `json:"authority_type,omitempty"`
	PreferredUsername string `json:"username,omitempty"`
	GivenName         string `json:"given_name,omitempty"`
	FamilyName        string `json:"family_name,omitempty"`
	MiddleName        string `json:"middle_name,omitempty"`
	Name              string `json:"name,omitempty"`
	AlternativeID     string `json:"alternative_account_id,omitempty"`
	RawClientInfo     string `json:"client_info,omitempty"`
	UserAssertionHash string `json:"user_assertion_hash,omitempty"`
	AdditionalFields  map[string]interface{}
}
// NewAccount creates an account from its identifying fields; every other field is left
// at its zero value.
func NewAccount(homeAccountID, env, realm, localAccountID, authorityType, username string) Account {
	var acc Account
	acc.HomeAccountID = homeAccountID
	acc.Environment = env
	acc.Realm = realm
	acc.LocalAccountID = localAccountID
	acc.AuthorityType = authorityType
	acc.PreferredUsername = username
	return acc
}
// Key creates the key for storing accounts in the cache. The key has the form
// "<HomeAccountID>-<Environment>-<Realm>", joined with CacheKeySeparator.
func (acc Account) Key() string {
	return acc.HomeAccountID + CacheKeySeparator + acc.Environment + CacheKeySeparator + acc.Realm
}
// IsZero reports whether every field of acc is its zero value. A non-nil but empty map or
// slice (such as an empty AdditionalFields) also counts as zero.
func (acc Account) IsZero() bool {
	val := reflect.ValueOf(acc)
	for i := 0; i < val.NumField(); i++ {
		fld := val.Field(i)
		if fld.IsZero() {
			continue
		}
		if k := fld.Kind(); (k == reflect.Map || k == reflect.Slice) && fld.Len() == 0 {
			continue
		}
		return false
	}
	return true
}
// DefaultClient is our default shared HTTP client.
var DefaultClient = &http.Client{}
microsoft-authentication-library-for-go-1.0.0/apps/internal/shared/shared_test.go 0000664 0000000 0000000 00000004217 14420263624 0030253 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package shared
import (
stdJSON "encoding/json"
"testing"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json"
"github.com/kylelemons/godebug/pretty"
)
var (
accHID = "hid"
accEnv = "env"
accRealm = "realm"
authType = "MSSTS"
accLid = "lid"
accUser = "user"
)
// TestAccountUnmarshal checks that tagged JSON keys decode into the matching Account
// fields, while an unknown key ("extra") is kept, in raw form, in AdditionalFields.
func TestAccountUnmarshal(t *testing.T) {
	input, err := stdJSON.Marshal(map[string]interface{}{
		"home_account_id": "hid",
		"environment":     "env",
		"extra":           "this_is_extra",
		"authority_type":  authType,
	})
	if err != nil {
		panic(err)
	}

	want := Account{
		HomeAccountID: accHID,
		Environment:   accEnv,
		AuthorityType: authType,
		AdditionalFields: map[string]interface{}{
			"extra": json.MarshalRaw("this_is_extra"),
		},
	}

	var got Account
	if err := json.Unmarshal(input, &got); err != nil {
		panic(err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestAccountUnmarshal: -want/+got:\n%s", diff)
	}
}
// TestAccountKey checks that Key joins home account ID, environment and realm with "-".
func TestAccountKey(t *testing.T) {
	acc := &Account{HomeAccountID: accHID, Environment: accEnv, Realm: accRealm}
	if actual, expected := acc.Key(), "hid-env-realm"; actual != expected {
		t.Errorf("Actual key %s differs from expected key %s", actual, expected)
	}
}
// TestAccountMarshal checks three things about marshalling an Account:
//   - populated fields are written under their JSON tag names;
//   - empty fields are omitted;
//   - AdditionalFields entries appear as top-level keys.
func TestAccountMarshal(t *testing.T) {
	acc := Account{
		HomeAccountID:     accHID,
		Environment:       accEnv,
		Realm:             accRealm,
		LocalAccountID:    accLid,
		AuthorityType:     authType,
		PreferredUsername: accUser,
		AdditionalFields:  map[string]interface{}{"extra": "extra"},
	}
	want := map[string]interface{}{
		"home_account_id":  "hid",
		"environment":      "env",
		"realm":            "realm",
		"local_account_id": "lid",
		"authority_type":   authType,
		"username":         "user",
		"extra":            "extra",
	}

	encoded, err := json.Marshal(acc)
	if err != nil {
		panic(err)
	}
	got := make(map[string]interface{})
	if err := stdJSON.Unmarshal(encoded, &got); err != nil {
		panic(err)
	}
	if diff := pretty.Compare(want, got); diff != "" {
		t.Errorf("TestAccountMarshal: -want/+got:\n%s", diff)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/internal/version/ 0000775 0000000 0000000 00000000000 14420263624 0025632 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/internal/version/version.go 0000664 0000000 0000000 00000000415 14420263624 0027646 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
// Package version keeps the version number of the client package.
package version
// Version is the version of this client package that is communicated to the server.
const Version = "1.0.0"
microsoft-authentication-library-for-go-1.0.0/apps/public/ 0000775 0000000 0000000 00000000000 14420263624 0023607 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/public/public.go 0000664 0000000 0000000 00000052260 14420263624 0025421 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/*
Package public provides a client for authentication of "public" applications. A "public"
application is defined as an app that runs on client devices (android, ios, windows, linux, ...).
These devices are "untrusted" and access resources via web APIs that must authenticate.
*/
package public
/*
Design note:
public.Client uses client.Base as an embedded type. client.Base statically assigns its attributes
during creation. As it doesn't have any pointers in it, anything borrowed from it, such as
Base.AuthParams is a copy that is free to be manipulated here.
*/
// TODO(msal): This should have example code for each method on client using Go's example doc framework.
// Base usage details should be included in the package documentation.
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/url"
"strconv"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/local"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/options"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"
"github.com/google/uuid"
"github.com/pkg/browser"
)
// AuthResult contains the results of one token acquisition operation.
// For details see https://aka.ms/msal-net-authenticationresult
type AuthResult = base.AuthResult

// Account is an alias of shared.Account, the cached user account record. Pass one to
// WithSilentAccount to select it in AcquireTokenSilent.
type Account = shared.Account
// clientOptions configures the Client's behavior. New fills in defaults, applies each
// Option, then calls validate.
type clientOptions struct {
	// accessor, set by WithCache, reads and writes authentication data to an external cache.
	accessor cache.ExportReplace
	// authority is the authority URL; validate requires an https scheme.
	authority string
	// capabilities are client capabilities such as "CP1" (WithClientCapabilities).
	capabilities []string
	// disableInstanceDiscovery is the inverse of the value passed to WithInstanceDiscovery.
	disableInstanceDiscovery bool
	// httpClient is passed to oauth.New by New.
	httpClient ops.HTTPClient
}
// validate reports an error when the configured authority cannot be parsed as a URL or
// does not use the https scheme.
func (p *clientOptions) validate() error {
	parsed, err := url.Parse(p.authority)
	switch {
	case err != nil:
		return fmt.Errorf("Authority options cannot be URL parsed: %w", err)
	case parsed.Scheme != "https":
		return fmt.Errorf("Authority(%s) did not start with https://", parsed.String())
	}
	return nil
}
// Option is an optional argument to the New constructor. Each Option mutates the
// clientOptions New starts from, before they are validated.
type Option func(o *clientOptions)
// WithAuthority allows for a custom authority to be set. This must be a valid https url;
// New rejects anything else.
func WithAuthority(authority string) Option {
	return func(opts *clientOptions) { opts.authority = authority }
}
// WithCache provides an accessor that will read and write authentication data to an externally managed cache.
func WithCache(accessor cache.ExportReplace) Option {
	return func(opts *clientOptions) { opts.accessor = accessor }
}
// WithClientCapabilities allows configuring one or more client capabilities such as "CP1".
func WithClientCapabilities(capabilities []string) Option {
	return func(opts *clientOptions) {
		// Storing the caller's slice is safe: New only hands it to
		// base.WithClientCapabilities, which copies its data.
		opts.capabilities = capabilities
	}
}
// WithHTTPClient allows for a custom HTTP client to be set, replacing shared.DefaultClient.
func WithHTTPClient(httpClient ops.HTTPClient) Option {
	return func(opts *clientOptions) { opts.httpClient = httpClient }
}
// WithInstanceDiscovery enables or disables instance discovery (authority validation).
// Pass false to disable it, e.g. to support private cloud scenarios.
func WithInstanceDiscovery(enabled bool) Option {
	return func(opts *clientOptions) {
		// Stored inverted so the zero value of clientOptions means "enabled".
		opts.disableInstanceDiscovery = !enabled
	}
}
// Client is a representation of authentication client for public applications as defined in the
// package doc. For more information, visit https://docs.microsoft.com/azure/active-directory/develop/msal-client-applications.
type Client struct {
	// base is the shared client implementation the public methods delegate to.
	base base.Client
}
// New is the constructor for Client.
//
// It defaults to the public cloud authority and shared.DefaultClient, applies options in
// order, validates the result (the authority must be an https URL), then builds the
// underlying base client.
func New(clientID string, options ...Option) (Client, error) {
	cfg := clientOptions{
		authority:  base.AuthorityPublicCloud,
		httpClient: shared.DefaultClient,
	}
	for _, apply := range options {
		apply(&cfg)
	}
	if err := cfg.validate(); err != nil {
		return Client{}, err
	}

	b, err := base.New(
		clientID,
		cfg.authority,
		oauth.New(cfg.httpClient),
		base.WithCacheAccessor(cfg.accessor),
		base.WithClientCapabilities(cfg.capabilities),
		base.WithInstanceDiscovery(!cfg.disableInstanceDiscovery),
	)
	if err != nil {
		return Client{}, err
	}
	return Client{base: b}, nil
}
// authCodeURLOptions contains options for AuthCodeURL. Its fields are set by options such
// as WithClaims and WithTenantID.
type authCodeURLOptions struct {
	claims, loginHint, tenantID, domainHint string
}

// AuthCodeURLOption is implemented by options for AuthCodeURL. Its unexported method
// acts as a marker, so only options from this package satisfy it.
type AuthCodeURLOption interface {
	authCodeURLOption()
}
// AuthCodeURL creates a URL used to acquire an authorization code.
//
// Options: [WithClaims], [WithDomainHint], [WithLoginHint], [WithTenantID]
func (pca Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, scopes []string, opts ...AuthCodeURLOption) (string, error) {
	var cfg authCodeURLOptions
	if err := options.ApplyOptions(&cfg, opts); err != nil {
		return "", err
	}
	// WithTenant returns a copy of the base params, optionally retargeted to cfg.tenantID.
	params, err := pca.base.AuthParams.WithTenant(cfg.tenantID)
	if err != nil {
		return "", err
	}
	params.Claims, params.LoginHint, params.DomainHint = cfg.claims, cfg.loginHint, cfg.domainHint
	return pca.base.AuthCodeURL(ctx, clientID, redirectURI, scopes, params)
}
// WithClaims sets additional claims to request for the token, such as those required by conditional access policies.
// Use this option when Azure AD returned a claims challenge for a prior request. The argument must be decoded.
// This option is valid for any token acquisition method.
func WithClaims(claims string) interface {
	AcquireByAuthCodeOption
	AcquireByDeviceCodeOption
	AcquireByUsernamePasswordOption
	AcquireInteractiveOption
	AcquireSilentOption
	AuthCodeURLOption
	options.CallOption
} {
	// setClaims copies claims into whichever options struct the receiving method passes.
	setClaims := func(a any) error {
		switch opts := a.(type) {
		case *acquireTokenSilentOptions:
			opts.claims = claims
		case *acquireTokenByUsernamePasswordOptions:
			opts.claims = claims
		case *acquireTokenByDeviceCodeOptions:
			opts.claims = claims
		case *acquireTokenByAuthCodeOptions:
			opts.claims = claims
		case *interactiveAuthOptions:
			opts.claims = claims
		case *authCodeURLOptions:
			opts.claims = claims
		default:
			return fmt.Errorf("unexpected options type %T", a)
		}
		return nil
	}
	// Only CallOption is populated. The other embedded interfaces let the value satisfy
	// each method's option type.
	return struct {
		AcquireByAuthCodeOption
		AcquireByDeviceCodeOption
		AcquireByUsernamePasswordOption
		AcquireInteractiveOption
		AcquireSilentOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setClaims)}
}
// WithTenantID specifies a tenant for a single authentication. It may be different than the tenant set in [New] by [WithAuthority].
// This option is valid for any token acquisition method.
func WithTenantID(tenantID string) interface {
	AcquireByAuthCodeOption
	AcquireByDeviceCodeOption
	AcquireByUsernamePasswordOption
	AcquireInteractiveOption
	AcquireSilentOption
	AuthCodeURLOption
	options.CallOption
} {
	// setTenant copies tenantID into whichever options struct the receiving method passes.
	setTenant := func(a any) error {
		switch opts := a.(type) {
		case *acquireTokenSilentOptions:
			opts.tenantID = tenantID
		case *acquireTokenByUsernamePasswordOptions:
			opts.tenantID = tenantID
		case *acquireTokenByDeviceCodeOptions:
			opts.tenantID = tenantID
		case *acquireTokenByAuthCodeOptions:
			opts.tenantID = tenantID
		case *interactiveAuthOptions:
			opts.tenantID = tenantID
		case *authCodeURLOptions:
			opts.tenantID = tenantID
		default:
			return fmt.Errorf("unexpected options type %T", a)
		}
		return nil
	}
	// Only CallOption is populated. The other embedded interfaces let the value satisfy
	// each method's option type.
	return struct {
		AcquireByAuthCodeOption
		AcquireByDeviceCodeOption
		AcquireByUsernamePasswordOption
		AcquireInteractiveOption
		AcquireSilentOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setTenant)}
}
// acquireTokenSilentOptions are all the optional settings to an AcquireTokenSilent() call.
// These are set by using various AcquireTokenSilentOption functions
// (WithSilentAccount, WithClaims, WithTenantID).
type acquireTokenSilentOptions struct {
	account          Account
	claims, tenantID string
}

// AcquireSilentOption is implemented by options for AcquireTokenSilent. Its unexported
// method acts as a marker.
type AcquireSilentOption interface {
	acquireSilentOption()
}
// WithSilentAccount uses the passed account during an AcquireTokenSilent() call.
// Passing it to any other method produces an "unexpected options type" error.
func WithSilentAccount(account Account) interface {
	AcquireSilentOption
	options.CallOption
} {
	setAccount := func(a any) error {
		silent, ok := a.(*acquireTokenSilentOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", a)
		}
		silent.account = account
		return nil
	}
	return struct {
		AcquireSilentOption
		options.CallOption
	}{CallOption: options.NewCallOption(setAccount)}
}
// AcquireTokenSilent acquires a token from either the cache or using a refresh token.
//
// Options: [WithClaims], [WithSilentAccount], [WithTenantID]
func (pca Client) AcquireTokenSilent(ctx context.Context, scopes []string, opts ...AcquireSilentOption) (AuthResult, error) {
	var cfg acquireTokenSilentOptions
	if err := options.ApplyOptions(&cfg, opts); err != nil {
		return AuthResult{}, err
	}
	// Public clients request ATPublic tokens and never use the app cache.
	return pca.base.AcquireTokenSilent(ctx, base.AcquireTokenSilentParameters{
		Scopes:      scopes,
		Account:     cfg.account,
		Claims:      cfg.claims,
		RequestType: accesstokens.ATPublic,
		IsAppCache:  false,
		TenantID:    cfg.tenantID,
	})
}
// acquireTokenByUsernamePasswordOptions holds the optional settings of an
// AcquireTokenByUsernamePassword call.
type acquireTokenByUsernamePasswordOptions struct {
	claims   string
	tenantID string
}

// AcquireByUsernamePasswordOption is implemented by every option that
// AcquireTokenByUsernamePassword accepts.
type AcquireByUsernamePasswordOption interface {
	acquireByUsernamePasswordOption()
}
// AcquireTokenByUsernamePassword acquires a security token from the authority by sending the
// user's username and password (the ROPC grant).
// NOTE: this flow is NOT recommended.
//
// Options: [WithClaims], [WithTenantID]
func (pca Client) AcquireTokenByUsernamePassword(ctx context.Context, scopes []string, username, password string, opts ...AcquireByUsernamePasswordOption) (AuthResult, error) {
	var o acquireTokenByUsernamePasswordOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	ap, err := pca.base.AuthParams.WithTenant(o.tenantID)
	if err != nil {
		return AuthResult{}, err
	}
	ap.AuthorizationType = authority.ATUsernamePassword
	ap.Scopes, ap.Claims = scopes, o.claims
	ap.Username, ap.Password = username, password

	tok, err := pca.base.Token.UsernamePassword(ctx, ap)
	if err != nil {
		return AuthResult{}, err
	}
	return pca.base.AuthResultFromToken(ctx, ap, tok, true)
}
// DeviceCodeResult is an alias of accesstokens.DeviceCodeResult; it is the type of DeviceCode.Result.
type DeviceCodeResult = accesstokens.DeviceCodeResult

// DeviceCode provides the results of the device code flow's first stage (containing the code)
// that must be entered on the second device and provides a method to retrieve the AuthenticationResult
// once that code has been entered and verified.
type DeviceCode struct {
	// Result holds the information about the device code (such as the code).
	Result DeviceCodeResult

	// authParams are the request parameters the device code was requested with; they are
	// reused when converting the eventual token into an AuthResult.
	authParams authority.AuthParams
	// client is the Client that started the flow.
	client Client
	// dc is the pending device code request; its Token method yields the token.
	dc oauth.DeviceCode
}
// AuthenticationResult retrieves the AuthResult once the user has entered the code on the
// second device. Until then it blocks; it stops early with an error when ctx is cancelled
// or the device code expires.
func (d DeviceCode) AuthenticationResult(ctx context.Context) (AuthResult, error) {
	tok, err := d.dc.Token(ctx)
	if err == nil {
		return d.client.base.AuthResultFromToken(ctx, d.authParams, tok, true)
	}
	return AuthResult{}, err
}
// acquireTokenByDeviceCodeOptions holds the optional settings of an AcquireTokenByDeviceCode call.
type acquireTokenByDeviceCodeOptions struct {
	claims   string
	tenantID string
}

// AcquireByDeviceCodeOption is implemented by every option that AcquireTokenByDeviceCode accepts.
type AcquireByDeviceCodeOption interface {
	acquireByDeviceCodeOptions()
}
// AcquireTokenByDeviceCode starts the device code flow. It requests a device code from the
// authority and returns a DeviceCode whose Result carries the code the user must enter on a
// second device. Call [DeviceCode.AuthenticationResult] on the returned value to wait for the
// user and retrieve the token.
//
// Options: [WithClaims], [WithTenantID]
func (pca Client) AcquireTokenByDeviceCode(ctx context.Context, scopes []string, opts ...AcquireByDeviceCodeOption) (DeviceCode, error) {
	o := acquireTokenByDeviceCodeOptions{}
	if err := options.ApplyOptions(&o, opts); err != nil {
		return DeviceCode{}, err
	}
	authParams, err := pca.base.AuthParams.WithTenant(o.tenantID)
	if err != nil {
		return DeviceCode{}, err
	}
	authParams.Scopes = scopes
	authParams.AuthorizationType = authority.ATDeviceCode
	authParams.Claims = o.claims
	dc, err := pca.base.Token.DeviceCode(ctx, authParams)
	if err != nil {
		return DeviceCode{}, err
	}
	// authParams and the client are kept so AuthenticationResult can build the AuthResult later
	return DeviceCode{Result: dc.Result, authParams: authParams, client: pca, dc: dc}, nil
}
// acquireTokenByAuthCodeOptions holds the optional settings of an AcquireTokenByAuthCode call.
type acquireTokenByAuthCodeOptions struct {
	challenge string
	claims    string
	tenantID  string
}

// AcquireByAuthCodeOption is implemented by every option that AcquireTokenByAuthCode accepts.
type AcquireByAuthCodeOption interface {
	acquireByAuthCodeOption()
}
// WithChallenge supplies the challenge value forwarded with an AcquireTokenByAuthCode call.
func WithChallenge(challenge string) interface {
	AcquireByAuthCodeOption
	options.CallOption
} {
	// setChallenge accepts only the auth code options struct.
	setChallenge := func(opts any) error {
		o, ok := opts.(*acquireTokenByAuthCodeOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", opts)
		}
		o.challenge = challenge
		return nil
	}
	return struct {
		AcquireByAuthCodeOption
		options.CallOption
	}{CallOption: options.NewCallOption(setChallenge)}
}
// AcquireTokenByAuthCode redeems an authorization code for a security token from the authority.
// redirectURI must be the same URI that was used when the authorization code was requested.
//
// Options: [WithChallenge], [WithClaims], [WithTenantID]
func (pca Client) AcquireTokenByAuthCode(ctx context.Context, code string, redirectURI string, scopes []string, opts ...AcquireByAuthCodeOption) (AuthResult, error) {
	var o acquireTokenByAuthCodeOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	return pca.base.AcquireTokenByAuthCode(ctx, base.AcquireTokenAuthCodeParameters{
		AppType:     accesstokens.ATPublic,
		Challenge:   o.challenge,
		Claims:      o.claims,
		Code:        code,
		RedirectURI: redirectURI,
		Scopes:      scopes,
		TenantID:    o.tenantID,
	})
}
// Accounts lists every account held in the token cache. When the cache holds no
// accounts, the returned slice is empty.
func (pca Client) Accounts(ctx context.Context) ([]Account, error) {
	accounts, err := pca.base.AllAccounts(ctx)
	return accounts, err
}

// RemoveAccount signs the given account out and removes it from the token cache.
func (pca Client) RemoveAccount(ctx context.Context, account Account) error {
	err := pca.base.RemoveAccount(ctx, account)
	return err
}
// interactiveAuthOptions holds the optional settings of an AcquireTokenInteractive call.
type interactiveAuthOptions struct {
	claims      string
	domainHint  string
	loginHint   string
	redirectURI string
	tenantID    string
}

// AcquireInteractiveOption is implemented by every option that AcquireTokenInteractive accepts.
type AcquireInteractiveOption interface {
	acquireInteractiveOption()
}
// WithLoginHint fills the login prompt with the given username in advance.
// It applies to AcquireTokenInteractive and AuthCodeURL.
func WithLoginHint(username string) interface {
	AcquireInteractiveOption
	AuthCodeURLOption
	options.CallOption
} {
	setHint := func(opts any) error {
		switch o := opts.(type) {
		case *interactiveAuthOptions:
			o.loginHint = username
		case *authCodeURLOptions:
			o.loginHint = username
		default:
			return fmt.Errorf("unexpected options type %T", opts)
		}
		return nil
	}
	return struct {
		AcquireInteractiveOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setHint)}
}
// WithDomainHint sets the IdP domain sent as the domain_hint query parameter of the auth URL.
// It applies to AcquireTokenInteractive and AuthCodeURL.
func WithDomainHint(domain string) interface {
	AcquireInteractiveOption
	AuthCodeURLOption
	options.CallOption
} {
	setHint := func(opts any) error {
		switch o := opts.(type) {
		case *interactiveAuthOptions:
			o.domainHint = domain
		case *authCodeURLOptions:
			o.domainHint = domain
		default:
			return fmt.Errorf("unexpected options type %T", opts)
		}
		return nil
	}
	return struct {
		AcquireInteractiveOption
		AuthCodeURLOption
		options.CallOption
	}{CallOption: options.NewCallOption(setHint)}
}
// WithRedirectURI chooses the port of the local server that interactive authentication listens
// on, given as a URI such as http://localhost:port. Only the port is used; every other component
// of the URI is ignored.
func WithRedirectURI(redirectURI string) interface {
	AcquireInteractiveOption
	options.CallOption
} {
	setRedirect := func(opts any) error {
		o, ok := opts.(*interactiveAuthOptions)
		if !ok {
			return fmt.Errorf("unexpected options type %T", opts)
		}
		o.redirectURI = redirectURI
		return nil
	}
	return struct {
		AcquireInteractiveOption
		options.CallOption
	}{CallOption: options.NewCallOption(setRedirect)}
}
// AcquireTokenInteractive acquires a security token from the authority, letting the user select
// an account in the default web browser.
// https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-authentication-flows#interactive-and-non-interactive-authentication
//
// Options: [WithDomainHint], [WithLoginHint], [WithRedirectURI], [WithTenantID]
func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, opts ...AcquireInteractiveOption) (AuthResult, error) {
	var o interactiveAuthOptions
	if err := options.ApplyOptions(&o, opts); err != nil {
		return AuthResult{}, err
	}
	// PKCE (https://tools.ietf.org/html/rfc7636): a random, unpadded base64 verifier whose hash
	// is sent as the challenge, protecting the auth code from interception.
	verifier, challenge, err := codeVerifier()
	if err != nil {
		return AuthResult{}, err
	}
	var redirectURL *url.URL
	if o.redirectURI != "" {
		if redirectURL, err = url.Parse(o.redirectURI); err != nil {
			return AuthResult{}, err
		}
	}
	ap, err := pca.base.AuthParams.WithTenant(o.tenantID)
	if err != nil {
		return AuthResult{}, err
	}
	ap.Scopes = scopes
	ap.AuthorizationType = authority.ATInteractive
	ap.Claims = o.claims
	ap.CodeChallenge, ap.CodeChallengeMethod = challenge, "S256"
	ap.LoginHint, ap.DomainHint = o.loginHint, o.domainHint
	ap.State = uuid.New().String()
	ap.Prompt = "select_account"

	login, err := pca.browserLogin(ctx, redirectURL, ap)
	if err != nil {
		return AuthResult{}, err
	}
	// the token request must name the redirect URI the local server actually listened on
	ap.Redirecturi = login.redirectURI

	req, err := accesstokens.NewCodeChallengeRequest(ap, accesstokens.ATPublic, nil, login.authCode, verifier)
	if err != nil {
		return AuthResult{}, err
	}
	tok, err := pca.base.Token.AuthCode(ctx, req)
	if err != nil {
		return AuthResult{}, err
	}
	return pca.base.AuthResultFromToken(ctx, ap, tok, true)
}
// interactiveAuthResult is what browserLogin hands back: the auth code received by the local
// redirect server and the address of that server.
type interactiveAuthResult struct {
	authCode, redirectURI string
}
// browserOpenURL opens u in the system browser. It is a variable so tests can replace it
// with a fake that simulates the browser.
var browserOpenURL = func(u string) error {
	return browser.OpenURL(u)
}
// parses the port number from the provided URL.
// returns 0 if nil or no port is specified.
func parsePort(u *url.URL) (int, error) {
if u == nil {
return 0, nil
}
p := u.Port()
if p == "" {
return 0, nil
}
return strconv.Atoi(p)
}
// browserLogin runs the browser half of interactive login. It starts a local redirect server on
// the port taken from redirectURI (0 when redirectURI is nil or has no port), opens the system
// browser at the auth code URL, and waits for the server to receive the authority's redirect.
// The server is shut down before returning.
func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams) (interactiveAuthResult, error) {
	port, err := parsePort(redirectURI)
	if err != nil {
		return interactiveAuthResult{}, err
	}
	// the local server receives the authority's redirect carrying the auth code
	srv, err := local.New(params.State, port)
	if err != nil {
		return interactiveAuthResult{}, err
	}
	defer srv.Shutdown()

	params.Scopes = accesstokens.AppendDefaultScopes(params)
	authURL, err := pca.base.AuthCodeURL(ctx, params.ClientID, srv.Addr, params.Scopes, params)
	if err != nil {
		return interactiveAuthResult{}, err
	}
	// let the user pick credentials in a browser window
	if err = browserOpenURL(authURL); err != nil {
		return interactiveAuthResult{}, err
	}
	// block until the login redirects back to the local server
	if res := srv.Result(ctx); res.Err != nil {
		return interactiveAuthResult{}, res.Err
	} else {
		return interactiveAuthResult{authCode: res.Code, redirectURI: srv.Addr}, nil
	}
}
// creates a code verifier string along with its SHA256 hash which
// is used as the challenge when requesting an auth code.
// used in interactive auth flow for PKCE.
func codeVerifier() (codeVerifier string, challenge string, err error) {
cvBytes := make([]byte, 32)
if _, err = rand.Read(cvBytes); err != nil {
return
}
codeVerifier = base64.RawURLEncoding.EncodeToString(cvBytes)
// for PKCE, create a hash of the code verifier
cvh := sha256.Sum256([]byte(codeVerifier))
challenge = base64.RawURLEncoding.EncodeToString(cvh[:])
return
}
microsoft-authentication-library-for-go-1.0.0/apps/public/public_test.go 0000664 0000000 0000000 00000073643 14420263624 0026470 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package public
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"testing"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
"github.com/kylelemons/godebug/pretty"
)
// authorityFmt builds an authority URL from a host and a tenant: fmt.Sprintf(authorityFmt, host, tenant).
const authorityFmt = "https://%s/%s"

// tokenScope is the scope requested by most tests in this file.
var tokenScope = []string{"the_scope"}
func fakeBrowserOpenURL(authURL string) error {
// we will get called with the URL for requesting an auth code
u, err := url.Parse(authURL)
if err != nil {
return err
}
// validate the URL content
q := u.Query()
if q.Get("code_challenge") == "" {
return errors.New("missing query param 'code_challenge")
}
if m := q.Get("code_challenge_method"); m != "S256" {
return fmt.Errorf("unexpected code_challenge_method '%s'", m)
}
if q.Get("prompt") == "" {
return errors.New("missing query param 'prompt")
}
state := q.Get("state")
if state == "" {
return errors.New("missing query param 'state'")
}
redirect := q.Get("redirect_uri")
if redirect == "" {
return errors.New("missing query param 'redirect_uri'")
}
// now send the info to our local redirect server
resp, err := http.DefaultClient.Get(redirect + fmt.Sprintf("/?state=%s&code=fake_auth_code", state))
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
return nil
}
// TestAcquireTokenInteractive checks that interactive authentication succeeds end to end when the
// browser is faked and the client's token dependencies are replaced with fakes.
func TestAcquireTokenInteractive(t *testing.T) {
	saved := browserOpenURL
	browserOpenURL = fakeBrowserOpenURL
	defer func() { browserOpenURL = saved }()

	client, err := New("some_client_id")
	if err != nil {
		t.Fatal(err)
	}
	client.base.Token.AccessTokens = &fake.AccessTokens{}
	client.base.Token.Authority = &fake.Authority{}
	client.base.Token.Resolver = &fake.ResolveEndpoints{}
	client.base.Token.WSTrust = &fake.WSTrust{}

	if _, err := client.AcquireTokenInteractive(context.Background(), []string{"the_scope"}); err != nil {
		t.Fatal(err)
	}
}
// TestAcquireTokenSilentHomeTenantAliases checks that a client whose authority uses a tenant alias
// ("common" or "organizations") can silently retrieve the token it cached for the user's home tenant.
func TestAcquireTokenSilentHomeTenantAliases(t *testing.T) {
	const (
		accessToken = "*"
		homeTenant  = "home-tenant"
		lmo         = "login.microsoftonline.com"
	)
	clientInfo := base64.RawStdEncoding.EncodeToString(
		[]byte(fmt.Sprintf(`{"uid":"uid","utid":"%s"}`, homeTenant)),
	)
	aliases := []string{"common", "organizations"}
	for _, alias := range aliases {
		mc := mock.Client{}
		mc.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, alias)))
		idToken := mock.GetIDToken(homeTenant, fmt.Sprintf(authorityFmt, lmo, homeTenant))
		mc.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, "rt", clientInfo, 3600)))
		mc.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, homeTenant)))

		client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, alias)), WithHTTPClient(&mc))
		if err != nil {
			t.Fatal(err)
		}
		ctx := context.Background()

		// the auth flow isn't important, we just need to populate the cache
		first, err := client.AcquireTokenByAuthCode(ctx, "code", "https://localhost", tokenScope)
		if err != nil {
			t.Fatal(err)
		}
		if first.AccessToken != accessToken {
			t.Fatalf("expected %q, got %q", accessToken, first.AccessToken)
		}

		silent, err := client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(first.Account))
		if err != nil {
			t.Fatal(err)
		}
		if silent.AccessToken != accessToken {
			t.Fatalf("expected %q, got %q", accessToken, silent.AccessToken)
		}
	}
}
// TestAcquireTokenSilentWithTenantID caches a token from each of two tenants, then checks that
// silent authentication returns the token belonging to the requested tenant (defaulting to the
// client's configured tenant).
func TestAcquireTokenSilentWithTenantID(t *testing.T) {
	const lmo = "login.microsoftonline.com"
	tenantA, tenantB := "a", "b"

	mc := mock.Client{}
	mc.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA)))
	client, err := New("client-id", WithAuthority(fmt.Sprintf(authorityFmt, lmo, tenantA)), WithHTTPClient(&mc))
	if err != nil {
		t.Fatal(err)
	}
	clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
	ctx := context.Background()

	// cache an access token for each tenant. To simplify determining their provenance below,
	// the value of each token is the ID of the tenant that provided it.
	for _, tenant := range []string{tenantA, tenantB} {
		if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
			t.Fatal("silent auth should fail because the cache is empty")
		}
		mc.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
		// user realm metadata
		mc.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`)))
		mc.AppendResponse(mock.WithBody(
			mock.GetAccessTokenBody(tenant, mock.GetIDToken(tenant, fmt.Sprintf(authorityFmt, lmo, tenant)), "rt-"+tenant, clientInfo, 3600)),
		)
		res, err := client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(tenant))
		if err != nil {
			t.Fatal(err)
		}
		if res.AccessToken != tenant {
			t.Fatalf(`unexpected token "%s"`, res.AccessToken)
		}
	}

	// expecting one account for each tenant we authenticated in above
	accounts, err := client.Accounts(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(accounts) != 2 {
		t.Fatalf("expected 2 accounts but got %d", len(accounts))
	}
	account := accounts[0]

	// the cache should return the correct access token for each tenant
	cases := []struct {
		desc, expected string
		opts           []AcquireSilentOption
	}{
		// when no tenant is specified the client should return the cached token for its configured authority
		{"no tenant specified", tenantA, []AcquireSilentOption{WithSilentAccount(account)}},
		// when a tenant is specified the client should return the cached token for that tenant
		{"redundant tenant specified", tenantA, []AcquireSilentOption{WithSilentAccount(account), WithTenantID(tenantA)}},
		{"different tenant specified", tenantB, []AcquireSilentOption{WithSilentAccount(account), WithTenantID(tenantB)}},
	}
	for _, tc := range cases {
		t.Run(tc.desc, func(t *testing.T) {
			res, err := client.AcquireTokenSilent(ctx, tokenScope, tc.opts...)
			if err != nil {
				t.Fatal(err)
			}
			if res.AccessToken != tc.expected {
				t.Fatalf(`expected "%s", got "%s"`, tc.expected, res.AccessToken)
			}
		})
	}
}
// TestAcquireTokenWithTenantID checks that WithTenantID redirects each public client flow to the
// requested tenant's authority, that invalid tenant overrides are rejected, and that the tokens
// obtained this way can then be retrieved silently.
func TestAcquireTokenWithTenantID(t *testing.T) {
	// replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive
	realBrowserOpenURL := browserOpenURL
	defer func() { browserOpenURL = realBrowserOpenURL }()
	browserOpenURL = fakeBrowserOpenURL

	accessToken := "*"
	clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
	uuid1 := "00000000-0000-0000-0000-000000000000"
	uuid2 := strings.ReplaceAll(uuid1, "0", "1")
	lmo := "login.microsoftonline.com"
	host := fmt.Sprintf("https://%s/", lmo)
	for _, test := range []struct {
		authority, expectedAuthority, tenant string
		expectError                          bool
	}{
		{authority: host + "common", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + "organizations", tenant: uuid1, expectedAuthority: host + uuid1},
		{authority: host + uuid1, tenant: uuid2, expectedAuthority: host + uuid2},
		{authority: host + uuid1, tenant: "common", expectError: true},
		{authority: host + uuid1, tenant: "organizations", expectError: true},
		{authority: host + "consumers", tenant: uuid1, expectError: true},
	} {
		for _, method := range []string{"authcode", "authcodeURL", "devicecode", "interactive", "password"} {
			t.Run(method, func(t *testing.T) {
				// URL records the token request's URL (or, for authcodeURL, the generated URL)
				URL := ""
				mockClient := mock.Client{}
				mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, test.tenant)))
				mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, test.tenant)))
				if method == "devicecode" {
					mockClient.AppendResponse(mock.WithBody([]byte(`{"device_code":"...","expires_in":600}`)))
				} else if method == "password" {
					// user realm metadata
					mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`)))
				}
				mockClient.AppendResponse(
					mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(test.tenant, test.authority), "rt", clientInfo, 3600)),
					mock.WithCallback(func(r *http.Request) { URL = r.URL.String() }),
				)
				client, err := New("client-id", WithAuthority(test.authority), WithHTTPClient(&mockClient))
				if err != nil {
					t.Fatal(err)
				}
				ctx := context.Background()
				if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(test.tenant)); err == nil {
					t.Fatal("silent auth should fail because the cache is empty")
				}
				var ar AuthResult
				var dc DeviceCode
				switch method {
				case "authcode":
					ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithTenantID(test.tenant))
				case "authcodeURL":
					URL, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithTenantID(test.tenant))
				case "devicecode":
					dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithTenantID(test.tenant))
				case "interactive":
					ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithTenantID(test.tenant))
				case "password":
					ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(test.tenant))
				default:
					// t.Fatal, not t.Fatalf: the message isn't a format string
					t.Fatal("test bug: no test for " + method)
				}
				if err != nil {
					if test.expectError {
						return
					}
					t.Fatal(err)
				} else if test.expectError {
					t.Fatal("expected an error")
				}
				if method == "devicecode" {
					if ar, err = dc.AuthenticationResult(ctx); err != nil {
						t.Fatal(err)
					}
				}
				if !strings.HasPrefix(URL, test.expectedAuthority) {
					t.Fatalf(`expected "%s", got "%s"`, test.expectedAuthority, URL)
				}
				if method == "authcodeURL" {
					// didn't acquire a token, no need to test silent auth
					return
				}
				if ar.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
				}
				// silent authentication should succeed for the given tenant because the client has a cached
				// access token, and for a different tenant because the client has a cached refresh token
				if ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithTenantID(test.tenant)); err != nil {
					t.Fatal(err)
				} else if ar.AccessToken != accessToken {
					t.Fatal("cached access token should match the one returned by AcquireToken...")
				}
				otherTenant := "not-" + test.tenant
				mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, otherTenant)))
				mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(otherTenant, test.authority), "rt", clientInfo, 3600)))
				if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account), WithTenantID(otherTenant)); err != nil {
					t.Fatal(err)
				}
			})
		}
	}
}
// TestWithInstanceDiscovery checks that, with instance discovery disabled, every public client flow
// works against a non-Microsoft authority (an ADFS tenant and a GUID tenant), and that the
// resulting token can then be retrieved silently.
func TestWithInstanceDiscovery(t *testing.T) {
	// replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive
	savedOpen := browserOpenURL
	defer func() { browserOpenURL = savedOpen }()
	browserOpenURL = fakeBrowserOpenURL

	const (
		accessToken = "*"
		host        = "stack.local"
	)
	clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
	stackurl := fmt.Sprintf("https://%s/", host)
	tenants := []string{
		"adfs",
		"98b8267d-e97f-426e-8b3f-7956511fd63f",
	}
	methods := []string{"authcode", "devicecode", "interactive", "password"}
	for _, tenant := range tenants {
		for _, method := range methods {
			t.Run(method, func(t *testing.T) {
				authority := stackurl + tenant
				mc := mock.Client{}
				mc.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
				switch {
				case method == "devicecode":
					mc.AppendResponse(mock.WithBody([]byte(`{"device_code":"...","expires_in":600}`)))
				case method == "password" && tenant != "adfs":
					// user realm metadata, which is not requested when AuthorityType is ADFS
					mc.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":"urn","cloud_instance_name":"...","domain_name":"..."}`)))
				}
				mc.AppendResponse(
					mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenant, authority), "rt", clientInfo, 3600)),
				)

				client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mc), WithInstanceDiscovery(false))
				if err != nil {
					t.Fatal(err)
				}
				ctx := context.Background()
				if _, err = client.AcquireTokenSilent(ctx, tokenScope); err == nil {
					t.Fatal("silent auth should fail because the cache is empty")
				}

				var (
					res AuthResult
					dc  DeviceCode
				)
				switch method {
				case "authcode":
					res, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope)
				case "devicecode":
					dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope)
				case "interactive":
					res, err = client.AcquireTokenInteractive(ctx, tokenScope)
				case "password":
					res, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password")
				default:
					t.Fatal("test bug: no test for " + method)
				}
				if err != nil {
					t.Fatal(err)
				}
				if method == "devicecode" {
					if res, err = dc.AuthenticationResult(ctx); err != nil {
						t.Fatal(err)
					}
				}
				if res.AccessToken != accessToken {
					t.Fatalf(`unexpected access token "%s"`, res.AccessToken)
				}

				res, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(res.Account))
				switch {
				case err != nil:
					t.Fatal(err)
				case res.AccessToken != accessToken:
					t.Fatal("cached access token should match the one returned by AcquireToken...")
				}
			})
		}
	}
}
// testCache is a minimal in-memory cache.ExportReplace implementation that stores serialized
// cache data keyed by partition key.
type testCache map[string][]byte

// Export marshals the cache and stores the result under the hint's partition key. Nothing is
// stored when marshaling fails.
func (c testCache) Export(ctx context.Context, m cache.Marshaler, h cache.ExportHints) error {
	data, err := m.Marshal()
	if err != nil {
		return err
	}
	c[h.PartitionKey] = data
	return nil
}

// Replace unmarshals the data stored under the hint's partition key, if any.
func (c testCache) Replace(ctx context.Context, u cache.Unmarshaler, h cache.ReplaceHints) error {
	data, found := c[h.PartitionKey]
	if !found {
		return nil
	}
	return u.Unmarshal(data)
}
// TestWithCache checks that a cache shared through WithCache lets a second client, configured for a
// different tenant, see the first client's account and authenticate silently. It uses the cached
// access token for the original tenant and the cached refresh token for the new one.
func TestWithCache(t *testing.T) {
	// named sharedCache rather than cache so it doesn't shadow the imported cache package
	sharedCache := make(testCache)
	accessToken, refreshToken := "*", "rt"
	clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
	lmo := "login.microsoftonline.com"
	tenantA, tenantB := "a", "b"
	authorityA, authorityB := fmt.Sprintf(authorityFmt, lmo, tenantA), fmt.Sprintf(authorityFmt, lmo, tenantB)
	mockClient := mock.Client{}
	mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantA)))
	mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken, mock.GetIDToken(tenantA, authorityA), refreshToken, clientInfo, 3600)))
	client, err := New("client-id", WithAuthority(authorityA), WithCache(&sharedCache), WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	// The particular flow isn't important, we just need to populate the cache. Auth code is the simplest for this test
	ar, err := client.AcquireTokenByAuthCode(context.Background(), "code", "https://localhost", tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken {
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}
	account := ar.Account
	if actual := account.Realm; actual != tenantA {
		t.Fatalf(`unexpected realm "%s"`, actual)
	}

	// a client configured for a different tenant should be able to authenticate silently with the shared cache's data
	client, err = New("client-id", WithAuthority(authorityB), WithCache(&sharedCache), WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	accounts, err := client.Accounts(context.Background())
	if err != nil {
		t.Fatal(err)
	}
	if actual := len(accounts); actual != 1 {
		t.Fatalf("expected 1 account but cache contains %d", actual)
	}
	if diff := pretty.Compare(account, accounts[0]); diff != "" {
		t.Fatal(diff)
	}

	// this should work because the cache contains an access token from tenantA
	mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenantA)))
	ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account), WithTenantID(tenantA))
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken {
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}

	// this should work because the cache contains a refresh token for the user
	accessToken2 := accessToken + "2"
	mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenantB)))
	mockClient.AppendResponse(mock.WithBody(mock.GetAccessTokenBody(accessToken2, mock.GetIDToken(tenantB, authorityB), refreshToken, clientInfo, 3600)))
	ar, err = client.AcquireTokenSilent(context.Background(), tokenScope, WithSilentAccount(account))
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken2 {
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}
}
func TestWithClaims(t *testing.T) {
// replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive
realBrowserOpenURL := browserOpenURL
defer func() { browserOpenURL = realBrowserOpenURL }()
browserOpenURL = fakeBrowserOpenURL
clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`))
lmo, tenant := "login.microsoftonline.com", "tenant"
authority := fmt.Sprintf(authorityFmt, lmo, tenant)
accessToken, idToken, refreshToken := "at", mock.GetIDToken(tenant, lmo), "rt"
for _, test := range []struct {
capabilities []string
claims, expected string
}{
{},
{
capabilities: []string{"cp1"},
expected: `{"access_token":{"xms_cc":{"values":["cp1"]}}}`,
},
{
claims: `{"id_token":{"auth_time":{"essential":true}}}`,
expected: `{"id_token":{"auth_time":{"essential":true}}}`,
},
{
capabilities: []string{"cp1", "cp2"},
claims: `{"access_token":{"nbf":{"essential":true, "value":"42"}}}`,
expected: `{"access_token":{"nbf":{"essential":true, "value":"42"}, "xms_cc":{"values":["cp1","cp2"]}}}`,
},
} {
var expected map[string]any
if err := json.Unmarshal([]byte(test.expected), &expected); err != nil && test.expected != "" {
t.Fatal("test bug: the expected result must be JSON or an empty string")
}
// validate determines whether a request's query or form values contain the expected claims
validate := func(t *testing.T, v url.Values) {
if test.expected == "" {
if v.Has("claims") {
t.Fatal("claims shouldn't be set")
}
return
}
claims, ok := v["claims"]
if !ok {
t.Fatal("claims should be set")
}
if len(claims) != 1 {
t.Fatalf("expected exactly 1 claims value, got %d", len(claims))
}
var actual map[string]any
if err := json.Unmarshal([]byte(claims[0]), &actual); err != nil {
t.Fatal(err)
}
if diff := pretty.Compare(expected, actual); diff != "" {
t.Fatal(diff)
}
}
for _, method := range []string{"authcode", "authcodeURL", "devicecode", "interactive", "password", "passwordFederated"} {
t.Run(method, func(t *testing.T) {
mockClient := mock.Client{}
if method == "obo" {
// TODO: OBO does instance discovery twice before first token request https://github.com/AzureAD/microsoft-authentication-library-for-go/issues/351
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
switch method {
case "devicecode":
mockClient.AppendResponse(mock.WithBody([]byte(`{"device_code":".","expires_in":600}`)))
case "password":
mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Managed","cloud_audience_urn":".","cloud_instance_name":".","domain_name":"."}`)))
case "passwordFederated":
mockClient.AppendResponse(mock.WithBody([]byte(`{"account_type":"Federated","cloud_audience_urn":".","cloud_instance_name":".","domain_name":".","federation_protocol":".","federation_metadata_url":"."}`)))
}
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(accessToken, idToken, refreshToken, clientInfo, 3600)),
mock.WithCallback(func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
}
validate(t, r.Form)
}),
)
client, err := New("client-id", WithAuthority(authority), WithClientCapabilities(test.capabilities), WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
if _, err = client.AcquireTokenSilent(context.Background(), tokenScope); err == nil {
t.Fatal("silent authentication should fail because the cache is empty")
}
ctx := context.Background()
var ar AuthResult
var dc DeviceCode
switch method {
case "authcode":
ar, err = client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope, WithClaims(test.claims))
case "authcodeURL":
u := ""
if u, err = client.AuthCodeURL(ctx, "client-id", "https://localhost", tokenScope, WithClaims(test.claims)); err == nil {
var parsed *url.URL
if parsed, err = url.Parse(u); err == nil {
validate(t, parsed.Query())
return // didn't acquire a token, no need for further validation
}
}
case "devicecode":
dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithClaims(test.claims))
case "interactive":
ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithClaims(test.claims))
case "password":
ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims))
case "passwordFederated":
client.base.Token.WSTrust = fake.WSTrust{SamlTokenInfo: wstrust.SamlTokenInfo{AssertionType: "urn:ietf:params:oauth:grant-type:saml1_1-bearer"}}
ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims))
default:
t.Fatalf("test bug: no test for " + method)
}
if method == "devicecode" && err == nil {
// complete the device code flow
ar, err = dc.AuthenticationResult(ctx)
}
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
// silent auth should now succeed because the client has an access token cached
ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithSilentAccount(ar.Account))
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != accessToken {
t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
}
if test.claims != "" {
// when given claims, AcquireTokenSilent should request a new access token instead of returning the cached one
newToken := "new-access-token"
mockClient.AppendResponse(
mock.WithBody(mock.GetAccessTokenBody(newToken, idToken, "", clientInfo, 3600)),
mock.WithCallback(func(r *http.Request) {
if err := r.ParseForm(); err != nil {
t.Fatal(err)
}
// all token requests should include any specified claims
validate(t, r.Form)
if actual := r.Form.Get("refresh_token"); actual != refreshToken {
t.Fatalf(`unexpected refresh token "%s"`, actual)
}
}),
)
ar, err = client.AcquireTokenSilent(ctx, tokenScope, WithClaims(test.claims), WithSilentAccount(ar.Account))
if err != nil {
t.Fatal(err)
}
if actual := ar.AccessToken; actual != newToken {
t.Fatalf("Expected %s, got %s. Client should have redeemed its cached refresh token for a new access token.", newToken, actual)
}
}
})
}
}
}
// TestWithPortAuthority checks that an authority URL carrying an explicit port is used
// verbatim as the prefix of the token request URL, and that the token obtained through
// the auth code flow is afterwards served from the cache by AcquireTokenSilent.
func TestWithPortAuthority(t *testing.T) {
	const (
		accessToken = "*"
		sl          = "stack.local"
		port        = ":3001"
		tenant      = "00000000-0000-0000-0000-000000000000"
	)
	host := sl + port
	authority := fmt.Sprintf("https://%s%s/%s", sl, port, tenant)
	var requestURL string

	mockClient := mock.Client{}
	// instance discovery happens twice because the host isn't trusted
	for i := 0; i < 2; i++ {
		mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(host, tenant)))
	}
	mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(host, tenant)))
	mockClient.AppendResponse(
		mock.WithBody(mock.GetAccessTokenBody(accessToken, "", "", "", 3600)),
		mock.WithCallback(func(r *http.Request) { requestURL = r.URL.String() }),
	)

	client, err := New("client-id", WithAuthority(authority), WithHTTPClient(&mockClient))
	if err != nil {
		t.Fatal(err)
	}
	ctx := context.Background()
	if _, err := client.AcquireTokenSilent(ctx, tokenScope); err == nil {
		t.Fatal("silent auth should fail because the cache is empty")
	}

	ar, err := client.AcquireTokenByAuthCode(ctx, "auth code", "https://localhost", tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	switch {
	case !strings.HasPrefix(requestURL, authority):
		t.Fatalf(`expected "%s", got "%s"`, authority, requestURL)
	case ar.AccessToken != accessToken:
		t.Fatalf(`unexpected access token "%s"`, ar.AccessToken)
	}

	// the token just acquired should now come from the cache
	ar, err = client.AcquireTokenSilent(ctx, tokenScope)
	if err != nil {
		t.Fatal(err)
	}
	if ar.AccessToken != accessToken {
		t.Fatal("cached access token should match the one returned by AcquireToken...")
	}
}
// TestWithLoginHint verifies that WithLoginHint puts a login_hint parameter equal to the
// given UPN into both the URL passed to the browser by AcquireTokenInteractive and the URL
// returned by AuthCodeURL. It also verifies that the parameter is absent when the option
// isn't given.
func TestWithLoginHint(t *testing.T) {
	realBrowserOpenURL := browserOpenURL
	defer func() { browserOpenURL = realBrowserOpenURL }()
	upn := "user@localhost"
	client, err := New("client-id")
	if err != nil {
		t.Fatal(err)
	}
	client.base.Token.AccessTokens = &fake.AccessTokens{}
	client.base.Token.Authority = &fake.Authority{}
	client.base.Token.Resolver = &fake.ResolveEndpoints{}
	for _, expectHint := range []bool{true, false} {
		t.Run(fmt.Sprint(expectHint), func(t *testing.T) {
			// validate reports whether v's login_hint matches expectHint/upn. It returns
			// its verdict directly rather than storing it in the enclosing function's err,
			// so a stale value of that variable can never be returned as the result.
			validate := func(v url.Values) error {
				if !v.Has("login_hint") {
					if !expectHint {
						return nil
					}
					return errors.New("expected a login hint")
				}
				if !expectHint {
					return errors.New("expected no login hint")
				}
				if actual := v["login_hint"]; len(actual) != 1 || actual[0] != upn {
					return fmt.Errorf(`unexpected login_hint "%v"`, actual)
				}
				return nil
			}
			// replace the browser launching function with a fake that validates login_hint is set as expected
			called := false
			browserOpenURL = func(authURL string) error {
				called = true
				parsed, err := url.Parse(authURL)
				if err != nil {
					return err
				}
				query, err := url.ParseQuery(parsed.RawQuery)
				if err != nil {
					return err
				}
				if err = validate(query); err != nil {
					t.Fatal(err)
					return err
				}
				// this helper validates the other params and completes the redirect
				return fakeBrowserOpenURL(authURL)
			}
			var acquireOpts []AcquireInteractiveOption
			var urlOpts []AuthCodeURLOption
			if expectHint {
				acquireOpts = append(acquireOpts, WithLoginHint(upn))
				urlOpts = append(urlOpts, WithLoginHint(upn))
			}
			if _, err := client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...); err != nil {
				t.Fatal(err)
			}
			if !called {
				t.Fatal("browserOpenURL wasn't called")
			}
			// AuthCodeURL should honor the option the same way
			u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
			if err != nil {
				t.Fatal(err)
			}
			parsed, err := url.Parse(u)
			if err != nil {
				t.Fatal(err)
			}
			if err = validate(parsed.Query()); err != nil {
				t.Fatal(err)
			}
		})
	}
}
// TestWithDomainHint verifies that WithDomainHint puts a domain_hint parameter equal to the
// given domain into both the URL passed to the browser by AcquireTokenInteractive and the
// URL returned by AuthCodeURL. It also verifies that the parameter is absent when the
// option isn't given.
func TestWithDomainHint(t *testing.T) {
	realBrowserOpenURL := browserOpenURL
	defer func() { browserOpenURL = realBrowserOpenURL }()
	domain := "contoso.com"
	client, err := New("client-id")
	if err != nil {
		t.Fatal(err)
	}
	client.base.Token.AccessTokens = &fake.AccessTokens{}
	client.base.Token.Authority = &fake.Authority{}
	client.base.Token.Resolver = &fake.ResolveEndpoints{}
	for _, expectHint := range []bool{true, false} {
		t.Run(fmt.Sprint(expectHint), func(t *testing.T) {
			// validate reports whether v's domain_hint matches expectHint/domain. It returns
			// its verdict directly rather than storing it in the enclosing function's err,
			// so a stale value of that variable can never be returned as the result.
			validate := func(v url.Values) error {
				if !v.Has("domain_hint") {
					if !expectHint {
						return nil
					}
					return errors.New("expected a domain hint")
				}
				if !expectHint {
					return errors.New("expected no domain hint")
				}
				if actual := v["domain_hint"]; len(actual) != 1 || actual[0] != domain {
					return fmt.Errorf(`unexpected domain_hint "%v"`, actual)
				}
				return nil
			}
			// replace the browser launching function with a fake that validates domain_hint is set as expected
			called := false
			browserOpenURL = func(authURL string) error {
				called = true
				parsed, err := url.Parse(authURL)
				if err != nil {
					return err
				}
				query, err := url.ParseQuery(parsed.RawQuery)
				if err != nil {
					return err
				}
				if err = validate(query); err != nil {
					t.Fatal(err)
					return err
				}
				// this helper validates the other params and completes the redirect
				return fakeBrowserOpenURL(authURL)
			}
			var acquireOpts []AcquireInteractiveOption
			var urlOpts []AuthCodeURLOption
			if expectHint {
				acquireOpts = append(acquireOpts, WithDomainHint(domain))
				urlOpts = append(urlOpts, WithDomainHint(domain))
			}
			if _, err := client.AcquireTokenInteractive(context.Background(), tokenScope, acquireOpts...); err != nil {
				t.Fatal(err)
			}
			if !called {
				t.Fatal("browserOpenURL wasn't called")
			}
			// AuthCodeURL should honor the option the same way
			u, err := client.AuthCodeURL(context.Background(), "id", "https://localhost", tokenScope, urlOpts...)
			if err != nil {
				t.Fatal(err)
			}
			parsed, err := url.Parse(u)
			if err != nil {
				t.Fatal(err)
			}
			if err = validate(parsed.Query()); err != nil {
				t.Fatal(err)
			}
		})
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/testdata/ 0000775 0000000 0000000 00000000000 14420263624 0024142 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert-chain-reverse.pem 0000664 0000000 0000000 00000011501 14420263624 0031306 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIIFGTCCAwGgAwIBAgIUBpOlpNN/cgasvozVw6mfa04+ZC0wDQYJKoZIhvcNAQEL
BQAwOzELMAkGA1UEBhMCVVMxDDAKBgNVBAoMA3h6eTEMMAoGA1UECwwDYWJjMRAw
DgYDVQQDDAdST09ULUNOMCAXDTIwMDgyMTE3MTAyNVoYDzMzODkwODA0MTcxMDI1
WjA+MQswCQYDVQQGEwJVUzEMMAoGA1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEzAR
BgNVBAMMCklOVEVSSU0tQ04wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoIC
AQCr+Tblr4DhX3Xahbei00OJnUgRw6FMsnyROZ170Lx0YNcOrRJ9PuaOZiYXY2Hm
t71o/PZjMtmiYMIxFaiMnql/dCca777l+uBmlwFOR8bquBWiLStmPpvf7Kh5GZNw
XvLGAhk/oxG0O9Pa3OfrlD5vrn/UEGJBu0C+c6ZSLyRk8RjAh8ZbUvnDhhQw3PoK
MQSmFK8BN8X34elu7kq0j7nS0D6Mt7eS40oYeHEaQDdBGl8f7rcqC3RjJ/b/F9wA
+CsKaps6TvpxE7ln9Y3+0yscgeRbyHW0zem6U7MMvVnK/znuNY90Wmajbea7SUj6
nGZpLGS1TqS4H5rn9U1N1WCSyFukTpAQLCPQHeUrSiHKa9Ye5KuC6u2ZXgy0qpGj
nMLu+7746wemi7jN06yZjEmDVneMNCxjLYs4ZhuhiTEItlZpR0VBugNbKo2mJw2U
UesizB3AzQkqGOKp70y74yC+ykLkR5vRNyY3MENJ+W83U1haS7C1rhqFV4eXflVe
EHl8tj7p4KrfhSPr0Rd12UIWDXkYUpCAPlDMdEa9+SDAyuSnkN4P1fAeuzG01jeJ
bnsrWgs3gH3KaGBcPTV4tOTavilGNYDvHZbN9XpYZoZQoPrDZc61M5Ol/cxBahkO
n4aDyhpx5hHnSs7VQuHnjeMUxt3J5HqrXPvaf6uPYNT8KQIDAQABoxAwDjAMBgNV
HRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4ICAQCHCxFqJwfVMI9kMvwlj+sxd4Q5
KuyWxlXRfzpYZ/6JCUq7VBceRVJ87KytMNCyq61rd3Jhb8ssoMCENB68HYhIFUGz
GR92AAc6LTh2Y3vQAg640Cz2vLCGnqnlbIslYV6fzxYqgSopR5wJ4D/kJ9w7NSrC
paN6bS8Olv//tN6RSnvEMJZdXFA40xFin6qT8Op3nrysEE7Z84wPG9Wj2DXskX6v
bZenCEgl1/Ezif5IEgJcYdRkXtYPp6JNbVV+KjDTIMEaUVMpGMGefrt22E+4nSa3
qFvcbzYEKeANe9IAxdPzeWiQ2U90PqWFYCA9sOVsrlSwrup+yYXl0yhTxKY67NCX
gyVtZRnzawv0AVFsfCOT4V0wJSuUz4BV6sH7kl2C7FW3zqYVdFEDigbUNsEEh/jF
3JiAtgNbpJ8TtiCFrCI4g9Jepa3polVPzDD8mLtkWWnfSBN/28cxa2jiUlfQxB39
kyqu4rWbm01lyucJxVgJzH0SGyEM5OvF/OIOU3Q7UIXEcZSX3m4Xo59+v6ZNDwKL
PcFDNK+PL3WNYfdexQCSAbLm1gkUrVIqvidpCSSVv5oWwTM5m7rbA16Hlu4Ea2ep
Pl7I9YXXXnIEFqLYZDnCJglcXmlt6OjI8D3w0TRWHb6bFqubDP417sJDX1S6udN5
wOnOIqg0ZZcqfvpxXA==
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIID7zCCAdcCAQEwDQYJKoZIhvcNAQEFBQAwPjELMAkGA1UEBhMCVVMxDDAKBgNV
BAoMA3h5ejEMMAoGA1UECwwDYWJjMRMwEQYDVQQDDApJTlRFUklNLUNOMCAXDTIw
MDgyMTE3MTA0M1oYDzMzODkwODA0MTcxMDQzWjA7MQswCQYDVQQGEwJVUzEMMAoG
A1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEDAOBgNVBAMMB1VTRVItQ04wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6eQYdbIFhsinob3t3AV4yEH/tz/LV
I+UAGLpxQnqGnuAV5GY3CXiAO8GZjx7y3oA1DGfe+/cc6n9BmYWXsKvxpKO8PQkB
PYIFtD878uDNv7kVoZG8EVsEngBxd4efMniKWwKtMle0hZ+jj3u4Ad49DsXcC0L2
8uV/eQ6hzsQiR0nTQJ/4QqNNtThSGAFSr7Oo8xzxBNTJhe+BvwDE8JMkCS0v22JW
my2GYrRKw4RlSKxwv9QZr83gSicKSUPUACBYfJ7RuXSQOHOMlIcC4oGtDrMshGzr
704Ho+DiByYf5G6nkfZ1I7T039gEKKIllNKWqhyQHejKba3nP163ZKI3AgMBAAEw
DQYJKoZIhvcNAQEFBQADggIBADfitSfjlYa2inBKlpWN8VT0DPm5uw8EHuwLymCM
WYrQMCuQVE2xYoqCSmXj6KLFt8ycgxHsthdkAzXxDhawaKjz2UFp6nszmUA4xfvS
mxLSajwzK/KMBkjdFL7TM+TTBJ1bleDbmoJvDiUeQwisbb1Uh8b3v/jpBwoiamm8
Y4Ca5A15SeBUvAt0/Mc4XJfZ/Ts+LBAPevI9ZyU7C5JZky1q41KPklEHfFZKQRfP
cTyTYYvlPoq57C8XPDs6r50EV3B6Z8MN21OB6MVGi8BOY/c7a2h1ZOhxNyBnJuQX
w4meJthoKcHUnAs8YCrEoQKayMqPH0Vdhaii/gx4jAgh4PNyIZz5cAst+ybPtQj4
i7LFEWjxis+NLQMHhyE4fIGIkEjzU0uGDugifheIwKALqYEgMDrcoolwvGMdPxGo
Qps7tkad5vZV9d9+tTbI+DMB16Y51S04/u1dGFz3jSrDVF08PznJc99VB69OReiC
K17n8Xyox/VAaYsRFbOAJpLRWwcnotDpFQbgiLrmXxNOoiWPNbQsQzaQx7cR9okQ
v5RTpFAkrdjadhMsXFFiQh+axlaGD368ZGAj5ZoyOiXkV88tNCtyP/RDgW5ftQQ7
fdv05bNXhDfLgEgQvVSDfClDL1hKukLmLQS3ILfB4FlM/XmE+FW/qgo9aSx2XIbx
E4ie
-----END CERTIFICATE-----
-----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAunkGHWyBYbIp6G97dwFeMhB/7c/y1SPlABi6cUJ6hp7gFeRm
Nwl4gDvBmY8e8t6ANQxn3vv3HOp/QZmFl7Cr8aSjvD0JAT2CBbQ/O/Lgzb+5FaGR
vBFbBJ4AcXeHnzJ4ilsCrTJXtIWfo497uAHePQ7F3AtC9vLlf3kOoc7EIkdJ00Cf
+EKjTbU4UhgBUq+zqPMc8QTUyYXvgb8AxPCTJAktL9tiVpsthmK0SsOEZUiscL/U
Ga/N4EonCklD1AAgWHye0bl0kDhzjJSHAuKBrQ6zLIRs6+9OB6Pg4gcmH+Rup5H2
dSO09N/YBCiiJZTSlqockB3oym2t5z9et2SiNwIDAQABAoIBAQCKzivPG0X0AztO
2i19mHcVrVKNI44POnjsaXvfcyzhqMIFic7MiTA5xEGInRDcmOO2mVV4lvaLf8La
gfz/vXNAnN2E8aoSUkbHGDU52sGcZmrPv0VMSV8HQNXzoJZD2r3/v19urVq79fuv
NM9TWZCkwqpl8bwXNxe+m85YhCFboY9G543qmuXzKAQLoSupT0e4eIo2IGp7eJYK
5J/wtlEumUdhsKo1ajLojDgsgPKfrCyvsmO+bj1dRKGXVLO2SL2pFVCjjHF4SP3q
1WX39beu61Zu+kGthDgj5muHgH06FtnWoHLIUrRmYpM+ezCxQHdRWz7AYjheeE7q
QqJv1PqBAoGBAOlb/gzsps+rInE+LQoEzVj8osILI4NxIpNc6+iG81dEi+zQABX/
bHV6hXGGceozVcX4B+V7f08PlZIAgM3IDqfy0fH2pwEQahJ8a3MwzCgR66RxYlkX
E8czkoz0pcHW58FnLLlWXpHRALTtqoPP5LnWs0SmoNvcHZ9yjJ6tvpRlAoGBAMyQ
fytsyla1ujO0l/kuLFG7gndeOc96SutH3V17lZ1pN0efHyk2aglOnl6YsdPKLZvZ
3ghj01HV0Q0f//xpftduuA7gdgDzSG1irXsxEidfVxX7RsPxX6cx8dhYnuk5rz5E
XyTko7zTpr+A4XMnq6+JNSSCIE+CVYcYf/hyemxrAoGAeC9py4xCaWgxR/OGzMcm
X3NV++wysSqebRkJYuvF/icOjbuen7W6TVL50Ts2BjHENj6FCpqtObHEDbr2m4Uy
jysPF7g50OF8T+MGkAAM1YJNQ5cl2M564DhefPwvNoMRP1l8/kNOV3k2DPjuvg5f
NZsvHudWp4VZOFqNs9e19MUCgYAjewCDoKfrqDN2mmEtmAOZ3YMAfzhZsyVhb6KG
f1Pw7HnpE0FNXaHAoYE4eRWG3W9Rs9Ud8WqKrCJJO36j4gxdA1grRGVTPt8WEeJz
FozGhXPOXTnl7GyhzDjdRGmznAy4KRWziXCY5MDsQEdaOMw/cvXjsio2gC2jc+1m
QzzWpwKBgHzszJ5s6vcWElox4Yc1elQ8xniPpo3RtfXZOLX8xA4eR9yQawah1zd6
ChfeYbHVfq007s+RWGTb+KYQ6ic9nkW464qmVxHGBatUo9+MR4Gk8blANoAfHxdV
g6JNgT2kIGu9IEwoD6XQldC/v24bvFSesyGRHNdI4mUG+hhU4aNw
-----END RSA PRIVATE KEY-----
microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert-chain.pem 0000664 0000000 0000000 00000011501 14420263624 0027635 0 ustar 00root root 0000000 0000000 -----BEGIN RSA PRIVATE KEY-----
MIIEowIBAAKCAQEAunkGHWyBYbIp6G97dwFeMhB/7c/y1SPlABi6cUJ6hp7gFeRm
Nwl4gDvBmY8e8t6ANQxn3vv3HOp/QZmFl7Cr8aSjvD0JAT2CBbQ/O/Lgzb+5FaGR
vBFbBJ4AcXeHnzJ4ilsCrTJXtIWfo497uAHePQ7F3AtC9vLlf3kOoc7EIkdJ00Cf
+EKjTbU4UhgBUq+zqPMc8QTUyYXvgb8AxPCTJAktL9tiVpsthmK0SsOEZUiscL/U
Ga/N4EonCklD1AAgWHye0bl0kDhzjJSHAuKBrQ6zLIRs6+9OB6Pg4gcmH+Rup5H2
dSO09N/YBCiiJZTSlqockB3oym2t5z9et2SiNwIDAQABAoIBAQCKzivPG0X0AztO
2i19mHcVrVKNI44POnjsaXvfcyzhqMIFic7MiTA5xEGInRDcmOO2mVV4lvaLf8La
gfz/vXNAnN2E8aoSUkbHGDU52sGcZmrPv0VMSV8HQNXzoJZD2r3/v19urVq79fuv
NM9TWZCkwqpl8bwXNxe+m85YhCFboY9G543qmuXzKAQLoSupT0e4eIo2IGp7eJYK
5J/wtlEumUdhsKo1ajLojDgsgPKfrCyvsmO+bj1dRKGXVLO2SL2pFVCjjHF4SP3q
1WX39beu61Zu+kGthDgj5muHgH06FtnWoHLIUrRmYpM+ezCxQHdRWz7AYjheeE7q
QqJv1PqBAoGBAOlb/gzsps+rInE+LQoEzVj8osILI4NxIpNc6+iG81dEi+zQABX/
bHV6hXGGceozVcX4B+V7f08PlZIAgM3IDqfy0fH2pwEQahJ8a3MwzCgR66RxYlkX
E8czkoz0pcHW58FnLLlWXpHRALTtqoPP5LnWs0SmoNvcHZ9yjJ6tvpRlAoGBAMyQ
fytsyla1ujO0l/kuLFG7gndeOc96SutH3V17lZ1pN0efHyk2aglOnl6YsdPKLZvZ
3ghj01HV0Q0f//xpftduuA7gdgDzSG1irXsxEidfVxX7RsPxX6cx8dhYnuk5rz5E
XyTko7zTpr+A4XMnq6+JNSSCIE+CVYcYf/hyemxrAoGAeC9py4xCaWgxR/OGzMcm
X3NV++wysSqebRkJYuvF/icOjbuen7W6TVL50Ts2BjHENj6FCpqtObHEDbr2m4Uy
jysPF7g50OF8T+MGkAAM1YJNQ5cl2M564DhefPwvNoMRP1l8/kNOV3k2DPjuvg5f
NZsvHudWp4VZOFqNs9e19MUCgYAjewCDoKfrqDN2mmEtmAOZ3YMAfzhZsyVhb6KG
f1Pw7HnpE0FNXaHAoYE4eRWG3W9Rs9Ud8WqKrCJJO36j4gxdA1grRGVTPt8WEeJz
FozGhXPOXTnl7GyhzDjdRGmznAy4KRWziXCY5MDsQEdaOMw/cvXjsio2gC2jc+1m
QzzWpwKBgHzszJ5s6vcWElox4Yc1elQ8xniPpo3RtfXZOLX8xA4eR9yQawah1zd6
ChfeYbHVfq007s+RWGTb+KYQ6ic9nkW464qmVxHGBatUo9+MR4Gk8blANoAfHxdV
g6JNgT2kIGu9IEwoD6XQldC/v24bvFSesyGRHNdI4mUG+hhU4aNw
-----END RSA PRIVATE KEY-----
-----BEGIN CERTIFICATE-----
MIID7zCCAdcCAQEwDQYJKoZIhvcNAQEFBQAwPjELMAkGA1UEBhMCVVMxDDAKBgNV
BAoMA3h5ejEMMAoGA1UECwwDYWJjMRMwEQYDVQQDDApJTlRFUklNLUNOMCAXDTIw
MDgyMTE3MTA0M1oYDzMzODkwODA0MTcxMDQzWjA7MQswCQYDVQQGEwJVUzEMMAoG
A1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEDAOBgNVBAMMB1VTRVItQ04wggEiMA0G
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC6eQYdbIFhsinob3t3AV4yEH/tz/LV
I+UAGLpxQnqGnuAV5GY3CXiAO8GZjx7y3oA1DGfe+/cc6n9BmYWXsKvxpKO8PQkB
PYIFtD878uDNv7kVoZG8EVsEngBxd4efMniKWwKtMle0hZ+jj3u4Ad49DsXcC0L2
8uV/eQ6hzsQiR0nTQJ/4QqNNtThSGAFSr7Oo8xzxBNTJhe+BvwDE8JMkCS0v22JW
my2GYrRKw4RlSKxwv9QZr83gSicKSUPUACBYfJ7RuXSQOHOMlIcC4oGtDrMshGzr
704Ho+DiByYf5G6nkfZ1I7T039gEKKIllNKWqhyQHejKba3nP163ZKI3AgMBAAEw
DQYJKoZIhvcNAQEFBQADggIBADfitSfjlYa2inBKlpWN8VT0DPm5uw8EHuwLymCM
WYrQMCuQVE2xYoqCSmXj6KLFt8ycgxHsthdkAzXxDhawaKjz2UFp6nszmUA4xfvS
mxLSajwzK/KMBkjdFL7TM+TTBJ1bleDbmoJvDiUeQwisbb1Uh8b3v/jpBwoiamm8
Y4Ca5A15SeBUvAt0/Mc4XJfZ/Ts+LBAPevI9ZyU7C5JZky1q41KPklEHfFZKQRfP
cTyTYYvlPoq57C8XPDs6r50EV3B6Z8MN21OB6MVGi8BOY/c7a2h1ZOhxNyBnJuQX
w4meJthoKcHUnAs8YCrEoQKayMqPH0Vdhaii/gx4jAgh4PNyIZz5cAst+ybPtQj4
i7LFEWjxis+NLQMHhyE4fIGIkEjzU0uGDugifheIwKALqYEgMDrcoolwvGMdPxGo
Qps7tkad5vZV9d9+tTbI+DMB16Y51S04/u1dGFz3jSrDVF08PznJc99VB69OReiC
K17n8Xyox/VAaYsRFbOAJpLRWwcnotDpFQbgiLrmXxNOoiWPNbQsQzaQx7cR9okQ
v5RTpFAkrdjadhMsXFFiQh+axlaGD368ZGAj5ZoyOiXkV88tNCtyP/RDgW5ftQQ7
fdv05bNXhDfLgEgQvVSDfClDL1hKukLmLQS3ILfB4FlM/XmE+FW/qgo9aSx2XIbx
E4ie
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIFGTCCAwGgAwIBAgIUBpOlpNN/cgasvozVw6mfa04+ZC0wDQYJKoZIhvcNAQEL
BQAwOzELMAkGA1UEBhMCVVMxDDAKBgNVBAoMA3h6eTEMMAoGA1UECwwDYWJjMRAw
DgYDVQQDDAdST09ULUNOMCAXDTIwMDgyMTE3MTAyNVoYDzMzODkwODA0MTcxMDI1
WjA+MQswCQYDVQQGEwJVUzEMMAoGA1UECgwDeHl6MQwwCgYDVQQLDANhYmMxEzAR
BgNVBAMMCklOVEVSSU0tQ04wggIiMA0GCSqGSIb3DQEBAQUAA4ICDwAwggIKAoIC
AQCr+Tblr4DhX3Xahbei00OJnUgRw6FMsnyROZ170Lx0YNcOrRJ9PuaOZiYXY2Hm
t71o/PZjMtmiYMIxFaiMnql/dCca777l+uBmlwFOR8bquBWiLStmPpvf7Kh5GZNw
XvLGAhk/oxG0O9Pa3OfrlD5vrn/UEGJBu0C+c6ZSLyRk8RjAh8ZbUvnDhhQw3PoK
MQSmFK8BN8X34elu7kq0j7nS0D6Mt7eS40oYeHEaQDdBGl8f7rcqC3RjJ/b/F9wA
+CsKaps6TvpxE7ln9Y3+0yscgeRbyHW0zem6U7MMvVnK/znuNY90Wmajbea7SUj6
nGZpLGS1TqS4H5rn9U1N1WCSyFukTpAQLCPQHeUrSiHKa9Ye5KuC6u2ZXgy0qpGj
nMLu+7746wemi7jN06yZjEmDVneMNCxjLYs4ZhuhiTEItlZpR0VBugNbKo2mJw2U
UesizB3AzQkqGOKp70y74yC+ykLkR5vRNyY3MENJ+W83U1haS7C1rhqFV4eXflVe
EHl8tj7p4KrfhSPr0Rd12UIWDXkYUpCAPlDMdEa9+SDAyuSnkN4P1fAeuzG01jeJ
bnsrWgs3gH3KaGBcPTV4tOTavilGNYDvHZbN9XpYZoZQoPrDZc61M5Ol/cxBahkO
n4aDyhpx5hHnSs7VQuHnjeMUxt3J5HqrXPvaf6uPYNT8KQIDAQABoxAwDjAMBgNV
HRMEBTADAQH/MA0GCSqGSIb3DQEBCwUAA4ICAQCHCxFqJwfVMI9kMvwlj+sxd4Q5
KuyWxlXRfzpYZ/6JCUq7VBceRVJ87KytMNCyq61rd3Jhb8ssoMCENB68HYhIFUGz
GR92AAc6LTh2Y3vQAg640Cz2vLCGnqnlbIslYV6fzxYqgSopR5wJ4D/kJ9w7NSrC
paN6bS8Olv//tN6RSnvEMJZdXFA40xFin6qT8Op3nrysEE7Z84wPG9Wj2DXskX6v
bZenCEgl1/Ezif5IEgJcYdRkXtYPp6JNbVV+KjDTIMEaUVMpGMGefrt22E+4nSa3
qFvcbzYEKeANe9IAxdPzeWiQ2U90PqWFYCA9sOVsrlSwrup+yYXl0yhTxKY67NCX
gyVtZRnzawv0AVFsfCOT4V0wJSuUz4BV6sH7kl2C7FW3zqYVdFEDigbUNsEEh/jF
3JiAtgNbpJ8TtiCFrCI4g9Jepa3polVPzDD8mLtkWWnfSBN/28cxa2jiUlfQxB39
kyqu4rWbm01lyucJxVgJzH0SGyEM5OvF/OIOU3Q7UIXEcZSX3m4Xo59+v6ZNDwKL
PcFDNK+PL3WNYfdexQCSAbLm1gkUrVIqvidpCSSVv5oWwTM5m7rbA16Hlu4Ea2ep
Pl7I9YXXXnIEFqLYZDnCJglcXmlt6OjI8D3w0TRWHb6bFqubDP417sJDX1S6udN5
wOnOIqg0ZZcqfvpxXA==
-----END CERTIFICATE-----
microsoft-authentication-library-for-go-1.0.0/apps/testdata/test-cert.pem 0000664 0000000 0000000 00000005150 14420263624 0026560 0 ustar 00root root 0000000 0000000 -----BEGIN CERTIFICATE-----
MIICljCCAX4CCQDNgteZ+lJH4zANBgkqhkiG9w0BAQsFADANMQswCQYDVQQGEwJ1
czAeFw0yMTAxMDQyMzQzNDVaFw0yMTAyMDMyMzQzNDVaMA0xCzAJBgNVBAYTAnVz
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA1r58wq7JQxM12viLNbdG
fFizeVQwWRwrx/4CH3kU8jjGovbhkvC/uLWqVGchgATThhGkvNrA92WvdkVwsZMk
Qf7ZnTA7kemo4VFtgo5XCGEej9gOTW13Evdc/0Flip+RXl3h3Q6BbbB9IFE0c6cS
3i/v/t8KGpVYQHQzBwTcYehM6eDO8ZjUyUUcJOMXdMCctamig7fMGlziKFahn4dX
JoiiK4oNKE9okXIAXCTbVkAxxH0hD+5XH1nn5LJnHe0e5DflI3YIiPgmRL5uC89K
XqmYCKWrq5z2D5k+5fQLmbOcxErBcFCh8hA+Xu0RLT4BHPEgc6iVIqxL4CZi/cke
uwIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQAAyDbm0Fda0/vY6ZVDML2IbGWbro1w
nWYNw6wclNU6sx1oeG/k/y2ni7NImPpbFN+594WS6rYHgFdROfeuNgGnjgQCJogk
+8ouf1R6vFMUAScWeSaFnZmBEgwofWsnIcUKkbDIXbpRhMrkNEcY09VgjmCKhspQ
iX2bJQTj49XBac9tBaJJYDZ4HgkO4nU7QeEPpvwlELZFoZZXtd3fan+VUyFS2a9n
gkAMDYoQPGN4tyGFabWws/GlMxelWvqUzpQKmeRPVz+cij75l8eKThEiu0zbjOTD
Gq81BcY61SPqN02zoPCtqZ/zU6HhaL3x7zUuzhLhNoh83A43UVYEoOOf
-----END CERTIFICATE-----
-----BEGIN PRIVATE KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDWvnzCrslDEzXa
+Is1t0Z8WLN5VDBZHCvH/gIfeRTyOMai9uGS8L+4tapUZyGABNOGEaS82sD3Za92
RXCxkyRB/tmdMDuR6ajhUW2CjlcIYR6P2A5NbXcS91z/QWWKn5FeXeHdDoFtsH0g
UTRzpxLeL+/+3woalVhAdDMHBNxh6Ezp4M7xmNTJRRwk4xd0wJy1qaKDt8waXOIo
VqGfh1cmiKIrig0oT2iRcgBcJNtWQDHEfSEP7lcfWefksmcd7R7kN+UjdgiI+CZE
vm4Lz0peqZgIpaurnPYPmT7l9AuZs5zESsFwUKHyED5e7REtPgEc8SBzqJUirEvg
JmL9yR67AgMBAAECggEAAQ/IBh5fGFnL9l0sMwPI8Wxu1ra31njxLnfvAsDSfbAS
K1QVIWjXSc58HRa1b7CWax9DNTvPoGl8SJVnTTlxAHKGGOTYJoyFLTf91ptlisEQ
KZ3j1DYqVImsiAaGvfyz90d3imQ795Lby4EbRUcaLMcH5LatkhwS556rcelwPXuq
M43XaZu5Es4pG0EmzfXplO/awt5HdUDPEAY3yw7QH8D1/l/toLPyiFv37RezkVK9
ffcUQpH7uH000Gja+JSEHgpWZhE96ac6H0zBtlM1VkMtfBuczz5tkKN/p70fhr8T
ZXARZqIaF4vx7RkBBzCfhvrgGqxXMuvTaW6N4RDWYQKBgQD1iZ7/xr9qy4cPFSOt
yBnG5cE6wC7wP8qgr0N7MgAii5OZgx6rtfGIVJDY58CFijnT8jZ5pjNS3p7j/Rzp
lQJMIwC5kIe/7FU7nmE3ko7Wg+bpd8iWLLIi/QWVFLbS7qVmulTc+CEXWyhAiI2u
RL/1APjIDFKp9gqtKmwb9erxDwKBgQDf5PbGHuPv5RBLJz9du+M/BIBY+HDltG89
p3huHHTjkJ5R38oximf2HnV4ygT/p2+ZUD6TJZZw6qou3/GiU5gZbRpg+4LXtQUR
vV+S2n/t86NG1YcGmM29r8LWqrK9gxLW0X62Fpps16rHSP7kVc4SvmrYwqNzqKlC
D9QbFYYflQKBgQCKEVzrDuNMNi43+PcbHU4BXeiOFMtQJU7XlDYp7C/PPRU+WVDB
1Yl/062vioHjlZp259hiB2cMzkoigY3kevnTvksGDZOIBGjZIXIhQbQ4Q+twlP6i
E3gH3Kdq8T7s1W0EmvplVtGkxImZ4C9rMxWNu4IpW2SQVd4jCZvJDTuTWQKBgQCn
LGjuCYacSubdlpKDxJSrKwtCY0641P7yhCcx4GGOwR7Vd0mbsAJsDNYduIn+8eAs
E3SFnl00NqOXmHLth4lcAtDddS5/LZR5aHMCTc+TtoVFkI3faRzF84SBkLchNctN
RuNbxojLmETVxDU9/Kt/51oUO1CcPWUUBImVJ38b+QKBgQCTbi0nS0n8kC7nlXWN
QtPcf4UraJAxv1DGq4lnJ8AHSZqqkP5fyjfknSw5ExOPDg4mEHhnnpsvwJuSX00d
UYUN2ZJXPZeaO0HmbYZ3/vC9bo6KW95PhidEUQpGlKrFY342khjQHJtH67YUThwU
lQFhpxvPgPNBuxVRnsxoH/sLOA==
-----END PRIVATE KEY-----
microsoft-authentication-library-for-go-1.0.0/apps/tests/ 0000775 0000000 0000000 00000000000 14420263624 0023473 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/tests/benchmarks/ 0000775 0000000 0000000 00000000000 14420263624 0025610 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/tests/benchmarks/confidential.go 0000664 0000000 0000000 00000013100 14420263624 0030571 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"fmt"
"os"
"runtime"
"strconv"
"sync"
"text/template"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
)
// accessToken is the token value returned by the fake access-token response in fakeClient
// and stored in every token that populateTokenCache adds to the cache.
const accessToken = "fake_token"
// tokenScope is the scope reported as granted by the fake access-token response and the
// scope set on the AuthParams that populateTokenCache uses.
var tokenScope = []string{"fake_scope"}
// testParams configures a single benchmark run. main prints it with %#v, so the field
// names and their order are visible in the program's output.
type testParams struct {
// the number of goroutines to use
Concurrency int
// the number of tokens in the cache
// must be divisible by Concurrency
TokenCount int
}
// fakeClient returns a base.Client whose OAuth client is assembled entirely from fake
// implementations (access tokens, authority instance discovery, endpoint resolution and
// WS-Trust). A base.Client is used because it accepts such an injected OAuth client.
// The fake access-token response carries accessToken, expires one hour from now, and
// grants tokenScope.
func fakeClient() (base.Client, error) {
	tokenResp := accesstokens.TokenResponse{
		AccessToken:   accessToken,
		ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(time.Hour)},
		GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
	}
	discovery := authority.InstanceDiscoveryResponse{
		Metadata: []authority.InstanceDiscoveryMetadata{{
			PreferredNetwork: "fake_authority",
			Aliases:          []string{"fake_authority"},
		}},
	}
	endpoints := authority.Endpoints{
		AuthorizationEndpoint: "auth_endpoint",
		TokenEndpoint:         "token_endpoint",
	}
	fakeOAuth := &oauth.Client{
		AccessTokens: &fake.AccessTokens{AccessToken: tokenResp},
		Authority:    &fake.Authority{InstanceResp: discovery},
		Resolver:     &fake.ResolveEndpoints{Endpoints: endpoints},
		WSTrust:      &fake.WSTrust{},
	}
	return base.New("fake_client_id", "https://fake_authority/fake", fakeOAuth)
}
type execTime struct {
start time.Time
end time.Time
}
// populateTokenCache concurrently adds params.TokenCount access tokens to client's cache.
// The work is split evenly across params.Concurrency goroutines, and the function returns
// when population started and finished. Each token holds accessToken, expires in one hour,
// and is made unique by its scope: the decimal string of the token's index.
//
// It panics if Concurrency is not positive, if TokenCount is not divisible by
// Concurrency, or if caching any token fails.
func populateTokenCache(client base.Client, params testParams) execTime {
	// Check explicitly so a bad configuration fails with a clear message instead of an
	// integer divide-by-zero panic in the modulo below.
	if params.Concurrency <= 0 {
		panic("Concurrency must be positive")
	}
	if r := params.TokenCount % params.Concurrency; r != 0 {
		panic("TokenCount must be divisible by Concurrency")
	}
	parts := params.TokenCount / params.Concurrency
	authParams := client.AuthParams
	authParams.Scopes = tokenScope
	authParams.AuthorizationType = authority.ATClientCredentials

	wg := &sync.WaitGroup{}
	fmt.Printf("Populating token cache with %d tokens...", params.TokenCount)
	start := time.Now()
	for n := 0; n < params.Concurrency; n++ {
		wg.Add(1)
		go func(chunk int) {
			defer wg.Done()
			// goroutine `chunk` owns token indices [parts*chunk, parts*(chunk+1))
			for i := parts * chunk; i < parts*(chunk+1); i++ {
				// we use this to add a fake token to the cache.
				// each token has a different scope which is what makes them unique
				_, err := client.AuthResultFromToken(context.Background(), authParams, accesstokens.TokenResponse{
					AccessToken:   accessToken,
					ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
					GrantedScopes: accesstokens.Scopes{Slice: []string{strconv.Itoa(i)}},
				}, true)
				if err != nil {
					panic(err)
				}
			}
		}(n)
	}
	wg.Wait()
	return execTime{start: start, end: time.Now()}
}
// executeTest starts params.Concurrency goroutines. Each goroutine looks up every one of
// the params.TokenCount cached tokens once, via AcquireTokenSilent with a confidential
// request and a fake secret. The function returns when retrieval started and finished,
// and any failed lookup panics.
func executeTest(client base.Client, params testParams) execTime {
	var wg sync.WaitGroup
	fmt.Print("Begin token retrieval.....")
	begin := time.Now()
	for worker := 0; worker < params.Concurrency; worker++ {
		wg.Add(1)
		go func() {
			// every goroutine walks the full set of token scopes
			for idx := 0; idx < params.TokenCount; idx++ {
				silentParams := base.AcquireTokenSilentParameters{
					Scopes:      []string{strconv.Itoa(idx)},
					RequestType: accesstokens.ATConfidential,
					Credential:  &accesstokens.Credential{Secret: "fake_secret"},
				}
				if _, err := client.AcquireTokenSilent(context.Background(), silentParams); err != nil {
					panic(err)
				}
			}
			wg.Done()
		}()
	}
	wg.Wait()
	finish := time.Now()
	return execTime{start: begin, end: finish}
}
// Stats is used with statsTemplText for reporting purposes. The template reads
// Concurrency and Count directly and calls the exported duration methods below.
type Stats struct {
	popExec     execTime // start/end of cache population
	retExec     execTime // start/end of token retrieval
	Concurrency int      // number of goroutines used
	Count       int64    // number of tokens cached
}

// PopDur returns the total duration for populating the cache.
func (s *Stats) PopDur() time.Duration {
	return s.popExec.end.Sub(s.popExec.start)
}

// RetDur returns the total duration for retrieving tokens.
func (s *Stats) RetDur() time.Duration {
	return s.retExec.end.Sub(s.retExec.start)
}

// PopAvg returns the mean average of caching a token (PopDur / Count), or 0 when Count
// is not positive.
func (s *Stats) PopAvg() time.Duration {
	return s.perToken(s.PopDur())
}

// RetAvg returns the mean average of retrieving a token (RetDur / Count), or 0 when Count
// is not positive.
func (s *Stats) RetAvg() time.Duration {
	return s.perToken(s.RetDur())
}

// perToken divides total by Count. It returns 0 when Count is 0 (or a meaningless
// negative value) instead of panicking with an integer divide-by-zero, which would
// otherwise surface as a template execution error.
func (s *Stats) perToken(total time.Duration) time.Duration {
	if s.Count <= 0 {
		return 0
	}
	return total / time.Duration(s.Count)
}
var statsTemplText = `
Test Results:
[{{.Concurrency}} goroutines][{{.Count}} tokens] [population: total {{.PopDur}}, avg {{.PopAvg}}] [retrieval: total {{.RetDur}}, avg {{.RetAvg}}]
==========================================================================
`
var statsTempl = template.Must(template.New("stats").Parse(statsTemplText))
// main benchmarks cache population and silent token retrieval for several cache sizes.
// It uses one goroutine per CPU and a fresh fake client for each size, and writes each
// run's parameters and results to stdout. Any failure panics.
func main() {
	workers := runtime.NumCPU()
	var tests []testParams
	for _, count := range []int{100, 1000, 10000, 20000} {
		tests = append(tests, testParams{Concurrency: workers, TokenCount: count})
	}

	for _, params := range tests {
		client, err := fakeClient()
		if err != nil {
			panic(err)
		}
		fmt.Printf("Test Params: %#v\n", params)
		popTime := populateTokenCache(client, params)
		retTime := executeTest(client, params)
		report := &Stats{
			popExec:     popTime,
			retExec:     retTime,
			Concurrency: params.Concurrency,
			Count:       int64(params.TokenCount),
		}
		if err := statsTempl.Execute(os.Stdout, report); err != nil {
			panic(err)
		}
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/ 0000775 0000000 0000000 00000000000 14420263624 0025135 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/README.md 0000664 0000000 0000000 00000003707 14420263624 0026423 0 ustar 00root root 0000000 0000000 # Running the Dev Apps for MSAL Go
To run one of the dev apps that use MSAL Go, the `config.json` and `confidential_config.json` files should look like the following:
```json
{
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"scopes": ["user.read"],
"username": "your_username",
"password": "your_password",
"redirect_uri": "redirect uri registered on the portal",
"code_challenge": "transformed code verifier from PKCE",
"state": "state parameter for authorization code flow",
"client_secret": "client secret you generated for your app",
"thumbprint": "the certificate thumbprint defined in your app generation",
"pem_file": "the file path of your private key pem"
}
```
The dev apps in this repo get tokens for the MS Graph API. To find permissible scopes for MS Graph, visit this [link](https://docs.microsoft.com/graph/permissions-reference). PKCE is explained [here](https://tools.ietf.org/html/rfc7636#section-4.1).
## On Windows
To run the dev samples, change to the dev apps folder:
`cd apps/tests/devapps`
Then run the command:
`go run ./ 1`
Alternatives:
* Option 1: build and run locally
  * In the devapps folder
  * Type `go build`
  * Type `devapps.exe 1` to run the device code flow
* Option 2 (advanced): install and run from the gobin folder
  * See more: https://golang.org/cmd/go/#hdr-Compile_and_install_packages_and_dependencies
  * In the devapps folder
  * Type `go install`
  * Locate your gobin folder, e.g. type `go env` to find your gobin folder location
  * `cd` to your gobin folder
  * Type `devapps.exe 1` to run the device code flow
## On Mac
To run one of the dev apps, run `go run apps/tests/devapps/*.go <number>` from the repository root, where `<number>` selects the sample:
* 1 - `device_code_flow_sample.go`
* 2 - `authorization_code_sample.go`
* 3 - `username_password_sample.go`
* 4 - `confidential_auth_code_sample.go`
* 5 - `client_secret_sample.go`
* 6 - `client_certificate_sample.go`
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/authorization_code_sample.go 0000664 0000000 0000000 00000004225 14420263624 0032722 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
// TODO(msal expert): This should be refactored into an example maybe?
// a "main" with a bunch of private functions that can't run isn't a good code sample.
/*
func getToken(w http.ResponseWriter, r *http.Request) {
// Getting the authorization code from the URL's query
states, ok := r.URL.Query()["state"]
if !ok || len(states[0]) < 1 {
log.Fatal(errors.New("State parameter missing, can't verify authorization code"))
}
codes, ok := r.URL.Query()["code"]
if !ok || len(codes[0]) < 1 {
log.Fatal(errors.New("Authorization code missing"))
}
if states[0] != config.State {
log.Fatal(errors.New("State parameter is incorrect"))
}
code := codes[0]
// Getting the access token using the authorization code
result, err := publicClientApp.AcquireTokenByAuthCode(context.Background(), config.Scopes, &msal.AcquireTokenByAuthCodeOptions{
Code: code,
CodeChallenge: config.CodeChallenge,
})
if err != nil {
log.Fatal(err)
}
// Prints the access token on the webpage
fmt.Fprintf(w, "Access token is "+result.GetAccessToken())
}
func acquireByAuthorizationCodePublic() {
options := msal.DefaultPublicClientApplicationOptions()
options.Authority = config.Authority
publicClientApp, err := msal.NewPublicClientApplication(config.ClientID, &options)
if err != nil {
panic(err)
}
http.HandleFunc("/", redirectToURL)
// The redirect uri set in our app's registration is http://localhost:port/redirect
http.HandleFunc("/redirect", getToken)
log.Fatal(http.ListenAndServe(":"+port, nil))
}
func redirectToURL(w http.ResponseWriter, r *http.Request) {
// Getting the URL to redirect to acquire the authorization code
authCodeURLParams := msal.CreateAuthorizationCodeURLParameters(config.ClientID, config.RedirectURI, config.Scopes)
authCodeURLParams.CodeChallenge = config.CodeChallenge
authCodeURLParams.State = config.State
authURL, err := publicClientApp.AuthCodeURL(context.Background(), authCodeURLParams)
if err != nil {
log.Fatal(err)
}
// Redirecting to the URL we have received
log.Info(authURL)
http.Redirect(w, r, authURL, http.StatusSeeOther)
}
*/
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/client_certificate_sample.go 0000664 0000000 0000000 00000002320 14420263624 0032642 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"fmt"
"log"
"os"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)
// acquireTokenClientCertificate demonstrates the client-credentials flow using a
// certificate loaded from the PEM file named in confidential_config.json. It tries
// the token cache first and only asks the authority for a new token when the
// silent attempt fails. Any other failure is fatal.
func acquireTokenClientCertificate() {
	cfg := CreateConfig("confidential_config.json")

	raw, err := os.ReadFile(cfg.PemData)
	if err != nil {
		log.Fatal(err)
	}

	// CertFromPEM pulls the public certificate chain and the private key out of
	// the PEM bytes. The second argument is the password; it must be set when
	// the key is encrypted.
	chain, key, err := confidential.CertFromPEM(raw, "")
	if err != nil {
		log.Fatal(err)
	}

	credential, err := confidential.NewCredFromCert(chain, key)
	if err != nil {
		log.Fatal(err)
	}

	client, err := confidential.New(cfg.Authority, cfg.ClientID, credential, confidential.WithCache(cacheAccessor))
	if err != nil {
		log.Fatal(err)
	}

	ctx := context.Background()
	if cached, silentErr := client.AcquireTokenSilent(ctx, cfg.Scopes); silentErr == nil {
		fmt.Println("Silently acquired token " + cached.AccessToken)
		return
	}

	// Nothing usable in the cache: request a fresh token with the certificate.
	fresh, err := client.AcquireTokenByCredential(ctx, cfg.Scopes)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("Access Token Is " + fresh.AccessToken)
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/client_secret_sample.go 0000664 0000000 0000000 00000001561 14420263624 0031653 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"fmt"
"log"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)
// acquireTokenClientSecret demonstrates the client-credentials flow using the
// client secret from confidential_config.json. A cached token is used when the
// silent call succeeds; otherwise a new token is requested from the authority.
// Exactly one token line is printed. Any failure is fatal.
func acquireTokenClientSecret() {
	config := CreateConfig("confidential_config.json")
	cred, err := confidential.NewCredFromSecret(config.ClientSecret)
	if err != nil {
		log.Fatal(err)
	}
	app, err := confidential.New(config.Authority, config.ClientID, cred, confidential.WithCache(cacheAccessor))
	if err != nil {
		log.Fatal(err)
	}
	result, err := app.AcquireTokenSilent(context.Background(), config.Scopes)
	if err != nil {
		// Nothing usable in the cache (or the silent call failed): ask the authority.
		result, err = app.AcquireTokenByCredential(context.Background(), config.Scopes)
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("Access Token Is " + result.AccessToken)
		// Return so a freshly acquired token is not also reported below as
		// having been acquired silently.
		return
	}
	fmt.Println("Silently acquired token " + result.AccessToken)
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/confidential_auth_code_sample.go 0000664 0000000 0000000 00000006567 14420263624 0033515 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
/*
var (
accessToken string
confidentialConfig = CreateConfig("confidential_config.json")
app confidential.Client
)
// TODO(msal): I'm not sure what to do here with the CodeChallenge and State. authCodeURLParams
// is no more. CodeChallenge is only used now in a confidential.AcquireTokenByAuthCode(), which
// this is not using. Maybe now this is a two step process????
func redirectToURLConfidential(w http.ResponseWriter, r *http.Request) {
// Getting the URL to redirect to acquire the authorization code
authCodeURLParams.CodeChallenge = confidentialConfig.CodeChallenge
authCodeURLParams.State = confidentialConfig.State
authURL, err := app.AuthCodeURL(context.Background(), confidentialConfig.ClientID, confidentialConfig.RedirectURI, confidentialConfig.Scopes)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
// Redirecting to the URL we have received
log.Println("redirecting to auth: ", authURL)
http.Redirect(w, r, authURL, http.StatusSeeOther)
}
func getTokenConfidential(w http.ResponseWriter, r *http.Request) {
// Getting the authorization code from the URL's query
states, ok := r.URL.Query()["state"]
if !ok || len(states[0]) < 1 {
log.Fatal(errors.New("State parameter missing, can't verify authorization code"))
}
codes, ok := r.URL.Query()["code"]
if !ok || len(codes[0]) < 1 {
log.Fatal(errors.New("Authorization code missing"))
}
if states[0] != config.State {
log.Fatal(errors.New("State parameter is incorrect"))
}
code := codes[0]
// Getting the access token using the authorization code
result, err := app.AcquireTokenByAuthCode(
context.Background(),
confidentialConfig.Scopes,
confidential.CodeChallenge(code, confidentialConfig.CodeChallenge),
)
if err != nil {
log.Fatal(err)
}
// Prints the access token on the webpage
fmt.Fprintf(w, "Access token is "+result.GetAccessToken())
accessToken = result.GetAccessToken()
}
// TODO(msal): Needs to use an x509 certificate like the other now that we are not using a
// thumbprint directly.
/*
func acquireByAuthorizationCodeConfidential(ctx context.Context) {
key, err := os.ReadFile(confidentialConfig.KeyFile)
if err != nil {
log.Fatal(err)
}
certificate, err := msal.CreateClientCredentialFromCertificate(confidentialConfig.Thumbprint, key)
if err != nil {
log.Fatal(err)
}
options := msal.DefaultConfidentialClientApplicationOptions()
options.Accessor = cacheAccessor
options.Authority = confidentialConfig.Authority
app, err := msal.NewConfidentialClientApplication(confidentialConfig.ClientID, certificate, &options)
if err != nil {
log.Fatal(err)
}
var userAccount shared.Account
for _, account := range app.Accounts(ctx) {
if account.PreferredUsername == confidentialConfig.Username {
userAccount = account
}
}
result, err := app.AcquireTokenSilent(
context.Background(),
confidentialConfig.Scopes,
&msal.AcquireTokenSilentOptions{
Account: userAccount,
},
)
if err != nil {
panic(err)
}
fmt.Printf("Access token is " + result.GetAccessToken())
accessToken = result.GetAccessToken()
http.HandleFunc("/", redirectToURLConfidential)
// The redirect uri set in our app's registration is http://localhost:port/redirect
http.HandleFunc("/redirect", getTokenConfidential)
log.Fatal(http.ListenAndServe(":"+port, nil))
}
*/
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/confidential_config.json 0000664 0000000 0000000 00000001004 14420263624 0032007 0 ustar 00root root 0000000 0000000 {
"authority": "https://login.microsoftonline.com/your_tenant_id",
"client_id": "your_client_id",
"scopes": ["requested_scope- usually of the format - {ResourceId}+/.default"],
"redirect_uri": "redirect uri registered on the portal",
"code_challenge": "transformed code verifier from PKCE",
"state": "state parameter for authorization code flow",
"client_secret": "client secret you generated for your app",
"pem_file": "the file path of pem containing public cert and private key"
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/config.json 0000664 0000000 0000000 00000001111 14420263624 0027267 0 ustar 00root root 0000000 0000000 {
"authority": "https://login.microsoftonline.com/organizations",
"client_id": "your_client_id",
"scopes": ["user.read"],
"username": "your_username",
"password": "your_password",
"redirect_uri": "redirect uri registered on the portal",
"code_challenge": "transformed code verifier from PKCE",
"state": "state parameter for authorization code flow",
"client_secret": "client secret you generated for your app",
"thumbprint": "the certificate thumbprint defined in your app generation",
"pem_file": "the file path of your private key pem"
} microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/device_code_flow_sample.go 0000664 0000000 0000000 00000003426 14420263624 0032312 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"fmt"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)
// acquireTokenDeviceCode demonstrates the device code flow for a public client.
// It looks for a cached account whose PreferredUsername matches config.Username
// and tries to serve a token from the cache. When that fails it starts a device
// code flow, prints the user instructions and waits (up to 100 seconds) for the
// user to complete sign-in. Failures panic.
func acquireTokenDeviceCode() {
	config := CreateConfig("config.json")
	app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority))
	if err != nil {
		panic(err)
	}

	// look in the cache to see if the account to use has been cached
	var userAccount public.Account
	accounts, err := app.Accounts(context.Background())
	if err != nil {
		// Keep the underlying error so the cause of the cache failure is visible.
		panic(fmt.Sprintf("failed to read the cache: %s", err))
	}
	for _, account := range accounts {
		if account.PreferredUsername == config.Username {
			userAccount = account
		}
	}

	// See whether an applicable token has been cached. userAccount is still the
	// zero value when no matching account was found.
	// NOTE: this API conflates error states, i.e. err is non-nil if an applicable token isn't
	// cached or if something goes wrong (making the HTTP request, unmarshalling, etc).
	authResult, err := app.AcquireTokenSilent(context.Background(), config.Scopes, public.WithSilentAccount(userAccount))
	if err != nil {
		// either there was no cached account/token or the call to AcquireTokenSilent() failed
		// make a new request to AAD
		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Second)
		defer cancel()
		devCode, err := app.AcquireTokenByDeviceCode(ctx, config.Scopes)
		if err != nil {
			panic(err)
		}
		fmt.Printf("Device Code is: %s\n", devCode.Result.Message)
		result, err := devCode.AuthenticationResult(ctx)
		if err != nil {
			panic(fmt.Sprintf("got error while waiting for user to input the device code: %s", err))
		}
		fmt.Println("Access token is " + result.AccessToken)
		return
	}
	fmt.Println("Access token is " + authResult.AccessToken)
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/main.go 0000664 0000000 0000000 00000001363 14420263624 0026413 0 ustar 00root root 0000000 0000000 package main
import (
"context"
"os"
)
// cacheAccessor is the cache accessor passed to every sample client; it reads and
// writes the token cache in serialized_cache.json via TokenCache.
var cacheAccessor = &TokenCache{file: "serialized_cache.json"}
// main runs the sample selected by the first command-line argument:
// 1 device code, 3 username/password, 4 confidential auth code (not implemented),
// 5 client secret, 6 client certificate. Sample 2 is currently disabled.
// A missing or unrecognized argument prints a message to stderr and exits with status 2.
func main() {
	ctx := context.Background()
	// TODO(msal): Consider switching to the flag package.
	if len(os.Args) < 2 {
		// Guard against the index-out-of-range panic os.Args[1] would otherwise cause.
		os.Stderr.WriteString("usage: devapps <example number>\n")
		os.Exit(2)
	}
	switch exampleType := os.Args[1]; exampleType {
	case "1":
		acquireTokenDeviceCode()
	// case "2":
	//	acquireByAuthorizationCodePublic()
	case "3":
		acquireByUsernamePasswordPublic(ctx)
	case "4":
		panic("currently not implemented")
		//acquireByAuthorizationCodeConfidential()
	case "5":
		acquireTokenClientSecret()
	case "6":
		acquireTokenClientCertificate()
	default:
		os.Stderr.WriteString("unknown example type: " + exampleType + "\n")
		os.Exit(2)
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/sample_cache_accessor.go 0000664 0000000 0000000 00000001241 14420263624 0031750 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"log"
"os"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
)
// TokenCache is the samples' cache accessor: it keeps the serialized token cache
// in a single file on disk.
type TokenCache struct {
	// file is the path the cache is read from (Replace) and written to (Export).
	file string
}
// Replace loads the cache contents stored in t.file into cache. Reading the file
// is best effort. If it cannot be read (for example, it does not exist yet),
// the error is logged and nil is returned, so the in-memory cache is left as it
// is instead of being handed empty data to unmarshal. Errors from
// cache.Unmarshal are returned to the caller.
func (t *TokenCache) Replace(ctx context.Context, cache cache.Unmarshaler, hints cache.ReplaceHints) error {
	data, err := os.ReadFile(t.file)
	if err != nil {
		log.Println(err)
		return nil
	}
	return cache.Unmarshal(data)
}
// Export serializes cache and writes it to t.file with 0600 permissions. If
// serialization fails, the error is logged and returned, and the existing file
// is left untouched rather than being overwritten with empty data.
func (t *TokenCache) Export(ctx context.Context, cache cache.Marshaler, hints cache.ExportHints) error {
	data, err := cache.Marshal()
	if err != nil {
		log.Println(err)
		return err
	}
	return os.WriteFile(t.file, data, 0600)
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/sample_utils.go 0000664 0000000 0000000 00000002132 14420263624 0030163 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"encoding/json"
"log"
"os"
)
// Config represents the config.json required to run the samples
type Config struct {
ClientID string `json:"client_id"`
Authority string `json:"authority"`
Scopes []string `json:"scopes"`
Username string `json:"username"`
Password string `json:"password"`
RedirectURI string `json:"redirect_uri"`
CodeChallenge string `json:"code_challenge"`
CodeChallengeMethod string `json:"code_challenge_method"`
State string `json:"state"`
ClientSecret string `json:"client_secret"`
Thumbprint string `json:"thumbprint"`
PemData string `json:"pem_file"`
}
// CreateConfig creates the Config struct from a json file.
func CreateConfig(fileName string) *Config {
data, err := os.ReadFile(fileName)
if err != nil {
log.Fatal(err)
}
config := &Config{}
err = json.Unmarshal(data, config)
if err != nil {
log.Fatal(err)
}
return config
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/serialized_cache.json 0000664 0000000 0000000 00000003274 14420263624 0031314 0 ustar 00root root 0000000 0000000 {
"Account": {
"uid.utid-login.windows.net-contoso": {
"username": "John Doe",
"local_account_id": "object1234",
"realm": "contoso",
"environment": "login.windows.net",
"home_account_id": "uid.utid",
"authority_type": "MSSTS"
}
},
"RefreshToken": {
"uid.utid-login.windows.net-refreshtoken-my_client_id--s2 s1 s3": {
"target": "s2 s1 s3",
"environment": "login.windows.net",
"credential_type": "RefreshToken",
"secret": "a refresh token",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"AccessToken": {
"uid.utid-login.windows.net-accesstoken-my_client_id-contoso-s2 s1 s3": {
"environment": "login.windows.net",
"credential_type": "AccessToken",
"secret": "an access token",
"realm": "contoso",
"target": "s2 s1 s3",
"client_id": "my_client_id",
"cached_at": "1000",
"home_account_id": "uid.utid",
"extended_expires_on": "4600",
"expires_on": "4600"
}
},
"IdToken": {
"uid.utid-login.windows.net-idtoken-my_client_id-contoso-": {
"realm": "contoso",
"environment": "login.windows.net",
"credential_type": "IdToken",
"secret": "header.eyJvaWQiOiAib2JqZWN0MTIzNCIsICJwcmVmZXJyZWRfdXNlcm5hbWUiOiAiSm9obiBEb2UiLCAic3ViIjogInN1YiJ9.signature",
"client_id": "my_client_id",
"home_account_id": "uid.utid"
}
},
"AppMetadata": {
"AppMetadata-login.windows.net-my_client_id": {
"environment": "login.windows.net",
"family_id": null,
"client_id": "my_client_id"
}
}
} microsoft-authentication-library-for-go-1.0.0/apps/tests/devapps/username_password_sample.go 0000664 0000000 0000000 00000003112 14420263624 0032563 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package main
import (
"context"
"fmt"
"log"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)
// acquireByUsernamePasswordPublic demonstrates the username/password (ROPC) flow
// for a public client. It looks for a cached account whose PreferredUsername
// matches config.Username and tries to serve a token from the cache. When that
// fails it authenticates with the configured username and password. ctx is used
// for every MSAL call made here.
func acquireByUsernamePasswordPublic(ctx context.Context) {
	config := CreateConfig("config.json")
	app, err := public.New(config.ClientID, public.WithCache(cacheAccessor), public.WithAuthority(config.Authority))
	if err != nil {
		panic(err)
	}

	// look in the cache to see if the account to use has been cached
	var userAccount public.Account
	accounts, err := app.Accounts(ctx)
	if err != nil {
		// Keep the underlying error so the cause of the cache failure is visible.
		panic(fmt.Sprintf("failed to read the cache: %s", err))
	}
	for _, account := range accounts {
		if account.PreferredUsername == config.Username {
			userAccount = account
		}
	}

	// See whether an applicable token has been cached. userAccount is still the
	// zero value when no matching account was found.
	// NOTE: this API conflates error states, i.e. err is non-nil if an applicable token isn't
	// cached or if something goes wrong (making the HTTP request, unmarshalling, etc).
	result, err := app.AcquireTokenSilent(ctx, config.Scopes, public.WithSilentAccount(userAccount))
	if err != nil {
		// either there was no cached account/token or the call to AcquireTokenSilent() failed
		log.Println(err)
		// make a new request to AAD
		result, err = app.AcquireTokenByUsernamePassword(ctx, config.Scopes, config.Username, config.Password)
		if err != nil {
			log.Fatal(err)
		}
	}
	fmt.Println("Access token is " + result.AccessToken)
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/integration/ 0000775 0000000 0000000 00000000000 14420263624 0026016 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/tests/integration/integration_test.go 0000664 0000000 0000000 00000033206 14420263624 0031733 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package integration
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
)
// Scopes and authority URLs used by the integration tests.
const (
	msIDlabDefaultScope = "https://msidlab.com/.default"
	graphDefaultScope   = "https://graph.windows.net/.default"

	microsoftAuthorityHost = "https://login.microsoftonline.com/"
	organizationsAuthority = microsoftAuthorityHost + "organizations/"
	microsoftAuthority     = microsoftAuthorityHost + "microsoft.onmicrosoft.com"
	//msIDlabTenantAuthority = microsoftAuthorityHost + "msidlab4.onmicrosoft.com" - Will be needed in the future
)
var httpClient = http.Client{}
func httpRequest(ctx context.Context, url string, query url.Values, accessToken string) ([]byte, error) {
if _, ok := ctx.Deadline(); !ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, 10*time.Second)
defer cancel()
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("failed to build new http request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.URL.RawQuery = query.Encode()
resp, err := httpClient.Do(req)
if err != nil {
return nil, fmt.Errorf("http.Get(%s) failed: %w", req.URL.String(), err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("http.Get(%s): could not read body: %w", req.URL.String(), err)
}
return body, nil
}
// labClient talks to the MSID lab API on behalf of the integration tests.
type labClient struct {
	// app is the confidential client used to obtain lab API access tokens.
	app confidential.Client
}
// TODO : Add app object
type (
	// user is a lab test account decoded from the lab user API response.
	// Password has no JSON tag; labClient.user fills it in from the lab secret API.
	user struct {
		AppID            string `json:"appId"`
		ObjectID         string `json:"objectId"`
		UserType         string `json:"userType"`
		DisplayName      string `json:"displayName"`
		Licenses         string `json:"licences"`
		Upn              string `json:"upn"`
		Mfa              string `json:"mfa"`
		ProtectionPolicy string `json:"protectionPolicy"`
		HomeDomain       string `json:"homeDomain"`
		HomeUPN          string `json:"homeUPN"`
		B2cProvider      string `json:"b2cProvider"`
		LabName          string `json:"labName"`
		LastUpdatedBy    string `json:"lastUpdatedBy"`
		LastUpdatedDate  string `json:"lastUpdatedDate"`
		Password         string
	}

	// secret is the JSON payload returned by the lab user-secret API.
	secret struct {
		Value string `json:"value"`
	}
)
// newLabClient builds a labClient whose confidential client targets
// microsoftAuthority, using the client ID and secret from the clientId and
// clientSecret environment variables.
func newLabClient() (*labClient, error) {
	credential, err := confidential.NewCredFromSecret(os.Getenv("clientSecret"))
	if err != nil {
		return nil, fmt.Errorf("could not create a cred from a secret: %w", err)
	}
	client, err := confidential.New(microsoftAuthority, os.Getenv("clientId"), credential)
	if err != nil {
		return nil, err
	}
	return &labClient{app: client}, nil
}
// labAccessToken returns an access token for the lab API's default scope. It
// uses the cached token when AcquireTokenSilent succeeds; otherwise it acquires
// one with the client credential.
func (l *labClient) labAccessToken() (string, error) {
	ctx := context.Background()
	scopes := []string{msIDlabDefaultScope}
	if cached, err := l.app.AcquireTokenSilent(ctx, scopes); err == nil {
		return cached.AccessToken, nil
	}
	fresh, err := l.app.AcquireTokenByCredential(ctx, scopes)
	if err != nil {
		return "", fmt.Errorf("AcquireTokenByCredential() error: %w", err)
	}
	return fresh.AccessToken, nil
}
// user fetches the first lab account matching query and fills in its Password
// from the lab secret API, keyed by the account's LabName. It returns an error
// if the token or HTTP call fails, the response cannot be decoded, or no account
// matches. If only the password lookup fails, the account is returned together
// with that error.
func (l *labClient) user(ctx context.Context, query url.Values) (user, error) {
	token, err := l.labAccessToken()
	if err != nil {
		return user{}, fmt.Errorf("problem getting lab access token: %w", err)
	}
	body, err := httpRequest(ctx, "https://msidlab.com/api/user", query, token)
	if err != nil {
		return user{}, err
	}
	var matches []user
	if err := json.Unmarshal(body, &matches); err != nil {
		return user{}, err
	}
	if len(matches) == 0 {
		return user{}, errors.New("No user found")
	}
	found := matches[0]
	found.Password, err = l.secret(ctx, url.Values{"Secret": []string{found.LabName}})
	return found, err
}
// secret queries the lab user-secret API with query and returns the decoded
// secret value; labClient.user uses it to look up a test account's password.
func (l *labClient) secret(ctx context.Context, query url.Values) (string, error) {
	token, err := l.labAccessToken()
	if err != nil {
		return "", err
	}
	body, err := httpRequest(ctx, "https://msidlab.com/api/LabUserSecret", query, token)
	if err != nil {
		return "", err
	}
	var payload secret
	if err := json.Unmarshal(body, &payload); err != nil {
		return "", err
	}
	return payload.Value, nil
}
// TODO: Add getApp() when needed

// testUser fetches a lab user matching query for the test identified by desc and
// panics if none can be obtained. Several tests share this helper, so the panic
// message names the helper and desc rather than one specific test.
func testUser(ctx context.Context, desc string, lc *labClient, query url.Values) user {
	u, err := lc.user(ctx, query)
	if err != nil {
		panic(fmt.Sprintf("testUser(%s) setup: failed to get input user: %s", desc, err))
	}
	return u
}
// TestUsernamePassword acquires a Graph token with the username/password flow
// for lab users of each account type below, running each as a subtest. It checks
// that an access token and ID token are returned and that the account's
// PreferredUsername matches the lab user's UPN.
func TestUsernamePassword(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}
	labClientInstance, err := newLabClient()
	if err != nil {
		t.Fatalf("failed to get a lab client: %s", err)
	}
	tests := []struct {
		desc string
		vals url.Values
	}{
		{"Managed", url.Values{"usertype": []string{"cloud"}}},
		{"ADFSv2", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv2"}}},
		{"ADFSv3", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv3"}}},
		{"ADFSv4", url.Values{"usertype": []string{"federated"}, "federationProvider": []string{"ADFSv4"}}},
	}
	for _, test := range tests {
		// Subtests (run sequentially) let one failing account type not hide the others.
		t.Run(test.desc, func(t *testing.T) {
			ctx := context.Background()
			u := testUser(ctx, test.desc, labClientInstance, test.vals)
			app, err := public.New(u.AppID, public.WithAuthority(organizationsAuthority))
			if err != nil {
				t.Fatalf("TestUsernamePassword(%s): on public.New(): got err == %s, want err == nil", test.desc, errors.Verbose(err))
			}
			result, err := app.AcquireTokenByUsernamePassword(ctx, []string{graphDefaultScope}, u.Upn, u.Password)
			if err != nil {
				t.Fatalf("TestUsernamePassword(%s): on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", test.desc, errors.Verbose(err))
			}
			if result.AccessToken == "" {
				t.Fatalf("TestUsernamePassword(%s): got AccessToken == '', want AccessToken != ''", test.desc)
			}
			if result.IDToken.IsZero() {
				t.Fatalf("TestUsernamePassword(%s): got IDToken == empty, want IDToken == non-empty struct", test.desc)
			}
			if result.Account.PreferredUsername != u.Upn {
				t.Fatalf("TestUsernamePassword(%s): got Username == %s, want Username == %s", test.desc, result.Account.PreferredUsername, u.Upn)
			}
		})
	}
}
// TestConfidentialClientwithSecret acquires a lab API token with the client
// credentials flow, using the clientId and clientSecret environment variables.
// It then checks that AcquireTokenSilent also returns a non-empty token.
func TestConfidentialClientwithSecret(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}
	clientID := os.Getenv("clientId")
	clientSecret := os.Getenv("clientSecret")
	cred, err := confidential.NewCredFromSecret(clientSecret)
	if err != nil {
		t.Fatalf("TestConfidentialClientwithSecret: on NewCredFromSecret(): got err == %s, want err == nil", errors.Verbose(err))
	}
	app, err := confidential.New(microsoftAuthority, clientID, cred)
	if err != nil {
		t.Fatalf("TestConfidentialClientwithSecret: on confidential.New(): got err == %s, want err == nil", errors.Verbose(err))
	}
	scopes := []string{msIDlabDefaultScope}
	result, err := app.AcquireTokenByCredential(context.Background(), scopes)
	if err != nil {
		t.Fatalf("TestConfidentialClientwithSecret: on AcquireTokenByCredential(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result.AccessToken == "" {
		t.Fatal("TestConfidentialClientwithSecret: on AcquireTokenByCredential(): got AccessToken == '', want AccessToken != ''")
	}
	silentResult, err := app.AcquireTokenSilent(context.Background(), scopes)
	if err != nil {
		t.Fatalf("TestConfidentialClientwithSecret: on AcquireTokenSilent(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if silentResult.AccessToken == "" {
		t.Fatal("TestConfidentialClientwithSecret: on AcquireTokenSilent(): got AccessToken == '', want AccessToken != ''")
	}
}
// TestOnBehalfOf exercises the on-behalf-of (OBO) flow. A public client obtains
// a user token scoped to the mid-tier app, and a confidential client exchanges
// that assertion for downstream tokens. It checks that repeated requests with
// the same assertion and scopes return the same token, and that a different
// scope or a new user assertion yields a different token.
func TestOnBehalfOf(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}
	labClientInstance, err := newLabClient()
	if err != nil {
		t.Fatalf("failed to get a lab client: %s", err)
	}
	ctx := context.Background()

	// Confidential Client Application Config
	ccaClientID := os.Getenv("oboConfidentialClientId")
	ccaClientSecret := os.Getenv("oboConfidentialClientSecret")
	ccaScopes := []string{"https://graph.microsoft.com/user.read"}

	// Public Client Application Config
	pcaClientID := os.Getenv("oboPublicClientId")
	u := testUser(ctx, "OnBehalfOf", labClientInstance, url.Values{"usertype": []string{"cloud"}})
	pcaScopes := []string{fmt.Sprintf("api://%s/.default", ccaClientID)}

	// 1. An app obtains a token representing a user, for our mid-tier service
	pca, err := public.New(pcaClientID, public.WithAuthority(organizationsAuthority))
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on public.New(): got err == %s, want err == nil", errors.Verbose(err))
	}
	result, err := pca.AcquireTokenByUsernamePassword(ctx, pcaScopes, u.Upn, u.Password)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result.AccessToken == "" {
		t.Fatal("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got AccessToken == '', want AccessToken != ''")
	}

	// 2. Our mid-tier service uses OBO to obtain a token for downstream service
	cred, err := confidential.NewCredFromSecret(ccaClientSecret)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on NewCredFromSecret(): got err == %s, want err == nil", errors.Verbose(err))
	}
	cca, err := confidential.New("https://login.microsoftonline.com/common", ccaClientID, cred)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on confidential.New(): got err == %s, want err == nil", errors.Verbose(err))
	}
	result1, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, ccaScopes)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result1.AccessToken == "" {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''")
	}

	// 3. Same scope and assertion should return cached access token
	result2, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, ccaScopes)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf() silent token retrieval: got err == %s, want err == nil", errors.Verbose(err))
	}
	if result1.AccessToken != result2.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens don't match")
	}

	// 4. scope2 should return new token
	scope2 := []string{"https://graph.windows.net/.default"}
	result3, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, scope2)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result3.AccessToken == "" {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''")
	}
	if result3.AccessToken == result2.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not")
	}

	// 5. scope2 should return cached token
	result4, err := cca.AcquireTokenOnBehalfOf(ctx, result.AccessToken, scope2)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result4.AccessToken == "" {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''")
	}
	if result4.AccessToken != result3.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens don't match")
	}

	// 6. New user assertion should return new token
	pca1, err := public.New(pcaClientID, public.WithAuthority(organizationsAuthority))
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on public.New(): got err == %s, want err == nil", errors.Verbose(err))
	}
	result5, err := pca1.AcquireTokenByUsernamePassword(ctx, pcaScopes, u.Upn, u.Password)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err))
	}
	result6, err := cca.AcquireTokenOnBehalfOf(ctx, result5.AccessToken, scope2)
	if err != nil {
		t.Fatalf("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got err == %s, want err == nil", errors.Verbose(err))
	}
	if result6.AccessToken == "" {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): got AccessToken == '', want AccessToken != ''")
	}
	if result6.AccessToken == result4.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not")
	}
	if result6.AccessToken == result3.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not")
	}
	if result6.AccessToken == result2.AccessToken {
		t.Fatal("TestOnBehalfOf: on AcquireTokenOnBehalfOf(): Access Tokens match when they should not")
	}
}
// TestRemoveAccount populates a public client's cache with the username/password
// flow, removes the cached account, and checks that AcquireTokenSilent for that
// account then fails.
func TestRemoveAccount(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping integration test")
	}
	labClientInstance, err := newLabClient()
	if err != nil {
		t.Fatalf("failed to get a lab client: %s", err)
	}
	ctx := context.Background()
	u := testUser(ctx, "TestRemoveAccount", labClientInstance, url.Values{"usertype": []string{"cloud"}})
	app, err := public.New(u.AppID, public.WithAuthority(organizationsAuthority))
	if err != nil {
		t.Fatalf("TestRemoveAccount: on public.New(): got err == %s, want err == nil", errors.Verbose(err))
	}
	// Populate the cache
	_, err = app.AcquireTokenByUsernamePassword(ctx, []string{graphDefaultScope}, u.Upn, u.Password)
	if err != nil {
		t.Fatalf("TestRemoveAccount: on AcquireTokenByUsernamePassword(): got err == %s, want err == nil", errors.Verbose(err))
	}
	accounts, err := app.Accounts(ctx)
	if err != nil {
		t.Fatal(err)
	}
	if len(accounts) == 0 {
		t.Fatal("TestRemoveAccount: No user accounts found in cache")
	}
	testAccount := accounts[0] // Only one account is populated and that is what we will remove.
	if err := app.RemoveAccount(ctx, testAccount); err != nil {
		t.Fatalf("TestRemoveAccount: on RemoveAccount(): got err == %s, want err == nil", errors.Verbose(err))
	}
	// Remove Account will clear the cache fields associated with this account so acquire token silent should fail
	if _, err := app.AcquireTokenSilent(ctx, []string{graphDefaultScope}, public.WithSilentAccount(testAccount)); err == nil {
		t.Fatal("TestRemoveAccount: RemoveAccount() didn't clear the cache as expected")
	}
}
microsoft-authentication-library-for-go-1.0.0/apps/tests/performance/ 0000775 0000000 0000000 00000000000 14420263624 0025774 5 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/apps/tests/performance/performance_test.go 0000664 0000000 0000000 00000011040 14420263624 0031657 0 ustar 00root root 0000000 0000000 // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package performance
import (
"context"
"fmt"
"math/rand"
"os"
"testing"
"time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/montanaflynn/stats"
)
// fakeClient returns a base.Client for client ID "fake_client_id" whose OAuth
// client uses the fakes from the oauth/fake package: canned instance discovery,
// fixed authorization/token endpoints, and a fake WS-Trust implementation.
func fakeClient() (base.Client, error) {
	discovery := &fake.Authority{
		InstanceResp: authority.InstanceDiscoveryResponse{
			Metadata: []authority.InstanceDiscoveryMetadata{{
				PreferredNetwork: "fake_authority",
				Aliases:          []string{"fake_authority"},
			}},
		},
	}
	resolver := &fake.ResolveEndpoints{
		Endpoints: authority.Endpoints{
			AuthorizationEndpoint: "auth_endpoint",
			TokenEndpoint:         "token_endpoint",
		},
	}
	oauthClient := &oauth.Client{
		Authority: discovery,
		Resolver:  resolver,
		WSTrust:   &fake.WSTrust{},
	}
	// A base.Client (rather than a public/confidential app) lets us inject the fake OAuth client.
	return base.New("fake_client_id", "https://fake_authority/my_utid", oauthClient)
}
// populateCache seeds client's token cache with users*tokens on-behalf-of
// entries by feeding synthetic token responses through AuthResultFromToken.
//
// Each user i gets the user assertion "fake_access_token<i>" (the same format
// queryCache looks up). User i also gets tokens entries whose scopes are
// "scope0" .. "scope<tokens-1>". Every entry expires one hour after it is
// written.
//
// authParams is the template copied for every entry; callers normally pass
// client.AuthParams. The template itself is never modified. populateCache
// panics if the client rejects a token.
func populateCache(users int, tokens int, authParams authority.AuthParams, client base.Client) {
	ctx := context.Background()
	for user := 0; user < users; user++ {
		assertion := fmt.Sprintf("fake_access_token%d", user)
		// A distinct UTID per user; presumably this gives each user its own
		// cached account. NOTE(review): confirm against cache key logic.
		info := accesstokens.ClientInfo{UID: "my_uid", UTID: fmt.Sprintf("%dmy_utid", user)}
		for token := 0; token < tokens; token++ {
			// Copy the template by value so per-entry changes never leak back.
			params := authParams
			params.UserAssertion = assertion
			params.AuthorizationType = authority.ATOnBehalfOf

			_, err := client.AuthResultFromToken(ctx, params, accesstokens.TokenResponse{
				AccessToken:   assertion,
				RefreshToken:  "fake_refresh_token",
				ClientInfo:    info,
				ExpiresOn:     internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
				GrantedScopes: accesstokens.Scopes{Slice: []string{fmt.Sprintf("scope%d", token)}},
				IDToken: accesstokens.IDToken{
					// Minimal JWT-shaped value; "e30" is base64 for "{}".
					RawToken: "x.e30",
				},
			}, true)
			if err != nil {
				panic(err)
			}
		}
	}
}
// calculateStats prints a latency summary for one users/tokens configuration.
//
// duration holds one sample per measured call, expressed as a float64 count
// of nanoseconds (i.e. float64(time.Duration)). Each statistic (mean, median,
// standard deviation, min, max) is printed in microseconds as a label line
// followed by a value line. It panics if the stats package returns an error
// for any statistic (NOTE(review): e.g. for empty input).
func calculateStats(users, tokens int, duration []float64) {
	fmt.Printf("No of users: %d, No of tokens per user: %d \n", users, tokens)
	summaries := []struct {
		label   string
		compute func(stats.Float64Data) (float64, error)
	}{
		{"Mean", stats.Mean},
		{"Median", stats.Median},
		{"Standard Deviation", stats.StandardDeviation},
		{"Min Time", stats.Min},
		{"Max Time", stats.Max},
	}
	for _, s := range summaries {
		value, err := s.compute(duration)
		// Check the error before using the value.
		if err != nil {
			panic(err)
		}
		fmt.Println(s.label)
		fmt.Println(value / float64(time.Microsecond))
	}
}
// benchMarkObo calls queryCache in a tight loop for one minute and records
// the latency of each call as float64 nanoseconds. It then hands the samples
// to calculateStats for printing.
func benchMarkObo(users int, tokens int, client base.Client) {
	deadline := time.Now().Add(time.Minute)
	var samples []float64
	for time.Now().Before(deadline) {
		began := time.Now()
		queryCache(users, tokens, client)
		samples = append(samples, float64(time.Since(began)))
	}
	calculateStats(users, tokens, samples)
}
// queryCache performs one AcquireTokenOnBehalfOf call.
//
// The call uses a randomly chosen user assertion ("fake_access_token<0..users-1>")
// and a single randomly chosen scope ("scope<0..tokens-1>"). These are the
// same formats populateCache writes. The credential is a fixed fake secret.
//
// users and tokens must be positive, because rand.Intn panics otherwise.
// queryCache panics if the lookup returns an error.
func queryCache(users int, tokens int, client base.Client) {
	// Draw the user index first, then the scope index, so the sequence of
	// calls on the global rand source is unchanged.
	userIdx := rand.Intn(users)
	scopeIdx := rand.Intn(tokens)

	req := base.AcquireTokenOnBehalfOfParameters{
		Scopes:        []string{fmt.Sprintf("scope%d", scopeIdx)},
		UserAssertion: fmt.Sprintf("fake_access_token%d", userIdx),
		Credential:    &accesstokens.Credential{Secret: "fake_secret"},
	}
	if _, err := client.AcquireTokenOnBehalfOf(context.Background(), req); err != nil {
		panic(err)
	}
}
// TestOnBehalfOfCacheTests measures on-behalf-of cache lookup latency across
// several cache sizes.
//
// Each case runs as its own subtest with a fresh fake client. The client is
// seeded with Users*Tokens entries via populateCache and then timed by
// benchMarkObo. The test is skipped when the CI environment variable is set.
func TestOnBehalfOfCacheTests(t *testing.T) {
	if os.Getenv("CI") != "" {
		t.Skip("Skipping testing in CI environment")
	}
	tests := []struct {
		Users  int
		Tokens int
	}{
		{1, 10000},
		{1, 100000},
		{100, 10000},
		{1000, 10000},
		{10000, 100},
	}
	for _, test := range tests {
		name := fmt.Sprintf("%dUsers_%dTokens", test.Users, test.Tokens)
		t.Run(name, func(t *testing.T) {
			client, err := fakeClient()
			if err != nil {
				// Report setup failure through the testing framework rather
				// than panicking.
				t.Fatalf("fakeClient(): got err == %v, want err == nil", err)
			}
			populateCache(test.Users, test.Tokens, client.AuthParams, client)
			benchMarkObo(test.Users, test.Tokens, client)
		})
	}
}
microsoft-authentication-library-for-go-1.0.0/changelog.md 0000664 0000000 0000000 00000000000 14420263624 0023625 0 ustar 00root root 0000000 0000000 microsoft-authentication-library-for-go-1.0.0/go.mod 0000664 0000000 0000000 00000000520 14420263624 0022471 0 ustar 00root root 0000000 0000000 module github.com/AzureAD/microsoft-authentication-library-for-go
go 1.18
require (
github.com/golang-jwt/jwt/v4 v4.4.3
github.com/google/uuid v1.3.0
github.com/kylelemons/godebug v1.1.0
github.com/montanaflynn/stats v0.7.0
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8
)
require golang.org/x/sys v0.5.0 // indirect
microsoft-authentication-library-for-go-1.0.0/go.sum 0000664 0000000 0000000 00000002221 14420263624 0022516 0 ustar 00root root 0000000 0000000 github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/montanaflynn/stats v0.7.0 h1:r3y12KyNxj/Sb/iOE46ws+3mS1+MZca1wlHQFPsY/JU=
github.com/montanaflynn/stats v0.7.0/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt6R8Bnaayow=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8 h1:KoWmjvw+nsYOo29YJK9vDA65RGE3NrOnUtO7a+RF9HU=
github.com/pkg/browser v0.0.0-20210911075715-681adbf594b8/go.mod h1:HKlIX3XHQyzLZPlr7++PzdhaXEj94dEiJgZDTsxEqUI=
golang.org/x/sys v0.0.0-20210616045830-e2b7044e8c71/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=