pax_global_header00006660000000000000000000000064143741264400014517gustar00rootroot0000000000000052 comment=debc4157d4842f33451c8e7e34f08377fa1c37ed dtls-2.2.6/000077500000000000000000000000001437412644000124745ustar00rootroot00000000000000dtls-2.2.6/.editorconfig000066400000000000000000000004671437412644000151600ustar00rootroot00000000000000# http://editorconfig.org/ root = true [*] charset = utf-8 insert_final_newline = true trim_trailing_whitespace = true end_of_line = lf [*.go] indent_style = tab indent_size = 4 [{*.yml,*.yaml}] indent_style = space indent_size = 2 # Makefiles always use tabs for indentation [Makefile] indent_style = tab dtls-2.2.6/.github/000077500000000000000000000000001437412644000140345ustar00rootroot00000000000000dtls-2.2.6/.github/.gitignore000066400000000000000000000000121437412644000160150ustar00rootroot00000000000000.goassets dtls-2.2.6/.github/fetch-scripts.sh000077500000000000000000000014351437412644000171540ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -eu SCRIPT_PATH="$(realpath "$(dirname "$0")")" GOASSETS_PATH="${SCRIPT_PATH}/.goassets" GOASSETS_REF=${GOASSETS_REF:-master} if [ -d "${GOASSETS_PATH}" ]; then if ! git -C "${GOASSETS_PATH}" diff --exit-code; then echo "${GOASSETS_PATH} has uncommitted changes" >&2 exit 1 fi git -C "${GOASSETS_PATH}" fetch origin git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} else git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" fi dtls-2.2.6/.github/install-hooks.sh000077500000000000000000000010771437412644000171670ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # SCRIPT_PATH="$(realpath "$(dirname "$0")")" . ${SCRIPT_PATH}/fetch-scripts.sh cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" cp "${GOASSETS_PATH}/hooks/pre-push.sh" "${SCRIPT_PATH}/../.git/hooks/pre-push" dtls-2.2.6/.github/workflows/000077500000000000000000000000001437412644000160715ustar00rootroot00000000000000dtls-2.2.6/.github/workflows/codeql-analysis.yml000066400000000000000000000011551437412644000217060ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: CodeQL on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master dtls-2.2.6/.github/workflows/e2e.yaml000066400000000000000000000005521437412644000174320ustar00rootroot00000000000000name: E2E on: pull_request: branches: - master push: branches: - master jobs: e2e-test: name: Test runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v2 - name: test run: | docker build -t pion-dtls-e2e -f e2e/Dockerfile . docker run -i --rm pion-dtls-e2e dtls-2.2.6/.github/workflows/generate-authors.yml000066400000000000000000000011041437412644000220650ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Generate Authors on: pull_request: jobs: generate: uses: pion/.goassets/.github/workflows/generate-authors.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} dtls-2.2.6/.github/workflows/lint.yaml000066400000000000000000000007521437412644000177270ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Lint on: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/lint.reusable.yml@master dtls-2.2.6/.github/workflows/release.yml000066400000000000000000000011051437412644000202310ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Release on: push: tags: - 'v*' jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version dtls-2.2.6/.github/workflows/renovate-go-sum-fix.yaml000066400000000000000000000011241437412644000225670ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Fix go.sum on: push: branches: - renovate/* jobs: fix: uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} dtls-2.2.6/.github/workflows/test.yaml000066400000000000000000000021121437412644000177300ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Test on: push: branches: - master pull_request: jobs: test: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version dtls-2.2.6/.github/workflows/tidy-check.yaml000066400000000000000000000011371437412644000210030ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. 
# # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Go mod tidy on: pull_request: push: branches: - master jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version dtls-2.2.6/.gitignore000066400000000000000000000004661437412644000144720ustar00rootroot00000000000000### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js dtls-2.2.6/.golangci.yml000066400000000000000000000172751437412644000150740ustar00rootroot00000000000000linters-settings: govet: check-shadowing: true misspell: locale: US exhaustive: default-signifies-exhaustive: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - contextcheck # check the function whether use a non-inherited context - decorder # check declaration order and count of types, constants, variables and functions - depguard # Go linter that checks if package imports are in a list of acceptable packages - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. 
Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - godox # Tool for detection of FIXME, TODO and other comment keywords - goerr113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. 
- goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. 
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - containedctx # containedctx is a linter that detects struct contained context.Context field - cyclop # checks function and package cyclomatic complexity - exhaustivestruct # Checks if all struct's fields are initialized - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - gomnd # An analyzer to detect magic numbers. - ifshort # Checks that your code uses short syntax for if-statements whenever possible - ireturn # Accept Interfaces, Return Concrete Types - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. 
- makezero # Finds slice declarations with non-zero initial length - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - nestif # Reports deeply nested if statements - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - varnamelen # checks that the length of a variable's name matches its scope - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: exclude-use-default: false exclude-rules: # Allow complex tests, better to be self contained - path: _test\.go linters: - gocognit # Allow complex main function in examples - path: examples text: "of func `main` is high" linters: - gocognit run: skip-dirs-use-default: false dtls-2.2.6/.goreleaser.yml000066400000000000000000000000251437412644000154220ustar00rootroot00000000000000builds: - skip: true dtls-2.2.6/AUTHORS.txt000066400000000000000000000041571437412644000143710ustar00rootroot00000000000000# Thank you to everyone that made Pion possible. If you are interested in contributing # we would love to have you https://github.com/pion/webrtc/wiki/Contributing # # This file is auto generated, using git to list all individuals contributors. 
# see https://github.com/pion/.goassets/blob/master/scripts/generate-authors.sh for the scripting Aleksandr Razumov alvarowolfx Arlo Breault Atsushi Watanabe backkem bjdgyc boks1971 Bragadeesh Carson Hoffman Cecylia Bocovich Chris Hiszpanski Daniele Sluijters folbrich Hayden James Hugo Arregui Hugo Arregui igolaizola <11333576+igolaizola@users.noreply.github.com> Jeffrey Stoke Jeroen de Bruijn Jeroen de Bruijn Jim Wert jinleileiking Jozef Kralik Julien Salleyron Juliusz Chroboczek Kegan Dougal Kevin Wang Lander Noterman Len Lukas Lihotzki ManuelBk <26275612+ManuelBk@users.noreply.github.com> Michael Zabka Michiel De Backker Rachel Chen Robert Eperjesi Ryan Gordon Sam Lancia Sean DuBois Sean DuBois Sean DuBois Shelikhoo Stefan Tatschner Steffen Vogel Vadim Vadim Filimonov wmiao ZHENK 吕海涛 # List of contributors not appearing in Git history dtls-2.2.6/LICENSE000066400000000000000000000020411437412644000134760ustar00rootroot00000000000000MIT License Copyright (c) 2018 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
dtls-2.2.6/README.md000066400000000000000000000116031437412644000137540ustar00rootroot00000000000000


Pion DTLS

A Go implementation of DTLS

Pion DTLS Sourcegraph Widget Slack Widget
Build Status GoDoc Coverage Status Go Report Card License: MIT


Native [DTLS 1.2][rfc6347] implementation in the Go programming language. A long term goal is a professional security review, and maybe an inclusion in stdlib. [rfc6347]: https://tools.ietf.org/html/rfc6347 ### Goals/Progress This will only be targeting DTLS 1.2, and the most modern/common cipher suites. We would love contributions that fall under the 'Planned Features' and any bug fixes! #### Current features * DTLS 1.2 Client/Server * Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK * Packet loss and re-ordering is handled during handshaking * Key export ([RFC 5705][rfc5705]) * Serialization and Resumption of sessions * Extended Master Secret extension ([RFC 7627][rfc7627]) * ALPN extension ([RFC 7301][rfc7301]) [rfc5705]: https://tools.ietf.org/html/rfc5705 [rfc7627]: https://tools.ietf.org/html/rfc7627 [rfc7301]: https://tools.ietf.org/html/rfc7301 #### Supported ciphers ##### ECDHE * TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) * TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) ##### PSK * TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_256_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487]) * TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487]) ##### ECDHE & PSK * TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5489][rfc5489]) [rfc5289]: https://tools.ietf.org/html/rfc5289 [rfc8422]: https://tools.ietf.org/html/rfc8422 [rfc6655]: https://tools.ietf.org/html/rfc6655 [rfc5487]: https://tools.ietf.org/html/rfc5487 [rfc5489]: https://tools.ietf.org/html/rfc5489 #### 
Planned Features * Chacha20Poly1305 #### Excluded Features * DTLS 1.0 * Renegotiation * Compression ### Using This library needs at least Go 1.13, and you should have [Go modules enabled](https://github.com/golang/go/wiki/Modules). #### Pion DTLS For a DTLS 1.2 Server that listens on 127.0.0.1:4444 ```sh go run examples/listen/selfsign/main.go ``` For a DTLS 1.2 Client that connects to 127.0.0.1:4444 ```sh go run examples/dial/selfsign/main.go ``` #### OpenSSL Pion DTLS can connect to itself and OpenSSL. ``` // Generate a certificate openssl ecparam -out key.pem -name prime256v1 -genkey openssl req -new -sha256 -key key.pem -out server.csr openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem // Use with examples/dial/selfsign/main.go openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444 // Use with examples/listen/selfsign/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem ``` ### Using with PSK Pion DTLS also comes with examples that do key exchange via PSK #### Pion DTLS ```sh go run examples/listen/psk/main.go ``` ```sh go run examples/dial/psk/main.go ``` #### OpenSSL ``` // Use with examples/dial/psk/main.go openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8 // Use with examples/listen/psk/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8 ``` ### Contributing Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible: ### License MIT License - see [LICENSE](LICENSE) for full text dtls-2.2.6/bench_test.go000066400000000000000000000046421437412644000151470ustar00rootroot00000000000000package dtls import ( "context" "crypto/tls" "fmt" "testing" "time" "github.com/pion/dtls/v2/internal/net/dpipe" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/logging" "github.com/pion/transport/v2/test" ) func 
TestSimpleReadWrite(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, cb, &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) if sErr != nil { t.Error(sErr) return } buf := make([]byte, 1024) if _, sErr = server.Read(buf); sErr != nil { t.Error(sErr) } gotHello <- struct{}{} if sErr = server.Close(); sErr != nil { //nolint:contextcheck t.Error(sErr) } }() client, err := testClient(ctx, ca, &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) if err != nil { t.Fatal(err) } if _, err = client.Write([]byte("hello")); err != nil { t.Error(err) } select { case <-gotHello: // OK case <-time.After(time.Second * 5): t.Error("timeout") } if err = client.Close(); err != nil { t.Error(err) } } func benchmarkConn(b *testing.B, n int64) { b.Run(fmt.Sprintf("%d", n), func(b *testing.B) { ctx := context.Background() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() server := make(chan *Conn) go func() { s, sErr := testServer(ctx, cb, &Config{ Certificates: []tls.Certificate{certificate}, }, false) if err != nil { b.Error(sErr) return } server <- s }() if err != nil { b.Fatal(err) } hw := make([]byte, n) b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { client, cErr := testClient(ctx, ca, &Config{InsecureSkipVerify: true}, false) if cErr != nil { b.Error(err) } for { if _, cErr = client.Write(hw); cErr != nil { //nolint:contextcheck b.Error(err) } } }() s := <-server buf := make([]byte, 2048) for i := 0; i < b.N; i++ { if _, err = s.Read(buf); err != nil { b.Error(err) } } }) } func BenchmarkConnReadWrite(b *testing.B) { for _, n := range []int64{16, 128, 512, 1024, 2048} { 
benchmarkConn(b, n) } } dtls-2.2.6/certificate.go000066400000000000000000000110211437412644000153000ustar00rootroot00000000000000package dtls import ( "bytes" "crypto/tls" "crypto/x509" "fmt" "strings" ) // ClientHelloInfo contains information from a ClientHello message in order to // guide application logic in the GetCertificate. type ClientHelloInfo struct { // ServerName indicates the name of the server requested by the client // in order to support virtual hosting. ServerName is only set if the // client is using SNI (see RFC 4366, Section 3.1). ServerName string // CipherSuites lists the CipherSuites supported by the client (e.g. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). CipherSuites []CipherSuiteID } // CertificateRequestInfo contains information from a server's // CertificateRequest message, which is used to demand a certificate and proof // of control from a client. type CertificateRequestInfo struct { // AcceptableCAs contains zero or more, DER-encoded, X.501 // Distinguished Names. These are the names of root or intermediate CAs // that the server wishes the returned certificate to be signed by. An // empty slice indicates that the server has no preference. AcceptableCAs [][]byte } // SupportsCertificate returns nil if the provided certificate is supported by // the server that sent the CertificateRequest. Otherwise, it returns an error // describing the reason for the incompatibility. // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { if len(cri.AcceptableCAs) == 0 { return nil } for j, cert := range c.Certificate { x509Cert := c.Leaf // Parse the certificate if this isn't the leaf node, or if // chain.Leaf was nil. 
if j != 0 || x509Cert == nil { var err error if x509Cert, err = x509.ParseCertificate(cert); err != nil { return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) } } for _, ca := range cri.AcceptableCAs { if bytes.Equal(x509Cert.RawIssuer, ca) { return nil } } } return errNotAcceptableCertificateChain } func (c *handshakeConfig) setNameToCertificateLocked() { nameToCertificate := make(map[string]*tls.Certificate) for i := range c.localCertificates { cert := &c.localCertificates[i] x509Cert := cert.Leaf if x509Cert == nil { var parseErr error x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0]) if parseErr != nil { continue } } if len(x509Cert.Subject.CommonName) > 0 { nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert } for _, san := range x509Cert.DNSNames { nameToCertificate[strings.ToLower(san)] = cert } } c.nameToCertificate = nameToCertificate } func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetCertificate != nil && (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) { cert, err := c.localGetCertificate(clientHelloInfo) if cert != nil || err != nil { return cert, err } } if c.nameToCertificate == nil { c.setNameToCertificateLocked() } if len(c.localCertificates) == 0 { return nil, errNoCertificates } if len(c.localCertificates) == 1 { // There's only one choice, so no point doing any work. return &c.localCertificates[0], nil } if len(clientHelloInfo.ServerName) == 0 { return &c.localCertificates[0], nil } name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".") if cert, ok := c.nameToCertificate[name]; ok { return cert, nil } // try replacing labels in the name with wildcards until we get a // match. 
labels := strings.Split(name, ".") for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".") if cert, ok := c.nameToCertificate[candidate]; ok { return cert, nil } } // If nothing matches, return the first certificate. return &c.localCertificates[0], nil } // NOTE: original src: https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetClientCertificate != nil { return c.localGetClientCertificate(cri) } for i := range c.localCertificates { chain := c.localCertificates[i] if err := cri.SupportsCertificate(&chain); err != nil { continue } return &chain, nil } // No acceptable certificate found. Don't send a certificate. return new(tls.Certificate), nil } dtls-2.2.6/certificate_test.go000066400000000000000000000046051437412644000163510ustar00rootroot00000000000000package dtls import ( "crypto/tls" "reflect" "testing" "github.com/pion/dtls/v2/pkg/crypto/selfsign" ) func TestGetCertificate(t *testing.T) { certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test") if err != nil { t.Fatal(err) } certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test") if err != nil { t.Fatal(err) } certificateRandom, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } testCases := []struct { localCertificates []tls.Certificate desc string serverName string expectedCertificate tls.Certificate getCertificate func(info *ClientHelloInfo) (*tls.Certificate, error) }{ { desc: "Simple match in CN", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "test.test", expectedCertificate: certificateTest, }, { desc: "Simple match in SANs", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, 
serverName: "www.test.test", expectedCertificate: certificateTest, }, { desc: "Wildcard match", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.test.test", expectedCertificate: certificateWildcard, }, { desc: "No match return first", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.bar", expectedCertificate: certificateRandom, }, { desc: "Get certificate from callback", getCertificate: func(info *ClientHelloInfo) (*tls.Certificate, error) { return &certificateTest, nil }, expectedCertificate: certificateTest, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { cfg := &handshakeConfig{ localCertificates: test.localCertificates, localGetCertificate: test.getCertificate, } cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName}) if err != nil { t.Fatal(err) } if !reflect.DeepEqual(cert.Leaf, test.expectedCertificate.Leaf) { t.Fatalf("Certificate does not match: expected(%v) actual(%v)", test.expectedCertificate.Leaf, cert.Leaf) } }) } } dtls-2.2.6/cipher_suite.go000066400000000000000000000242111437412644000155060ustar00rootroot00000000000000package dtls import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "fmt" "hash" "github.com/pion/dtls/v2/internal/ciphersuite" "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) // CipherSuiteID is an ID for our supported CipherSuites type CipherSuiteID = ciphersuite.ID // Supported Cipher Suites const ( // AES-128-CCM TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 //nolint:revive,stylecheck // AES-128-GCM-SHA256 TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = 
ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 //nolint:revive,stylecheck // AES-256-CBC-SHA TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 //nolint:revive,stylecheck TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 //nolint:revive,stylecheck TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 //nolint:revive,stylecheck ) // CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType // AuthenticationType Enums const ( CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = 
ciphersuite.AuthenticationTypeAnonymous ) // CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm // CipherSuiteKeyExchangeAlgorithm Bitmask const ( CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithmEcdhe CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmEcdhe ) var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 // CipherSuite is an interface that all DTLS CipherSuites must satisfy type CipherSuite interface { // String of CipherSuite, only used for logging String() string // ID of CipherSuite. ID() CipherSuiteID // What type of Certificate does this CipherSuite use CertificateType() clientcertificate.Type // What Hash function is used during verification HashFunc() func() hash.Hash // AuthenticationType controls what authentication method is using during the handshake AuthenticationType() CipherSuiteAuthenticationType // KeyExchangeAlgorithm controls what exchange algorithm is using during the handshake KeyExchangeAlgorithm() CipherSuiteKeyExchangeAlgorithm // ECC (Elliptic Curve Cryptography) determines whether ECC extesions will be send during handshake. // https://datatracker.ietf.org/doc/html/rfc4492#page-10 ECC() bool // Called when keying material has been generated, should initialize the internal cipher Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error IsInitialized() bool Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) Decrypt(in []byte) ([]byte, error) } // CipherSuiteName provides the same functionality as tls.CipherSuiteName // that appeared first in Go 1.14. 
// // Our implementation differs slightly in that it takes in a CiperSuiteID, // like the rest of our library, instead of a uint16 like crypto/tls. func CipherSuiteName(id CipherSuiteID) string { suite := cipherSuiteForID(id, nil) if suite != nil { return suite.String() } return fmt.Sprintf("0x%04X", uint16(id)) } // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml // A cipherSuite is a specific combination of key agreement, cipher and MAC // function. func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { switch id { //nolint:exhaustive case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8() case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{} case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{} case TLS_PSK_WITH_AES_128_CCM: return ciphersuite.NewTLSPskWithAes128Ccm() case TLS_PSK_WITH_AES_128_CCM_8: return ciphersuite.NewTLSPskWithAes128Ccm8() case TLS_PSK_WITH_AES_256_CCM_8: return ciphersuite.NewTLSPskWithAes256Ccm8() case TLS_PSK_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSPskWithAes128GcmSha256{} case TLS_PSK_WITH_AES_128_CBC_SHA256: return &ciphersuite.TLSPskWithAes128CbcSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{} case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{} case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: return ciphersuite.NewTLSEcdhePskWithAes128CbcSha256() } if customCiphers != nil { for _, c := range customCiphers() { if c.ID() == id { return c } } } return nil } // 
CipherSuites we support in order of preference func defaultCipherSuites() []CipherSuite { return []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func allCipherSuites() []CipherSuite { return []CipherSuite{ ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(), ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(), &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, ciphersuite.NewTLSPskWithAes128Ccm(), ciphersuite.NewTLSPskWithAes128Ccm8(), ciphersuite.NewTLSPskWithAes256Ccm8(), &ciphersuite.TLSPskWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { rtrn := []uint16{} for _, c := range cipherSuites { rtrn = append(rtrn, uint16(c.ID())) } return rtrn } func parseCipherSuites(userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool) ([]CipherSuite, error) { cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { cipherSuites := []CipherSuite{} for _, id := range ids { c := cipherSuiteForID(id, nil) if c == nil { return nil, &invalidCipherSuiteError{id} } cipherSuites = append(cipherSuites, c) } return cipherSuites, nil } var ( cipherSuites []CipherSuite err error i int ) if userSelectedSuites != nil { cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) if err != nil { return nil, err } } else { cipherSuites = defaultCipherSuites() } // Put CustomCipherSuites before ID selected suites if customCipherSuites != nil { cipherSuites = append(customCipherSuites(), cipherSuites...) 
} var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool for _, c := range cipherSuites { switch { case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: foundCertificateSuite = true case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey: foundPSKSuite = true case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous: foundAnonymousSuite = true default: continue } cipherSuites[i] = c i++ } switch { case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite: return nil, errNoAvailableCertificateCipherSuite case includePSKSuites && !foundPSKSuite: return nil, errNoAvailablePSKCipherSuite case i == 0: return nil, errNoAvailableCipherSuites } return cipherSuites[:i], nil } func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []CipherSuite) []CipherSuite { if cert == nil || cert.PrivateKey == nil { return cipherSuites } var certType clientcertificate.Type switch cert.PrivateKey.(type) { case ed25519.PrivateKey, *ecdsa.PrivateKey: certType = clientcertificate.ECDSASign case *rsa.PrivateKey: certType = clientcertificate.RSASign } filtered := []CipherSuite{} for _, c := range cipherSuites { if c.AuthenticationType() != CipherSuiteAuthenticationTypeCertificate || certType == c.CertificateType() { filtered = append(filtered, c) } } return filtered } dtls-2.2.6/cipher_suite_go114.go000066400000000000000000000020761437412644000164260ustar00rootroot00000000000000//go:build go1.14 // +build go1.14 package dtls import ( "crypto/tls" ) // VersionDTLS12 is the DTLS version in the same style as // VersionTLSXX from crypto/tls const VersionDTLS12 = 0xfefd // Convert from our cipherSuite interface to a tls.CipherSuite struct func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { return &tls.CipherSuite{ ID: uint16(c.ID()), Name: c.String(), SupportedVersions: []uint16{VersionDTLS12}, Insecure: false, } } // CipherSuites returns a list of 
cipher suites currently implemented by this // package, excluding those with security issues, which are returned by // InsecureCipherSuites. func CipherSuites() []*tls.CipherSuite { suites := allCipherSuites() res := make([]*tls.CipherSuite, len(suites)) for i, c := range suites { res[i] = toTLSCipherSuite(c) } return res } // InsecureCipherSuites returns a list of cipher suites currently implemented by // this package and which have security issues. func InsecureCipherSuites() []*tls.CipherSuite { var res []*tls.CipherSuite return res } dtls-2.2.6/cipher_suite_go114_test.go000066400000000000000000000021311437412644000174550ustar00rootroot00000000000000//go:build go1.14 // +build go1.14 package dtls import ( "testing" ) func TestInsecureCipherSuites(t *testing.T) { r := InsecureCipherSuites() if len(r) != 0 { t.Fatalf("Expected no insecure ciphersuites, got %d", len(r)) } } func TestCipherSuites(t *testing.T) { ours := allCipherSuites() theirs := CipherSuites() if len(ours) != len(theirs) { t.Fatalf("Expected %d CipherSuites, got %d", len(ours), len(theirs)) } for i, s := range ours { i := i s := s t.Run(s.String(), func(t *testing.T) { c := theirs[i] if c.ID != uint16(s.ID()) { t.Fatalf("Expected ID: 0x%04X, got 0x%04X", s.ID(), c.ID) } if c.Name != s.String() { t.Fatalf("Expected Name: %s, got %s", s.String(), c.Name) } if len(c.SupportedVersions) != 1 { t.Fatalf("Expected %d SupportedVersion, got %d", 1, len(c.SupportedVersions)) } if c.SupportedVersions[0] != VersionDTLS12 { t.Fatalf("Expected SupportedVersions 0x%04X, got 0x%04X", VersionDTLS12, c.SupportedVersions[0]) } if c.Insecure { t.Fatalf("Expected Insecure %t, got %t", false, c.Insecure) } }) } } dtls-2.2.6/cipher_suite_test.go000066400000000000000000000046461437412644000165570ustar00rootroot00000000000000package dtls import ( "context" "testing" "time" "github.com/pion/dtls/v2/internal/ciphersuite" "github.com/pion/dtls/v2/internal/net/dpipe" "github.com/pion/transport/v2/test" ) func 
TestCipherSuiteName(t *testing.T) { testCases := []struct { suite CipherSuiteID expected string }{ {TLS_ECDHE_ECDSA_WITH_AES_128_CCM, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"}, {CipherSuiteID(0x0000), "0x0000"}, } for _, testCase := range testCases { res := CipherSuiteName(testCase.suite) if res != testCase.expected { t.Fatalf("Expected: %s, got %s", testCase.expected, res) } } } func TestAllCipherSuites(t *testing.T) { actual := len(allCipherSuites()) if actual == 0 { t.Fatal() } } // CustomCipher that is just used to assert Custom IDs work type testCustomCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 authenticationType CipherSuiteAuthenticationType } func (t *testCustomCipherSuite) ID() CipherSuiteID { return 0xFFFF } func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationType { return t.authenticationType } // Assert that two connections that pass in a CipherSuite with a CustomID works func TestCustomCipherSuite(t *testing.T) { type result struct { c *Conn err error } // Check for leaking routines report := test.CheckRoutines(t) defer report() runTest := func(cipherFactory func() []CipherSuite) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() c := make(chan result) go func() { client, err := testClient(ctx, ca, &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) c <- result{client, err} }() server, err := testServer(ctx, cb, &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) clientResult := <-c if err != nil { t.Error(err) } else { _ = server.Close() } if clientResult.err != nil { t.Error(clientResult.err) } else { _ = clientResult.c.Close() } } t.Run("Custom ID", func(t *testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeCertificate}} }) }) t.Run("Anonymous Cipher", func(t *testing.T) { runTest(func() 
[]CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeAnonymous}} }) }) } dtls-2.2.6/codecov.yml000066400000000000000000000005521437412644000146430ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" dtls-2.2.6/compression_method.go000066400000000000000000000002601437412644000167220ustar00rootroot00000000000000package dtls import "github.com/pion/dtls/v2/pkg/protocol" func defaultCompressionMethods() []*protocol.CompressionMethod { return []*protocol.CompressionMethod{ {}, } } dtls-2.2.6/config.go000066400000000000000000000221551437412644000142750ustar00rootroot00000000000000package dtls import ( "context" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "io" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/logging" ) const keyLogLabelTLS12 = "CLIENT_RANDOM" // Config is used to configure a DTLS client or server. // After a Config is passed to a DTLS function it must not be modified. type Config struct { // Certificates contains certificate chain to present to the other side of the connection. // Server MUST set this if PSK is non-nil // client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil Certificates []tls.Certificate // CipherSuites is a list of supported cipher suites. // If CipherSuites is nil, a default list is used CipherSuites []CipherSuiteID // CustomCipherSuites is a list of CipherSuites that can be // provided by the user. This allow users to user Ciphers that are reserved // for private usage. CustomCipherSuites func() []CipherSuite // SignatureSchemes contains the signature and hash schemes that the peer requests to verify. 
SignatureSchemes []tls.SignatureScheme // SRTPProtectionProfiles are the supported protection profiles // Clients will send this via use_srtp and assert that the server properly responds // Servers will assert that clients send one of these profiles and will respond as needed SRTPProtectionProfiles []SRTPProtectionProfile // ClientAuth determines the server's policy for // TLS Client Authentication. The default is NoClientCert. ClientAuth ClientAuthType // RequireExtendedMasterSecret determines if the "Extended Master Secret" extension // should be disabled, requested, or required (default requested). ExtendedMasterSecret ExtendedMasterSecretType // FlightInterval controls how often we send outbound handshake messages // defaults to time.Second FlightInterval time.Duration // PSK sets the pre-shared key used by this DTLS connection // If PSK is non-nil only PSK CipherSuites will be used PSK PSKCallback PSKIdentityHint []byte // InsecureSkipVerify controls whether a client verifies the // server's certificate chain and host name. // If InsecureSkipVerify is true, TLS accepts any certificate // presented by the server and any host name in that certificate. // In this mode, TLS is susceptible to man-in-the-middle attacks. // This should be used only for testing. InsecureSkipVerify bool // InsecureHashes allows the use of hashing algorithms that are known // to be vulnerable. InsecureHashes bool // VerifyPeerCertificate, if not nil, is called after normal // certificate verification by either a client or server. It // receives the certificate provided by the peer and also a flag // that tells if normal verification has succeedded. If it returns a // non-nil error, the handshake is aborted and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. 
If normal verification is disabled by // setting InsecureSkipVerify, or (for a server) when ClientAuth is // RequestClientCert or RequireAnyClientCert, then this callback will // be considered but the verifiedChains will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error // VerifyConnection, if not nil, is called after normal certificate // verification/PSK and after VerifyPeerCertificate by either a TLS client // or server. If it returns a non-nil error, the handshake is aborted // and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. This callback will run for all connections // regardless of InsecureSkipVerify or ClientAuth settings. VerifyConnection func(*State) error // RootCAs defines the set of root certificate authorities // that one peer uses when verifying the other peer's certificates. // If RootCAs is nil, TLS uses the host's root CA set. RootCAs *x509.CertPool // ClientCAs defines the set of root certificate authorities // that servers use if required to verify a client certificate // by the policy in ClientAuth. ClientCAs *x509.CertPool // ServerName is used to verify the hostname on the returned // certificates unless InsecureSkipVerify is given. ServerName string LoggerFactory logging.LoggerFactory // ConnectContextMaker is a function to make a context used in Dial(), // Client(), Server(), and Accept(). If nil, the default ConnectContextMaker // is used. It can be implemented as following. // // func ConnectContextMaker() (context.Context, func()) { // return context.WithTimeout(context.Background(), 30*time.Second) // } ConnectContextMaker func() (context.Context, func()) // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int // ReplayProtectionWindow is the size of the replay attack protection window. 
// Duplication of the sequence number is checked in this window size. // Packet with sequence number older than this value compared to the latest // accepted packet will be discarded. (default is 64) ReplayProtectionWindow int // KeyLogWriter optionally specifies a destination for TLS master secrets // in NSS key log format that can be used to allow external programs // such as Wireshark to decrypt TLS connections. // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. // Use of KeyLogWriter compromises security and should only be // used for debugging. KeyLogWriter io.Writer // SessionStore is the container to store session for resumption. SessionStore SessionStore // List of application protocols the peer supports, for ALPN SupportedProtocols []string // List of Elliptic Curves to use // // If an ECC ciphersuite is configured and EllipticCurves is empty // it will default to X25519, P-256, P-384 in this specific order. EllipticCurves []elliptic.Curve // GetCertificate returns a Certificate based on the given // ClientHelloInfo. It will only be called if the client supplies SNI // information or if Certificates is empty. // // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // best element of Certificates will be used. GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) // GetClientCertificate, if not nil, is called when a server requests a // certificate from a client. If set, the contents of Certificates will // be ignored. // // If GetClientCertificate returns an error, the handshake will be // aborted and that error will be returned. Otherwise // GetClientCertificate must return a non-nil Certificate. If // Certificate.Certificate is empty then no certificate will be sent to // the server. If this is unacceptable to the server then it may abort // the handshake. 
GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) // InsecureSkipVerifyHello, if true and when acting as server, allow client to // skip hello verify phase and receive ServerHello after initial ClientHello. // This have implication on DoS attack resistance. InsecureSkipVerifyHello bool } func defaultConnectContextMaker() (context.Context, func()) { return context.WithTimeout(context.Background(), 30*time.Second) } func (c *Config) connectContextMaker() (context.Context, func()) { if c.ConnectContextMaker == nil { return defaultConnectContextMaker() } return c.ConnectContextMaker() } func (c *Config) includeCertificateSuites() bool { return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil } const defaultMTU = 1200 // bytes var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals // PSKCallback is called once we have the remote's PSKIdentityHint. // If the remote provided none it will be nil type PSKCallback func([]byte) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. 
type ClientAuthType int // ClientAuthType enums const ( NoClientCert ClientAuthType = iota RequestClientCert RequireAnyClientCert VerifyClientCertIfGiven RequireAndVerifyClientCert ) // ExtendedMasterSecretType declares the policy the client and server // will follow for the Extended Master Secret extension type ExtendedMasterSecretType int // ExtendedMasterSecretType enums const ( RequestExtendedMasterSecret ExtendedMasterSecretType = iota RequireExtendedMasterSecret DisableExtendedMasterSecret ) func validateConfig(config *Config) error { switch { case config == nil: return errNoConfigProvided case config.PSKIdentityHint != nil && config.PSK == nil: return errIdentityNoPSK } for _, cert := range config.Certificates { if cert.Certificate == nil { return errInvalidCertificate } if cert.PrivateKey != nil { switch cert.PrivateKey.(type) { case ed25519.PrivateKey: case *ecdsa.PrivateKey: case *rsa.PrivateKey: default: return errInvalidPrivateKey } } } _, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) return err } dtls-2.2.6/config_test.go000066400000000000000000000103331437412644000153270ustar00rootroot00000000000000package dtls import ( "crypto/dsa" //nolint:staticcheck "crypto/rand" "crypto/rsa" "crypto/tls" "errors" "testing" "github.com/pion/dtls/v2/pkg/crypto/selfsign" ) func TestValidateConfig(t *testing.T) { cert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), self signed certificate not generated", err) return } dsaPrivateKey := &dsa.PrivateKey{} err = dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA parameters not generated", err) return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), DSA private key not generated", err) return } 
rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("TestValidateConfig: Config validation error(%v), RSA private key not generated", err) return } cases := map[string]struct { config *Config wantAnyErr bool expErr error }{ "Empty config": { expErr: errNoConfigProvided, }, "PSK and Certificate, valid cipher suites": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func(hint []byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, }, "PSK and Certificate, no PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func(hint []byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailablePSKCipherSuite, }, "PSK and Certificate, no non-PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, PSK: func(hint []byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailableCertificateCipherSuite, }, "PSK identity hint with not PSK": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: nil, PSKIdentityHint: []byte{}, }, expErr: errIdentityNoPSK, }, "Invalid private key": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{Certificate: cert.Certificate, PrivateKey: dsaPrivateKey}}, }, expErr: errInvalidPrivateKey, }, "PrivateKey without Certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{PrivateKey: cert.PrivateKey}}, }, expErr: errInvalidCertificate, }, "Invalid cipher suites": { config: &Config{CipherSuites: []CipherSuiteID{0x0000}}, wantAnyErr: true, }, "Valid config": { config: &Config{ CipherSuites: 
[]CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{cert, {Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}}, }, }, "Valid config with get certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, "Valid config with get client certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { err := validateConfig(testCase.config) if testCase.expErr != nil || testCase.wantAnyErr { if testCase.expErr != nil && !errors.Is(err, testCase.expErr) { t.Fatalf("TestValidateConfig: Config validation error exp(%v) failed(%v)", testCase.expErr, err) } if err == nil { t.Fatalf("TestValidateConfig: Config validation expected an error") } } }) } } dtls-2.2.6/conn.go000066400000000000000000000660211437412644000137650ustar00rootroot00000000000000package dtls import ( "context" "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" "github.com/pion/dtls/v2/internal/closer" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v2/connctx" "github.com/pion/transport/v2/deadline" "github.com/pion/transport/v2/replaydetector" ) const ( initialTickerInterval = time.Second cookieLength = 20 sessionLength = 32 defaultNamedCurve = elliptic.X25519 
inboundBufferSize = 8192 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6 defaultReplayProtectionWindow = 64 ) func invalidKeyingLabels() map[string]bool { return map[string]bool{ "client finished": true, "server finished": true, "master secret": true, "key expansion": true, } } // Conn represents a DTLS connection type Conn struct { lock sync.RWMutex // Internal lock (must not be public) nextConn connctx.ConnCtx // Embedded Conn, typically a udpconn we read/write from fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling handshakeCache *handshakeCache // caching of handshake messages for verifyData generation decrypted chan interface{} // Decrypted Application Data or error, pull by calling `Read` state State // Internal state maximumTransmissionUnit int handshakeCompletedSuccessfully atomic.Value encryptedPackets [][]byte connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer handshakeLoopsFinished sync.WaitGroup readDeadline *deadline.Deadline writeDeadline *deadline.Deadline log logging.LeveledLogger reading chan struct{} handshakeRecv chan chan struct{} cancelHandshaker func() cancelHandshakeReader func() fsm *handshakeFSM replayProtectionWindow uint } func createConn(ctx context.Context, nextConn net.Conn, config *Config, isClient bool, initialState *State) (*Conn, error) { err := validateConfig(config) if err != nil { return nil, err } if nextConn == nil { return nil, errNilNextConn } cipherSuites, err := parseCipherSuites(config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil) if err != nil { return nil, err } signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) if err != nil { return nil, err } workerInterval := initialTickerInterval if config.FlightInterval != 0 { workerInterval = config.FlightInterval } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = 
logging.NewDefaultLoggerFactory() } logger := loggerFactory.NewLogger("dtls") mtu := config.MTU if mtu <= 0 { mtu = defaultMTU } replayProtectionWindow := config.ReplayProtectionWindow if replayProtectionWindow <= 0 { replayProtectionWindow = defaultReplayProtectionWindow } c := &Conn{ nextConn: connctx.New(nextConn), fragmentBuffer: newFragmentBuffer(), handshakeCache: newHandshakeCache(), maximumTransmissionUnit: mtu, decrypted: make(chan interface{}, 1), log: logger, readDeadline: deadline.New(), writeDeadline: deadline.New(), reading: make(chan struct{}, 1), handshakeRecv: make(chan chan struct{}), closed: closer.NewCloser(), cancelHandshaker: func() {}, replayProtectionWindow: uint(replayProtectionWindow), state: State{ isClient: isClient, }, } c.setRemoteEpoch(0) c.setLocalEpoch(0) serverName := config.ServerName // Do not allow the use of an IP address literal as an SNI value. // See RFC 6066, Section 3. if net.ParseIP(serverName) != nil { serverName = "" } curves := config.EllipticCurves if len(curves) == 0 { curves = defaultCurves } hsCfg := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, localSignatureSchemes: signatureSchemes, extendedMasterSecret: config.ExtendedMasterSecret, localSRTPProtectionProfiles: config.SRTPProtectionProfiles, serverName: serverName, supportedProtocols: config.SupportedProtocols, clientAuth: config.ClientAuth, localCertificates: config.Certificates, insecureSkipVerify: config.InsecureSkipVerify, verifyPeerCertificate: config.VerifyPeerCertificate, verifyConnection: config.VerifyConnection, rootCAs: config.RootCAs, clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, retransmitInterval: workerInterval, log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, ellipticCurves: curves, localGetCertificate: config.GetCertificate, localGetClientCertificate: config.GetClientCertificate, 
insecureSkipHelloVerify: config.InsecureSkipVerifyHello, } // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. if !isClient { cert, err := hsCfg.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { return nil, err } hsCfg.localCipherSuites = filterCipherSuitesForCertificate(cert, cipherSuites) } var initialFlight flightVal var initialFSMState handshakeState if initialState != nil { if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished c.state = *initialState } else { if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 } initialFSMState = handshakePreparing } // Do handshake if err := c.handshake(ctx, hsCfg, initialFlight, initialFSMState); err != nil { return nil, err } c.log.Trace("Handshake Completed") return c, nil } // Dial connects to the given network address and establishes a DTLS connection on top. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use DialWithContext() instead. func Dial(network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() return DialWithContext(ctx, network, raddr, config) } // Client establishes a DTLS connection over an existing connection. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use ClientWithContext() instead. func Client(conn net.Conn, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() return ClientWithContext(ctx, conn, config) } // Server listens for incoming DTLS connections. // Connection handshake will timeout using ConnectContextMaker in the Config. // If you want to specify the timeout duration, use ServerWithContext() instead. 
func Server(conn net.Conn, config *Config) (*Conn, error) { ctx, cancel := config.connectContextMaker() defer cancel() return ServerWithContext(ctx, conn, config) } // DialWithContext connects to the given network address and establishes a DTLS connection on top. func DialWithContext(ctx context.Context, network string, raddr *net.UDPAddr, config *Config) (*Conn, error) { pConn, err := net.DialUDP(network, nil, raddr) if err != nil { return nil, err } return ClientWithContext(ctx, pConn, config) } // ClientWithContext establishes a DTLS connection over an existing connection. func ClientWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided case config.PSK != nil && config.PSKIdentityHint == nil: return nil, errPSKAndIdentityMustBeSetForClient } return createConn(ctx, conn, config, true, nil) } // ServerWithContext listens for incoming DTLS connections. func ServerWithContext(ctx context.Context, conn net.Conn, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } return createConn(ctx, conn, config, false, nil) } // Read reads data from the connection. 
func (c *Conn) Read(p []byte) (n int, err error) { if !c.isHandshakeCompletedSuccessfully() { return 0, errHandshakeInProgress } select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded default: } for { select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded case out, ok := <-c.decrypted: if !ok { return 0, io.EOF } switch val := out.(type) { case ([]byte): if len(p) < len(val) { return 0, errBufferTooSmall } copy(p, val) return len(val), nil case (error): return 0, val } } } } // Write writes len(p) bytes from p to the DTLS connection func (c *Conn) Write(p []byte) (int, error) { if c.isConnectionClosed() { return 0, ErrConnClosed } select { case <-c.writeDeadline.Done(): return 0, errDeadlineExceeded default: } if !c.isHandshakeCompletedSuccessfully() { return 0, errHandshakeInProgress } return len(p), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: p, }, }, shouldEncrypt: true, }, }) } // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck c.handshakeLoopsFinished.Wait() return err } // ConnectionState returns basic DTLS details about the connection. // Note that this replaced the `Export` function of v1. 
func (c *Conn) ConnectionState() State { c.lock.RLock() defer c.lock.RUnlock() return *c.state.clone() } // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { c.lock.RLock() defer c.lock.RUnlock() if c.state.srtpProtectionProfile == 0 { return 0, false } return c.state.srtpProtectionProfile, true } func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { c.lock.Lock() defer c.lock.Unlock() var rawPackets [][]byte for _, p := range pkts { if h, ok := p.record.Content.(*handshake.Handshake); ok { handshakeRaw, err := p.record.Marshal() if err != nil { return err } c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", srvCliStr(c.state.isClient), h.Header.Type.String(), p.record.Header.Epoch, h.Header.MessageSequence) c.handshakeCache.push(handshakeRaw[recordlayer.HeaderSize:], p.record.Header.Epoch, h.Header.MessageSequence, h.Header.Type, c.state.isClient) rawHandshakePackets, err := c.processHandshakePacket(p, h) if err != nil { return err } rawPackets = append(rawPackets, rawHandshakePackets...) 
} else { rawPacket, err := c.processPacket(p) if err != nil { return err } rawPackets = append(rawPackets, rawPacket) } } if len(rawPackets) == 0 { return nil } compactedRawPackets := c.compactRawPackets(rawPackets) for _, compactedRawPackets := range compactedRawPackets { if _, err := c.nextConn.WriteContext(ctx, compactedRawPackets); err != nil { return netError(err) } } return nil } func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { combinedRawPackets := make([][]byte, 0) currentCombinedRawPacket := make([]byte, 0) for _, rawPacket := range rawPackets { if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) currentCombinedRawPacket = []byte{} } currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) } combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) return combinedRawPackets } func (c *Conn) processPacket(p *packet) ([]byte, error) { epoch := p.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { // RFC 6347 Section 4.1.0 // The implementation must either abandon an association or rehandshake // prior to allowing the sequence number to wrap. 
return nil, errSequenceNumberOverflow } p.record.Header.SequenceNumber = seq rawPacket, err := p.record.Marshal() if err != nil { return nil, err } if p.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) if err != nil { return nil, err } } return rawPacket, nil } func (c *Conn) processHandshakePacket(p *packet, h *handshake.Handshake) ([][]byte, error) { rawPackets := make([][]byte, 0) handshakeFragments, err := c.fragmentHandshake(h) if err != nil { return nil, err } epoch := p.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } for _, handshakeFragment := range handshakeFragments { seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { return nil, errSequenceNumberOverflow } recordlayerHeader := &recordlayer.Header{ Version: p.record.Header.Version, ContentType: p.record.Header.ContentType, ContentLen: uint16(len(handshakeFragment)), Epoch: p.record.Header.Epoch, SequenceNumber: seq, } rawPacket, err := recordlayerHeader.Marshal() if err != nil { return nil, err } p.record.Header = *recordlayerHeader rawPacket = append(rawPacket, handshakeFragment...) 
if p.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(p.record, rawPacket) if err != nil { return nil, err } } rawPackets = append(rawPackets, rawPacket) } return rawPackets, nil } func (c *Conn) fragmentHandshake(h *handshake.Handshake) ([][]byte, error) { content, err := h.Message.Marshal() if err != nil { return nil, err } fragmentedHandshakes := make([][]byte, 0) contentFragments := splitBytes(content, c.maximumTransmissionUnit) if len(contentFragments) == 0 { contentFragments = [][]byte{ {}, } } offset := 0 for _, contentFragment := range contentFragments { contentFragmentLen := len(contentFragment) headerFragment := &handshake.Header{ Type: h.Header.Type, Length: h.Header.Length, MessageSequence: h.Header.MessageSequence, FragmentOffset: uint32(offset), FragmentLength: uint32(contentFragmentLen), } offset += contentFragmentLen fragmentedHandshake, err := headerFragment.Marshal() if err != nil { return nil, err } fragmentedHandshake = append(fragmentedHandshake, contentFragment...) 
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake) } return fragmentedHandshakes, nil } var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals New: func() interface{} { b := make([]byte, inboundBufferSize) return &b }, } func (c *Conn) readAndBuffer(ctx context.Context) error { bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer } defer poolReadBuffer.Put(bufptr) b := *bufptr i, err := c.nextConn.ReadContext(ctx, b) if err != nil { return netError(err) } pkts, err := recordlayer.UnpackDatagram(b[:i]) if err != nil { return err } var hasHandshake bool for _, p := range pkts { hs, alert, err := c.handleIncomingPacket(ctx, p, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } if hs { hasHandshake = true } var e *alertError if errors.As(err, &e) { if e.IsFatalOrCloseNotify() { return e } } else if err != nil { return e } } if hasHandshake { done := make(chan struct{}) select { case c.handshakeRecv <- done: // If the other party may retransmit the flight, // we should respond even if it not a new message. 
<-done case <-c.fsm.Done(): } } return nil } func (c *Conn) handleQueuedPackets(ctx context.Context) error { pkts := c.encryptedPackets c.encryptedPackets = nil for _, p := range pkts { _, alert, err := c.handleIncomingPacket(ctx, p, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } var e *alertError if errors.As(err, &e) { if e.IsFatalOrCloseNotify() { return e } } else if err != nil { return e } } return nil } func (c *Conn) handleIncomingPacket(ctx context.Context, buf []byte, enqueue bool) (bool, *alert.Alert, error) { //nolint:gocognit h := &recordlayer.Header{} if err := h.Unmarshal(buf); err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) return false, nil, nil } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() if h.Epoch > remoteEpoch { if h.Epoch > remoteEpoch+1 { c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", h.Epoch, h.SequenceNumber, ) return false, nil, nil } if enqueue { c.log.Debug("received packet of next epoch, queuing packet") c.encryptedPackets = append(c.encryptedPackets, buf) } return false, nil, nil } // Anti-replay protection for len(c.state.replayDetector) <= int(h.Epoch) { c.state.replayDetector = append(c.state.replayDetector, replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), ) } markPacketAsValid, ok := c.state.replayDetector[int(h.Epoch)].Check(h.SequenceNumber) if !ok { c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", h.Epoch, h.SequenceNumber, ) return false, nil, nil } // Decrypt if h.Epoch != 0 { if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { c.encryptedPackets = append(c.encryptedPackets, buf) c.log.Debug("handshake not finished, queuing packet") } return false, nil, nil } var err error buf, err = c.state.cipherSuite.Decrypt(buf) if err != 
nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) return false, nil, nil } } isHandshake, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) return false, nil, nil } else if isHandshake { markPacketAsValid() for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) continue } c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } return true, nil, nil } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(buf); err != nil { return false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err } switch content := r.Content.(type) { case *alert.Alert: c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) var a *alert.Alert if content.Description == alert.CloseNotify { // Respond with a close_notify [RFC5246 Section 7.2.1] a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } markPacketAsValid() return false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { c.encryptedPackets = append(c.encryptedPackets, buf) c.log.Debugf("CipherSuite not initialized, queuing packet") } return false, nil, nil } newRemoteEpoch := h.Epoch + 1 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) if c.state.getRemoteEpoch()+1 == newRemoteEpoch { c.setRemoteEpoch(newRemoteEpoch) markPacketAsValid() } case *protocol.ApplicationData: if h.Epoch == 0 { return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errApplicationDataEpochZero } markPacketAsValid() select { case c.decrypted <- 
content.Data: case <-c.closed.Done(): case <-ctx.Done(): } default: return false, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) } return false, nil, nil } func (c *Conn) recvHandshake() <-chan chan struct{} { return c.handshakeRecv } func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { if level == alert.Fatal && len(c.state.SessionID) > 0 { // According to the RFC, we need to delete the stored session. // https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 if ss := c.fsm.cfg.sessionStore; ss != nil { c.log.Tracef("clean invalid session: %s", c.state.SessionID) if err := ss.Del(c.sessionKey()); err != nil { return err } } } return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &alert.Alert{ Level: level, Description: desc, }, }, shouldEncrypt: c.isHandshakeCompletedSuccessfully(), }, }) } func (c *Conn) setHandshakeCompletedSuccessfully() { c.handshakeCompletedSuccessfully.Store(struct{ bool }{true}) } func (c *Conn) isHandshakeCompletedSuccessfully() bool { boolean, _ := c.handshakeCompletedSuccessfully.Load().(struct{ bool }) return boolean.bool } func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState) error { //nolint:gocognit c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) done := make(chan struct{}) ctxRead, cancelRead := context.WithCancel(context.Background()) c.cancelHandshakeReader = cancelRead cfg.onFlightState = func(f flightVal, s handshakeState) { if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() { c.setHandshakeCompletedSuccessfully() close(done) } } ctxHs, cancel := context.WithCancel(context.Background()) c.cancelHandshaker = cancel firstErr := make(chan error, 1) 
c.handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { defer c.handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { case firstErr <- err: default: } } }() go func() { defer func() { // Escaping read loop. // It's safe to close decrypted channnel now. close(c.decrypted) // Force stop handshaker when the underlying connection is closed. cancel() }() defer c.handshakeLoopsFinished.Done() for { if err := c.readAndBuffer(ctxRead); err != nil { var e *alertError if errors.As(err, &e) { if !e.IsFatalOrCloseNotify() { if c.isHandshakeCompletedSuccessfully() { // Pass the error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } } continue // non-fatal alert must not stop read loop } } else { switch { case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF): default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } continue // non-fatal alert must not stop read loop } } } select { case firstErr <- err: default: } if e != nil { if e.IsFatalOrCloseNotify() { _ = c.close(false) //nolint:contextcheck } } if !c.isConnectionClosed() && errors.Is(err, context.Canceled) { c.log.Trace("handshake timeouts - closing underline connection") _ = c.close(false) //nolint:contextcheck } return } } }() select { case err := <-firstErr: cancelRead() cancel() c.handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() c.handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil } } func (c *Conn) translateHandshakeCtxError(err error) error { if err == nil { return nil } if errors.Is(err, 
context.Canceled) && c.isHandshakeCompletedSuccessfully() { return nil } return &HandshakeError{Err: err} } func (c *Conn) close(byUser bool) error { c.cancelHandshaker() c.cancelHandshakeReader() if c.isHandshakeCompletedSuccessfully() && byUser { // Discard error from notify() to return non-error on the first user call of Close() // even if the underlying connection is already closed. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify) } c.closeLock.Lock() // Don't return ErrConnClosed at the first time of the call from user. closedByUser := c.connectionClosedByUser if byUser { c.connectionClosedByUser = true } isClosed := c.isConnectionClosed() c.closed.Close() c.closeLock.Unlock() if closedByUser { return ErrConnClosed } if isClosed { return nil } return c.nextConn.Close() } func (c *Conn) isConnectionClosed() bool { select { case <-c.closed.Done(): return true default: return false } } func (c *Conn) setLocalEpoch(epoch uint16) { c.state.localEpoch.Store(epoch) } func (c *Conn) setRemoteEpoch(epoch uint16) { c.state.remoteEpoch.Store(epoch) } // LocalAddr implements net.Conn.LocalAddr func (c *Conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } // RemoteAddr implements net.Conn.RemoteAddr func (c *Conn) RemoteAddr() net.Addr { return c.nextConn.RemoteAddr() } func (c *Conn) sessionKey() []byte { if c.state.isClient { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. return []byte(c.nextConn.RemoteAddr().String() + "_" + c.fsm.cfg.serverName) } return c.state.SessionID } // SetDeadline implements net.Conn.SetDeadline func (c *Conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) return c.SetWriteDeadline(t) } // SetReadDeadline implements net.Conn.SetReadDeadline func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) // Read deadline is fully managed by this layer. 
// Don't set read deadline to underlying connection. return nil } // SetWriteDeadline implements net.Conn.SetWriteDeadline func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline is also fully managed by this layer. return nil } dtls-2.2.6/conn_go_test.go000066400000000000000000000076331437412644000155150ustar00rootroot00000000000000//go:build !js // +build !js package dtls import ( "bytes" "context" "crypto/tls" "errors" "net" "testing" "time" "github.com/pion/dtls/v2/internal/net/dpipe" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/transport/v2/test" ) func TestContextConfig(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() report := test.CheckRoutines(t) defer report() addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") if err != nil { t.Fatalf("Unexpected error: %v", err) } // Dummy listener listen, err := net.ListenUDP("udp", addrListen) if err != nil { t.Fatalf("Unexpected error: %v", err) } defer func() { _ = listen.Close() }() addr, ok := listen.LocalAddr().(*net.UDPAddr) if !ok { t.Fatal("Failed to cast net.UDPAddr") } cert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatalf("Unexpected error: %v", err) } config := &Config{ ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(context.Background(), 40*time.Millisecond) }, Certificates: []tls.Certificate{cert}, } dials := map[string]struct { f func() (func() (net.Conn, error), func()) order []byte }{ "Dial": { f: func() (func() (net.Conn, error), func()) { return func() (net.Conn, error) { return Dial("udp", addr, config) }, func() { } }, order: []byte{0, 1, 2}, }, "DialWithContext": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) return func() (net.Conn, error) { return DialWithContext(ctx, "udp", addr, config) }, func() { cancel() } }, order: []byte{0, 2, 1}, }, 
"Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { return Client(ca, config) }, func() { _ = ca.Close() } }, order: []byte{0, 1, 2}, }, "ClientWithContext": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { return ClientWithContext(ctx, ca, config) }, func() { cancel() _ = ca.Close() } }, order: []byte{0, 2, 1}, }, "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() return func() (net.Conn, error) { return Server(ca, config) }, func() { _ = ca.Close() } }, order: []byte{0, 1, 2}, }, "ServerWithContext": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) ca, _ := dpipe.Pipe() return func() (net.Conn, error) { return ServerWithContext(ctx, ca, config) }, func() { cancel() _ = ca.Close() } }, order: []byte{0, 2, 1}, }, } for name, dial := range dials { dial := dial t.Run(name, func(t *testing.T) { done := make(chan struct{}) go func() { d, cancel := dial.f() conn, err := d() defer cancel() var netError net.Error if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck t.Errorf("Client error exp(Temporary network error) failed(%v)", err) close(done) return } done <- struct{}{} if err == nil { _ = conn.Close() } }() var order []byte early := time.After(20 * time.Millisecond) late := time.After(60 * time.Millisecond) func() { for len(order) < 3 { select { case <-early: order = append(order, 0) case _, ok := <-done: if !ok { return } order = append(order, 1) case <-late: order = append(order, 2) } } }() if !bytes.Equal(dial.order, order) { t.Errorf("Invalid cancel timing, expected: %v, got: %v", dial.order, order) } }) } } dtls-2.2.6/conn_test.go000066400000000000000000002367231437412644000150340ustar00rootroot00000000000000package dtls import ( "bytes" 
"context" "crypto" "crypto/ecdsa" cryptoElliptic "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/hex" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v2/internal/ciphersuite" "github.com/pion/dtls/v2/internal/net/dpipe" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/hash" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/dtls/v2/pkg/crypto/signature" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v2/test" ) var ( errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity") errPSKRejected = errors.New("PSK Rejected") errNotExpectedChain = errors.New("not expected chain") errExpecedChain = errors.New("expected chain") errWrongCert = errors.New("wrong cert") ) func TestStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() // Run the test stressDuplex(t) } func stressDuplex(t *testing.T) { ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) } defer func() { err = ca.Close() if err != nil { t.Fatal(err) } err = cb.Close() if err != nil { t.Fatal(err) } }() opt := test.Options{ MsgSize: 2048, MsgCount: 100, } err = test.StressDuplex(ca, cb, opt) if err != nil { t.Fatal(err) } } func TestRoutineLeakOnClose(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) } if _, err := ca.Write(make([]byte, 100)); err != 
nil { t.Fatal(err) } if err := cb.Close(); err != nil { t.Fatal(err) } if err := ca.Close(); err != nil { t.Fatal(err) } // Packet is sent, but not read. // inboundLoop routine should not be leaked. } func TestReadWriteDeadline(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() var e net.Error ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) } if err := ca.SetDeadline(time.Unix(0, 1)); err != nil { t.Fatal(err) } _, werr := ca.Write(make([]byte, 100)) if errors.As(werr, &e) { if !e.Timeout() { t.Error("Deadline exceeded Write must return Timeout error") } if !e.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Write must return Temporary error") } } else { t.Error("Write must return net.Error error") } _, rerr := ca.Read(make([]byte, 100)) if errors.As(rerr, &e) { if !e.Timeout() { t.Error("Deadline exceeded Read must return Timeout error") } if !e.Temporary() { //nolint:staticcheck t.Error("Deadline exceeded Read must return Temporary error") } } else { t.Error("Read must return net.Error error") } if err := ca.SetDeadline(time.Time{}); err != nil { t.Error(err) } if err := ca.Close(); err != nil { t.Error(err) } if err := cb.Close(); err != nil { t.Error(err) } if _, err := ca.Write(make([]byte, 100)); !errors.Is(err, ErrConnClosed) { t.Errorf("Write must return %v after close, got %v", ErrConnClosed, err) } if _, err := ca.Read(make([]byte, 100)); !errors.Is(err, io.EOF) { t.Errorf("Read must return %v after close, got %v", io.EOF, err) } } func TestSequenceNumberOverflow(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ApplicationData", func(t *testing.T) { ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) } atomic.StoreUint64(&ca.state.localSequenceNumber[1], 
recordlayer.MaxSequenceNumber) if _, werr := ca.Write(make([]byte, 100)); werr != nil { t.Errorf("Write must send message with maximum sequence number, but errord: %v", werr) } if _, werr := ca.Write(make([]byte, 100)); !errors.Is(werr, errSequenceNumberOverflow) { t.Errorf("Write must abandonsend message with maximum sequence number, but errord: %v", werr) } if err := ca.Close(); err != nil { t.Error(err) } if err := cb.Close(); err != nil { t.Error(err) } }) t.Run("Handshake", func(t *testing.T) { ca, cb, err := pipeMemory() if err != nil { t.Fatal(err) } ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1) // Try to send handshake packet. if werr := ca.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: make([]byte, 64), CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), CompressionMethods: defaultCompressionMethods(), }, }, }, }, }); !errors.Is(werr, errSequenceNumberOverflow) { t.Errorf("Connection must fail on handshake packet reaches maximum sequence number") } if err := ca.Close(); err != nil { t.Error(err) } if err := cb.Close(); err != nil { t.Error(err) } }) } func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() return pipeConn(ca, cb) } func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { type result struct { c *Conn err error } c := make(chan result) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Setup client go func() { client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) c <- result{client, err} }() // Setup server server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: 
[]SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}}, true) if err != nil { return nil, nil, err } // Receive client res := <-c if res.err != nil { _ = server.Close() return nil, nil, res.err } return res.c, server, nil } func testClient(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true return ClientWithContext(ctx, c, cfg) } func testServer(ctx context.Context, c net.Conn, cfg *Config, generateCertificate bool) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{serverCert} } return ServerWithContext(ctx, c, cfg) } func sendClientHello(cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension) error { packet, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: sequenceNumber, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: uint16(sequenceNumber), }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, }).Marshal() if err != nil { return err } if _, err = ca.Write(packet); err != nil { return err } return nil } func TestHandshakeWithAlert(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cases := map[string]struct { configServer, configClient *Config errServer, errClient error }{ "CipherSuiteNoIntersection": { configServer: &Config{ 
CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, }, errServer: errCipherSuiteNoIntersection, errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, }, "SignatureSchemesNoIntersection": { configServer: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, }, errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, errClient: errNoAvailableSignatureSchemes, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() go func() { _, err := testClient(ctx, ca, testCase.configClient, true) clientErr <- err }() _, errServer := testServer(ctx, cb, testCase.configServer, true) if !errors.Is(errServer, testCase.errServer) { t.Fatalf("Server error exp(%v) failed(%v)", testCase.errServer, errServer) } errClient := <-clientErr if !errors.Is(errClient, testCase.errClient) { t.Fatalf("Client error exp(%v) failed(%v)", testCase.errClient, errClient) } }) } } func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() var rand [28]byte exportLabel := "EXTRACTOR-dtls_srtp" expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} c := &Conn{ state: State{ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, localSequenceNumber: []uint64{0, 0}, 
cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
		},
	}
	c.setLocalEpoch(0)
	c.setRemoteEpoch(0)

	state := c.ConnectionState()
	_, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
	if !errors.Is(err, errHandshakeInProgress) {
		t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
	}

	c.setLocalEpoch(1)
	state = c.ConnectionState()
	_, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
	if !errors.Is(err, errContextUnsupported) {
		t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
	}

	for k := range invalidKeyingLabels() {
		state = c.ConnectionState()
		_, err = state.ExportKeyingMaterial(k, nil, 0)
		if !errors.Is(err, errReservedExportKeyingMaterial) {
			t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
		}
	}

	state = c.ConnectionState()
	keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
	if err != nil {
		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
	} else if !bytes.Equal(keyingMaterial, expectedServerKey) {
		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedServerKey, keyingMaterial)
	}

	c.state.isClient = true
	state = c.ConnectionState()
	keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
	if err != nil {
		t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
	} else if !bytes.Equal(keyingMaterial, expectedClientKey) {
		t.Errorf("ExportKeyingMaterial client export: expected (% 02x) actual (% 02x)", expectedClientKey, keyingMaterial)
	}
}

// TestPSK runs PSK handshakes over an in-memory pipe for several cipher
// suites and server identity hints. It checks that the server observes the
// client's identity hint, and that a failing VerifyConnection callback on
// either side aborts the handshake with the expected errors on both ends.
func TestPSK(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                   string
		ServerIdentity         []byte
		CipherSuites           []CipherSuiteID
		ClientVerifyConnection func(*State) error
		ServerVerifyConnection func(*State) error
		WantFail               bool
		ExpectedServerErr      string
		ExpectedClientErr      string
	}{
		{
			Name:           "Server identity specified",
			ServerIdentity: []byte("Test Identity"),
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		},
		{
			Name:           "Server identity specified - Server verify connection fails",
			ServerIdentity: []byte("Test Identity"),
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
			ServerVerifyConnection: func(s *State) error {
				return errExample
			},
			WantFail:          true,
			ExpectedServerErr: errExample.Error(),
			ExpectedClientErr: alert.BadCertificate.String(),
		},
		{
			Name:           "Server identity specified - Client verify connection fails",
			ServerIdentity: []byte("Test Identity"),
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
			ClientVerifyConnection: func(s *State) error {
				return errExample
			},
			WantFail:          true,
			ExpectedServerErr: alert.BadCertificate.String(),
			ExpectedClientErr: errExample.Error(),
		},
		{
			Name:           "Server identity nil",
			ServerIdentity: nil,
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		},
		{
			Name:           "TLS_PSK_WITH_AES_128_CBC_SHA256",
			ServerIdentity: nil,
			CipherSuites:   []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256},
		},
		{
			Name:           "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256",
			ServerIdentity: nil,
			CipherSuites:   []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256},
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			clientIdentity := []byte("Client Identity")
			type result struct {
				c   *Conn
				err error
			}
			clientRes := make(chan result, 1)

			ca, cb := dpipe.Pipe()
			go func() {
				conf := &Config{
					PSK: func(hint []byte) ([]byte, error) {
						if !bytes.Equal(test.ServerIdentity, hint) {
							return nil, fmt.Errorf("TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint) //nolint:goerr113
						}

						return []byte{0xAB, 0xC1, 0x23}, nil
					},
					PSKIdentityHint:  clientIdentity,
					CipherSuites:     test.CipherSuites,
					VerifyConnection: test.ClientVerifyConnection,
				}

				c, err := testClient(ctx, ca, conf, false)
				clientRes <- result{c, err}
			}()

			config := &Config{
				PSK: func(hint []byte) ([]byte, error) {
					if !bytes.Equal(clientIdentity, hint) {
						return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, clientIdentity, hint)
					}
					return []byte{0xAB, 0xC1, 0x23}, nil
				},
				PSKIdentityHint:  test.ServerIdentity,
				CipherSuites:     test.CipherSuites,
				VerifyConnection: test.ServerVerifyConnection,
			}

			server, err := testServer(ctx, cb, config, false)
			if test.WantFail {
				res := <-clientRes
				if err == nil || !strings.Contains(err.Error(), test.ExpectedServerErr) {
					t.Fatalf("TestPSK: Server expected(%v) actual(%v)", test.ExpectedServerErr, err)
				}
				if res.err == nil || !strings.Contains(res.err.Error(), test.ExpectedClientErr) {
					t.Fatalf("TestPSK: Client expected(%v) actual(%v)", test.ExpectedClientErr, res.err)
				}
				return
			}
			if err != nil {
				t.Fatalf("TestPSK: Server failed(%v)", err)
			}

			actualPSKIdentityHint := server.ConnectionState().IdentityHint
			if !bytes.Equal(actualPSKIdentityHint, clientIdentity) {
				t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, clientIdentity, actualPSKIdentityHint)
			}

			defer func() {
				_ = server.Close()
			}()

			res := <-clientRes
			if res.err != nil {
				t.Fatal(res.err)
			}
			_ = res.c.Close()
		})
	}
}

// TestPSKHintFail checks that when both PSK callbacks reject the hint, the
// client handshake fails with the callback's error and the server handshake
// fails with a fatal InternalError alert error.
func TestPSKHintFail(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}}
	pskRejected := errPSKRejected

	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	clientErr := make(chan error, 1)

	ca, cb := dpipe.Pipe()
	go func() {
		conf := &Config{
			PSK: func(hint []byte) ([]byte, error) {
				return nil, pskRejected
			},
			PSKIdentityHint: []byte{},
			CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
		}

		_, err := testClient(ctx, ca, conf, false)
		clientErr <- err
	}()

	config := &Config{
		PSK: func(hint []byte) ([]byte, error) {
			return nil, pskRejected
		},
		PSKIdentityHint: []byte{},
		CipherSuites:    []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8},
	}

	if _, err := testServer(ctx, cb, config, false); !errors.Is(err, serverAlertError) {
		t.Fatalf("TestPSKHintFail: Server error exp(%v) failed(%v)", serverAlertError, err)
	}

	if err := <-clientErr; !errors.Is(err, pskRejected) {
		t.Fatalf("TestPSKHintFail: Client error exp(%v) failed(%v)", pskRejected, err)
	}
}

// TestClientTimeout checks that a client handshake with no peer on the other
// end of the pipe fails with a net.Error whose Timeout() is true once the
// context deadline (one second) expires.
func TestClientTimeout(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	clientErr := make(chan error, 1)

	ca, _ := dpipe.Pipe()
	go func() {
		conf := &Config{}

		c, err := testClient(ctx, ca, conf, true)
		if err == nil {
			_ = c.Close() //nolint:contextcheck
		}
		clientErr <- err
	}()

	// no server!
	err := <-clientErr
	var netErr net.Error
	if !errors.As(err, &netErr) || !netErr.Timeout() {
		t.Fatalf("Client error exp(Temporary network error) failed(%v)", err)
	}
}

// TestSRTPConfiguration checks SRTP protection profile negotiation for several
// client/server profile lists, including the expected handshake errors when
// the client requires SRTP and the server offers none.
//
// Each case runs as its own subtest. This lets a case whose handshake is
// expected to fail end on its own without skipping the remaining cases, and
// lets each case's context and connections be released when the case ends.
func TestSRTPConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name            string
		ClientSRTP      []SRTPProtectionProfile
		ServerSRTP      []SRTPProtectionProfile
		ExpectedProfile SRTPProtectionProfile
		WantClientError error
		WantServerError error
	}{
		{
			Name:            "No SRTP in use",
			ClientSRTP:      nil,
			ServerSRTP:      nil,
			ExpectedProfile: 0,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "SRTP both ends",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "SRTP client only",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ServerSRTP:      nil,
			ExpectedProfile: 0,
			WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError: errServerNoMatchingSRTPProfile,
		},
		{
			Name:            "SRTP server only",
			ClientSRTP:      nil,
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: 0,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "Multiple Suites",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
		{
			Name:            "Multiple Suites, Client Chooses",
			ClientSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32},
			ServerSRTP:      []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80},
			ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80,
			WantClientError: nil,
			WantServerError: nil,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			// Buffered so the client goroutine can always deliver its result,
			// even if this subtest stops early via t.Fatalf.
			c := make(chan result, 1)

			go func() {
				client, err := testClient(ctx, ca, &Config{SRTPProtectionProfiles: test.ClientSRTP}, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{SRTPProtectionProfiles: test.ServerSRTP}, true)
			if !errors.Is(err, test.WantServerError) {
				t.Errorf("TestSRTPConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			}

			res := <-c
			if res.err == nil {
				defer func() {
					_ = res.c.Close()
				}()
			}
			if !errors.Is(res.err, test.WantClientError) {
				t.Fatalf("TestSRTPConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}
			// Negotiated profiles can only be compared when both handshakes
			// completed. This return ends just this subtest.
			if res.c == nil || err != nil {
				return
			}

			actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile()
			if actualClientSRTP != test.ExpectedProfile {
				t.Errorf("TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualClientSRTP)
			}

			actualServerSRTP, _ := server.SelectedSRTPProtectionProfile()
			if actualServerSRTP != test.ExpectedProfile {
				t.Errorf("TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ExpectedProfile, actualServerSRTP)
			}
		})
	}
}

// TestClientCertificate exercises each ClientAuth mode with and without a
// client certificate, including certificate callbacks and VerifyConnection
// hooks. For successful handshakes it checks which certificate each side
// observed.
func TestClientCertificate(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	srvCert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	srvCAPool := x509.NewCertPool()
	srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	srvCAPool.AddCert(srvCertificate)

	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"NoClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   NoClientCert,
					ClientCAs:    caPool,
				},
			},
			"NoClientCert_ServerVerifyConnectionFails": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   NoClientCert,
					ClientCAs:    caPool,
					VerifyConnection: func(s *State) error {
						return errExample
					},
				},
				wantErr: true,
			},
			"NoClientCert_ClientVerifyConnectionFails": {
				clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(s *State) error {
					return errExample
				}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   NoClientCert,
					ClientCAs:    caPool,
				},
				wantErr: true,
			},
			"NoClientCert_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
			},
			"RequestClientCert_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequestClientCert,
				},
			},
			"RequestClientCert_no_cert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequestClientCert,
					ClientCAs:    caPool,
				},
			},
			"RequireAnyClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
			},
			"RequireAnyClientCert_error": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAnyClientCert,
				},
				wantErr: true,
			},
			"VerifyClientCertIfGiven_no_cert": {
				clientCfg: &Config{RootCAs: srvCAPool},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
					ClientCAs:    caPool,
				},
			},
			"VerifyClientCertIfGiven_cert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
					ClientCAs:    caPool,
				},
			},
			"VerifyClientCertIfGiven_error": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   VerifyClientCertIfGiven,
				},
				wantErr: true,
			},
			"RequireAndVerifyClientCert": {
				clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error {
					if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok {
						return errExample
					}
					return nil
				}},
				serverCfg: &Config{
					Certificates: []tls.Certificate{srvCert},
					ClientAuth:   RequireAndVerifyClientCert,
					ClientCAs:    caPool,
					VerifyConnection: func(s *State) error {
						if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok {
							return errExample
						}
						return nil
					},
				},
			},
			"RequireAndVerifyClientCert_callbacks": {
				clientCfg: &Config{
					RootCAs: srvCAPool,
					// Certificates: []tls.Certificate{cert},
					GetClientCertificate: func(cri *CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil },
				},
				serverCfg: &Config{
					GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil },
					// Certificates: []tls.Certificate{srvCert},
					ClientAuth: RequireAndVerifyClientCert,
					ClientCAs:  caPool,
				},
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()
				type result struct {
					c   *Conn
					err error
				}
				c := make(chan result)

				go func() {
					client, err := Client(ca, tt.clientCfg)
					c <- result{client, err}
				}()

				server, err := Server(cb, tt.serverCfg)
				res := <-c
				defer func() {
					if err == nil {
						_ = server.Close()
					}
					if res.err == nil {
						_ = res.c.Close()
					}
				}()

				if tt.wantErr {
					if err != nil {
						// Error expected, test succeeded
						return
					}
					t.Error("Error expected")
				}
				// The checks below dereference both conns, so a failed
				// handshake must stop the subtest here.
				if err != nil {
					t.Fatalf("Server failed(%v)", err)
				}

				if res.err != nil {
					t.Fatalf("Client failed(%v)", res.err)
				}

				actualClientCert := server.ConnectionState().PeerCertificates
				if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
					if len(actualClientCert) == 0 {
						t.Fatal("Client did not provide a certificate")
					}

					var cfgCert [][]byte
					if len(tt.clientCfg.Certificates) > 0 {
						cfgCert = tt.clientCfg.Certificates[0].Certificate
					}
					if tt.clientCfg.GetClientCertificate != nil {
						crt, err := tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{})
						if err != nil {
							t.Fatalf("Client configuration did not provide a certificate: %v", err)
						}
						cfgCert = crt.Certificate
					}
					if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualClientCert[0]) {
						t.Errorf("Client certificate was not communicated correctly")
					}
				}
				if tt.serverCfg.ClientAuth == NoClientCert {
					if actualClientCert != nil {
						t.Errorf("Client certificate wasn't expected")
					}
				}

				actualServerCert := res.c.ConnectionState().PeerCertificates
				if len(actualServerCert) == 0 {
					t.Fatal("Server did not provide a certificate")
				}
				var cfgCert [][]byte
				if len(tt.serverCfg.Certificates) > 0 {
					cfgCert = tt.serverCfg.Certificates[0].Certificate
				}
				if tt.serverCfg.GetCertificate != nil {
					crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{})
					if err != nil {
						t.Fatalf("Server configuration did not provide a certificate: %v", err)
					}
					cfgCert = crt.Certificate
				}
				if len(cfgCert) == 0 || !bytes.Equal(cfgCert[0], actualServerCert[0]) {
					t.Errorf("Server certificate was not communicated correctly")
				}
			})
		}
	})
}

// TestExtendedMasterSecret checks the handshake outcome, on both sides, for
// every combination of client and server ExtendedMasterSecret settings.
func TestExtendedMasterSecret(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	tests := map[string]struct {
		clientCfg         *Config
		serverCfg         *Config
		expectedClientErr error
		expectedServerErr error
	}{
		"Request_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Request_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Require_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: errClientRequiredButNoServerEMS,
			expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
		},
		"Disable_Request_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequestExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
		"Disable_Require_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: RequireExtendedMasterSecret,
			},
			expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			expectedServerErr: errServerRequiredButNoClientEMS,
		},
		"Disable_Disable_ExtendedMasterSecret": {
			clientCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			serverCfg: &Config{
				ExtendedMasterSecret: DisableExtendedMasterSecret,
			},
			expectedClientErr: nil,
			expectedServerErr: nil,
		},
	}
	for name, tt := range tests {
		tt := tt
		t.Run(name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, tt.clientCfg, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, tt.serverCfg, true)
			res := <-c
			defer func() {
				if err == nil {
					_ = server.Close()
				}
				if res.err == nil {
					_ = res.c.Close()
				}
			}()

			if !errors.Is(res.err, tt.expectedClientErr) {
				t.Errorf("Client error expected: \"%v\" but got \"%v\"", tt.expectedClientErr, res.err)
			}

			if !errors.Is(err, tt.expectedServerErr) {
				t.Errorf("Server error expected: \"%v\" but got \"%v\"", tt.expectedServerErr, err)
			}
		})
	}
}

// TestServerCertificate checks certificate verification scenarios, including
// missing root CAs, InsecureSkipVerify, custom VerifyPeerCertificate hooks
// and ServerName matching. It only asserts on the client's handshake result.
func TestServerCertificate(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cert, err := selfsign.GenerateSelfSigned()
	if err != nil {
		t.Fatal(err)
	}
	certificate, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		t.Fatal(err)
	}
	caPool := x509.NewCertPool()
	caPool.AddCert(certificate)

	t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak
		tests := map[string]struct {
			clientCfg *Config
			serverCfg *Config
			wantErr   bool
		}{
			"no_ca": {
				clientCfg: &Config{},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"good_ca": {
				clientCfg: &Config{RootCAs: caPool},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"no_ca_skip_verify": {
				clientCfg: &Config{InsecureSkipVerify: true},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"good_ca_skip_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
					if len(chain) != 0 {
						return errNotExpectedChain
					}
					return nil
				}},
			},
			"good_ca_verify_custom_verify_peer": {
				clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}},
				serverCfg: &Config{ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(cert [][]byte, chain [][]*x509.Certificate) error {
					if len(chain) == 0 {
						return errExpecedChain
					}
					return nil
				}},
			},
			"good_ca_custom_verify_peer": {
				clientCfg: &Config{
					RootCAs: caPool,
					VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error {
						return errWrongCert
					},
				},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
			"server_name": {
				clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
			},
			"server_name_error": {
				clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"},
				serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert},
				wantErr:   true,
			},
		}
		for name, tt := range tests {
			tt := tt
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()

				type result struct {
					c   *Conn
					err error
				}
				// Buffered so the server goroutine never blocks on send.
				srvCh := make(chan result, 1)
				go func() {
					s, err := Server(cb, tt.serverCfg)
					srvCh <- result{s, err}
				}()

				cli, err := Client(ca, tt.clientCfg)
				if err == nil {
					_ = cli.Close()
				}

				// Collect and close the server side before asserting, so a
				// t.Fatal below cannot leak the server goroutine or its conn.
				srv := <-srvCh
				if srv.err == nil {
					_ = srv.c.Close()
				}

				if !tt.wantErr && err != nil {
					t.Errorf("Client failed(%v)", err)
				}
				if tt.wantErr && err == nil {
					t.Fatal("Error expected")
				}
			})
		}
	})
}

// TestCipherSuiteConfiguration checks cipher suite validation and
// negotiation. It covers invalid IDs, disjoint client/server lists and the
// suite selected when the lists overlap.
func TestCipherSuiteConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                    string
		ClientCipherSuites      []CipherSuiteID
		ServerCipherSuites      []CipherSuiteID
		WantClientError         error
		WantServerError         error
		WantSelectedCipherSuite CipherSuiteID
	}{
		{
			Name:               "No CipherSuites specified",
			ClientCipherSuites: nil,
			ServerCipherSuites: nil,
			WantClientError:    nil,
			WantServerError:    nil,
		},
		{
			Name:               "Invalid CipherSuite",
			ClientCipherSuites: []CipherSuiteID{0x00},
			ServerCipherSuites: []CipherSuiteID{0x00},
			WantClientError:    &invalidCipherSuiteError{0x00},
			WantServerError:    &invalidCipherSuiteError{0x00},
		},
		{
			Name:                    "Valid CipherSuites specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
		},
		{
			Name:               "CipherSuites mismatch",
			ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:    &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}},
			WantServerError:    errCipherSuiteNoIntersection,
		},
		{
			Name:                    "Valid CipherSuites CCM specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM,
		},
		{
			Name:                    "Valid CipherSuites CCM-8 specified",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8,
		},
		{
			Name:                    "Server supports subset of client suites",
			ClientCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			ServerCipherSuites:      []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA},
			WantClientError:         nil,
			WantServerError:         nil,
			WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				client, err := testClient(ctx, ca, &Config{CipherSuites: test.ClientCipherSuites}, true)
				c <- result{client, err}
			}()

			server, err := testServer(ctx, cb, &Config{CipherSuites: test.ServerCipherSuites}, true)
			// The server is closed only through this defer, and only when
			// the handshake succeeded, so no path can call Close on a nil conn.
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			}
			if !errors.Is(err, test.WantServerError) {
				t.Errorf("TestCipherSuiteConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
			}

			res := <-c
			if res.err == nil {
				_ = res.c.Close()
			}
			if !errors.Is(res.err, test.WantClientError) {
				t.Errorf("TestCipherSuiteConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
			}
			// Only inspect the negotiated suite when the client conn exists.
			if res.c != nil && test.WantSelectedCipherSuite != 0x00 && res.c.state.cipherSuite.ID() != test.WantSelectedCipherSuite {
				t.Errorf("TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s': expected(%v) actual(%v)", test.Name, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID())
			}
		})
	}
}

// TestCertificateAndPSKServer checks that a server offering both a
// certificate-based suite and a PSK suite accepts a certificate client as
// well as a PSK client.
func TestCertificateAndPSKServer(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name      string
		ClientPSK bool
	}{
		{
			Name:      "Client uses PKI",
			ClientPSK: false,
		},
		{
			Name:      "Client uses PSK",
			ClientPSK: true,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			c := make(chan result)

			go func() {
				config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}
				if test.ClientPSK {
					config.PSK = func([]byte) ([]byte, error) {
						return []byte{0x00, 0x01, 0x02}, nil
					}
					config.PSKIdentityHint = []byte{0x00}
					config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256}
				}

				client, err := testClient(ctx, ca, config, false)
				c <- result{client, err}
			}()

			config := &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256},
				PSK: func([]byte) ([]byte, error) {
					return []byte{0x00, 0x01, 0x02}, nil
				},
			}

			server, err := testServer(ctx, cb, config, true)
			if err == nil {
				defer func() {
					_ = server.Close()
				}()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, err)
			}

			res := <-c
			// The server conn is released by the defer above. Closing it here
			// as well would dereference nil if the server handshake failed.
			if res.err == nil {
				_ = res.c.Close()
			} else {
				t.Errorf("TestCertificateAndPSKServer: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, nil, res.err)
			}
		})
	}
}

// TestPSKConfiguration checks that inconsistent PSK, identity-hint and
// certificate configurations are rejected with the expected errors on both
// the client and the server.
func TestPSKConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name                 string
		ClientHasCertificate bool
		ServerHasCertificate bool
		ClientPSK            PSKCallback
		ServerPSK            PSKCallback
		ClientPSKIdentity    []byte
		ServerPSKIdentity    []byte
		WantClientError      error
		WantServerError      error
	}{
		{
			Name:                 "PSK and no certificate specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and certificate specified",
			ClientHasCertificate: true,
			ServerHasCertificate: true,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and no identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    nil,
			ServerPSKIdentity:    nil,
			WantClientError:      errPSKAndIdentityMustBeSetForClient,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "No PSK and identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            nil,
			ServerPSK:            nil,
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errIdentityNoPSK,
			WantServerError:      errIdentityNoPSK,
		},
	} {
		test := test
		t.Run(test.Name, func(t *testing.T) {
			// Scoped per case so the context is released when the case ends.
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			// Buffered so the client goroutine can exit even if the server
			// assertion below stops the case before the result is received.
			c := make(chan result, 1)

			go func() {
				client, err := testClient(ctx, ca, &Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity}, test.ClientHasCertificate)
				c <- result{client, err}
			}()

			_, err := testServer(ctx, cb, &Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity}, test.ServerHasCertificate)
			if err != nil || test.WantServerError != nil {
				if !(err != nil && test.WantServerError != nil && err.Error() == test.WantServerError.Error()) {
					t.Fatalf("TestPSKConfiguration: Server Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantServerError, err)
				}
			}

			res := <-c
			if res.err != nil || test.WantClientError != nil {
				if !(res.err != nil && test.WantClientError != nil && res.err.Error() == test.WantClientError.Error()) {
					t.Fatalf("TestPSKConfiguration: Client Error Mismatch '%s': expected(%v) actual(%v)", test.Name, test.WantClientError, res.err)
				}
			}
		})
	}
}

// TestServerTimeout repeatedly writes a ClientHello to a server whose context
// expires after 50ms. It checks that the server fails with a net.Error whose
// Timeout() is true, and that the server sends nothing further afterwards.
func TestServerTimeout(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookie := make([]byte, 20)
	_, err := rand.Read(cookie)
	if err != nil {
		t.Fatal(err)
	}

	var rand [28]byte
	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}

	cipherSuites := []CipherSuite{
		&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{},
		&ciphersuite.TLSEcdheRsaWithAes128GcmSha256{},
	}

	extensions := []extension.Extension{
		&extension.SupportedSignatureAlgorithms{
			SignatureHashAlgorithms: []signaturehash.Algorithm{
				{Hash: hash.SHA256, Signature: signature.ECDSA},
				{Hash: hash.SHA384, Signature: signature.ECDSA},
				{Hash: hash.SHA512, Signature: signature.ECDSA},
				{Hash: hash.SHA256, Signature: signature.RSA},
				{Hash: hash.SHA384, Signature: signature.RSA},
				{Hash: hash.SHA512, Signature: signature.RSA},
			},
		},
		&extension.SupportedEllipticCurves{
			EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384},
		},
		&extension.SupportedPointFormats{
			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
		},
	}

	record := &recordlayer.RecordLayer{
		Header: recordlayer.Header{
			SequenceNumber: 0,
			Version:        protocol.Version1_2,
		},
		Content: &handshake.Handshake{
			// sequenceNumber and messageSequence line up, may need to be re-evaluated
			Header: handshake.Header{
				MessageSequence: 0,
			},
			Message: &handshake.MessageClientHello{
				Version:            protocol.Version1_2,
				Cookie:             cookie,
				Random:             random,
				CipherSuiteIDs:     cipherSuiteIDs(cipherSuites),
				CompressionMethods: defaultCompressionMethods(),
				Extensions:         extensions,
			},
		},
	}

	packet, err := record.Marshal()
	if err != nil {
		t.Fatal(err)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		err := ca.Close()
		if err != nil {
			t.Fatal(err)
		}
	}()

	// Client reader
	caReadChan := make(chan []byte, 1000)
	go func() {
		for {
			data := make([]byte, 8192)
			n, err := ca.Read(data)
			if err != nil {
				return
			}

			caReadChan <- data[:n]
		}
	}()

	// Start sending ClientHello packets until server responds with first packet
	go func() {
		for {
			select {
			case <-time.After(10 * time.Millisecond):
				_, err := ca.Write(packet)
				if err != nil {
					return
				}
			case <-caReadChan:
				// Once we receive the first reply from the server, stop
				return
			}
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()

	config := &Config{
		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
		FlightInterval: 100 * time.Millisecond,
	}

	_, serverErr := testServer(ctx, cb, config, true)
	var netErr net.Error
	if !errors.As(serverErr, &netErr) || !netErr.Timeout() {
		t.Fatalf("Server error exp(Temporary network error) failed(%v)", serverErr)
	}

	// Wait a little longer to ensure no additional messages have been sent by the server
	time.Sleep(300 * time.Millisecond)
	select {
	case msg := <-caReadChan:
		t.Fatalf("Expected no additional messages from server, got: %+v", msg)
	default:
	}
}

// TestProtocolVersionValidation writes hand-crafted records that advertise
// protocol version {0xfe, 0xff} to a server and to a client. It checks that
// each side aborts with errUnsupportedProtocolVersion and answers with an
// alert record.
func TestProtocolVersionValidation(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookie := make([]byte, 20)
	if _, err := rand.Read(cookie); err != nil {
		t.Fatal(err)
	}

	var rand [28]byte
	random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}

	config := &Config{
		CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
		FlightInterval: 100 * time.Millisecond,
	}

	t.Run("Server", func(t *testing.T) {
		serverCases := map[string]struct {
			records []*recordlayer.RecordLayer
		}{
			"ClientHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
				},
			},
			"SecondsClientHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version1_2,
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
					{
						Header: recordlayer.Header{
							Version:        protocol.Version1_2,
							SequenceNumber: 1,
						},
						Content: &handshake.Handshake{
							Header: handshake.Header{
								MessageSequence: 1,
							},
							Message: &handshake.MessageClientHello{
								Version:            protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Cookie:             cookie,
								Random:             random,
								CipherSuiteIDs:     []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())},
								CompressionMethods: defaultCompressionMethods(),
							},
						},
					},
				},
			},
		}
		for name, c := range serverCases {
			c := c
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()
				defer func() {
					err := ca.Close()
					if err != nil {
						t.Error(err)
					}
				}()

				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
				defer cancel()

				var wg sync.WaitGroup
				wg.Add(1)
				defer wg.Wait()
				go func() {
					defer wg.Done()
					if _, err := testServer(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
						t.Errorf("Server error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
					}
				}()

				time.Sleep(50 * time.Millisecond)

				resp := make([]byte, 1024)
				for _, record := range c.records {
					packet, err := record.Marshal()
					if err != nil {
						t.Fatal(err)
					}
					if _, werr := ca.Write(packet); werr != nil {
						t.Fatal(werr)
					}
					n, rerr := ca.Read(resp[:cap(resp)])
					if rerr != nil {
						t.Fatal(rerr)
					}
					resp = resp[:n]
				}

				h := &recordlayer.Header{}
				if err := h.Unmarshal(resp); err != nil {
					t.Fatal("Failed to unmarshal response")
				}
				if h.ContentType != protocol.ContentTypeAlert {
					t.Errorf("Peer must return alert to unsupported protocol version")
				}
			})
		}
	})

	t.Run("Client", func(t *testing.T) {
		clientCases := map[string]struct {
			records []*recordlayer.RecordLayer
		}{
			"ServerHelloVersion": {
				records: []*recordlayer.RecordLayer{
					{
						Header: recordlayer.Header{
							Version: protocol.Version1_2,
						},
						Content: &handshake.Handshake{
							Message: &handshake.MessageHelloVerifyRequest{
								Version: protocol.Version1_2,
								Cookie:  cookie,
							},
						},
					},
					{
						Header: recordlayer.Header{
							Version:        protocol.Version1_2,
							SequenceNumber: 1,
						},
						Content: &handshake.Handshake{
							Header: handshake.Header{
								MessageSequence: 1,
							},
							Message: &handshake.MessageServerHello{
								Version:           protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade
								Random:            random,
								CipherSuiteID:     func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256); return &id }(),
								CompressionMethod: defaultCompressionMethods()[0],
							},
						},
					},
				},
			},
		}
		for name, c := range clientCases {
			c := c
			t.Run(name, func(t *testing.T) {
				ca, cb := dpipe.Pipe()
				defer func() {
					err := ca.Close()
					if err != nil {
						t.Error(err)
					}
				}()

				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
				defer cancel()

				var wg sync.WaitGroup
				wg.Add(1)
				defer wg.Wait()
				go func() {
					defer wg.Done()
					if _, err := testClient(ctx, cb, config, true); !errors.Is(err, errUnsupportedProtocolVersion) {
						t.Errorf("Client error exp(%v) failed(%v)", errUnsupportedProtocolVersion, err)
					}
				}()

				time.Sleep(50 * time.Millisecond)

				for _, record := range c.records {
					if _, err := ca.Read(make([]byte, 1024)); err != nil {
						t.Fatal(err)
					}

					packet, err := record.Marshal()
					if err != nil {
						t.Fatal(err)
					}
					if _, err := ca.Write(packet); err != nil {
						t.Fatal(err)
					}
				}
				resp := make([]byte, 1024)
				n, err := ca.Read(resp)
				if err != nil {
					t.Fatal(err)
				}
				resp = resp[:n]

				h := &recordlayer.Header{}
				if err := h.Unmarshal(resp); err != nil {
					t.Fatal("Failed to unmarshal response")
				}
				if h.ContentType != protocol.ContentTypeAlert {
					t.Errorf("Peer must return alert to unsupported protocol version")
				}
			})
		}
	})
}

// TestMultipleHelloVerifyRequest answers the client's ClientHellos with two
// consecutive HelloVerifyRequests carrying fresh random cookies. It checks
// that each ClientHello echoes the cookie from the preceding request; the
// first ClientHello carries an empty cookie.
func TestMultipleHelloVerifyRequest(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	cookies := [][]byte{
		// first clientHello contains an empty cookie
		{},
	}
	var packets [][]byte
	for i := 0; i < 2; i++ {
		cookie := make([]byte, 20)
		if _, err := rand.Read(cookie); err != nil {
			t.Fatal(err)
		}
		cookies = append(cookies, cookie)

		record := &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				SequenceNumber: uint64(i),
				Version:        protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Header: handshake.Header{
					MessageSequence: uint16(i),
				},
				Message: &handshake.MessageHelloVerifyRequest{
					Version: protocol.Version1_2,
					Cookie:  cookie,
				},
			},
		}
		packet, err := record.Marshal()
		if err != nil {
			t.Fatal(err)
		}
		packets = append(packets, packet)
	}

	ca, cb := dpipe.Pipe()
	defer func() {
		err := ca.Close()
		if err != nil {
			t.Error(err)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(1)
	defer wg.Wait()
	go func() {
		defer wg.Done()
		_, _ = testClient(ctx, ca, &Config{}, false)
	}()

	for i, cookie := range cookies {
		// read client hello
		resp := make([]byte, 1024)
		n, err := cb.Read(resp)
		if err != nil {
			t.Fatal(err)
		}
		record := &recordlayer.RecordLayer{}
		if err := record.Unmarshal(resp[:n]); err != nil {
			t.Fatal(err)
		}
		clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello)
		if !ok {
			t.Fatal("Failed to cast MessageClientHello")
		}

		if !bytes.Equal(clientHello.Cookie, cookie) {
			t.Fatalf("Wrong cookie, expected: %x, got: %x", cookie, clientHello.Cookie)
		}
		if len(packets) <= i {
			break
		}
		// write hello verify request
		if _, err := cb.Write(packets[i]); err != nil {
			t.Fatal(err)
		}
	}
	cancel()
}

// Assert that a DTLS Server always responds with RenegotiationInfo if
// a ClientHello contained that extension or not
func TestRenegotationInfo(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(10 * time.Second)
	defer
lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() resp := make([]byte, 1024) for _, testCase := range []struct { Name string SendRenegotiationInfo bool }{ { "Include RenegotiationInfo", true, }, { "No RenegotiationInfo", false, }, } { test := testCase t.Run(test.Name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { if err := ca.Close(); err != nil { t.Error(err) } }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() time.Sleep(50 * time.Millisecond) extensions := []extension.Extension{} if test.SendRenegotiationInfo { extensions = append(extensions, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }) } err := sendClientHello([]byte{}, ca, 0, extensions) if err != nil { t.Fatal(err) } n, err := ca.Read(resp) if err != nil { t.Fatal(err) } r := &recordlayer.RecordLayer{} if err = r.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) if err != nil { t.Fatal(err) } if n, err = ca.Read(resp); err != nil { t.Fatal(err) } messages, err := recordlayer.UnpackDatagram(resp[:n]) if err != nil { t.Fatal(err) } if err := r.Unmarshal(messages[0]); err != nil { t.Fatal(err) } serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } gotNegotationInfo := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.RenegotiationInfo); ok { gotNegotationInfo = true } } if !gotNegotationInfo { t.Fatalf("Received ServerHello without RenegotiationInfo") } }) } } func TestServerNameIndicationExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := 
test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ServerName string Expected []byte IncludeSNI bool }{ { Name: "Server name is a valid hostname", ServerName: "example.com", Expected: []byte("example.com"), IncludeSNI: true, }, { Name: "Server name is an IP literal", ServerName: "1.2.3.4", Expected: []byte(""), IncludeSNI: false, }, { Name: "Server name is empty", ServerName: "", Expected: []byte(""), IncludeSNI: false, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ ServerName: test.ServerName, } _, _ = testClient(ctx, ca, conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) if err != nil { t.Fatal(err) } r := &recordlayer.RecordLayer{} if err = r.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) if !ok { t.Fatal("Failed to cast MessageClientHello") } gotSNI := false var actualServerName string for _, v := range clientHello.Extensions { if _, ok := v.(*extension.ServerName); ok { gotSNI = true extensionServerName, ok := v.(*extension.ServerName) if !ok { t.Fatal("Failed to cast extension.ServerName") } actualServerName = extensionServerName.ServerName } } if gotSNI != test.IncludeSNI { t.Errorf("TestSNI: unexpected SNI inclusion '%s': expected(%v) actual(%v)", test.Name, test.IncludeSNI, gotSNI) } if !bytes.Equal([]byte(actualServerName), test.Expected) { t.Errorf("TestSNI: server name mismatch '%s': expected(%v) actual(%v)", test.Name, test.Expected, actualServerName) } }) } } func TestALPNExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer 
report() for _, test := range []struct { Name string ClientProtocolNameList []string ServerProtocolNameList []string ExpectedProtocol string ExpectAlertFromClient bool ExpectAlertFromServer bool Alert alert.Description }{ { Name: "Negotiate a protocol", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{"spd/1"}, ExpectedProtocol: "spd/1", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Server doesn't support any", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Negotiate with higher server precedence", ClientProtocolNameList: []string{"http/1.1", "spd/1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "http/3", "spd/1"}, ExpectedProtocol: "http/3", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Empty intersection", ClientProtocolNameList: []string{"http/1.1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "spd/1"}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: true, Alert: alert.NoApplicationProtocol, }, { Name: "Multiple protocols in ServerHello", ClientProtocolNameList: []string{"http/1.1"}, ServerProtocolNameList: []string{"http/1.1"}, ExpectedProtocol: "http/1.1", ExpectAlertFromClient: true, ExpectAlertFromServer: false, Alert: alert.InternalError, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } _, _ = testClient(ctx, ca, conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) if err != nil { t.Fatal(err) } ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) defer cancel2() ca2, cb2 := dpipe.Pipe() go func() { conf 
:= &Config{ SupportedProtocols: test.ServerProtocolNameList, } if _, err2 := testServer(ctx2, cb2, conf, true); !errors.Is(err2, context.Canceled) { if test.ExpectAlertFromServer { // Assert the error type? } else { t.Error(err2) } } }() time.Sleep(50 * time.Millisecond) // Forward ClientHello if _, err = ca2.Write(resp[:n]); err != nil { t.Fatal(err) } // Receive HelloVerify resp2 := make([]byte, 1024) n, err = ca2.Read(resp2) if err != nil { t.Fatal(err) } // Forward HelloVerify if _, err = cb.Write(resp2[:n]); err != nil { t.Fatal(err) } // Receive ClientHello resp3 := make([]byte, 1024) n, err = cb.Read(resp3) if err != nil { t.Fatal(err) } // Forward ClientHello if _, err = ca2.Write(resp3[:n]); err != nil { t.Fatal(err) } // Receive ServerHello resp4 := make([]byte, 1024) n, err = ca2.Read(resp4) if err != nil { t.Fatal(err) } messages, err := recordlayer.UnpackDatagram(resp4[:n]) if err != nil { t.Fatal(err) } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(messages[0]); err != nil { t.Fatal(err) } if test.ExpectAlertFromServer { a, ok := r.Content.(*alert.Alert) if !ok { t.Fatal("Failed to cast alert.Alert") } if a.Description != test.Alert { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) } } else { serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast handshake.MessageServerHello") } var negotiatedProtocol string for _, v := range serverHello.Extensions { if _, ok := v.(*extension.ALPN); ok { e, ok := v.(*extension.ALPN) if !ok { t.Fatal("Failed to cast extension.ALPN") } negotiatedProtocol = e.ProtocolNameList[0] // Manipulate ServerHello if test.ExpectAlertFromClient { e.ProtocolNameList = append(e.ProtocolNameList, "oops") } } } if negotiatedProtocol != test.ExpectedProtocol { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.ExpectedProtocol, negotiatedProtocol) } s, err := r.Marshal() if err != nil { t.Fatal(err) } // Forward 
ServerHello if _, err = cb.Write(s); err != nil { t.Fatal(err) } if test.ExpectAlertFromClient { resp5 := make([]byte, 1024) n, err = cb.Read(resp5) if err != nil { t.Fatal(err) } r2 := &recordlayer.RecordLayer{} if err := r2.Unmarshal(resp5[:n]); err != nil { t.Fatal(err) } a, ok := r2.Content.(*alert.Alert) if !ok { t.Fatal("Failed to cast alert.Alert") } if a.Description != test.Alert { t.Errorf("ALPN %v: expected(%v) actual(%v)", test.Name, test.Alert, a.Description) } } } time.Sleep(50 * time.Millisecond) // Give some time for returned errors }) } } // Make sure the supported_groups extension is not included in the ServerHello func TestSupportedGroupsExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ServerHello Supported Groups", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { if _, err := testServer(ctx, cb, &Config{}, true); !errors.Is(err, context.Canceled) { t.Error(err) } }() extensions := []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, } time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) err := sendClientHello([]byte{}, ca, 0, extensions) if err != nil { t.Fatal(err) } // Receive ServerHello n, err := ca.Read(resp) if err != nil { t.Fatal(err) } r := &recordlayer.RecordLayer{} if err = r.Unmarshal(resp[:n]); err != nil { t.Fatal(err) } helloVerifyRequest, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) if !ok { t.Fatal("Failed to cast MessageHelloVerifyRequest") } err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) if err != nil { 
t.Fatal(err) } if n, err = ca.Read(resp); err != nil { t.Fatal(err) } messages, err := recordlayer.UnpackDatagram(resp[:n]) if err != nil { t.Fatal(err) } if err := r.Unmarshal(messages[0]); err != nil { t.Fatal(err) } serverHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) if !ok { t.Fatal("Failed to cast MessageServerHello") } gotGroups := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.SupportedEllipticCurves); ok { gotGroups = true } } if gotGroups { t.Errorf("TestSupportedGroups: supported_groups extension was sent in ServerHello") } }) } func TestSessionResume(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("resumed", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) ss := &memSessStore{} id, _ := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306") secret, _ := hex.DecodeString("2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7") s := Session{ID: id, Secret: secret} ca, cb := dpipe.Pipe() _ = ss.Set(id, s) _ = ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s) go func() { config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerName: "example.com", SessionStore: ss, MTU: 100, } c, err := testClient(ctx, ca, config, false) clientRes <- result{c, err} }() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerName: "example.com", SessionStore: ss, MTU: 100, } server, err := testServer(ctx, cb, config, true) if err != nil { t.Fatalf("TestSessionResume: Server failed(%v)", err) } actualSessionID := server.ConnectionState().SessionID actualMasterSecret := 
server.ConnectionState().masterSecret if !bytes.Equal(actualSessionID, id) { t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID) } if !bytes.Equal(actualMasterSecret, secret) { t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", secret, actualMasterSecret) } defer func() { _ = server.Close() }() res := <-clientRes if res.err != nil { t.Fatal(res.err) } _ = res.c.Close() }) t.Run("new session", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) s1 := &memSessStore{} s2 := &memSessStore{} ca, cb := dpipe.Pipe() go func() { config := &Config{ ServerName: "example.com", SessionStore: s1, } c, err := testClient(ctx, ca, config, false) clientRes <- result{c, err} }() config := &Config{ SessionStore: s2, } server, err := testServer(ctx, cb, config, true) if err != nil { t.Fatalf("TestSessionResumetion: Server failed(%v)", err) } actualSessionID := server.ConnectionState().SessionID actualMasterSecret := server.ConnectionState().masterSecret ss, _ := s2.Get(actualSessionID) if !bytes.Equal(actualMasterSecret, ss.Secret) { t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) } defer func() { _ = server.Close() }() res := <-clientRes if res.err != nil { t.Fatal(res.err) } cs, _ := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com")) if !bytes.Equal(actualMasterSecret, cs.Secret) { t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret) } _ = res.c.Close() }) } type memSessStore struct { sync.Map } func (ms *memSessStore) Set(key []byte, s Session) error { k := hex.EncodeToString(key) ms.Store(k, s) return nil } func (ms *memSessStore) Get(key []byte) (Session, error) { k := hex.EncodeToString(key) v, ok := ms.Load(k) if !ok { return Session{}, nil 
} s, ok := v.(Session) if !ok { return Session{}, nil } return s, nil } func (ms *memSessStore) Del(key []byte) error { k := hex.EncodeToString(key) ms.Delete(k) return nil } // Assert that the server only uses CipherSuites with a hash+signature that matches // the certificate. As specified in rfc5246#section-7.4.3 func TestCipherSuiteMatchesCertificateType(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string cipherList []CipherSuiteID expectedCipher CipherSuiteID generateRSA bool }{ { Name: "ECDSA Certificate with RSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, }, { Name: "RSA Certificate with ECDSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, generateRSA: true, }, } { test := test t.Run(test.Name, func(t *testing.T) { clientErr := make(chan error, 1) client := make(chan *Conn, 1) ca, cb := dpipe.Pipe() go func() { c, err := testClient(context.TODO(), ca, &Config{CipherSuites: test.cipherList}, false) clientErr <- err client <- c }() var ( priv crypto.PrivateKey err error ) if test.generateRSA { if priv, err = rsa.GenerateKey(rand.Reader, 2048); err != nil { t.Fatal(err) } } else { if priv, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader); err != nil { t.Fatal(err) } } serverCert, err := selfsign.SelfSign(priv) if err != nil { t.Fatal(err) } if s, err := testServer(context.TODO(), cb, &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) } if c, err := <-client, <-clientErr; err != nil 
{ t.Fatal(err) } else if err := c.Close(); err != nil { t.Fatal(err) } else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher { t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID()) } }) } } // Test that we return the proper certificate if we are serving multiple ServerNames on a single Server func TestMultipleServerCertificates(t *testing.T) { fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo") if err != nil { t.Fatal(err) } barCert, err := selfsign.GenerateSelfSignedWithDNS("bar") if err != nil { t.Fatal(err) } caPool := x509.NewCertPool() for _, cert := range []tls.Certificate{fooCert, barCert} { certificate, err := x509.ParseCertificate(cert.Certificate[0]) if err != nil { t.Fatal(err) } caPool.AddCert(certificate) } for _, test := range []struct { RequestServerName string ExpectedDNSName string }{ { "foo", "foo", }, { "bar", "bar", }, { "invalid", "foo", }, } { test := test t.Run(test.RequestServerName, func(t *testing.T) { clientErr := make(chan error, 2) client := make(chan *Conn, 1) ca, cb := dpipe.Pipe() go func() { c, err := testClient(context.TODO(), ca, &Config{ RootCAs: caPool, ServerName: test.RequestServerName, VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { certificate, err := x509.ParseCertificate(rawCerts[0]) if err != nil { return err } if certificate.DNSNames[0] != test.ExpectedDNSName { return errWrongCert } return nil }, }, false) clientErr <- err client <- c }() if s, err := testServer(context.TODO(), cb, &Config{Certificates: []tls.Certificate{fooCert, barCert}}, false); err != nil { t.Fatal(err) } else if err = s.Close(); err != nil { t.Fatal(err) } if c, err := <-client, <-clientErr; err != nil { t.Fatal(err) } else if err := c.Close(); err != nil { t.Fatal(err) } }) } } func TestEllipticCurveConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, 
test := range []struct { Name string ConfigCurves []elliptic.Curve HadnshakeCurves []elliptic.Curve }{ { Name: "Curve defaulting", ConfigCurves: nil, HadnshakeCurves: defaultCurves, }, { Name: "Single curve", ConfigCurves: []elliptic.Curve{elliptic.X25519}, HadnshakeCurves: []elliptic.Curve{elliptic.X25519}, }, { Name: "Multiple curves", ConfigCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, HadnshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519}, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } c := make(chan result) go func() { client, err := testClient(ctx, ca, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) c <- result{client, err} }() server, err := testServer(ctx, cb, &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: test.ConfigCurves}, true) if err != nil { t.Fatalf("Server error: %v", err) } if len(test.ConfigCurves) == 0 && len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { t.Fatalf("Failed to default Elliptic curves, expected %d, got: %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) } if len(test.ConfigCurves) != 0 { if len(test.HadnshakeCurves) != len(server.fsm.cfg.ellipticCurves) { t.Fatalf("Failed to configure Elliptic curves, expect %d, got %d", len(test.HadnshakeCurves), len(server.fsm.cfg.ellipticCurves)) } for i, c := range test.ConfigCurves { if c != server.fsm.cfg.ellipticCurves[i] { t.Fatalf("Failed to maintain Elliptic curve order, expected %s, got %s", c, server.fsm.cfg.ellipticCurves[i]) } } } res := <-c if res.err != nil { t.Fatalf("Client error; %v", err) } defer func() { err = server.Close() if err != nil { t.Fatal(err) } err = res.c.Close() if err != nil { t.Fatal(err) } }() } } func TestSkipHelloVerify(t *testing.T) { report := test.CheckRoutines(t) defer 
report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, cb, &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, }, false) if sErr != nil { t.Error(sErr) return } buf := make([]byte, 1024) if _, sErr = server.Read(buf); sErr != nil { t.Error(sErr) } gotHello <- struct{}{} if sErr = server.Close(); sErr != nil { //nolint:contextcheck t.Error(sErr) } }() client, err := testClient(ctx, ca, &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) if err != nil { t.Fatal(err) } if _, err = client.Write([]byte("hello")); err != nil { t.Error(err) } select { case <-gotHello: // OK case <-time.After(time.Second * 5): t.Error("timeout") } if err = client.Close(); err != nil { t.Error(err) } } dtls-2.2.6/crypto.go000066400000000000000000000160521437412644000143470ustar00rootroot00000000000000package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/x509" "encoding/asn1" "encoding/binary" "math/big" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/hash" ) type ecdsaSignature struct { R, S *big.Int } func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte { serverECDHParams := make([]byte, 4) serverECDHParams[0] = 3 // named curve binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve)) serverECDHParams[3] = byte(len(publicKey)) plaintext := []byte{} plaintext = append(plaintext, clientRandom...) plaintext = append(plaintext, serverRandom...) plaintext = append(plaintext, serverECDHParams...) plaintext = append(plaintext, publicKey...) 
return plaintext } // If the client provided a "signature_algorithms" extension, then all // certificates provided by the server MUST be signed by a // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 func generateKeySignature(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) switch p := privateKey.(type) { case ed25519.PrivateKey: // https://crypto.stackexchange.com/a/55483 return p.Sign(rand.Reader, msg, crypto.Hash(0)) case *ecdsa.PrivateKey: hashed := hashAlgorithm.Digest(msg) return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PrivateKey: hashed := hashAlgorithm.Digest(msg) return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented } func verifyKeySignature(message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, rawCertificates [][]byte) error { //nolint:dupl if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } switch p := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(p, message, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hashed := hashAlgorithm.Digest(message) if !ecdsa.Verify(p, hashed, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: switch certificate.SignatureAlgorithm { case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: hashed := hashAlgorithm.Digest(message) return 
rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) default: return errKeySignatureVerifyUnimplemented } } return errKeySignatureVerifyUnimplemented } // If the server has sent a CertificateRequest message, the client MUST send the Certificate // message. The ClientKeyExchange message is now sent, and the content // of that message will depend on the public key algorithm selected // between the ClientHello and the ServerHello. If the client has sent // a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 func generateCertificateVerify(handshakeBodies []byte, privateKey crypto.PrivateKey, hashAlgorithm hash.Algorithm) ([]byte, error) { if p, ok := privateKey.(ed25519.PrivateKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over // messages to be signed and therefore cannot handle pre-hashed messages. 
return p.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) } h := sha256.New() if _, err := h.Write(handshakeBodies); err != nil { return nil, err } hashed := h.Sum(nil) switch p := privateKey.(type) { case *ecdsa.PrivateKey: return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PrivateKey: return p.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errInvalidSignatureAlgorithm } func verifyCertificateVerify(handshakeBodies []byte, hashAlgorithm hash.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte) error { //nolint:dupl if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } switch p := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(p, handshakeBodies, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hash := hashAlgorithm.Digest(handshakeBodies) if !ecdsa.Verify(p, hash, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: switch certificate.SignatureAlgorithm { case x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: hash := hashAlgorithm.Digest(handshakeBodies) return rsa.VerifyPKCS1v15(p, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) default: return errKeySignatureVerifyUnimplemented } } return errKeySignatureVerifyUnimplemented } func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { if len(rawCertificates) == 0 { return nil, errLengthMismatch } certs := make([]*x509.Certificate, 0, len(rawCertificates)) for _, rawCert := range rawCertificates { cert, err := x509.ParseCertificate(rawCert) if err != nil { return nil, err } certs = append(certs, cert) } 
return certs, nil } func verifyClientCert(rawCertificates [][]byte, roots *x509.CertPool) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } return certificate[0].Verify(opts) } func verifyServerCert(rawCertificates [][]byte, roots *x509.CertPool, serverName string) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), DNSName: serverName, Intermediates: intermediateCAPool, } return certificate[0].Verify(opts) } dtls-2.2.6/crypto_test.go000066400000000000000000000107711437412644000154100ustar00rootroot00000000000000package dtls import ( "bytes" "crypto/x509" "encoding/pem" "testing" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/hash" ) const rawPrivateKey = ` -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP 6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3 bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1 ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0 N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3 AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0 xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD 
8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22 HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87 mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71 hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5 lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ -----END RSA PRIVATE KEY----- ` func TestGenerateKeySignature(t *testing.T) { block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) if err != nil { t.Error(err) } clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f} serverRandom := []byte{0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f} publicKey := []byte{0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15} expectedSignature := []byte{ 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 
0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86, } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256) if err != nil { t.Error(err) } else if !bytes.Equal(expectedSignature, signature) { t.Errorf("Signature generation failed \nexp % 02x \nactual % 02x ", expectedSignature, signature) } } dtls-2.2.6/dtls.go000066400000000000000000000001251437412644000137670ustar00rootroot00000000000000// Package dtls implements Datagram Transport Layer Security (DTLS) 1.2 package dtls 
dtls-2.2.6/e2e/000077500000000000000000000000001437412644000131475ustar00rootroot00000000000000dtls-2.2.6/e2e/Dockerfile000066400000000000000000000002521437412644000151400ustar00rootroot00000000000000FROM docker.io/library/golang:1.18-bullseye COPY . /go/src/github.com/pion/dtls WORKDIR /go/src/github.com/pion/dtls/e2e CMD ["go", "test", "-tags=openssl", "-v", "."] dtls-2.2.6/e2e/e2e.go000066400000000000000000000001031437412644000141430ustar00rootroot00000000000000// Package e2e contains end to end tests for pion/dtls package e2e dtls-2.2.6/e2e/e2e_lossy_test.go000066400000000000000000000111551437412644000164440ustar00rootroot00000000000000package e2e import ( "crypto/tls" "fmt" "math/rand" "testing" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" transportTest "github.com/pion/transport/v2/test" ) const ( flightInterval = time.Millisecond * 100 lossyTestTimeout = 30 * time.Second ) /* DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments */ func TestPionE2ELossy(t *testing.T) { // Check for leaking routines report := transportTest.CheckRoutines(t) defer report() type runResult struct { dtlsConn *dtls.Conn err error } serverCert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } clientCert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } for _, test := range []struct { LossChanceRange int DoClientAuth bool CipherSuites []dtls.CipherSuiteID MTU int }{ { LossChanceRange: 0, }, { LossChanceRange: 10, }, { LossChanceRange: 20, }, { LossChanceRange: 50, }, { LossChanceRange: 0, DoClientAuth: true, }, { LossChanceRange: 10, DoClientAuth: true, }, { LossChanceRange: 20, DoClientAuth: true, }, { LossChanceRange: 50, DoClientAuth: true, }, { LossChanceRange: 0, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 
20, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 50, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 20, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 50, MTU: 100, DoClientAuth: true, }, } { name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU) if test.DoClientAuth { name += "_WithCliAuth" } for _, ciph := range test.CipherSuites { name += "_With" + ciph.String() } test := test t.Run(name, func(t *testing.T) { // Limit runtime in case of deadlocks lim := transportTest.TimeOut(lossyTestTimeout + time.Second) defer lim.Stop() rand.Seed(time.Now().UTC().UnixNano()) chosenLoss := rand.Intn(9) + test.LossChanceRange //nolint:gosec serverDone := make(chan runResult) clientDone := make(chan runResult) br := transportTest.NewBridge() if err = br.SetLossChance(chosenLoss); err != nil { t.Fatal(err) } go func() { cfg := &dtls.Config{ FlightInterval: flightInterval, CipherSuites: test.CipherSuites, InsecureSkipVerify: true, MTU: test.MTU, } if test.DoClientAuth { cfg.Certificates = []tls.Certificate{clientCert} } client, startupErr := dtls.Client(br.GetConn0(), cfg) clientDone <- runResult{client, startupErr} }() go func() { cfg := &dtls.Config{ Certificates: []tls.Certificate{serverCert}, FlightInterval: flightInterval, MTU: test.MTU, } if test.DoClientAuth { cfg.ClientAuth = dtls.RequireAnyClientCert } server, startupErr := dtls.Server(br.GetConn1(), cfg) serverDone <- runResult{server, startupErr} }() testTimer := time.NewTimer(lossyTestTimeout) var serverConn, clientConn *dtls.Conn defer func() { if serverConn != nil { if err = serverConn.Close(); err != nil { t.Error(err) } } if clientConn != nil { if err = clientConn.Close(); err != nil { t.Error(err) } } }() for { if serverConn != nil && clientConn != nil { break } br.Tick() select { case serverResult := <-serverDone: if 
serverResult.err != nil { t.Errorf("Fail, serverError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) return } serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { t.Errorf("Fail, clientError: clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) return } clientConn = clientResult.dtlsConn case <-testTimer.C: t.Errorf("Test expired: clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss) return case <-time.After(10 * time.Millisecond): } } }) } } dtls-2.2.6/e2e/e2e_openssl_test.go000066400000000000000000000204011437412644000167500ustar00rootroot00000000000000//go:build openssl && !js // +build openssl,!js package e2e import ( "crypto/x509" "encoding/pem" "errors" "fmt" "io/ioutil" "net" "os" "os/exec" "regexp" "strings" "testing" "time" "github.com/pion/dtls/v2" ) func serverOpenSSL(c *comm) { go func() { c.serverMutex.Lock() defer c.serverMutex.Unlock() cfg := c.serverConfig // create openssl arguments args := []string{ "s_server", "-dtls1_2", "-quiet", "-verify_quiet", "-verify_return_error", fmt.Sprintf("-accept=%d", c.serverPort), } ciphers := ciphersOpenSSL(cfg) if ciphers != "" { args = append(args, fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if cfg.PSK != nil { psk, err := cfg.PSK(nil) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-psk=%X", psk)) if len(cfg.PSKIdentityHint) > 0 { args = append(args, fmt.Sprintf("-psk_hint=%s", cfg.PSKIdentityHint)) } } // certs arguments if len(cfg.Certificates) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEM(cfg) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } 
else { args = append(args, "-nocert") } // launch command // #nosec G204 cmd := exec.CommandContext(c.ctx, "openssl", args...) var inner net.Conn inner, c.serverConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } // Ensure that server has started time.Sleep(500 * time.Millisecond) c.serverReady <- struct{}{} simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) }() } func clientOpenSSL(c *comm) { select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errors.New("waiting on serverReady err: timeout") } c.clientMutex.Lock() defer c.clientMutex.Unlock() cfg := c.clientConfig // create openssl arguments args := []string{ "s_client", "-dtls1_2", "-quiet", "-verify_quiet", "-servername=localhost", fmt.Sprintf("-connect=127.0.0.1:%d", c.serverPort), } ciphers := ciphersOpenSSL(cfg) if ciphers != "" { args = append(args, fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if cfg.PSK != nil { psk, err := cfg.PSK(nil) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-psk=%X", psk)) } // certificate arguments if len(cfg.Certificates) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEM(cfg) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-CAfile=%s", certPEM), fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } if !cfg.InsecureSkipVerify { args = append(args, "-verify_return_error") } // launch command // #nosec G204 cmd := exec.CommandContext(c.ctx, "openssl", args...) 
var inner net.Conn inner, c.clientConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) } func ciphersOpenSSL(cfg *dtls.Config) string { // See https://tls.mbed.org/supported-ssl-ciphersuites translate := map[dtls.CipherSuiteID]string{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM: "ECDHE-ECDSA-AES128-CCM", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: "ECDHE-ECDSA-AES128-CCM8", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", dtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", dtls.TLS_PSK_WITH_AES_128_CCM: "PSK-AES128-CCM", dtls.TLS_PSK_WITH_AES_128_CCM_8: "PSK-AES128-CCM8", dtls.TLS_PSK_WITH_AES_256_CCM_8: "PSK-AES256-CCM8", dtls.TLS_PSK_WITH_AES_128_GCM_SHA256: "PSK-AES128-GCM-SHA256", dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: "ECDHE-PSK-AES128-CBC-SHA256", } var ciphers []string for _, c := range cfg.CipherSuites { if text, ok := translate[c]; ok { ciphers = append(ciphers, text) } } return strings.Join(ciphers, ";") } func writeTempPEM(cfg *dtls.Config) (string, string, error) { certOut, err := ioutil.TempFile("", "cert.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } keyOut, err := ioutil.TempFile("", "key.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } cert := cfg.Certificates[0] derBytes := cert.Certificate[0] if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to cert.pem: %w", 
err) } if err = certOut.Close(); err != nil { return "", "", fmt.Errorf("error closing cert.pem: %w", err) } priv := cert.PrivateKey var privBytes []byte privBytes, err = x509.MarshalPKCS8PrivateKey(priv) if err != nil { return "", "", fmt.Errorf("unable to marshal private key: %w", err) } if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to key.pem: %w", err) } if err = keyOut.Close(); err != nil { return "", "", fmt.Errorf("error closing key.pem: %w", err) } return certOut.Name(), keyOut.Name(), nil } func minimumOpenSSLVersion(t *testing.T) bool { t.Helper() cmd := exec.Command("openssl", "version") allOut, err := cmd.CombinedOutput() if err != nil { t.Log("Cannot determine OpenSSL version: ", err) return false } verMatch := regexp.MustCompile(`(?i)^OpenSSL\s(?P(\d+\.)?(\d+\.)?(\*|\d+)(\w)?).+$`) match := verMatch.FindStringSubmatch(strings.TrimSpace(string(allOut))) params := map[string]string{} for i, name := range verMatch.SubexpNames() { if i > 0 && i <= len(match) { params[name] = match[i] } } var ver string if val, ok := params["version"]; !ok { t.Log("Could not extract OpenSSL version") return false } else { ver = val } cmp := strings.Compare(ver, "3.0.0") if cmp == -1 { return false } return true } func TestPionOpenSSLE2ESimple(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimple(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimple(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimplePSK(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimplePSK(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2EMTUs(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2EMTUs(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { 
testPionE2EMTUs(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519ClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519ClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleECDSAClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientOpenSSL) }) } dtls-2.2.6/e2e/e2e_test.go000066400000000000000000000310731437412644000152140ustar00rootroot00000000000000//go:build !js // +build !js package e2e import ( "context" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "net" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/transport/v2/test" ) const ( testMessage = "Hello World" testTimeLimit = 5 * time.Second messageRetry = 200 * time.Millisecond ) var errServerTimeout = errors.New("waiting on serverReady err: timeout") func randomPort(t testing.TB) int { t.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") if err != nil { t.Fatalf("failed to pickPort: %v", err) } defer func() { _ = conn.Close() }() switch addr := conn.LocalAddr().(type) { case *net.UDPAddr: return addr.Port default: t.Fatalf("unknown addr type %T", addr) 
return 0 } } func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) { go func() { buffer := make([]byte, 8192) n, err := conn.Read(buffer) if err != nil { errChan <- err return } outChan <- string(buffer[:n]) atomic.AddUint64(messageRecvCount, 1) }() for { if atomic.LoadUint64(messageRecvCount) == 2 { break } else if _, err := conn.Write([]byte(testMessage)); err != nil { errChan <- err break } time.Sleep(messageRetry) } } type comm struct { ctx context.Context clientConfig, serverConfig *dtls.Config serverPort int messageRecvCount *uint64 // Counter to make sure both sides got a message clientMutex *sync.Mutex clientConn net.Conn serverMutex *sync.Mutex serverConn net.Conn serverListener net.Listener serverReady chan struct{} errChan chan error clientChan chan string serverChan chan string client func(*comm) server func(*comm) } func newComm(ctx context.Context, clientConfig, serverConfig *dtls.Config, serverPort int, server, client func(*comm)) *comm { messageRecvCount := uint64(0) c := &comm{ ctx: ctx, clientConfig: clientConfig, serverConfig: serverConfig, serverPort: serverPort, messageRecvCount: &messageRecvCount, clientMutex: &sync.Mutex{}, serverMutex: &sync.Mutex{}, serverReady: make(chan struct{}), errChan: make(chan error), clientChan: make(chan string), serverChan: make(chan string), server: server, client: client, } return c } func (c *comm) assert(t *testing.T) { // DTLS Client go c.client(c) // DTLS Server go c.server(c) defer func() { if c.clientConn != nil { if err := c.clientConn.Close(); err != nil { t.Fatal(err) } } if c.serverConn != nil { if err := c.serverConn.Close(); err != nil { t.Fatal(err) } } if c.serverListener != nil { if err := c.serverListener.Close(); err != nil { t.Fatal(err) } } }() func() { seenClient, seenServer := false, false for { select { case err := <-c.errChan: t.Fatal(err) case <-time.After(testTimeLimit): t.Fatalf("Test timeout, seenClient %t seenServer %t", 
seenClient, seenServer) case clientMsg := <-c.clientChan: if clientMsg != testMessage { t.Fatalf("clientMsg does not equal test message: %s %s", clientMsg, testMessage) } seenClient = true if seenClient && seenServer { return } case serverMsg := <-c.serverChan: if serverMsg != testMessage { t.Fatalf("serverMsg does not equal test message: %s %s", serverMsg, testMessage) } seenServer = true if seenClient && seenServer { return } } } }() } func clientPion(c *comm) { select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errServerTimeout } c.clientMutex.Lock() defer c.clientMutex.Unlock() var err error c.clientConn, err = dtls.DialWithContext(c.ctx, "udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientConfig, ) if err != nil { c.errChan <- err return } simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) } func serverPion(c *comm) { c.serverMutex.Lock() defer c.serverMutex.Unlock() var err error c.serverListener, err = dtls.Listen("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.serverConfig, ) if err != nil { c.errChan <- err return } c.serverReady <- struct{}{} c.serverConn, err = c.serverListener.Accept() if err != nil { c.errChan <- err return } simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) } /* Simple DTLS Client/Server can communicate - Assert that you can send messages both ways - Assert that Close() on both ends work - Assert that no Goroutines are leaked */ func testPionE2ESimple(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := 
context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") if err != nil { t.Fatal(err) } cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) comm.assert(t) }) } } func testPionE2ESimplePSK(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_PSK_WITH_AES_128_CCM, dtls.TLS_PSK_WITH_AES_128_CCM_8, dtls.TLS_PSK_WITH_AES_256_CCM_8, dtls.TLS_PSK_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cfg := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) comm.assert(t) }) } } func testPionE2EMTUs(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, mtu := range []int{ 10000, 1000, 100, } { mtu := mtu t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") if err != nil { t.Fatal(err) } cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, MTU: mtu, } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, 
serverPort, server, client) comm.assert(t) }) } } func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, key, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } cert, err := selfsign.SelfSign(key) if err != nil { t.Fatal(err) } cfg := &dtls.Config{ Certificates: []tls.Certificate{cert}, CipherSuites: []dtls.CipherSuiteID{cipherSuite}, InsecureSkipVerify: true, } serverPort := randomPort(t) comm := newComm(ctx, cfg, cfg, serverPort, server, client) comm.assert(t) }) } } func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, skey, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } scert, err := selfsign.SelfSign(skey) if err != nil { t.Fatal(err) } _, ckey, err := ed25519.GenerateKey(rand.Reader) if err != nil { t.Fatal(err) } ccert, err := selfsign.SelfSign(ckey) if err != nil { t.Fatal(err) } scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ClientAuth: dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } serverPort := 
randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) comm.assert(t) } func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() scert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } ccert, err := selfsign.GenerateSelfSigned() if err != nil { t.Fatal(err) } clientCAs := x509.NewCertPool() caCert, err := x509.ParseCertificate(ccert.Certificate[0]) if err != nil { t.Fatal(err) } clientCAs.AddCert(caCert) scfg := &dtls.Config{ ClientCAs: clientCAs, Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ClientAuth: dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) comm.assert(t) } func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm)) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() spriv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } scert, err := selfsign.SelfSign(spriv) if err != nil { t.Fatal(err) } cpriv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatal(err) } ccert, err := selfsign.SelfSign(cpriv) if err != nil { t.Fatal(err) } scfg := &dtls.Config{ Certificates: []tls.Certificate{scert}, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, ClientAuth: dtls.RequireAnyClientCert, } ccfg := &dtls.Config{ Certificates: []tls.Certificate{ccert}, CipherSuites: 
[]dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, InsecureSkipVerify: true, } serverPort := randomPort(t) comm := newComm(ctx, ccfg, scfg, serverPort, server, client) comm.assert(t) } func TestPionE2ESimple(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion) } func TestPionE2ESimplePSK(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientPion) } func TestPionE2EMTUs(t *testing.T) { testPionE2EMTUs(t, serverPion, clientPion) } func TestPionE2ESimpleED25519(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientPion) } func TestPionE2ESimpleED25519ClientCert(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleECDSAClientCert(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleRSAClientCert(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion) } dtls-2.2.6/errors.go000066400000000000000000000234721437412644000143470ustar00rootroot00000000000000package dtls import ( "context" "errors" "fmt" "io" "net" "os" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" ) // Typed errors var ( ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:goerr113 errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)} errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:goerr113 errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")} //nolint:goerr113 errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")} //nolint:goerr113 errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")} //nolint:goerr113 errReservedExportKeyingMaterial = &TemporaryError{Err: errors.New("ExportKeyingMaterial can not be used with a reserved label")} //nolint:goerr113 errApplicationDataEpochZero = &TemporaryError{Err: 
errors.New("ApplicationData with epoch of 0")} //nolint:goerr113 errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")} //nolint:goerr113 errCertificateVerifyNoCertificate = &FatalError{Err: errors.New("client sent certificate verify but we have no certificate to verify")} //nolint:goerr113 errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")} //nolint:goerr113 errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")} //nolint:goerr113 errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")} //nolint:goerr113 errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")} //nolint:goerr113 errClientRequiredButNoServerEMS = &FatalError{Err: errors.New("client required Extended Master Secret extension, but server does not support it")} //nolint:goerr113 errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")} //nolint:goerr113 errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")} //nolint:goerr113 errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")} //nolint:goerr113 errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")} //nolint:goerr113 errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")} //nolint:goerr113 errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")} //nolint:goerr113 errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")} //nolint:goerr113 errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")} //nolint:goerr113 errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")} 
//nolint:goerr113 errNoAvailableCipherSuites = &FatalError{Err: errors.New("connection can not be created, no CipherSuites satisfy this Config")} //nolint:goerr113 errNoAvailablePSKCipherSuite = &FatalError{Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite")} //nolint:goerr113 errNoAvailableCertificateCipherSuite = &FatalError{Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite")} //nolint:goerr113 errNoAvailableSignatureSchemes = &FatalError{Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config")} //nolint:goerr113 errNoCertificates = &FatalError{Err: errors.New("no certificates configured")} //nolint:goerr113 errNoConfigProvided = &FatalError{Err: errors.New("no config provided")} //nolint:goerr113 errNoSupportedEllipticCurves = &FatalError{Err: errors.New("client requested zero or more elliptic curves that are not supported by the server")} //nolint:goerr113 errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")} //nolint:goerr113 errPSKAndIdentityMustBeSetForClient = &FatalError{Err: errors.New("PSK and PSK Identity Hint must both be set for client")} //nolint:goerr113 errRequestedButNoSRTPExtension = &FatalError{Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension")} //nolint:goerr113 errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")} //nolint:goerr113 errServerRequiredButNoClientEMS = &FatalError{Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it")} //nolint:goerr113 errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")} //nolint:goerr113 errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")} //nolint:goerr113 
errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")} //nolint:goerr113 errKeySignatureGenerateUnimplemented = &InternalError{Err: errors.New("unable to generate key signature, unimplemented")} //nolint:goerr113 errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")} //nolint:goerr113 errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")} //nolint:goerr113 errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")} //nolint:goerr113 errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")} //nolint:goerr113 errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")} //nolint:goerr113 errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")} //nolint:goerr113 ) // FatalError indicates that the DTLS connection is no longer available. // It is mainly caused by wrong configuration of server or client. type FatalError = protocol.FatalError // InternalError indicates and internal error caused by the implementation, and the DTLS connection is no longer available. // It is mainly caused by bugs or tried to use unimplemented features. type InternalError = protocol.InternalError // TemporaryError indicates that the DTLS connection is still available, but the request was failed temporary. type TemporaryError = protocol.TemporaryError // TimeoutError indicates that the request was timed out. type TimeoutError = protocol.TimeoutError // HandshakeError indicates that the handshake failed. type HandshakeError = protocol.HandshakeError // errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite. 
type invalidCipherSuiteError struct { id CipherSuiteID } func (e *invalidCipherSuiteError) Error() string { return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id) } func (e *invalidCipherSuiteError) Is(err error) bool { var other *invalidCipherSuiteError if errors.As(err, &other) { return e.id == other.id } return false } // errAlert wraps DTLS alert notification as an error type alertError struct { *alert.Alert } func (e *alertError) Error() string { return fmt.Sprintf("alert: %s", e.Alert.String()) } func (e *alertError) IsFatalOrCloseNotify() bool { return e.Level == alert.Fatal || e.Description == alert.CloseNotify } func (e *alertError) Is(err error) bool { var other *alertError if errors.As(err, &other) { return e.Level == other.Level && e.Description == other.Description } return false } // netError translates an error from underlying Conn to corresponding net.Error. func netError(err error) error { switch { case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): // Return io.EOF and context errors as is. return err } var ( ne net.Error opError *net.OpError se *os.SyscallError ) if errors.As(err, &opError) { if errors.As(opError, &se) { if se.Timeout() { return &TimeoutError{Err: err} } if isOpErrorTemporary(se) { return &TemporaryError{Err: err} } } } if errors.As(err, &ne) { return err } return &FatalError{Err: err} } dtls-2.2.6/errors_errno.go000066400000000000000000000010641437412644000155450ustar00rootroot00000000000000//go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows // +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows // For systems having syscall.Errno. // Update build targets by following command: // $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \ // | tr "." 
"_" | cut -d"_" -f"2" | sort | uniq package dtls import ( "errors" "os" "syscall" ) func isOpErrorTemporary(err *os.SyscallError) bool { return errors.Is(err.Err, syscall.ECONNREFUSED) } dtls-2.2.6/errors_errno_test.go000066400000000000000000000021421437412644000166020ustar00rootroot00000000000000//go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows // +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows // For systems having syscall.Errno. // The build target must be same as errors_errno.go. package dtls import ( "errors" "net" "testing" ) func TestErrorsTemporary(t *testing.T) { addrListen, errListen := net.ResolveUDPAddr("udp", "localhost:0") if errListen != nil { t.Fatalf("Unexpected error: %v", errListen) } // Server is not listening. conn, errDial := net.DialUDP("udp", nil, addrListen) if errDial != nil { t.Fatalf("Unexpected error: %v", errDial) } _, _ = conn.Write([]byte{0x00}) // trigger _, err := conn.Read(make([]byte, 10)) _ = conn.Close() if err == nil { t.Skip("ECONNREFUSED is not set by system") } var ne net.Error if !errors.As(netError(err), &ne) { t.Fatalf("netError must return net.Error") } if ne.Timeout() { t.Errorf("%v must not be timeout error", err) } if !ne.Temporary() { //nolint:staticcheck t.Errorf("%v must be temporary error", err) } } dtls-2.2.6/errors_noerrno.go000066400000000000000000000006461437412644000161070ustar00rootroot00000000000000//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !nacl && !nacljs && !netbsd && !openbsd && !solaris && !windows // +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows // For systems without syscall.Errno. 
// Build targets must be inverse of errors_errno.go package dtls import ( "os" ) func isOpErrorTemporary(err *os.SyscallError) bool { return false } dtls-2.2.6/errors_test.go000066400000000000000000000040611437412644000153770ustar00rootroot00000000000000package dtls import ( "errors" "fmt" "net" "testing" ) var errExample = errors.New("an example error") func TestErrorUnwrap(t *testing.T) { cases := []struct { err error errUnwrapped []error }{ { &FatalError{Err: errExample}, []error{errExample}, }, { &TemporaryError{Err: errExample}, []error{errExample}, }, { &InternalError{Err: errExample}, []error{errExample}, }, { &TimeoutError{Err: errExample}, []error{errExample}, }, { &HandshakeError{Err: errExample}, []error{errExample}, }, } for _, c := range cases { c := c t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { err := c.err for _, unwrapped := range c.errUnwrapped { e := errors.Unwrap(err) if !errors.Is(e, unwrapped) { t.Errorf("Unwrapped error is expected to be '%v', got '%v'", unwrapped, e) } } }) } } func TestErrorNetError(t *testing.T) { cases := []struct { err error str string timeout, temporary bool }{ {&FatalError{Err: errExample}, "dtls fatal: an example error", false, false}, {&TemporaryError{Err: errExample}, "dtls temporary: an example error", false, true}, {&InternalError{Err: errExample}, "dtls internal: an example error", false, false}, {&TimeoutError{Err: errExample}, "dtls timeout: an example error", true, true}, {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false}, {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true}, } for _, c := range cases { c := c t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { var ne net.Error if !errors.As(c.err, &ne) { t.Fatalf("%T doesn't implement net.Error", c.err) } if ne.Timeout() != c.timeout { t.Errorf("%T.Timeout() should be %v", c.err, c.timeout) } if ne.Temporary() != c.temporary { //nolint:staticcheck 
t.Errorf("%T.Temporary() should be %v", c.err, c.temporary) } if ne.Error() != c.str { t.Errorf("%T.Error() should be %v", c.err, c.str) } }) } } dtls-2.2.6/examples/000077500000000000000000000000001437412644000143125ustar00rootroot00000000000000dtls-2.2.6/examples/certificates/000077500000000000000000000000001437412644000167575ustar00rootroot00000000000000dtls-2.2.6/examples/certificates/README.md000066400000000000000000000024321437412644000202370ustar00rootroot00000000000000# Certificates The certificates in for the examples are generated using the commands shown below. Note that this was run on OpenSSL 1.1.1d, of which the arguments can be found in the [OpenSSL Manpages](https://www.openssl.org/docs/man1.1.1/man1), and is not guaranteed to work on different OpenSSL versions. ```shell # Extensions required for certificate validation. $ EXTFILE='extfile.conf' $ echo 'subjectAltName = IP:127.0.0.1\nbasicConstraints = critical,CA:true' > "${EXTFILE}" # Server. $ SERVER_NAME='server' $ openssl ecparam -name prime256v1 -genkey -noout -out "${SERVER_NAME}.pem" $ openssl req -key "${SERVER_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${SERVER_NAME}.csr" $ openssl x509 -req -in "${SERVER_NAME}.csr" -extfile "${EXTFILE}" -days 365 -signkey "${SERVER_NAME}.pem" -sha256 -out "${SERVER_NAME}.pub.pem" # Client. $ CLIENT_NAME='client' $ openssl ecparam -name prime256v1 -genkey -noout -out "${CLIENT_NAME}.pem" $ openssl req -key "${CLIENT_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${CLIENT_NAME}.csr" $ openssl x509 -req -in "${CLIENT_NAME}.csr" -extfile "${EXTFILE}" -days 365 -CA "${SERVER_NAME}.pub.pem" -CAkey "${SERVER_NAME}.pem" -set_serial '0xabcd' -sha256 -out "${CLIENT_NAME}.pub.pem" # Cleanup. 
$ rm "${EXTFILE}" "${SERVER_NAME}.csr" "${CLIENT_NAME}.csr" ``` dtls-2.2.6/examples/certificates/client.pem000066400000000000000000000003431437412644000207400ustar00rootroot00000000000000-----BEGIN EC PRIVATE KEY----- MHcCAQEEIGOO78dEAcepxdUIeDzC28jMcFrJr2q7x+UdhgtJ/RS3oAoGCCqGSM49 AwEHoUQDQgAEGLSNxlkJ9mETKI2Hogq3Cyh06pJKA1YMgcKqYKS6yQQlvvk5rU88 +RojFPgXJukymhfIJmw4eGxxEMSjuEZY7w== -----END EC PRIVATE KEY----- dtls-2.2.6/examples/certificates/client.pub.pem000066400000000000000000000007251437412644000215310ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIBLTCB1aADAgECAgMAq80wCgYIKoZIzj0EAwIwDTELMAkGA1UEBhMCTkwwHhcN MjAwMzIwMDk0NjQ0WhcNMjEwMzIwMDk0NjQ0WjANMQswCQYDVQQGEwJOTDBZMBMG ByqGSM49AgEGCCqGSM49AwEHA0IABBi0jcZZCfZhEyiNh6IKtwsodOqSSgNWDIHC qmCkuskEJb75Oa1PPPkaIxT4FybpMpoXyCZsOHhscRDEo7hGWO+jJDAiMA8GA1Ud EQQIMAaHBH8AAAEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBx sIkcADN9E60veZOFOeANaRWAiQaLWZfUxqkOmfHztQIgI2CfHMjDQwJZFh35HvFs NOPJj8wxFhqR5pqMF23cgOY= -----END CERTIFICATE----- dtls-2.2.6/examples/certificates/server.pem000066400000000000000000000003431437412644000207700ustar00rootroot00000000000000-----BEGIN EC PRIVATE KEY----- MHcCAQEEIDT8Xyx5RpPP+98ulYZKsvKIVdBUJug/L9H2M8JThv+GoAoGCCqGSM49 AwEHoUQDQgAE6Wf0qQqIb5G7g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMy bJ6MGIwGcVBMgoL3pfxDKdZ3mnzmoibU0w== -----END EC PRIVATE KEY----- dtls-2.2.6/examples/certificates/server.pub.pem000066400000000000000000000007551437412644000215640ustar00rootroot00000000000000-----BEGIN CERTIFICATE----- MIIBPzCB5qADAgECAhRtzyVTL+9D0KHfbcKYeKckpLVRmTAKBggqhkjOPQQDAjAN MQswCQYDVQQGEwJOTDAeFw0yMDAzMjAwOTQ2NDRaFw0yMTAzMjAwOTQ2NDRaMA0x CzAJBgNVBAYTAk5MMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6Wf0qQqIb5G7 g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMybJ6MGIwGcVBMgoL3pfxDKdZ3 mnzmoibU06MkMCIwDwYDVR0RBAgwBocEfwAAATAPBgNVHRMBAf8EBTADAQH/MAoG CCqGSM49BAMCA0gAMEUCIQD000SU+klkNLGvHZcMYNVkCFsImnGKIqPMy3LELSiF 0gIgSGIFkNEIAyNxn44CXZJu3piyz1ouK2fLefDJMYfcXgM= -----END CERTIFICATE----- 
dtls-2.2.6/examples/dial/000077500000000000000000000000001437412644000152235ustar00rootroot00000000000000dtls-2.2.6/examples/dial/psk/000077500000000000000000000000001437412644000160205ustar00rootroot00000000000000dtls-2.2.6/examples/dial/psk/main.go000066400000000000000000000021031437412644000172670ustar00rootroot00000000000000package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Server"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-2.2.6/examples/dial/selfsign/000077500000000000000000000000001437412644000170355ustar00rootroot00000000000000dtls-2.2.6/examples/dial/selfsign/main.go000066400000000000000000000021361437412644000203120ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() 
util.Check(genErr) // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, InsecureSkipVerify: true, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-2.2.6/examples/dial/verify/000077500000000000000000000000001437412644000165275ustar00rootroot00000000000000dtls-2.2.6/examples/dial/verify/main.go000066400000000000000000000024701437412644000200050ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/client.pem", "examples/certificates/client.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, RootCAs: certPool, } // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithContext(ctx, "udp", addr, config) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-2.2.6/examples/listen/000077500000000000000000000000001437412644000156105ustar00rootroot00000000000000dtls-2.2.6/examples/listen/psk/000077500000000000000000000000001437412644000164055ustar00rootroot00000000000000dtls-2.2.6/examples/listen/psk/main.go000066400000000000000000000033071437412644000176630ustar00rootroot00000000000000package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Create parent context to cleanup handshaking connections on exit. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Prepare the configuration of the DTLS connection config := &dtls.Config{ PSK: func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte("Pion DTLS Client"), CipherSuites: []dtls.CipherSuiteID{dtls.TLS_PSK_WITH_AES_128_CCM_8}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, // Create timeout context for accepted connection. ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, 30*time.Second) }, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } dtls-2.2.6/examples/listen/selfsign/000077500000000000000000000000001437412644000174225ustar00rootroot00000000000000dtls-2.2.6/examples/listen/selfsign/main.go000066400000000000000000000033041437412644000206750ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" "github.com/pion/dtls/v2/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) // Create parent context to cleanup handshaking connections on exit. 
ctx, cancel := context.WithCancel(context.Background()) defer cancel() // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, // Create timeout context for accepted connection. ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, 30*time.Second) }, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } dtls-2.2.6/examples/listen/verify/000077500000000000000000000000001437412644000171145ustar00rootroot00000000000000dtls-2.2.6/examples/listen/verify/main.go000066400000000000000000000037341437412644000203760ustar00rootroot00000000000000package main import ( "context" "crypto/tls" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v2" "github.com/pion/dtls/v2/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Create parent context to cleanup handshaking connections on exit. ctx, cancel := context.WithCancel(context.Background()) defer cancel() // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Prepare the configuration of the DTLS connection config := &dtls.Config{ Certificates: []tls.Certificate{certificate}, ExtendedMasterSecret: dtls.RequireExtendedMasterSecret, ClientAuth: dtls.RequireAndVerifyClientCert, ClientCAs: certPool, // Create timeout context for accepted connection. ConnectContextMaker: func() (context.Context, func()) { return context.WithTimeout(ctx, 30*time.Second) }, } // Connect to a DTLS server listener, err := dtls.Listen("udp", addr, config) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. 
// Register the connection with the chat hub hub.Register(conn) } }() // Start chatting hub.Chat() } dtls-2.2.6/examples/util/000077500000000000000000000000001437412644000152675ustar00rootroot00000000000000dtls-2.2.6/examples/util/hub.go000066400000000000000000000030021437412644000163670ustar00rootroot00000000000000package util import ( "bufio" "fmt" "net" "os" "strings" "sync" ) // Hub is a helper to handle one to many chat type Hub struct { conns map[string]net.Conn lock sync.RWMutex } // NewHub builds a new hub func NewHub() *Hub { return &Hub{conns: make(map[string]net.Conn)} } // Register adds a new conn to the Hub func (h *Hub) Register(conn net.Conn) { fmt.Printf("Connected to %s\n", conn.RemoteAddr()) h.lock.Lock() defer h.lock.Unlock() h.conns[conn.RemoteAddr().String()] = conn go h.readLoop(conn) } func (h *Hub) readLoop(conn net.Conn) { b := make([]byte, bufSize) for { n, err := conn.Read(b) if err != nil { h.unregister(conn) return } fmt.Printf("Got message: %s\n", string(b[:n])) } } func (h *Hub) unregister(conn net.Conn) { h.lock.Lock() defer h.lock.Unlock() delete(h.conns, conn.RemoteAddr().String()) err := conn.Close() if err != nil { fmt.Println("Failed to disconnect", conn.RemoteAddr(), err) } else { fmt.Println("Disconnected ", conn.RemoteAddr()) } } func (h *Hub) broadcast(msg []byte) { h.lock.RLock() defer h.lock.RUnlock() for _, conn := range h.conns { _, err := conn.Write(msg) if err != nil { fmt.Printf("Failed to write message to %s: %v\n", conn.RemoteAddr(), err) } } } // Chat starts the stdin readloop to dispatch messages to the hub func (h *Hub) Chat() { reader := bufio.NewReader(os.Stdin) for { msg, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(msg) == "exit" { return } h.broadcast([]byte(msg)) } } dtls-2.2.6/examples/util/util.go000066400000000000000000000037031437412644000165760ustar00rootroot00000000000000// Package util provides auxiliary utilities used in examples package util import ( "bufio" "crypto/tls" 
"encoding/pem" "errors" "fmt" "io" "io/ioutil" "net" "os" "path/filepath" "strings" ) const bufSize = 8192 var ( errBlockIsNotCertificate = errors.New("block is not a certificate, unable to load certificates") errNoCertificateFound = errors.New("no certificate found, unable to load certificates") ) // Chat simulates a simple text chat session over the connection func Chat(conn io.ReadWriter) { go func() { b := make([]byte, bufSize) for { n, err := conn.Read(b) Check(err) fmt.Printf("Got message: %s\n", string(b[:n])) } }() reader := bufio.NewReader(os.Stdin) for { text, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(text) == "exit" { return } _, err = conn.Write([]byte(text)) Check(err) } } // Check is a helper to throw errors in the examples func Check(err error) { var netError net.Error if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck fmt.Printf("Warning: %v\n", err) } else if err != nil { fmt.Printf("error: %v\n", err) panic(err) } } // LoadKeyAndCertificate reads certificates or key from file func LoadKeyAndCertificate(keyPath string, certificatePath string) (tls.Certificate, error) { return tls.LoadX509KeyPair(certificatePath, keyPath) } // LoadCertificate Load/read certificate(s) from file func LoadCertificate(path string) (*tls.Certificate, error) { rawData, err := ioutil.ReadFile(filepath.Clean(path)) if err != nil { return nil, err } var certificate tls.Certificate for { block, rest := pem.Decode(rawData) if block == nil { break } if block.Type != "CERTIFICATE" { return nil, errBlockIsNotCertificate } certificate.Certificate = append(certificate.Certificate, block.Bytes) rawData = rest } if len(certificate.Certificate) == 0 { return nil, errNoCertificateFound } return &certificate, nil } dtls-2.2.6/flight.go000066400000000000000000000057661437412644000143160ustar00rootroot00000000000000package dtls /* DTLS messages are grouped into a series of message flights, according to the diagrams below. 
Although each flight of messages may consist of a number of messages, they should be viewed as monolithic for the purpose of timeout and retransmission. https://tools.ietf.org/html/rfc4347#section-4.2.4 Message flights for full handshake: Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 <------- HelloVerifyRequest Flight 2 ClientHello --------> Flight 3 ServerHello \ Certificate* \ ServerKeyExchange* Flight 4 CertificateRequest* / <-------- ServerHelloDone / Certificate* \ ClientKeyExchange \ CertificateVerify* Flight 5 [ChangeCipherSpec] / Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / Message flights for session-resuming handshake (no cookie exchange): Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 ServerHello \ [ChangeCipherSpec] Flight 4b <-------- Finished / [ChangeCipherSpec] \ Flight 5b Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / */ type flightVal uint8 const ( flight0 flightVal = iota + 1 flight1 flight2 flight3 flight4 flight4b flight5 flight5b flight6 ) func (f flightVal) String() string { switch f { case flight0: return "Flight 0" case flight1: return "Flight 1" case flight2: return "Flight 2" case flight3: return "Flight 3" case flight4: return "Flight 4" case flight4b: return "Flight 4b" case flight5: return "Flight 5" case flight5b: return "Flight 5b" case flight6: return "Flight 6" default: return "Invalid Flight" } } func (f flightVal) isLastSendFlight() bool { return f == flight6 || f == flight5b } func (f flightVal) isLastRecvFlight() bool { return f == flight5 || f == flight4b } dtls-2.2.6/flight0handler.go000066400000000000000000000106121437412644000157160ustar00rootroot00000000000000package dtls import ( "context" "crypto/rand" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" 
"github.com/pion/dtls/v2/pkg/protocol/handshake" ) func flight0Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.remoteRandom = clientHello.Random cipherSuites := []CipherSuite{} for _, id := range clientHello.CipherSuiteIDs { if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { cipherSuites = append(cipherSuites, c) } } if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } for _, val := range clientHello.Extensions { switch e := val.(type) { case *extension.SupportedEllipticCurves: if len(e.EllipticCurves) == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoSupportedEllipticCurves } state.namedCurve = e.EllipticCurves[0] case *extension.UseSRTP: profile, ok := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } state.srtpProtectionProfile = profile case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { 
state.extendedMasterSecret = true } case *extension.ServerName: state.serverName = e.ServerName // remote server name case *extension.ALPN: state.peerSupportedProtocols = e.ProtocolNameList } } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS } if state.localKeypair == nil { var err error state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } nextFlight := flight2 if cfg.insecureSkipHelloVerify { nextFlight = flight4 } return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) } func handleHelloResume(sessionID []byte, state *State, cfg *handshakeConfig, next flightVal) (flightVal, *alert.Alert, error) { if len(sessionID) > 0 && cfg.sessionStore != nil { if s, err := cfg.sessionStore.Get(sessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] resume session: %x", sessionID) state.SessionID = sessionID state.masterSecret = s.Secret if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } clientRandom := state.localRandom.MarshalFixed() cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return flight4b, nil, nil } } return next, nil, nil } func flight0Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { // Initialize if !cfg.insecureSkipHelloVerify { state.cookie = make([]byte, cookieLength) if _, err := rand.Read(state.cookie); err != nil { return nil, nil, err } } var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve if err := 
state.localRandom.Populate(); err != nil { return nil, nil, err } return nil, nil, nil } dtls-2.2.6/flight1handler.go000066400000000000000000000105531437412644000157230ustar00rootroot00000000000000package dtls import ( "context" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight1Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { // HelloVerifyRequest can be skipped by the server, // so allow ServerHello during flight1 also seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok := msgs[handshake.TypeServerHello]; ok { // Flight1 and flight2 were skipped. // Parse as flight3. return flight3Parse(ctx, c, state, cache, cfg) } if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) 
state.handshakeRecvSequence = seq return flight3, nil, nil } return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } func flight1Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve state.cookie = nil if err := state.localRandom.Populate(); err != nil { return nil, nil, err } extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, }, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }, } var setEllipticCurveCryptographyClientHelloExtensions bool for _, c := range cfg.localCipherSuites { if c.ECC() { setEllipticCurveCryptographyClientHelloExtensions = true break } } if setEllipticCurveCryptographyClientHelloExtensions { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: cfg.ellipticCurves, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, }...) 
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } if cfg.sessionStore != nil { cfg.log.Tracef("[handshake] try to resume session") if s, err := cfg.sessionStore.Get(c.sessionKey()); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] get saved session: %x", s.ID) state.SessionID = s.ID state.masterSecret = s.Secret } } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, }, }, }, nil, nil } dtls-2.2.6/flight2handler.go000066400000000000000000000036001437412644000157170ustar00rootroot00000000000000package dtls import ( "bytes" "context" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight2Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, 
handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped. // Parse as flight 0 in this case. return flight0Parse(ctx, c, state, cache, cfg) } state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } if len(clientHello.Cookie) == 0 { return 0, nil, nil } if !bytes.Equal(state.cookie, clientHello.Cookie) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch } return flight4, nil, nil } func flight2Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { state.handshakeSendSequence = 0 return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: state.cookie, }, }, }, }, }, nil, nil } dtls-2.2.6/flight3handler.go000066400000000000000000000263741437412644000157350ustar00rootroot00000000000000package dtls import ( "bytes" "context" "github.com/pion/dtls/v2/internal/ciphersuite/types" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight3Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) 
(flightVal, *alert.Alert, error) { //nolint:gocognit // Clients may receive multiple HelloVerifyRequest messages with different cookies. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, ) if ok { if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq return flight3, nil, nil } } _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, ) if !ok { // Don't have enough messages. 
Keep reading return 0, nil, nil } if h, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { if !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } for _, v := range h.Extensions { switch e := v.(type) { case *extension.UseSRTP: profile, found := findMatchingSRTPProfile(e.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } state.srtpProtectionProfile = profile case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ALPN: if len(e.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, extension.ErrALPNInvalidFormat // Meh, internal error? 
} state.NegotiatedProtocol = e.ProtocolNameList[0] } } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS } if len(cfg.localSRTPProtectionProfiles) > 0 && state.srtpProtectionProfile == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension } remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*h.CipherSuiteID), cfg.customCipherSuites) if remoteCipherSuite == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } state.cipherSuite = selectedCipherSuite state.remoteRandom = h.Random cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) if len(h.SessionID) > 0 && bytes.Equal(state.SessionID, h.SessionID) { return handleResumption(ctx, c, state, cache, cfg) } if len(state.SessionID) > 0 { cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID) if err := cfg.sessionStore.Del(state.SessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } if cfg.sessionStore == nil { state.SessionID = []byte{} } else { state.SessionID = h.SessionID } state.masterSecret = []byte{} } if cfg.localPSKCallback != nil { seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, ) } else { seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, 
handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, ) } if !ok { // Don't have enough messages. Keep reading return 0, nil, nil } state.handshakeRecvSequence = seq if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok { state.PeerCertificates = h.Certificate } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate } if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { alertPtr, err := handleServerKeyExchange(c, state, cfg, h) if err != nil { return 0, alertPtr, err } } if _, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { state.remoteRequestedCertificate = true } return flight5, nil, nil } func handleResumption(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } // Now, encrypted packets can be handled if err := c.handleQueuedPackets(ctx); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. 
Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, ) expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } clientRandom := state.localRandom.MarshalFixed() cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return flight5b, nil, nil } func handleServerKeyExchange(_ flightConn, state *State, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange) (*alert.Alert, error) { var err error if state.cipherSuite == nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } if cfg.localPSKCallback != nil { var psk []byte if psk, err = cfg.localPSKCallback(h.IdentityHint); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.IdentityHint = h.IdentityHint switch state.cipherSuite.KeyExchangeAlgorithm() { case types.KeyExchangeAlgorithmPsk: state.preMasterSecret = prf.PSKPreMasterSecret(psk) case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve) if err != 
nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } default: return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } } else { if state.localKeypair, err = elliptic.GenerateKeypair(h.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if state.preMasterSecret, err = prf.PreMasterSecret(h.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } return nil, nil //nolint:nilnil } func flight3Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, }, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }, } if state.namedCurve != 0 { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, }...) 
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, }, }, }, nil, nil } dtls-2.2.6/flight4bhandler.go000066400000000000000000000107371437412644000160740ustar00rootroot00000000000000package dtls import ( "bytes" "context" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/extension" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight4bParse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. 
Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } // Other party may re-transmit the last flight. Keep state to be flight4b. return flight4b, nil, nil } func flight4bGenerate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { var pkts []*packet extensions := []extension.Extension{&extension.RenegotiationInfo{ RenegotiatedConnection: 0, }} if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if state.srtpProtectionProfile != 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile}, }) } selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err } if selectedProto != "" { extensions = append(extensions, &extension.ALPN{ ProtocolNameList: []string{selectedProto}, }) 
state.NegotiatedProtocol = selectedProto } cipherSuiteID := uint16(state.cipherSuite.ID()) serverHello := &handshake.Handshake{ Message: &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: state.localRandom, SessionID: state.SessionID, CipherSuiteID: &cipherSuiteID, CompressionMethod: defaultCompressionMethods()[0], Extensions: extensions, }, } serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) raw, err := serverHello.Marshal() if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } plainText = append(plainText, raw...) state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: serverHello, }, }, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) return pkts, nil, nil } dtls-2.2.6/flight4handler.go000066400000000000000000000361501437412644000157270ustar00rootroot00000000000000package dtls import ( "context" "crypto/rand" "crypto/x509" "github.com/pion/dtls/v2/internal/ciphersuite" "github.com/pion/dtls/v2/pkg/crypto/clientcertificate" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" 
"github.com/pion/dtls/v2/pkg/protocol"
	"github.com/pion/dtls/v2/pkg/protocol/alert"
	"github.com/pion/dtls/v2/pkg/protocol/extension"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
	"github.com/pion/dtls/v2/pkg/protocol/recordlayer"
)

// flight4Parse is the server's parse step in a full handshake. It consumes
// the client's optional Certificate, its ClientKeyExchange and an optional
// CertificateVerify, then:
//   - verifies CertificateVerify against the transcript and
//     cfg.localSignatureSchemes, runs client-certificate chain verification
//     when cfg.clientAuth >= VerifyClientCertIfGiven, and calls
//     cfg.verifyPeerCertificate if set;
//   - derives the pre-master secret (PSK, ECDHE-PSK or ECDHE) and the master
//     secret (extended or classic), then initializes the cipher suite;
//   - stores the session when a session ID is in use;
//   - waits for the client's encrypted Finished (epoch cfg.initialEpoch+1);
//   - enforces the cfg.clientAuth policy and calls cfg.verifyConnection if set,
//     before moving to flight6.
//
// It returns (0, nil, nil) while required messages are still missing.
func flight4Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { //nolint:gocognit
	seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true},
		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}

	// Validate type
	var clientKeyExchange *handshake.MessageClientKeyExchange
	if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert {
		state.PeerCertificates = h.Certificate
		// If the client offers its certificate, just disable session resumption.
		// Otherwise, we have to store the certificate identification and expire time.
		// And we have to check whether this certificate expired, revoked or changed.
		//
		// https://curl.se/docs/CVE-2016-5419.html
		state.SessionID = nil
	}

	if h, hasCertVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasCertVerify {
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate
		}

		// Transcript covered by the CertificateVerify signature: every
		// handshake message up to and including ClientKeyExchange.
		plainText := cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		)

		// Verify that the pair of hash algorithm and signature is listed.
		var validSignatureScheme bool
		for _, ss := range cfg.localSignatureSchemes {
			if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm {
				validSignatureScheme = true
				break
			}
		}
		if !validSignatureScheme {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
		}

		if err := verifyCertificateVerify(plainText, h.HashAlgorithm, h.Signature, state.PeerCertificates); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
		var chains [][]*x509.Certificate
		var err error
		var verified bool
		// Chain verification against cfg.clientCAs only happens for the
		// stricter client-auth modes; chains stays nil otherwise and is
		// still handed to verifyPeerCertificate.
		if cfg.clientAuth >= VerifyClientCertIfGiven {
			if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
			verified = true
		}
		if cfg.verifyPeerCertificate != nil {
			if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		state.peerCertificatesVerified = verified
	} else if state.PeerCertificates != nil {
		// A certificate was received, but we haven't seen a CertificateVerify
		// keep reading until we receive one
		return 0, nil, nil
	}

	// Derive keys only once; this parse step may run again while waiting
	// for the client's Finished.
	if !state.cipherSuite.IsInitialized() {
		// On the server, the local random is the ServerHello random and the
		// remote random is the ClientHello random.
		serverRandom := state.localRandom.MarshalFixed()
		clientRandom := state.remoteRandom.MarshalFixed()

		var err error
		var preMasterSecret []byte
		if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey {
			var psk []byte
			if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
			state.IdentityHint = clientKeyExchange.IdentityHint
			switch state.cipherSuite.KeyExchangeAlgorithm() {
			case CipherSuiteKeyExchangeAlgorithmPsk:
				preMasterSecret = prf.PSKPreMasterSecret(psk)
			case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe):
				if preMasterSecret, err = prf.EcdhePSKPreMasterSecret(psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve); err != nil {
					return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
				}
			default:
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite
			}
		} else {
			preMasterSecret, err = prf.PreMasterSecret(clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
			}
		}

		if state.extendedMasterSecret {
			var sessionHash []byte
			sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch)
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}

			state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		} else {
			state.masterSecret, err = prf.MasterSecret(preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc())
			if err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
			}
		}

		if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)
	}

	// A non-empty SessionID here means a session store is in use and the
	// client did not present a certificate (which clears SessionID above).
	if len(state.SessionID) > 0 {
		s := Session{
			ID:     state.SessionID,
			Secret: state.masterSecret,
		}
		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
		if err := cfg.sessionStore.Set(state.SessionID, s); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	// Now, encrypted packets can be handled
	if err := c.handleQueuedPackets(ctx); err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}
	state.handshakeRecvSequence = seq

	if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}

	// Anonymous suites skip the client-certificate policy entirely.
	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
		if cfg.verifyConnection != nil {
			if err := cfg.verifyConnection(state.clone()); err != nil {
				return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		return flight6, nil, nil
	}

	switch cfg.clientAuth {
	case RequireAnyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
	case VerifyClientCertIfGiven:
		if state.PeerCertificates != nil && !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case RequireAndVerifyClientCert:
		if state.PeerCertificates == nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired
		}
		if !state.peerCertificatesVerified {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified
		}
	case NoClientCert, RequestClientCert:
		// go to flight6
	}
	if cfg.verifyConnection != nil {
		if err := cfg.verifyConnection(state.clone()); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
	}

	return flight6, nil, nil
}

// flight4Generate emits the server's full-handshake flight:
//   - ServerHello, with a fresh random session ID when a session store is
//     configured;
//   - for certificate suites: Certificate, a signed ServerKeyExchange, and a
//     CertificateRequest when cfg.clientAuth > NoClientCert;
//   - for PSK suites with an identity hint or ECDHE: an unsigned
//     ServerKeyExchange;
//   - ServerHelloDone.
func flight4Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) {
	extensions := []extension.Extension{&extension.RenegotiationInfo{
		RenegotiatedConnection: 0,
	}}
	if (cfg.extendedMasterSecret == RequestExtendedMasterSecret ||
		cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret {
		extensions = append(extensions, &extension.UseExtendedMasterSecret{
			Supported: true,
		})
	}
	if state.srtpProtectionProfile != 0 {
		extensions = append(extensions, &extension.UseSRTP{
			ProtectionProfiles: []SRTPProtectionProfile{state.srtpProtectionProfile},
		})
	}
	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate {
		extensions = append(extensions, &extension.SupportedPointFormats{
			PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed},
		})
	}

	selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols)
	if err != nil {
		return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err
	}
	if selectedProto != "" {
		extensions = append(extensions, &extension.ALPN{
			ProtocolNameList: []string{selectedProto},
		})
		state.NegotiatedProtocol = selectedProto
	}

	var pkts []*packet
	cipherSuiteID := uint16(state.cipherSuite.ID())

	// Issue a new random session ID only when sessions can be stored.
	if cfg.sessionStore != nil {
		state.SessionID = make([]byte, sessionLength)
		if _, err := rand.Read(state.SessionID); err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	pkts = append(pkts, &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Version: protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Message: &handshake.MessageServerHello{
					Version:           protocol.Version1_2,
					Random:            state.localRandom,
					SessionID:         state.SessionID,
					CipherSuiteID:     &cipherSuiteID,
					CompressionMethod: defaultCompressionMethods()[0],
					Extensions:        extensions,
				},
			},
		},
	})

	switch {
	case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate:
		certificate, err := cfg.getCertificate(&ClientHelloInfo{
			ServerName:   state.serverName,
			CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()},
		})
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
		}

		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageCertificate{
						Certificate: certificate.Certificate,
					},
				},
			},
		})

		serverRandom := state.localRandom.MarshalFixed()
		clientRandom := state.remoteRandom.MarshalFixed()

		// Find compatible signature scheme
		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, certificate.PrivateKey)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
		}

		// Sign both randoms together with the ephemeral public key and curve
		// that are sent in the ServerKeyExchange below.
		signature, err := generateKeySignature(clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, certificate.PrivateKey, signatureHashAlgo.Hash)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		state.localKeySignature = signature

		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageServerKeyExchange{
						EllipticCurveType:  elliptic.CurveTypeNamedCurve,
						NamedCurve:         state.namedCurve,
						PublicKey:          state.localKeypair.PublicKey,
						HashAlgorithm:      signatureHashAlgo.Hash,
						SignatureAlgorithm: signatureHashAlgo.Signature,
						Signature:          state.localKeySignature,
					},
				},
			},
		})

		if cfg.clientAuth > NoClientCert {
			// An empty list of certificateAuthorities signals to
			// the client that it may send any certificate in response
			// to our request. When we know the CAs we trust, then
			// we can send them down, so that the client can choose
			// an appropriate certificate to give to us.
			var certificateAuthorities [][]byte
			if cfg.clientCAs != nil {
				//nolint:staticcheck // x509.CertPool.Subjects is deprecated because it omits system roots for pools from SystemCertPool; clientCAs does not come from SystemCertPool, and an empty authorities list is acceptable here.
				certificateAuthorities = cfg.clientCAs.Subjects()
			}

			pkts = append(pkts, &packet{
				record: &recordlayer.RecordLayer{
					Header: recordlayer.Header{
						Version: protocol.Version1_2,
					},
					Content: &handshake.Handshake{
						Message: &handshake.MessageCertificateRequest{
							CertificateTypes:            []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign},
							SignatureHashAlgorithms:     cfg.localSignatureSchemes,
							CertificateAuthoritiesNames: certificateAuthorities,
						},
					},
				},
			})
		}
	case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe):
		// To help the client in selecting which identity to use, the server
		// can provide a "PSK identity hint" in the ServerKeyExchange message.
		// If no hint is provided and cipher suite doesn't use elliptic curve,
		// the ServerKeyExchange message is omitted.
		//
		// https://tools.ietf.org/html/rfc4279#section-2
		srvExchange := &handshake.MessageServerKeyExchange{
			IdentityHint: cfg.localPSKIdentityHint,
		}
		if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) {
			srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve
			srvExchange.NamedCurve = state.namedCurve
			srvExchange.PublicKey = state.localKeypair.PublicKey
		}
		pkts = append(pkts, &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: srvExchange,
				},
			},
		})
	}

	pkts = append(pkts, &packet{
		record: &recordlayer.RecordLayer{
			Header: recordlayer.Header{
				Version: protocol.Version1_2,
			},
			Content: &handshake.Handshake{
				Message: &handshake.MessageServerHelloDone{},
			},
		},
	})

	return pkts, nil, nil
}
dtls-2.2.6/flight4handler_test.go000066400000000000000000000115641437412644000167700ustar00rootroot00000000000000
package dtls

import (
	"context"
	"testing"
	"time"

	"github.com/pion/dtls/v2/internal/ciphersuite"
	"github.com/pion/dtls/v2/pkg/protocol/alert"
	"github.com/pion/dtls/v2/pkg/protocol/handshake"
"github.com/pion/transport/v2/test" ) type flight4TestMockFlightConn struct{} func (f *flight4TestMockFlightConn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { return nil } func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } func (f *flight4TestMockFlightConn) recvHandshake() <-chan chan struct{} { return nil } func (f *flight4TestMockFlightConn) setLocalEpoch(epoch uint16) {} func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil } type flight4TestMockCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 t *testing.T } func (f *flight4TestMockCipherSuite) IsInitialized() bool { f.t.Fatal("IsInitialized called with Certificate but not CertificateVerify") return true } // Assert that if a Client sends a certificate they // must also send a CertificateVerify message. // The flight4handler must not interact with the CipherSuite // if the CertificateVerify is missing func TestFlight4_Process_CertificateVerify(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() mockConn := &flight4TestMockFlightConn{} state := &State{ cipherSuite: &flight4TestMockCipherSuite{t: t}, } cache := newHandshakeCache() cfg := &handshakeConfig{} rawCertificate := []byte{ 0x0b, 0x00, 0x01, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x9b, 0x00, 0x01, 0x98, 0x00, 0x01, 0x95, 0x30, 0x82, 0x01, 0x91, 0x30, 0x82, 0x01, 0x38, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x01, 0x65, 0x03, 0x3f, 0x4d, 0x0b, 0x9a, 0x62, 0x91, 0xdb, 0x4d, 0x28, 0x2c, 0x1f, 0xd6, 0x73, 0x32, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x00, 0x30, 0x1e, 0x17, 0x0d, 0x32, 0x32, 0x30, 0x35, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x17, 0x0d, 0x32, 
0x32, 0x30, 0x36, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x30, 0x00, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xc3, 0xb7, 0x13, 0x1a, 0x0a, 0xfc, 0xd0, 0x82, 0xf8, 0x94, 0x5e, 0xc0, 0x77, 0x07, 0x81, 0x28, 0xc9, 0xcb, 0x08, 0x84, 0x50, 0x6b, 0xf0, 0x22, 0xe8, 0x79, 0xb9, 0x15, 0x33, 0xc4, 0x56, 0xa1, 0xd3, 0x1b, 0x24, 0xe3, 0x61, 0xbd, 0x4d, 0x65, 0x80, 0x6b, 0x5d, 0x96, 0x48, 0xa2, 0x44, 0x9e, 0xce, 0xe8, 0x65, 0xd6, 0x3c, 0xe0, 0x9b, 0x6b, 0xa1, 0x36, 0x34, 0xb2, 0x39, 0xe2, 0x03, 0x00, 0xa3, 0x81, 0x92, 0x30, 0x81, 0x8f, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x02, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0x1a, 0xe3, 0xeb, 0x6f, 0x7c, 0xc3, 0x8f, 0xba, 0x6f, 0x1c, 0xe8, 0xf0, 0x23, 0x08, 0x50, 0x8d, 0x3c, 0xea, 0x31, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x1d, 0x11, 0x01, 0x01, 0xff, 0x04, 0x24, 0x30, 0x22, 0x82, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x47, 0x00, 0x30, 0x44, 0x02, 0x20, 0x06, 0x31, 0x43, 0xac, 0x03, 0x45, 0x79, 0x3c, 0xd7, 0x5f, 0x6e, 0x6a, 0xf8, 0x0e, 0xfd, 0x35, 0x49, 0xee, 0x1b, 0xbc, 0x47, 0xce, 0xe3, 0x39, 0xec, 0xe4, 0x62, 0xe1, 0x30, 0x1a, 0xa1, 0x89, 0x02, 0x20, 0x35, 0xcd, 0x7a, 0x15, 0x68, 0x09, 0x50, 0x49, 0x9e, 0x3e, 0x05, 0xd7, 0xc2, 0x69, 0x3f, 0x9c, 0x0c, 0x98, 0x92, 0x65, 0xec, 0xae, 0x44, 0xfe, 0xe5, 0x68, 
0xb8, 0x09, 0x78, 0x7f, 0x6b, 0x77, } rawClientKeyExchange := []byte{ 0x10, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0x20, 0x96, 0xed, 0x0c, 0xee, 0xf3, 0x11, 0xb1, 0x9d, 0x8b, 0x1c, 0x02, 0x7f, 0x06, 0x7c, 0x57, 0x7a, 0x14, 0xa6, 0x41, 0xde, 0x63, 0x57, 0x9e, 0xcd, 0x34, 0x54, 0xba, 0x37, 0x4d, 0x34, 0x15, 0x18, } cache.push(rawCertificate, 0, 0, handshake.TypeCertificate, true) cache.push(rawClientKeyExchange, 0, 1, handshake.TypeClientKeyExchange, true) if _, _, err := flight4Parse(context.TODO(), mockConn, state, cache, cfg); err != nil { t.Fatal(err) } } dtls-2.2.6/flight5bhandler.go000066400000000000000000000043551437412644000160740ustar00rootroot00000000000000package dtls import ( "context" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight5bParse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight5b. 
return flight5b, nil, nil } func flight5bGenerate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { //nolint:gocognit var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) var err error state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }) return pkts, nil, nil } dtls-2.2.6/flight5handler.go000066400000000000000000000323451437412644000157320ustar00rootroot00000000000000package dtls import ( "bytes" "context" "crypto" "crypto/x509" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight5Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, 
false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } if len(state.SessionID) > 0 { s := Session{ ID: state.SessionID, Secret: state.masterSecret, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) if err := cfg.sessionStore.Set(c.sessionKey(), s); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } return flight5, nil, nil } func flight5Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, 
*alert.Alert, error) { //nolint:gocognit var privateKey crypto.PrivateKey var pkts []*packet if state.remoteRequestedCertificate { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired } reqInfo := CertificateRequestInfo{} if r, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames } else { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired } certificate, err := cfg.getClientCertificate(&reqInfo) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err } if certificate == nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain } if certificate.Certificate != nil { privateKey = certificate.PrivateKey } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageCertificate{ Certificate: certificate.Certificate, }, }, }, }) } clientKeyExchange := &handshake.MessageClientKeyExchange{} if cfg.localPSKCallback == nil { clientKeyExchange.PublicKey = state.localKeypair.PublicKey } else { clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint } if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 { clientKeyExchange.PublicKey = state.localKeypair.PublicKey } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: clientKeyExchange, }, }, }) serverKeyExchangeData := cache.pullAndMerge( 
handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, ) serverKeyExchange := &handshake.MessageServerKeyExchange{} // handshakeMessageServerKeyExchange is optional for PSK if len(serverKeyExchangeData) == 0 { alertPtr, err := handleServerKeyExchange(c, state, cfg, &handshake.MessageServerKeyExchange{}) if err != nil { return nil, alertPtr, err } } else { rawHandshake := &handshake.Handshake{ KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(), } err := rawHandshake.Unmarshal(serverKeyExchangeData) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err } switch h := rawHandshake.Message.(type) { case *handshake.MessageServerKeyExchange: serverKeyExchange = h default: return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType } } // Append not-yet-sent packets merged := []byte{} seqPred := uint16(state.handshakeSendSequence) for _, p := range pkts { h, ok := p.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } h.Header.MessageSequence = seqPred seqPred++ raw, err := h.Marshal() if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } merged = append(merged, raw...) } if alertPtr, err := initalizeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil { return nil, alertPtr, err } // If the client has sent a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of the // private key in the certificate. 
if state.remoteRequestedCertificate && privateKey != nil { plainText := append(cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, ), merged...) // Find compatible signature scheme signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, privateKey) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } certVerify, err := generateCertificateVerify(plainText, privateKey, signatureHashAlgo.Hash) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.localCertificatesVerify = certVerify p := &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageCertificateVerify{ HashAlgorithm: signatureHashAlgo.Hash, SignatureAlgorithm: signatureHashAlgo.Signature, Signature: state.localCertificatesVerify, }, }, }, } pkts = append(pkts, p) h, ok := p.record.Content.(*handshake.Handshake) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType } h.Header.MessageSequence = seqPred // seqPred++ // this is the last use of seqPred raw, err := h.Marshal() if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err 
} merged = append(merged, raw...) } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) var err error state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, append(plainText, merged...), state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }) return pkts, nil, nil } func initalizeCipherSuite(state *State, cache *handshakeCache, cfg *handshakeConfig, h *handshake.MessageServerKeyExchange, sendingPlainText []byte) (*alert.Alert, error) { //nolint:gocognit if state.cipherSuite.IsInitialized() { return nil, nil //nolint } clientRandom := 
state.localRandom.MarshalFixed() serverRandom := state.remoteRandom.MarshalFixed() var err error if state.extendedMasterSecret { var sessionHash []byte sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } else { state.masterSecret, err = prf.MasterSecret(state.preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc()) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { // Verify that the pair of hash algorithm and signiture is listed. var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { if ss.Hash == h.HashAlgorithm && ss.Signature == h.SignatureAlgorithm { validSignatureScheme = true break } } if !validSignatureScheme { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } expectedMsg := valueKeyMessage(clientRandom[:], serverRandom[:], h.PublicKey, h.NamedCurve) if err = verifyKeySignature(expectedMsg, h.Signature, h.HashAlgorithm, state.PeerCertificates); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate if !cfg.insecureSkipVerify { if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } if cfg.verifyPeerCertificate != nil { if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { return &alert.Alert{Level: alert.Fatal, Description: 
alert.BadCertificate}, err } } } if cfg.verifyConnection != nil { if err = cfg.verifyConnection(state.clone()); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return nil, nil //nolint } dtls-2.2.6/flight6handler.go000066400000000000000000000055131437412644000157300ustar00rootroot00000000000000package dtls import ( "context" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) func flight6Parse(ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight6. 
return flight6, nil, nil } func flight6Generate(c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig) ([]*packet, *alert.Alert, error) { var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) var err error state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) return pkts, nil, nil } dtls-2.2.6/flighthandler.go000066400000000000000000000031411437412644000156350ustar00rootroot00000000000000package dtls import ( "context" 
"github.com/pion/dtls/v2/pkg/protocol/alert" ) // Parse received handshakes and return next flightVal type flightParser func(context.Context, flightConn, *State, *handshakeCache, *handshakeConfig) (flightVal, *alert.Alert, error) // Generate flights type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) func (f flightVal) getFlightParser() (flightParser, error) { switch f { case flight0: return flight0Parse, nil case flight1: return flight1Parse, nil case flight2: return flight2Parse, nil case flight3: return flight3Parse, nil case flight4: return flight4Parse, nil case flight4b: return flight4bParse, nil case flight5: return flight5Parse, nil case flight5b: return flight5bParse, nil case flight6: return flight6Parse, nil default: return nil, errInvalidFlight } } func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { switch f { case flight0: return flight0Generate, true, nil case flight1: return flight1Generate, true, nil case flight2: // https://tools.ietf.org/html/rfc6347#section-3.2.1 // HelloVerifyRequests must not be retransmitted. 
return flight2Generate, false, nil case flight3: return flight3Generate, true, nil case flight4: return flight4Generate, true, nil case flight4b: return flight4bGenerate, true, nil case flight5: return flight5Generate, true, nil case flight5b: return flight5bGenerate, true, nil case flight6: return flight6Generate, true, nil default: return nil, false, errInvalidFlight } } dtls-2.2.6/fragment_buffer.go000066400000000000000000000066641437412644000161730ustar00rootroot00000000000000package dtls import ( "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" ) // 2 megabytes const fragmentBufferMaxSize = 2000000 type fragment struct { recordLayerHeader recordlayer.Header handshakeHeader handshake.Header data []byte } type fragmentBuffer struct { // map of MessageSequenceNumbers that hold slices of fragments cache map[uint16][]*fragment currentMessageSequenceNumber uint16 } func newFragmentBuffer() *fragmentBuffer { return &fragmentBuffer{cache: map[uint16][]*fragment{}} } // current total size of buffer func (f *fragmentBuffer) size() int { size := 0 for i := range f.cache { for j := range f.cache[i] { size += len(f.cache[i][j].data) } } return size } // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled // when an error returns it is fatal, and the DTLS connection should be stopped func (f *fragmentBuffer) push(buf []byte) (bool, error) { if f.size()+len(buf) >= fragmentBufferMaxSize { return false, errFragmentBufferOverflow } frag := new(fragment) if err := frag.recordLayerHeader.Unmarshal(buf); err != nil { return false, err } // fragment isn't a handshake, we don't need to handle it if frag.recordLayerHeader.ContentType != protocol.ContentTypeHandshake { return false, nil } for buf = buf[recordlayer.HeaderSize:]; len(buf) != 0; frag = new(fragment) { if err := 
frag.handshakeHeader.Unmarshal(buf); err != nil { return false, err } if _, ok := f.cache[frag.handshakeHeader.MessageSequence]; !ok { f.cache[frag.handshakeHeader.MessageSequence] = []*fragment{} } // end index should be the length of handshake header but if the handshake // was fragmented, we should keep them all end := int(handshake.HeaderLength + frag.handshakeHeader.Length) if size := len(buf); end > size { end = size } // Discard all headers, when rebuilding the packet we will re-build frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...) f.cache[frag.handshakeHeader.MessageSequence] = append(f.cache[frag.handshakeHeader.MessageSequence], frag) buf = buf[end:] } return true, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { frags, ok := f.cache[f.currentMessageSequenceNumber] if !ok { return nil, 0 } // Go doesn't support recursive lambdas var appendMessage func(targetOffset uint32) bool rawMessage := []byte{} appendMessage = func(targetOffset uint32) bool { for _, f := range frags { if f.handshakeHeader.FragmentOffset == targetOffset { fragmentEnd := (f.handshakeHeader.FragmentOffset + f.handshakeHeader.FragmentLength) if fragmentEnd != f.handshakeHeader.Length && f.handshakeHeader.FragmentLength != 0 { if !appendMessage(fragmentEnd) { return false } } rawMessage = append(f.data, rawMessage...) 
return true } } return false } // Recursively collect up if !appendMessage(0) { return nil, 0 } firstHeader := frags[0].handshakeHeader firstHeader.FragmentOffset = 0 firstHeader.FragmentLength = firstHeader.Length rawHeader, err := firstHeader.Marshal() if err != nil { return nil, 0 } messageEpoch := frags[0].recordLayerHeader.Epoch delete(f.cache, f.currentMessageSequenceNumber) f.currentMessageSequenceNumber++ return append(rawHeader, rawMessage...), messageEpoch } dtls-2.2.6/fragment_buffer_test.go000066400000000000000000000125121437412644000172170ustar00rootroot00000000000000package dtls import ( "errors" "reflect" "testing" ) func TestFragmentBuffer(t *testing.T) { for _, test := range []struct { Name string In [][]byte Expected [][]byte Epoch uint16 }{ { Name: "Single Fragment", In: [][]byte{ {0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 0, }, { Name: "Single Fragment Epoch 3", In: [][]byte{ {0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 3, }, { Name: "Multiple Fragments", In: [][]byte{ {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 
0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, }, Expected: [][]byte{ {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, }, Epoch: 0, }, { Name: "Multiple Unordered Fragments", In: [][]byte{ {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04}, {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E}, {0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09}, }, Expected: [][]byte{ {0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e}, }, Epoch: 0, }, { Name: "Multiple Handshakes in Single Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, /* record header */ 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 1*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 2*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 3*/ }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, }, Epoch: 0, }, // Assert that a zero length fragment doesn't cause the fragmentBuffer to enter an 
infinite loop { Name: "Zero Length Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, Expected: [][]byte{ {0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00}, }, Epoch: 0, }, } { fragmentBuffer := newFragmentBuffer() for _, frag := range test.In { status, err := fragmentBuffer.push(frag) if err != nil { t.Error(err) } else if !status { t.Errorf("fragmentBuffer didn't accept fragments for '%s'", test.Name) } } for _, expected := range test.Expected { out, epoch := fragmentBuffer.pop() if !reflect.DeepEqual(out, expected) { t.Errorf("fragmentBuffer '%s' push/pop: got % 02x, want % 02x", test.Name, out, expected) } if epoch != test.Epoch { t.Errorf("fragmentBuffer returned wrong epoch: got %d, want %d", epoch, test.Epoch) } } if frag, _ := fragmentBuffer.pop(); frag != nil { t.Errorf("fragmentBuffer popped single buffer multiple times for '%s'", test.Name) } } } func TestFragmentBuffer_Overflow(t *testing.T) { fragmentBuffer := newFragmentBuffer() // Push a buffer that doesn't exceed size limits if _, err := fragmentBuffer.push([]byte{0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}); err != nil { t.Fatal(err) } // Allocate a buffer that exceeds cache size largeBuffer := make([]byte, fragmentBufferMaxSize) if _, err := fragmentBuffer.push(largeBuffer); !errors.Is(err, errFragmentBufferOverflow) { t.Fatalf("Pushing a large buffer returned (%s) expected(%s)", err, errFragmentBufferOverflow) } } dtls-2.2.6/go.mod000066400000000000000000000003161437412644000136020ustar00rootroot00000000000000module github.com/pion/dtls/v2 require ( github.com/pion/logging v0.2.2 github.com/pion/transport/v2 v2.0.2 github.com/pion/udp/v2 v2.0.1 golang.org/x/crypto v0.5.0 golang.org/x/net v0.7.0 ) go 1.13 
dtls-2.2.6/go.sum000066400000000000000000000121761437412644000136360ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/transport/v2 v2.0.2 h1:St+8o+1PEzPT51O9bv+tH/KYYLMNR5Vwm5Z3Qkjsywg= github.com/pion/transport/v2 v2.0.2/go.mod h1:vrz6bUbFr/cjdwbnxq8OdDDzHf7JJfGsIRkxfpZoTA0= github.com/pion/udp/v2 v2.0.1 h1:xP0z6WNux1zWEjhC7onRA3EwwSliXqu1ElUZAQhUP54= github.com/pion/udp/v2 v2.0.1/go.mod h1:B7uvTMP00lzWdyMr/1PVZXtV3wpPIxBRd4Wl6AksXn8= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.5.0 
h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= 
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= dtls-2.2.6/handshake_cache.go000066400000000000000000000110351437412644000160740ustar00rootroot00000000000000package dtls import ( "sync" "github.com/pion/dtls/v2/pkg/crypto/prf" "github.com/pion/dtls/v2/pkg/protocol/handshake" ) type handshakeCacheItem struct { typ handshake.Type isClient bool epoch uint16 messageSequence uint16 data []byte } type handshakeCachePullRule struct { typ handshake.Type epoch uint16 isClient bool optional bool } type handshakeCache struct { cache []*handshakeCacheItem mu sync.Mutex } func newHandshakeCache() *handshakeCache { return &handshakeCache{} } func (h *handshakeCache) push(data []byte, epoch, 
messageSequence uint16, typ handshake.Type, isClient bool) { h.mu.Lock() defer h.mu.Unlock() h.cache = append(h.cache, &handshakeCacheItem{ data: append([]byte{}, data...), epoch: epoch, messageSequence: messageSequence, typ: typ, isClient: isClient, }) } // returns a list handshakes that match the requested rules // the list will contain null entries for rules that can't be satisfied // multiple entries may match a rule, but only the last match is returned (ie ClientHello with cookies) func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem { h.mu.Lock() defer h.mu.Unlock() out := make([]*handshakeCacheItem, len(rules)) for i, r := range rules { for _, c := range h.cache { if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { switch { case out[i] == nil: out[i] = c case out[i].messageSequence < c.messageSequence: out[i] = c } } } } return out } // fullPullMap pulls all handshakes between rules[0] to rules[len(rules)-1] as map. func (h *handshakeCache) fullPullMap(startSeq int, cipherSuite CipherSuite, rules ...handshakeCachePullRule) (int, map[handshake.Type]handshake.Message, bool) { h.mu.Lock() defer h.mu.Unlock() ci := make(map[handshake.Type]*handshakeCacheItem) for _, r := range rules { var item *handshakeCacheItem for _, c := range h.cache { if c.typ == r.typ && c.isClient == r.isClient && c.epoch == r.epoch { switch { case item == nil: item = c case item.messageSequence < c.messageSequence: item = c } } } if !r.optional && item == nil { // Missing mandatory message. 
return startSeq, nil, false } ci[r.typ] = item } out := make(map[handshake.Type]handshake.Message) seq := startSeq for _, r := range rules { t := r.typ i := ci[t] if i == nil { continue } var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm if cipherSuite != nil { keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm() } rawHandshake := &handshake.Handshake{ KeyExchangeAlgorithm: keyExchangeAlgorithm, } if err := rawHandshake.Unmarshal(i.data); err != nil { return startSeq, nil, false } if uint16(seq) != rawHandshake.Header.MessageSequence { // There is a gap. Some messages are not arrived. return startSeq, nil, false } seq++ out[t] = rawHandshake.Message } return seq, out, true } // pullAndMerge calls pull and then merges the results, ignoring any null entries func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte { merged := []byte{} for _, p := range h.pull(rules...) { if p != nil { merged = append(merged, p.data...) } } return merged } // sessionHash returns the session hash for Extended Master Secret support // https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4 func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) { merged := []byte{} // Order defined by https://tools.ietf.org/html/rfc5246#section-7.3 handshakeBuffer := h.pull( handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false}, ) for _, p := range handshakeBuffer { if p == nil { 
continue } merged = append(merged, p.data...) } for _, a := range additional { merged = append(merged, a...) } hash := hf() if _, err := hash.Write(merged); err != nil { return []byte{}, err } return hash.Sum(nil), nil } dtls-2.2.6/handshake_cache_test.go000066400000000000000000000160261437412644000171400ustar00rootroot00000000000000package dtls import ( "bytes" "testing" "github.com/pion/dtls/v2/internal/ciphersuite" "github.com/pion/dtls/v2/pkg/protocol/handshake" ) func TestHandshakeCacheSinglePush(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Single Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, }, Expected: []byte{0x00}, }, { Name: "Multi Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Rules set order", Input: []handshakeCacheItem{ {2, true, 0, 2, []byte{0x02}}, {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Dupe Seqnum", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, }, Expected: []byte{0x00, 0x01}, }, { Name: "Multi Push, Dupe Seqnum Client/Server", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, false, 0, 1, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {1, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi 
Push, Dupe Seqnum with Unique HandshakeType", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 0, 1, []byte{0x01}}, {3, false, 0, 0, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 0, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Wrong epoch", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 1, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x11}}, {3, false, 0, 0, []byte{0x02}}, {3, false, 1, 0, []byte{0x12}}, {3, false, 2, 0, []byte{0x12}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 1, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } verifyData := h.pullAndMerge(test.Rule...) if !bytes.Equal(verifyData, test.Expected) { t.Errorf("handshakeCache '%s' exp: % 02x actual % 02x", test.Name, test.Expected, verifyData) } } } func TestHandshakeCacheSessionHash(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Standard Handshake", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeServerHelloDone, false, 0, 4, []byte{0x04}}, {handshake.TypeClientKeyExchange, true, 0, 5, []byte{0x05}}, }, Expected: []byte{0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43}, }, { Name: "Handshake With Client Cert Request", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, 
{handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, }, Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, }, { Name: "Handshake Ignores after ClientKeyExchange", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, }, Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, }, { Name: "Handshake Ignores wrong epoch", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, 
{handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, }, Expected: []byte{0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b}, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } cipherSuite := ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} verifyData, err := h.sessionHash(cipherSuite.HashFunc(), 0) if err != nil { t.Error(err) } if !bytes.Equal(verifyData, test.Expected) { t.Errorf("handshakeCacheSesssionHassh '%s' exp: % 02x actual % 02x", test.Name, test.Expected, verifyData) } } } dtls-2.2.6/handshake_test.go000066400000000000000000000034341437412644000160140ustar00rootroot00000000000000package dtls import ( "reflect" "testing" "time" "github.com/pion/dtls/v2/pkg/protocol" "github.com/pion/dtls/v2/pkg/protocol/extension" "github.com/pion/dtls/v2/pkg/protocol/handshake" ) func TestHandshakeMessage(t *testing.T) { rawHandshakeMessage := []byte{ 0x01, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } parsedHandshake := &handshake.Handshake{ Header: handshake.Header{ Length: 0x29, FragmentLength: 0x29, Type: handshake.TypeClientHello, }, Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: handshake.Random{ GMTUnixTime: time.Unix(3056586332, 0), RandomBytes: [28]byte{0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 
0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b}, }, SessionID: []byte{}, Cookie: []byte{}, CipherSuiteIDs: []uint16{}, CompressionMethods: []*protocol.CompressionMethod{}, Extensions: []extension.Extension{}, }, } h := &handshake.Handshake{} if err := h.Unmarshal(rawHandshakeMessage); err != nil { t.Error(err) } else if !reflect.DeepEqual(h, parsedHandshake) { t.Errorf("handshakeMessageClientHello unmarshal: got %#v, want %#v", h, parsedHandshake) } raw, err := h.Marshal() if err != nil { t.Error(err) } else if !reflect.DeepEqual(raw, rawHandshakeMessage) { t.Errorf("handshakeMessageClientHello marshal: got %#v, want %#v", raw, rawHandshakeMessage) } } dtls-2.2.6/handshaker.go000066400000000000000000000237261437412644000151450ustar00rootroot00000000000000package dtls import ( "context" "crypto/tls" "crypto/x509" "fmt" "io" "sync" "time" "github.com/pion/dtls/v2/pkg/crypto/elliptic" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/logging" ) // [RFC6347 Section-4.2.4] // +-----------+ // +---> | PREPARING | <--------------------+ // | +-----------+ | // | | | // | | Buffer next flight | // | | | // | \|/ | // | +-----------+ | // | | SENDING |<------------------+ | Send // | +-----------+ | | HelloRequest // Receive | | | | // next | | Send flight | | or // flight | +--------+ | | // | | | Set retransmit timer | | Receive // | | \|/ | | HelloRequest // | | +-----------+ | | Send // +--)--| WAITING |-------------------+ | ClientHello // | | +-----------+ Timer expires | | // | | | | | // | | +------------------------+ | // Receive | | Send Read retransmit | // last | | last | // flight | | flight | // | | | // \|/\|/ | // +-----------+ | // | FINISHED | -------------------------------+ // +-----------+ // | /|\ // | | // +---+ // Read retransmit // Retransmit last flight type handshakeState uint8 const ( 
handshakeErrored handshakeState = iota handshakePreparing handshakeSending handshakeWaiting handshakeFinished ) func (s handshakeState) String() string { switch s { case handshakeErrored: return "Errored" case handshakePreparing: return "Preparing" case handshakeSending: return "Sending" case handshakeWaiting: return "Waiting" case handshakeFinished: return "Finished" default: return "Unknown" } } type handshakeFSM struct { currentFlight flightVal flights []*packet retransmit bool state *State cache *handshakeCache cfg *handshakeConfig closed chan struct{} } type handshakeConfig struct { localPSKCallback PSKCallback localPSKIdentityHint []byte localCipherSuites []CipherSuite // Available CipherSuites localSignatureSchemes []signaturehash.Algorithm // Available signature schemes extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support serverName string supportedProtocols []string clientAuth ClientAuthType // If we are a client should we request a client certificate localCertificates []tls.Certificate nameToCertificate map[string]*tls.Certificate insecureSkipVerify bool verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error verifyConnection func(*State) error sessionStore SessionStore rootCAs *x509.CertPool clientCAs *x509.CertPool retransmitInterval time.Duration customCipherSuites func() []CipherSuite ellipticCurves []elliptic.Curve insecureSkipHelloVerify bool onFlightState func(flightVal, handshakeState) log logging.LeveledLogger keyLogWriter io.Writer localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) initialEpoch uint16 mu sync.Mutex } type flightConn interface { notify(ctx context.Context, level alert.Level, desc alert.Description) error writePackets(context.Context, 
[]*packet) error recvHandshake() <-chan chan struct{} setLocalEpoch(epoch uint16) handleQueuedPackets(context.Context) error sessionKey() []byte } func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { if c.keyLogWriter == nil { return } c.mu.Lock() defer c.mu.Unlock() _, err := c.keyLogWriter.Write([]byte(fmt.Sprintf("%s %x %x\n", label, clientRandom, secret))) if err != nil { c.log.Debugf("failed to write key log file: %s", err) } } func srvCliStr(isClient bool) string { if isClient { return "client" } return "server" } func newHandshakeFSM( s *State, cache *handshakeCache, cfg *handshakeConfig, initialFlight flightVal, ) *handshakeFSM { return &handshakeFSM{ currentFlight: initialFlight, state: s, cache: cache, cfg: cfg, closed: make(chan struct{}), } } func (s *handshakeFSM) Run(ctx context.Context, c flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) }() for { s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) if s.cfg.onFlightState != nil { s.cfg.onFlightState(s.currentFlight, state) } var err error switch state { case handshakePreparing: state, err = s.prepare(ctx, c) case handshakeSending: state, err = s.send(ctx, c) case handshakeWaiting: state, err = s.wait(ctx, c) case handshakeFinished: state, err = s.finish(ctx, c) default: return errInvalidFSMTransition } if err != nil { return err } } } func (s *handshakeFSM) Done() <-chan struct{} { return s.closed } func (s *handshakeFSM) prepare(ctx context.Context, c flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( a *alert.Alert err error pkts []*packet ) gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() if errFlight != nil { err = errFlight a = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} } else { pkts, a, err = gen(c, s.state, s.cache, s.cfg) s.retransmit = retransmit } if a != nil { if alertErr := 
c.notify(ctx, a.Level, a.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } s.flights = pkts epoch := s.cfg.initialEpoch nextEpoch := epoch for _, p := range s.flights { p.record.Header.Epoch += epoch if p.record.Header.Epoch > nextEpoch { nextEpoch = p.record.Header.Epoch } if h, ok := p.record.Content.(*handshake.Handshake); ok { h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) s.state.handshakeSendSequence++ } } if epoch != nextEpoch { s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) c.setLocalEpoch(nextEpoch) } return handshakeSending, nil } func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { // Send flights if err := c.writePackets(ctx, s.flights); err != nil { return handshakeErrored, err } if s.currentFlight.isLastSendFlight() { return handshakeFinished, nil } return handshakeWaiting, nil } func (s *handshakeFSM) wait(ctx context.Context, c flightConn) (handshakeState, error) { //nolint:gocognit parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { if errFlight != nil { return handshakeErrored, alertErr } } return handshakeErrored, errFlight } retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) for { select { case done := <-c.recvHandshake(): nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) close(done) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } if nextFlight == 0 { break } s.cfg.log.Tracef("[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String()) if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } s.currentFlight = nextFlight return 
handshakePreparing, nil case <-retransmitTimer.C: if !s.retransmit { return handshakeWaiting, nil } return handshakeSending, nil case <-ctx.Done(): return handshakeErrored, ctx.Err() } } } func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := c.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { if errFlight != nil { return handshakeErrored, alertErr } } return handshakeErrored, errFlight } retransmitTimer := time.NewTimer(s.cfg.retransmitInterval) select { case done := <-c.recvHandshake(): nextFlight, alert, err := parse(ctx, c, s.state, s.cache, s.cfg) close(done) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } if nextFlight == 0 { break } if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } <-retransmitTimer.C // Retransmit last flight return handshakeSending, nil case <-ctx.Done(): return handshakeErrored, ctx.Err() } return handshakeFinished, nil } dtls-2.2.6/handshaker_test.go000066400000000000000000000300051437412644000161700ustar00rootroot00000000000000//nolint:dupl package dtls import ( "bytes" "context" "crypto/tls" "errors" "sync" "testing" "time" "github.com/pion/dtls/v2/pkg/crypto/selfsign" "github.com/pion/dtls/v2/pkg/crypto/signaturehash" "github.com/pion/dtls/v2/pkg/protocol/alert" "github.com/pion/dtls/v2/pkg/protocol/handshake" "github.com/pion/dtls/v2/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v2/test" ) const nonZeroRetransmitInterval = 100 * time.Millisecond // Test that writes to the key log are in the correct format and only applies // when a key log writer is given. 
func TestWriteKeyLog(t *testing.T) { var buf bytes.Buffer cfg := handshakeConfig{ keyLogWriter: &buf, } cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) // Secrets follow the format