pax_global_header00006660000000000000000000000064151433026730014516gustar00rootroot0000000000000052 comment=a621789e8dba850944500deda8eaa1c0dc4d92f0 dtls-3.1.2/000077500000000000000000000000001514330267300124675ustar00rootroot00000000000000dtls-3.1.2/.editorconfig000066400000000000000000000006321514330267300151450ustar00rootroot00000000000000# http://editorconfig.org/ # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT root = true [*] charset = utf-8 insert_final_newline = true trim_trailing_whitespace = true end_of_line = lf [*.go] indent_style = tab indent_size = 4 [{*.yml,*.yaml}] indent_style = space indent_size = 2 # Makefiles always use tabs for indentation [Makefile] indent_style = tab dtls-3.1.2/.github/000077500000000000000000000000001514330267300140275ustar00rootroot00000000000000dtls-3.1.2/.github/.gitignore000066400000000000000000000001561514330267300160210ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT .goassets dtls-3.1.2/.github/fetch-scripts.sh000077500000000000000000000016001514330267300171410ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT set -eu SCRIPT_PATH="$(realpath "$(dirname "$0")")" GOASSETS_PATH="${SCRIPT_PATH}/.goassets" GOASSETS_REF=${GOASSETS_REF:-master} if [ -d "${GOASSETS_PATH}" ]; then if ! 
git -C "${GOASSETS_PATH}" diff --exit-code; then echo "${GOASSETS_PATH} has uncommitted changes" >&2 exit 1 fi git -C "${GOASSETS_PATH}" fetch origin git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} else git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" fi dtls-3.1.2/.github/install-hooks.sh000077500000000000000000000011221514330267300171510ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT SCRIPT_PATH="$(realpath "$(dirname "$0")")" . ${SCRIPT_PATH}/fetch-scripts.sh cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" dtls-3.1.2/.github/workflows/000077500000000000000000000000001514330267300160645ustar00rootroot00000000000000dtls-3.1.2/.github/workflows/api.yaml000066400000000000000000000011141514330267300175160ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: API on: pull_request: jobs: check: uses: pion/.goassets/.github/workflows/api.reusable.yml@master dtls-3.1.2/.github/workflows/codeql-analysis.yml000066400000000000000000000013201514330267300216730ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: CodeQL on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master dtls-3.1.2/.github/workflows/e2e.yaml000066400000000000000000000007461514330267300174320ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: E2E on: pull_request: branches: - master push: branches: - master jobs: e2e-test: name: Test runs-on: ubuntu-latest timeout-minutes: 10 steps: - name: checkout uses: actions/checkout@v6 - name: test run: | docker build -t pion-dtls-e2e -f e2e/Dockerfile . docker run -i --rm pion-dtls-e2e dtls-3.1.2/.github/workflows/fuzz.yaml000066400000000000000000000013421514330267300177460ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Fuzz on: push: branches: - master schedule: - cron: "0 */8 * * *" jobs: fuzz: uses: pion/.goassets/.github/workflows/fuzz.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version fuzz-time: "60s" dtls-3.1.2/.github/workflows/lint.yaml000066400000000000000000000011151514330267300177140ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Lint on: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/lint.reusable.yml@master dtls-3.1.2/.github/workflows/release.yml000066400000000000000000000012501514330267300202250ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Release on: push: tags: - 'v*' jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version dtls-3.1.2/.github/workflows/renovate-go-sum-fix.yaml000066400000000000000000000012671514330267300225720ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Fix go.sum on: push: branches: - renovate/* jobs: fix: uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} dtls-3.1.2/.github/workflows/reuse.yml000066400000000000000000000011511514330267300177300ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: REUSE Compliance Check on: push: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/reuse.reusable.yml@master dtls-3.1.2/.github/workflows/test.yaml000066400000000000000000000033271514330267300177340ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Test on: push: branches: - master pull_request: jobs: test: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} secrets: inherit test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-windows: uses: pion/.goassets/.github/workflows/test-windows.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-macos: uses: pion/.goassets/.github/workflows/test-macos.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version secrets: inherit dtls-3.1.2/.github/workflows/tidy-check.yaml000066400000000000000000000013021514330267300207700ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT name: Go mod tidy on: pull_request: push: branches: - master jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version dtls-3.1.2/.gitignore000066400000000000000000000006321514330267300144600ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT ### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js dtls-3.1.2/.golangci.yml000066400000000000000000000202661514330267300150610ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT version: "2" linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. 
Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions - gochecknoglobals # Checks that no globals are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - goheader # Checks if file header matches the pattern - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function.
- makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - tagliatelle # Checks the struct tags. - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - gochecknoinits # Checks that no init functions are present in Go code - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - interfacebloat # A linter that checks length of interface. 
- ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! settings: staticcheck: checks: - all - -QF1008 # "could remove embedded field", to keep it explicit! - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! 
exhaustive: default-signifies-exhaustive: true forbidigo: forbid: - pattern: ^fmt.Print(f|ln)?$ - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ - pattern: ^os.Exit$ - pattern: ^panic$ - pattern: ^print(ln)?$ - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ pkg: ^testing$ msg: use testify/assert instead analyze-types: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors govet: enable: - shadow revive: rules: # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility - name: use-any severity: warning disabled: false misspell: locale: US varnamelen: max-distance: 12 min-name-length: 2 ignore-type-assert-ok: true ignore-map-index-ok: true ignore-chan-recv-ok: true ignore-decls: - i int - n int - w io.Writer - r io.Reader - b []byte exclusions: generated: lax rules: - linters: - forbidigo - gocognit path: (examples|main\.go) - linters: - gocognit path: _test\.go - linters: - forbidigo path: cmd formatters: enable: - gci # Gci control golang package import order and make it always deterministic. - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goimports # Goimports does everything that gofmt does. 
Additionally it checks unused imports exclusions: generated: lax dtls-3.1.2/.goreleaser.yml000066400000000000000000000001711514330267300154170ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT builds: - skip: true dtls-3.1.2/.reuse/000077500000000000000000000000001514330267300136705ustar00rootroot00000000000000dtls-3.1.2/.reuse/dep5000066400000000000000000000011141514330267300144450ustar00rootroot00000000000000Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json yarn.lock Copyright: 2026 The Pion community License: MIT Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2026 The Pion community License: CC0-1.0 dtls-3.1.2/LICENSE000066400000000000000000000021051514330267300134720ustar00rootroot00000000000000MIT License Copyright (c) 2026 The Pion community Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. dtls-3.1.2/LICENSES/000077500000000000000000000000001514330267300136745ustar00rootroot00000000000000dtls-3.1.2/LICENSES/CC0-1.0.txt000066400000000000000000000156101514330267300153010ustar00rootroot00000000000000Creative Commons Legal Code CC0 1.0 Universal CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED HEREUNDER. Statement of Purpose The laws of most jurisdictions throughout the world automatically confer exclusive Copyright and Related Rights (defined below) upon the creator and subsequent owner(s) (each and all, an "owner") of an original work of authorship and/or a database (each, a "Work"). Certain owners wish to permanently relinquish those rights to a Work for the purpose of contributing to a commons of creative, cultural and scientific works ("Commons") that the public can reliably and without fear of later claims of infringement build upon, modify, incorporate in other works, reuse and redistribute as freely as possible in any form whatsoever and for any purposes, including without limitation commercial purposes. These owners may contribute to the Commons to promote the ideal of a free culture and the further production of creative, cultural and scientific works, or to gain reputation or greater distribution for their Work in part through the use and efforts of others. 
For these and/or other purposes and motivations, and without any expectation of additional consideration or compensation, the person associating CC0 with a Work (the "Affirmer"), to the extent that he or she is an owner of Copyright and Related Rights in the Work, voluntarily elects to apply CC0 to the Work and publicly distribute the Work under its terms, with knowledge of his or her Copyright and Related Rights in the Work and the meaning and intended legal effect of CC0 on those rights. 1. Copyright and Related Rights. A Work made available under CC0 may be protected by copyright and related or neighboring rights ("Copyright and Related Rights"). Copyright and Related Rights include, but are not limited to, the following: i. the right to reproduce, adapt, distribute, perform, display, communicate, and translate a Work; ii. moral rights retained by the original author(s) and/or performer(s); iii. publicity and privacy rights pertaining to a person's image or likeness depicted in a Work; iv. rights protecting against unfair competition in regards to a Work, subject to the limitations in paragraph 4(a), below; v. rights protecting the extraction, dissemination, use and reuse of data in a Work; vi. database rights (such as those arising under Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, and under any national implementation thereof, including any amended or successor version of such directive); and vii. other similar, equivalent or corresponding rights throughout the world based on applicable law or treaty, and any national implementations thereof. 2. Waiver. 
To the greatest extent permitted by, but not in contravention of, applicable law, Affirmer hereby overtly, fully, permanently, irrevocably and unconditionally waives, abandons, and surrenders all of Affirmer's Copyright and Related Rights and associated claims and causes of action, whether now known or unknown (including existing as well as future claims and causes of action), in the Work (i) in all territories worldwide, (ii) for the maximum duration provided by applicable law or treaty (including future time extensions), (iii) in any current or future medium and for any number of copies, and (iv) for any purpose whatsoever, including without limitation commercial, advertising or promotional purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each member of the public at large and to the detriment of Affirmer's heirs and successors, fully intending that such Waiver shall not be subject to revocation, rescission, cancellation, termination, or any other legal or equitable action to disrupt the quiet enjoyment of the Work by the public as contemplated by Affirmer's express Statement of Purpose. 3. Public License Fallback. Should any part of the Waiver for any reason be judged legally invalid or ineffective under applicable law, then the Waiver shall be preserved to the maximum extent permitted taking into account Affirmer's express Statement of Purpose. In addition, to the extent the Waiver is so judged Affirmer hereby grants to each affected person a royalty-free, non transferable, non sublicensable, non exclusive, irrevocable and unconditional license to exercise Affirmer's Copyright and Related Rights in the Work (i) in all territories worldwide, (ii) for the maximum duration provided by applicable law or treaty (including future time extensions), (iii) in any current or future medium and for any number of copies, and (iv) for any purpose whatsoever, including without limitation commercial, advertising or promotional purposes (the "License"). 
The License shall be deemed effective as of the date CC0 was applied by Affirmer to the Work. Should any part of the License for any reason be judged legally invalid or ineffective under applicable law, such partial invalidity or ineffectiveness shall not invalidate the remainder of the License, and in such case Affirmer hereby affirms that he or she will not (i) exercise any of his or her remaining Copyright and Related Rights in the Work or (ii) assert any associated claims and causes of action with respect to the Work, in either case contrary to Affirmer's express Statement of Purpose. 4. Limitations and Disclaimers. a. No trademark or patent rights held by Affirmer are waived, abandoned, surrendered, licensed or otherwise affected by this document. b. Affirmer offers the Work as-is and makes no representations or warranties of any kind concerning the Work, express, implied, statutory or otherwise, including without limitation warranties of title, merchantability, fitness for a particular purpose, non infringement, or the absence of latent or other defects, accuracy, or the present or absence of errors, whether or not discoverable, all to the greatest extent permissible under applicable law. c. Affirmer disclaims responsibility for clearing rights of other persons that may apply to the Work or any use thereof, including without limitation any person's Copyright and Related Rights in the Work. Further, Affirmer disclaims responsibility for obtaining any necessary consents, permissions or other rights required for any use of the Work. d. Affirmer understands and acknowledges that Creative Commons is not a party to this document and has no duty or obligation with respect to this CC0 or use of the Work. 
dtls-3.1.2/LICENSES/MIT.txt000066400000000000000000000020661514330267300150720ustar00rootroot00000000000000MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. dtls-3.1.2/README.md000066400000000000000000000161151514330267300137520ustar00rootroot00000000000000


Pion DTLS

A Go implementation of DTLS

Pion DTLS Sourcegraph Widget join us on Discord Follow us on Bluesky
GitHub Workflow Status Go Reference Coverage Status Go Report Card License: MIT


Native [DTLS 1.2][rfc6347] implementation in the Go programming language. A long term goal is a professional security review, and maybe an inclusion in stdlib. ### RFCs #### Implemented - **RFC 6347**: [Datagram Transport Layer Security Version 1.2][rfc6347] - **RFC 5705**: [Keying Material Exporters for Transport Layer Security (TLS)][rfc5705] - **RFC 7627**: [Transport Layer Security (TLS) - Session Hash and Extended Master Secret Extension][rfc7627] - **RFC 7301**: [Transport Layer Security (TLS) - Application-Layer Protocol Negotiation Extension][rfc7301] [rfc5289]: https://tools.ietf.org/html/rfc5289 [rfc5487]: https://tools.ietf.org/html/rfc5487 [rfc5489]: https://tools.ietf.org/html/rfc5489 [rfc5705]: https://tools.ietf.org/html/rfc5705 [rfc6347]: https://tools.ietf.org/html/rfc6347 [rfc6655]: https://tools.ietf.org/html/rfc6655 [rfc7301]: https://tools.ietf.org/html/rfc7301 [rfc7627]: https://tools.ietf.org/html/rfc7627 [rfc8422]: https://tools.ietf.org/html/rfc8422 [rfc9147]: https://tools.ietf.org/html/rfc9147 ### Goals/Progress This will only be targeting DTLS 1.2, and the most modern/common cipher suites. We would love contributions that fall under the 'Planned Features' and any bug fixes! 
#### Current features * DTLS 1.2 Client/Server * Key Exchange via ECDHE(curve25519, nistp256, nistp384) and PSK * Packet loss and re-ordering is handled during handshaking * Key export ([RFC 5705][rfc5705]) * Serialization and Resumption of sessions * Extended Master Secret extension ([RFC 7627][rfc7627]) * ALPN extension ([RFC 7301][rfc7301]) #### Supported ciphers ##### ECDHE * TLS_ECDHE_ECDSA_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 ([RFC 5289][rfc5289]) * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) * TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA ([RFC 8422][rfc8422]) ##### PSK * TLS_PSK_WITH_AES_128_CCM ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_256_CCM_8 ([RFC 6655][rfc6655]) * TLS_PSK_WITH_AES_128_GCM_SHA256 ([RFC 5487][rfc5487]) * TLS_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5487][rfc5487]) ##### ECDHE & PSK * TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 ([RFC 5489][rfc5489]) #### Planned Features * DTLS 1.3 ([RFC 9147][rfc9147]) * Chacha20Poly1305 #### Excluded Features * DTLS 1.0 * Renegotiation * Compression ### Using This library needs at least Go 1.21, and you should have [Go modules enabled](https://github.com/golang/go/wiki/Modules). #### Pion DTLS For a DTLS 1.2 Server that listens on 127.0.0.1:4444 ```sh go run examples/listen/selfsign/main.go ``` For a DTLS 1.2 Client that connects to 127.0.0.1:4444 ```sh go run examples/dial/selfsign/main.go ``` #### OpenSSL Pion DTLS can connect to itself and OpenSSL. 
``` // Generate a certificate openssl ecparam -out key.pem -name prime256v1 -genkey openssl req -new -sha256 -key key.pem -out server.csr openssl x509 -req -sha256 -days 365 -in server.csr -signkey key.pem -out cert.pem // Use with examples/dial/selfsign/main.go openssl s_server -dtls1_2 -cert cert.pem -key key.pem -accept 4444 // Use with examples/listen/selfsign/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -debug -cert cert.pem -key key.pem ``` ### Using with PSK Pion DTLS also comes with examples that do key exchange via PSK #### Pion DTLS ```sh go run examples/listen/psk/main.go ``` ```sh go run examples/dial/psk/main.go ``` #### OpenSSL ``` // Use with examples/dial/psk/main.go openssl s_server -dtls1_2 -accept 4444 -nocert -psk abc123 -cipher PSK-AES128-CCM8 // Use with examples/listen/psk/main.go openssl s_client -dtls1_2 -connect 127.0.0.1:4444 -psk abc123 -cipher PSK-AES128-CCM8 ``` ### Community Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt). Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news. We are always looking to support **your projects**. Please reach out if you have something to build! 
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible ### Funding NLnet foundation logo NLnet foundation logo The DTLS 1.3 implementation in this project is funded through the [NGI0 Commons Fund](https://nlnet.nl/commonsfund), a fund established by [NLnet](https://nlnet.nl/) with financial support from the European Commission's [Next Generation Internet](https://ngi.eu/) programme, under the aegis of [DG Communications Networks, Content and Technology](https://commission.europa.eu/about-european-commission/departments-and-executive-agencies/communications-networks-content-and-technology_en) under grant agreement No [101135429](https://cordis.europa.eu/project/id/101135429). Additional funding is made available by the [Swiss State Secretariat for Education, Research and Innovation](https://www.sbfi.admin.ch/sbfi/en/home.html) (SERI). Learn more on the [NLnet project page](https://nlnet.nl/project/PION-DTLS1.3/). 
### License MIT License - see [LICENSE](LICENSE) for full text dtls-3.1.2/bench_test.go000066400000000000000000000052231514330267300151360ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/tls" "fmt" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/logging" "github.com/pion/transport/v4/dpipe" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) func TestSimpleReadWrite(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } func benchmarkConn(b *testing.B, payloadSize int64) { b.Helper() b.Run(fmt.Sprintf("%d", payloadSize), func(b *testing.B) { ctx := context.Background() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(b, err) server := make(chan *Conn) go func() { s, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), 
cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, }, false) assert.NoError(b, sErr) server <- s }() hw := make([]byte, payloadSize) b.ReportAllocs() b.SetBytes(int64(len(hw))) go func() { client, cErr := testClient( ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{InsecureSkipVerify: true}, false, ) assert.NoError(b, cErr) for { _, cErr = client.Write(hw) //nolint:contextcheck assert.NoError(b, cErr) } }() s := <-server buf := make([]byte, 2048) for i := 0; i < b.N; i++ { _, err = s.Read(buf) assert.NoError(b, err) } }) } func BenchmarkConnReadWrite(b *testing.B) { for _, n := range []int64{16, 128, 512, 1024, 2048} { benchmarkConn(b, n) } } dtls-3.1.2/certificate.go000066400000000000000000000114471514330267300153070ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "crypto/tls" "crypto/x509" "fmt" "strings" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) // ClientHelloInfo contains information from a ClientHello message in order to // guide application logic in the GetCertificate. type ClientHelloInfo struct { // ServerName indicates the name of the server requested by the client // in order to support virtual hosting. ServerName is only set if the // client is using SNI (see RFC 4366, Section 3.1). ServerName string // CipherSuites lists the CipherSuites supported by the client (e.g. // TLS_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256). CipherSuites []CipherSuiteID // RandomBytes stores the client hello random bytes RandomBytes [handshake.RandomBytesLength]byte } // CertificateRequestInfo contains information from a server's // CertificateRequest message, which is used to demand a certificate and proof // of control from a client. type CertificateRequestInfo struct { // AcceptableCAs contains zero or more, DER-encoded, X.501 // Distinguished Names. 
These are the names of root or intermediate CAs // that the server wishes the returned certificate to be signed by. An // empty slice indicates that the server has no preference. AcceptableCAs [][]byte } // SupportsCertificate returns nil if the provided certificate is supported by // the server that sent the CertificateRequest. Otherwise, it returns an error // describing the reason for the incompatibility. // NOTE: original src: // https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/common.go#L1273 func (cri *CertificateRequestInfo) SupportsCertificate(c *tls.Certificate) error { if len(cri.AcceptableCAs) == 0 { return nil } for j, cert := range c.Certificate { x509Cert := c.Leaf // Parse the certificate if this isn't the leaf node, or if // chain.Leaf was nil. if j != 0 || x509Cert == nil { var err error if x509Cert, err = x509.ParseCertificate(cert); err != nil { return fmt.Errorf("failed to parse certificate #%d in the chain: %w", j, err) } } for _, ca := range cri.AcceptableCAs { if bytes.Equal(x509Cert.RawIssuer, ca) { return nil } } } return errNotAcceptableCertificateChain } func (c *handshakeConfig) setNameToCertificateLocked() { nameToCertificate := make(map[string]*tls.Certificate) for i := range c.localCertificates { cert := &c.localCertificates[i] x509Cert := cert.Leaf if x509Cert == nil { var parseErr error x509Cert, parseErr = x509.ParseCertificate(cert.Certificate[0]) if parseErr != nil { continue } } if len(x509Cert.Subject.CommonName) > 0 { nameToCertificate[strings.ToLower(x509Cert.Subject.CommonName)] = cert } for _, san := range x509Cert.DNSNames { nameToCertificate[strings.ToLower(san)] = cert } } c.nameToCertificate = nameToCertificate } //nolint:cyclop func (c *handshakeConfig) getCertificate(clientHelloInfo *ClientHelloInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetCertificate != nil && (len(c.localCertificates) == 0 || len(clientHelloInfo.ServerName) > 0) { cert, err := 
c.localGetCertificate(clientHelloInfo) if cert != nil || err != nil { return cert, err } } if c.nameToCertificate == nil { c.setNameToCertificateLocked() } if len(c.localCertificates) == 0 { return nil, errNoCertificates } if len(c.localCertificates) == 1 { // There's only one choice, so no point doing any work. return &c.localCertificates[0], nil } if len(clientHelloInfo.ServerName) == 0 { return &c.localCertificates[0], nil } name := strings.TrimRight(strings.ToLower(clientHelloInfo.ServerName), ".") if cert, ok := c.nameToCertificate[name]; ok { return cert, nil } // try replacing labels in the name with wildcards until we get a // match. labels := strings.Split(name, ".") for i := range labels { labels[i] = "*" candidate := strings.Join(labels, ".") if cert, ok := c.nameToCertificate[candidate]; ok { return cert, nil } } // If nothing matches, return the first certificate. return &c.localCertificates[0], nil } // NOTE: original src: // https://github.com/golang/go/blob/29b9a328d268d53833d2cc063d1d8b4bf6852675/src/crypto/tls/handshake_client.go#L974 func (c *handshakeConfig) getClientCertificate(cri *CertificateRequestInfo) (*tls.Certificate, error) { c.mu.Lock() defer c.mu.Unlock() if c.localGetClientCertificate != nil { return c.localGetClientCertificate(cri) } for i := range c.localCertificates { chain := c.localCertificates[i] if err := cri.SupportsCertificate(&chain); err != nil { continue } return &chain, nil } // No acceptable certificate found. Don't send a certificate. 
return new(tls.Certificate), nil } dtls-3.1.2/certificate_test.go000066400000000000000000000045771514330267300163540ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/tls" "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/stretchr/testify/assert" ) func TestGetCertificate(t *testing.T) { certificateWildcard, err := selfsign.GenerateSelfSignedWithDNS("*.test.test") assert.NoError(t, err) certificateTest, err := selfsign.GenerateSelfSignedWithDNS("test.test", "www.test.test", "pop.test.test") assert.NoError(t, err) certificateRandom, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) testCases := []struct { localCertificates []tls.Certificate desc string serverName string expectedCertificate tls.Certificate getCertificate func(info *ClientHelloInfo) (*tls.Certificate, error) }{ { desc: "Simple match in CN", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "test.test", expectedCertificate: certificateTest, }, { desc: "Simple match in SANs", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "www.test.test", expectedCertificate: certificateTest, }, { desc: "Wildcard match", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.test.test", expectedCertificate: certificateWildcard, }, { desc: "No match return first", localCertificates: []tls.Certificate{ certificateRandom, certificateTest, certificateWildcard, }, serverName: "foo.bar", expectedCertificate: certificateRandom, }, { desc: "Get certificate from callback", getCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &certificateTest, nil }, expectedCertificate: certificateTest, }, } for _, test := range testCases { test := test t.Run(test.desc, func(t *testing.T) { cfg := &handshakeConfig{ 
localCertificates: test.localCertificates, localGetCertificate: test.getCertificate, } cert, err := cfg.getCertificate(&ClientHelloInfo{ServerName: test.serverName}) assert.NoError(t, err) assert.Equal(t, test.expectedCertificate.Leaf, cert.Leaf, "Certificate Leaf should match expected") }) } } dtls-3.1.2/cipher_suite.go000066400000000000000000000250521514330267300155050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "fmt" "hash" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) // CipherSuiteID is an ID for our supported CipherSuites. type CipherSuiteID = ciphersuite.ID // Supported Cipher Suites. const ( // nolint: godot // AES-128-CCM TLS_ECDHE_ECDSA_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM // nolint: revive,staticcheck,lll TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 // nolint: revive,staticcheck,lll // nolint: godot // AES-128-GCM-SHA256 TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 // nolint: revive,staticcheck,lll TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 CipherSuiteID = ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 // nolint: revive,staticcheck,lll // nolint: godot // AES-256-CBC-SHA TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA CipherSuiteID = ciphersuite.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA // nolint: revive,staticcheck,lll TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA CipherSuiteID = 
ciphersuite.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA // nolint: revive,staticcheck,lll TLS_PSK_WITH_AES_128_CCM CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM // nolint: revive,staticcheck,lll TLS_PSK_WITH_AES_128_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CCM_8 // nolint: revive,staticcheck,lll TLS_PSK_WITH_AES_256_CCM_8 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_256_CCM_8 // nolint: revive,staticcheck,lll TLS_PSK_WITH_AES_128_GCM_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_GCM_SHA256 // nolint: revive,staticcheck,lll TLS_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_PSK_WITH_AES_128_CBC_SHA256 // nolint: revive,staticcheck,lll TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 CipherSuiteID = ciphersuite.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 // nolint: revive,staticcheck,lll ) // CipherSuiteAuthenticationType controls what authentication method is using during the handshake for a CipherSuite. type CipherSuiteAuthenticationType = ciphersuite.AuthenticationType // AuthenticationType Enums. const ( CipherSuiteAuthenticationTypeCertificate CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeCertificate CipherSuiteAuthenticationTypePreSharedKey CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypePreSharedKey CipherSuiteAuthenticationTypeAnonymous CipherSuiteAuthenticationType = ciphersuite.AuthenticationTypeAnonymous ) // CipherSuiteKeyExchangeAlgorithm controls what exchange algorithm is using during the handshake for a CipherSuite. type CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithm // CipherSuiteKeyExchangeAlgorithm Bitmask. 
const ( CipherSuiteKeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmNone CipherSuiteKeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmPsk CipherSuiteKeyExchangeAlgorithmEcdhe CipherSuiteKeyExchangeAlgorithm = ciphersuite.KeyExchangeAlgorithmEcdhe ) var _ = allCipherSuites() // Necessary until this function isn't only used by Go 1.14 // CipherSuite is an interface that all DTLS CipherSuites must satisfy. type CipherSuite interface { // String of CipherSuite, only used for logging String() string // ID of CipherSuite. ID() CipherSuiteID // What type of Certificate does this CipherSuite use CertificateType() clientcertificate.Type // What Hash function is used during verification HashFunc() func() hash.Hash // AuthenticationType controls what authentication method is using during the handshake AuthenticationType() CipherSuiteAuthenticationType // KeyExchangeAlgorithm controls what exchange algorithm is using during the handshake KeyExchangeAlgorithm() CipherSuiteKeyExchangeAlgorithm // ECC (Elliptic Curve Cryptography) determines whether ECC extesions will be send during handshake. // https://datatracker.ietf.org/doc/html/rfc4492#page-10 ECC() bool // Called when keying material has been generated, should initialize the internal cipher Init(masterSecret, clientRandom, serverRandom []byte, isClient bool) error IsInitialized() bool Encrypt(pkt *recordlayer.RecordLayer, raw []byte) ([]byte, error) Decrypt(h recordlayer.Header, in []byte) ([]byte, error) } // CipherSuiteName provides the same functionality as tls.CipherSuiteName // that appeared first in Go 1.14. // // Our implementation differs slightly in that it takes in a CiperSuiteID, // like the rest of our library, instead of a uint16 like crypto/tls. 
func CipherSuiteName(id CipherSuiteID) string { suite := cipherSuiteForID(id, nil) if suite != nil { return suite.String() } return fmt.Sprintf("0x%04X", uint16(id)) } // Taken from https://www.iana.org/assignments/tls-parameters/tls-parameters.xml // A cipherSuite is a specific combination of key agreement, cipher and MAC // function. func cipherSuiteForID(id CipherSuiteID, customCiphers func() []CipherSuite) CipherSuite { //nolint:cyclop switch id { //nolint:exhaustive case TLS_ECDHE_ECDSA_WITH_AES_128_CCM: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm() case TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: return ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8() case TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} case TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{} case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: return &ciphersuite.TLSEcdheRsaWithAes256CbcSha{} case TLS_PSK_WITH_AES_128_CCM: return ciphersuite.NewTLSPskWithAes128Ccm() case TLS_PSK_WITH_AES_128_CCM_8: return ciphersuite.NewTLSPskWithAes128Ccm8() case TLS_PSK_WITH_AES_256_CCM_8: return ciphersuite.NewTLSPskWithAes256Ccm8() case TLS_PSK_WITH_AES_128_GCM_SHA256: return &ciphersuite.TLSPskWithAes128GcmSha256{} case TLS_PSK_WITH_AES_128_CBC_SHA256: return &ciphersuite.TLSPskWithAes128CbcSha256{} case TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{} case TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: return &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{} case TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: return ciphersuite.NewTLSEcdhePskWithAes128CbcSha256() } if customCiphers != nil { for _, c := range customCiphers() { if c.ID() == id { return c } } } return nil } // CipherSuites we support in order of preference. 
func defaultCipherSuites() []CipherSuite { return []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func allCipherSuites() []CipherSuite { return []CipherSuite{ ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm(), ciphersuite.NewTLSEcdheEcdsaWithAes128Ccm8(), &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256CbcSha{}, &ciphersuite.TLSEcdheRsaWithAes256CbcSha{}, ciphersuite.NewTLSPskWithAes128Ccm(), ciphersuite.NewTLSPskWithAes128Ccm8(), ciphersuite.NewTLSPskWithAes256Ccm8(), &ciphersuite.TLSPskWithAes128GcmSha256{}, &ciphersuite.TLSEcdheEcdsaWithAes256GcmSha384{}, &ciphersuite.TLSEcdheRsaWithAes256GcmSha384{}, } } func cipherSuiteIDs(cipherSuites []CipherSuite) []uint16 { rtrn := []uint16{} for _, c := range cipherSuites { rtrn = append(rtrn, uint16(c.ID())) } return rtrn } //nolint:cyclop func parseCipherSuites( userSelectedSuites []CipherSuiteID, customCipherSuites func() []CipherSuite, includeCertificateSuites, includePSKSuites bool, ) ([]CipherSuite, error) { cipherSuitesForIDs := func(ids []CipherSuiteID) ([]CipherSuite, error) { cipherSuites := []CipherSuite{} for _, id := range ids { c := cipherSuiteForID(id, nil) if c == nil { return nil, &invalidCipherSuiteError{id} } cipherSuites = append(cipherSuites, c) } return cipherSuites, nil } var ( cipherSuites []CipherSuite err error i int ) if userSelectedSuites != nil { cipherSuites, err = cipherSuitesForIDs(userSelectedSuites) if err != nil { return nil, err } } else { cipherSuites = defaultCipherSuites() } // Put CustomCipherSuites before ID selected suites if customCipherSuites != nil { cipherSuites = append(customCipherSuites(), cipherSuites...) 
} var foundCertificateSuite, foundPSKSuite, foundAnonymousSuite bool for _, c := range cipherSuites { switch { case includeCertificateSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: foundCertificateSuite = true case includePSKSuites && c.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey: foundPSKSuite = true case c.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous: foundAnonymousSuite = true default: continue } cipherSuites[i] = c i++ } switch { case includeCertificateSuites && !foundCertificateSuite && !foundAnonymousSuite: return nil, errNoAvailableCertificateCipherSuite case includePSKSuites && !foundPSKSuite: return nil, errNoAvailablePSKCipherSuite case i == 0: return nil, errNoAvailableCipherSuites } return cipherSuites[:i], nil } func filterCipherSuitesForCertificate(cert *tls.Certificate, cipherSuites []CipherSuite) []CipherSuite { if cert == nil || cert.PrivateKey == nil { return cipherSuites } signer, ok := cert.PrivateKey.(crypto.Signer) if !ok { return cipherSuites } var certType clientcertificate.Type switch signer.Public().(type) { case ed25519.PublicKey, *ecdsa.PublicKey: certType = clientcertificate.ECDSASign case *rsa.PublicKey: certType = clientcertificate.RSASign } filtered := []CipherSuite{} for _, c := range cipherSuites { if c.AuthenticationType() != CipherSuiteAuthenticationTypeCertificate || certType == c.CertificateType() { filtered = append(filtered, c) } } return filtered } dtls-3.1.2/cipher_suite_go114.go000066400000000000000000000022501514330267300164130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build go1.14 // +build go1.14 package dtls import ( "crypto/tls" ) // VersionDTLS12 is the DTLS version in the same style as // VersionTLSXX from crypto/tls. const VersionDTLS12 = 0xfefd // Convert from our cipherSuite interface to a tls.CipherSuite struct. 
func toTLSCipherSuite(c CipherSuite) *tls.CipherSuite { return &tls.CipherSuite{ ID: uint16(c.ID()), Name: c.String(), SupportedVersions: []uint16{VersionDTLS12}, Insecure: false, } } // CipherSuites returns a list of cipher suites currently implemented by this // package, excluding those with security issues, which are returned by // InsecureCipherSuites. func CipherSuites() []*tls.CipherSuite { suites := allCipherSuites() res := make([]*tls.CipherSuite, len(suites)) for i, c := range suites { res[i] = toTLSCipherSuite(c) } return res } // InsecureCipherSuites returns a list of cipher suites currently implemented by // this package and which have security issues. func InsecureCipherSuites() []*tls.CipherSuite { var res []*tls.CipherSuite return res } dtls-3.1.2/cipher_suite_go114_test.go000066400000000000000000000016521514330267300174570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build go1.14 // +build go1.14 package dtls import ( "testing" "github.com/stretchr/testify/assert" ) func TestInsecureCipherSuites(t *testing.T) { assert.Empty(t, InsecureCipherSuites(), "Expected no insecure ciphersuites") } func TestCipherSuites(t *testing.T) { ours := allCipherSuites() theirs := CipherSuites() assert.Equal(t, len(ours), len(theirs)) for i, s := range ours { i := i s := s t.Run(s.String(), func(t *testing.T) { cipher := theirs[i] assert.Equal(t, cipher.ID, uint16(s.ID())) assert.Equal(t, cipher.Name, s.String()) assert.Equal(t, 1, len(cipher.SupportedVersions), "Expected SupportedVersion to be 1") assert.Equal(t, uint16(VersionDTLS12), cipher.SupportedVersions[0], "Expected SupportedVersion to match") assert.False(t, cipher.Insecure, "Expected Insecure") }) } } dtls-3.1.2/cipher_suite_test.go000066400000000000000000000050741514330267300165460ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "testing" 
"time" "github.com/pion/dtls/v3/internal/ciphersuite" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v4/dpipe" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) func TestCipherSuiteName(t *testing.T) { testCases := []struct { suite CipherSuiteID expected string }{ {TLS_ECDHE_ECDSA_WITH_AES_128_CCM, "TLS_ECDHE_ECDSA_WITH_AES_128_CCM"}, {CipherSuiteID(0x0000), "0x0000"}, } for _, testCase := range testCases { assert.Equal(t, testCase.expected, CipherSuiteName(testCase.suite)) } } func TestAllCipherSuites(t *testing.T) { assert.NotEmpty(t, allCipherSuites()) } // CustomCipher that is just used to assert Custom IDs work. type testCustomCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 authenticationType CipherSuiteAuthenticationType } func (t *testCustomCipherSuite) ID() CipherSuiteID { return 0xFFFF } func (t *testCustomCipherSuite) AuthenticationType() CipherSuiteAuthenticationType { return t.authenticationType } // Assert that two connections that pass in a CipherSuite with a CustomID works. 
func TestCustomCipherSuite(t *testing.T) { type result struct { c *Conn err error } // Check for leaking routines report := test.CheckRoutines(t) defer report() runTest := func(cipherFactory func() []CipherSuite) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{}, CustomCipherSuites: cipherFactory, }, true) clientResult := <-resultCh assert.NoError(t, err) assert.NoError(t, server.Close()) assert.Nil(t, clientResult.err) assert.NoError(t, clientResult.c.Close()) } t.Run("Custom ID", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeCertificate}} }) }) t.Run("Anonymous Cipher", func(*testing.T) { runTest(func() []CipherSuite { return []CipherSuite{&testCustomCipherSuite{authenticationType: CipherSuiteAuthenticationTypeAnonymous}} }) }) } dtls-3.1.2/codecov.yml000066400000000000000000000007151514330267300146370ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. 
threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" dtls-3.1.2/compression_method.go000066400000000000000000000004261514330267300167210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import "github.com/pion/dtls/v3/pkg/protocol" func defaultCompressionMethods() []*protocol.CompressionMethod { return []*protocol.CompressionMethod{ {}, } } dtls-3.1.2/config.go000066400000000000000000000303721514330267300142700ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" "crypto/tls" "crypto/x509" "io" "net" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) const keyLogLabelTLS12 = "CLIENT_RANDOM" // Config is used to configure a DTLS client or server. // After a Config is passed to a DTLS function it must not be modified. // // Deprecated: prefer the options-based APIs (`*WithOptions`) to construct immutable configs, // This will be removed in the next major version. type Config struct { //nolint:dupl // Certificates contains certificate chain to present to the other side of the connection. // Server MUST set this if PSK is non-nil // client SHOULD sets this so CertificateRequests can be handled if PSK is non-nil Certificates []tls.Certificate // CipherSuites is a list of supported cipher suites. // If CipherSuites is nil, a default list is used CipherSuites []CipherSuiteID // CustomCipherSuites is a list of CipherSuites that can be // provided by the user. This allow users to user Ciphers that are reserved // for private usage. CustomCipherSuites func() []CipherSuite // SignatureSchemes contains the signature and hash schemes that the peer requests to verify. 
SignatureSchemes []tls.SignatureScheme // CertificateSignatureSchemes contains the signature and hash schemes that may be used // in digital signatures for X.509 certificates. If not set, the signature_algorithms_cert // extension is not sent, and SignatureSchemes is used for both handshake signatures and // certificate chain validation, as specified in RFC 8446 Section 4.2.3. CertificateSignatureSchemes []tls.SignatureScheme // SRTPProtectionProfiles are the supported protection profiles // Clients will send this via use_srtp and assert that the server properly responds // Servers will assert that clients send one of these profiles and will respond as needed SRTPProtectionProfiles []SRTPProtectionProfile // SRTPMasterKeyIdentifier value (if any) is sent via the use_srtp // extension for Clients and Servers SRTPMasterKeyIdentifier []byte // ClientAuth determines the server's policy for // TLS Client Authentication. The default is NoClientCert. ClientAuth ClientAuthType // RequireExtendedMasterSecret determines if the "Extended Master Secret" extension // should be disabled, requested, or required (default requested). ExtendedMasterSecret ExtendedMasterSecretType // FlightInterval controls how often we send outbound handshake messages // defaults to time.Second FlightInterval time.Duration // DisableRetransmitBackoff can be used to the disable the backoff feature // when sending outbound messages as specified in RFC 4347 4.2.4.1 DisableRetransmitBackoff bool // PSK sets the pre-shared key used by this DTLS connection // If PSK is non-nil only PSK CipherSuites will be used PSK PSKCallback PSKIdentityHint []byte // InsecureSkipVerify controls whether a client verifies the // server's certificate chain and host name. // If InsecureSkipVerify is true, TLS accepts any certificate // presented by the server and any host name in that certificate. // In this mode, TLS is susceptible to man-in-the-middle attacks. // This should be used only for testing. 
InsecureSkipVerify bool // InsecureHashes allows the use of hashing algorithms that are known // to be vulnerable. InsecureHashes bool // VerifyPeerCertificate, if not nil, is called after normal // certificate verification by either a client or server. It // receives the certificate provided by the peer and also a flag // that tells if normal verification has succeedded. If it returns a // non-nil error, the handshake is aborted and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. If normal verification is disabled by // setting InsecureSkipVerify, or (for a server) when ClientAuth is // RequestClientCert or RequireAnyClientCert, then this callback will // be considered but the verifiedChains will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error // VerifyConnection, if not nil, is called after normal certificate // verification/PSK and after VerifyPeerCertificate by either a TLS client // or server. If it returns a non-nil error, the handshake is aborted // and that error results. // // If normal verification fails then the handshake will abort before // considering this callback. This callback will run for all connections // regardless of InsecureSkipVerify or ClientAuth settings. VerifyConnection func(*State) error // RootCAs defines the set of root certificate authorities // that one peer uses when verifying the other peer's certificates. // If RootCAs is nil, TLS uses the host's root CA set. RootCAs *x509.CertPool // ClientCAs defines the set of root certificate authorities // that servers use if required to verify a client certificate // by the policy in ClientAuth. ClientCAs *x509.CertPool // ServerName is used to verify the hostname on the returned // certificates unless InsecureSkipVerify is given. 
ServerName string LoggerFactory logging.LoggerFactory // MTU is the length at which handshake messages will be fragmented to // fit within the maximum transmission unit (default is 1200 bytes) MTU int // ReplayProtectionWindow is the size of the replay attack protection window. // Duplication of the sequence number is checked in this window size. // Packet with sequence number older than this value compared to the latest // accepted packet will be discarded. (default is 64) ReplayProtectionWindow int // KeyLogWriter optionally specifies a destination for TLS master secrets // in NSS key log format that can be used to allow external programs // such as Wireshark to decrypt TLS connections. // See https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format. // Use of KeyLogWriter compromises security and should only be // used for debugging. KeyLogWriter io.Writer // SessionStore is the container to store session for resumption. SessionStore SessionStore // List of application protocols the peer supports, for ALPN SupportedProtocols []string // List of Elliptic Curves to use // // If an ECC ciphersuite is configured and EllipticCurves is empty // it will default to X25519, P-256, P-384 in this specific order. EllipticCurves []elliptic.Curve // GetCertificate returns a Certificate based on the given // ClientHelloInfo. It will only be called if the client supplies SNI // information or if Certificates is empty. // // If GetCertificate is nil or returns nil, then the certificate is // retrieved from NameToCertificate. If NameToCertificate is nil, the // best element of Certificates will be used. GetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) // GetClientCertificate, if not nil, is called when a server requests a // certificate from a client. If set, the contents of Certificates will // be ignored. // // If GetClientCertificate returns an error, the handshake will be // aborted and that error will be returned. 
Otherwise // GetClientCertificate must return a non-nil Certificate. If // Certificate.Certificate is empty then no certificate will be sent to // the server. If this is unacceptable to the server then it may abort // the handshake. GetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) // InsecureSkipVerifyHello, if true and when acting as server, allow client to // skip hello verify phase and receive ServerHello after initial ClientHello. // This have implication on DoS attack resistance. InsecureSkipVerifyHello bool // ConnectionIDGenerator generates connection identifiers that should be // sent by the remote party if it supports the DTLS Connection Identifier // extension, as determined during the handshake. Generated connection // identifiers must always have the same length. Returning a zero-length // connection identifier indicates that the local party supports sending // connection identifiers but does not require the remote party to send // them. A nil ConnectionIDGenerator indicates that connection identifiers // are not supported. // https://datatracker.ietf.org/doc/html/rfc9146 ConnectionIDGenerator func() []byte // PaddingLengthGenerator generates the number of padding bytes used to // inflate ciphertext size in order to obscure content size from observers. // The length of the content is passed to the generator such that both // deterministic and random padding schemes can be applied while not // exceeding maximum record size. // If no PaddingLengthGenerator is specified, padding will not be applied. // https://datatracker.ietf.org/doc/html/rfc9146#section-4 PaddingLengthGenerator func(uint) uint // HelloRandomBytesGenerator generates custom client hello random bytes. 
HelloRandomBytesGenerator func() [handshake.RandomBytesLength]byte // Handshake hooks: hooks can be used for testing invalid messages, // mimicking other implementations or randomizing fields, which is valuable // for applications that need censorship-resistance by making // fingerprinting more difficult. // ClientHelloMessageHook, if not nil, is called when a Client Hello message is sent // from a client. The returned handshake message replaces the original message. ClientHelloMessageHook func(handshake.MessageClientHello) handshake.Message // ServerHelloMessageHook, if not nil, is called when a Server Hello message is sent // from a server. The returned handshake message replaces the original message. ServerHelloMessageHook func(handshake.MessageServerHello) handshake.Message // CertificateRequestMessageHook, if not nil, is called when a Certificate Request // message is sent from a server. The returned handshake message replaces the original message. CertificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message // OnConnectionAttempt is fired Whenever a connection attempt is made, // the server or application can call this callback function. // The callback function can then implement logic to handle the connection attempt, such as logging the attempt, // checking against a list of blocked IPs, or counting the attempts to prevent brute force attacks. // If the callback function returns an error, the connection attempt will be aborted. OnConnectionAttempt func(net.Addr) error } func (c *Config) includeCertificateSuites() bool { return c.PSK == nil || len(c.Certificates) > 0 || c.GetCertificate != nil || c.GetClientCertificate != nil } const defaultMTU = 1200 // bytes var defaultCurves = []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384} //nolint:gochecknoglobals // PSKCallback is called once we have the remote's PSKIdentityHint. // If the remote provided none it will be nil. 
type PSKCallback func([]byte) ([]byte, error) // ClientAuthType declares the policy the server will follow for // TLS Client Authentication. type ClientAuthType int // ClientAuthType enums. const ( NoClientCert ClientAuthType = iota RequestClientCert RequireAnyClientCert VerifyClientCertIfGiven RequireAndVerifyClientCert ) // ExtendedMasterSecretType declares the policy the client and server // will follow for the Extended Master Secret extension. type ExtendedMasterSecretType int // ExtendedMasterSecretType enums. const ( RequestExtendedMasterSecret ExtendedMasterSecretType = iota RequireExtendedMasterSecret DisableExtendedMasterSecret ) func validateConfig(config *Config) error { //nolint:cyclop switch { case config == nil: return errNoConfigProvided case config.PSKIdentityHint != nil && config.PSK == nil: return errIdentityNoPSK } for _, cert := range config.Certificates { if cert.Certificate == nil { return errInvalidCertificate } if cert.PrivateKey != nil { signer, ok := cert.PrivateKey.(crypto.Signer) if !ok { return errInvalidPrivateKey } switch signer.Public().(type) { case ed25519.PublicKey: case *ecdsa.PublicKey: case *rsa.PublicKey: default: return errInvalidPrivateKey } } } _, err := parseCipherSuites( config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, ) return err } dtls-3.1.2/config_test.go000066400000000000000000000104661514330267300153310ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/dsa" //nolint:staticcheck "crypto/rand" "crypto/rsa" "crypto/tls" "errors" "testing" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/stretchr/testify/assert" ) func TestValidateConfig(t *testing.T) { cert, err := selfsign.GenerateSelfSigned() if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, self signed certificate not generated") return } dsaPrivateKey := &dsa.PrivateKey{} err = 
dsa.GenerateParameters(&dsaPrivateKey.Parameters, rand.Reader, dsa.L1024N160) if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA parameters not generated") return } err = dsa.GenerateKey(dsaPrivateKey, rand.Reader) if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, DSA private key not generated") return } rsaPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { assert.NoError(t, err, "TestValidateConfig: Config validation error, RSA private key not generated") return } cases := map[string]struct { config *Config wantAnyErr bool expErr error }{ "Empty config": { expErr: errNoConfigProvided, }, "PSK and Certificate, valid cipher suites": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, }, "PSK and Certificate, no PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailablePSKCipherSuite, }, "PSK and Certificate, no non-PSK cipher suite": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, PSK: func([]byte) ([]byte, error) { return nil, nil }, Certificates: []tls.Certificate{cert}, }, expErr: errNoAvailableCertificateCipherSuite, }, "PSK identity hint with not PSK": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, PSK: nil, PSKIdentityHint: []byte{}, }, expErr: errIdentityNoPSK, }, "Invalid private key": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{Certificate: cert.Certificate, PrivateKey: dsaPrivateKey}}, }, expErr: errInvalidPrivateKey, }, "PrivateKey without Certificate": { config: &Config{ CipherSuites: 
[]CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{{PrivateKey: cert.PrivateKey}}, }, expErr: errInvalidCertificate, }, "Invalid cipher suites": { config: &Config{CipherSuites: []CipherSuiteID{0x0000}}, wantAnyErr: true, }, "Valid config": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, Certificates: []tls.Certificate{cert, {Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}}, }, }, "Valid config with get certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, "Valid config with get client certificate": { config: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &tls.Certificate{Certificate: cert.Certificate, PrivateKey: rsaPrivateKey}, nil }, }, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { err := validateConfig(testCase.config) if testCase.expErr != nil || testCase.wantAnyErr { if testCase.expErr != nil && !errors.Is(err, testCase.expErr) { assert.ErrorIs(t, err, testCase.expErr, "TestValidateConfig") } assert.Error(t, err, "TestValidateConfig: Config validation expected an error") } }) } } dtls-3.1.2/conn.go000066400000000000000000001123411514330267300137550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "errors" "fmt" "io" "net" "sync" "sync/atomic" "time" "github.com/pion/dtls/v3/internal/closer" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" 
"github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v4/deadline" "github.com/pion/transport/v4/netctx" "github.com/pion/transport/v4/replaydetector" ) const ( initialTickerInterval = time.Second cookieLength = 20 sessionLength = 32 defaultNamedCurve = elliptic.X25519 inboundBufferSize = 8192 // Default replay protection window is specified by RFC 6347 Section 4.1.2.6. defaultReplayProtectionWindow = 64 // maxAppDataPacketQueueSize is the maximum number of app data packets we will. // enqueue before the handshake is completed. maxAppDataPacketQueueSize = 100 ) func invalidKeyingLabels() map[string]bool { return map[string]bool{ "client finished": true, "server finished": true, "master secret": true, "key expansion": true, } } type addrPkt struct { rAddr net.Addr data []byte } type recvHandshakeState struct { done chan struct{} isRetransmit bool } // Conn represents a DTLS connection. type Conn struct { lock sync.RWMutex // Internal lock (must not be public) nextConn netctx.PacketConn // Embedded Conn, typically a udpconn we read/write from fragmentBuffer *fragmentBuffer // out-of-order and missing fragment handling handshakeCache *handshakeCache // caching of handshake messages for verifyData generation decrypted chan any // Decrypted Application Data or error, pull by calling `Read` rAddr net.Addr state State // Internal state maximumTransmissionUnit int paddingLengthGenerator func(uint) uint handshakeCompletedSuccessfully atomic.Bool handshakeMutex sync.Mutex handshakeDone chan struct{} encryptedPackets []addrPkt connectionClosedByUser bool closeLock sync.Mutex closed *closer.Closer readDeadline *deadline.Deadline writeDeadline *deadline.Deadline log logging.LeveledLogger reading chan struct{} handshakeRecv chan recvHandshakeState cancelHandshaker func() cancelHandshakeReader func() fsm *handshakeFSM replayProtectionWindow uint handshakeConfig *handshakeConfig } // 
createConn creates a new DTLS connection. // Caller is responsible for validating the config before calling this function. // //nolint:cyclop func createConn( nextConn net.PacketConn, rAddr net.Addr, config *Config, isClient bool, resumeState *State, ) (*Conn, error) { if nextConn == nil { return nil, errNilNextConn } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() } logger := loggerFactory.NewLogger("dtls") mtu := config.MTU if mtu <= 0 { mtu = defaultMTU } replayProtectionWindow := config.ReplayProtectionWindow if replayProtectionWindow <= 0 { replayProtectionWindow = defaultReplayProtectionWindow } paddingLengthGenerator := config.PaddingLengthGenerator if paddingLengthGenerator == nil { paddingLengthGenerator = func(uint) uint { return 0 } } cipherSuites, err := parseCipherSuites( config.CipherSuites, config.CustomCipherSuites, config.includeCertificateSuites(), config.PSK != nil, ) if err != nil { return nil, err } signatureSchemes, err := signaturehash.ParseSignatureSchemes(config.SignatureSchemes, config.InsecureHashes) if err != nil { return nil, err } // Parse certificate signature schemes only if explicitly configured var certSignatureSchemes []signaturehash.Algorithm if len(config.CertificateSignatureSchemes) > 0 { certSignatureSchemes, err = signaturehash.ParseSignatureSchemes( config.CertificateSignatureSchemes, config.InsecureHashes, ) if err != nil { return nil, err } } workerInterval := initialTickerInterval if config.FlightInterval > 0 { workerInterval = config.FlightInterval } serverName := config.ServerName // Do not allow the use of an IP address literal as an SNI value. // See RFC 6066, Section 3. 
if net.ParseIP(serverName) != nil { serverName = "" } curves := config.EllipticCurves if len(curves) == 0 { curves = defaultCurves } handshakeConfig := &handshakeConfig{ localPSKCallback: config.PSK, localPSKIdentityHint: config.PSKIdentityHint, localCipherSuites: cipherSuites, localSignatureSchemes: signatureSchemes, localCertSignatureSchemes: certSignatureSchemes, extendedMasterSecret: config.ExtendedMasterSecret, localSRTPProtectionProfiles: config.SRTPProtectionProfiles, localSRTPMasterKeyIdentifier: config.SRTPMasterKeyIdentifier, serverName: serverName, supportedProtocols: config.SupportedProtocols, clientAuth: config.ClientAuth, localCertificates: config.Certificates, insecureSkipVerify: config.InsecureSkipVerify, verifyPeerCertificate: config.VerifyPeerCertificate, verifyConnection: config.VerifyConnection, rootCAs: config.RootCAs, clientCAs: config.ClientCAs, customCipherSuites: config.CustomCipherSuites, initialRetransmitInterval: workerInterval, disableRetransmitBackoff: config.DisableRetransmitBackoff, log: logger, initialEpoch: 0, keyLogWriter: config.KeyLogWriter, sessionStore: config.SessionStore, ellipticCurves: curves, localGetCertificate: config.GetCertificate, localGetClientCertificate: config.GetClientCertificate, insecureSkipHelloVerify: config.InsecureSkipVerifyHello, connectionIDGenerator: config.ConnectionIDGenerator, helloRandomBytesGenerator: config.HelloRandomBytesGenerator, clientHelloMessageHook: config.ClientHelloMessageHook, serverHelloMessageHook: config.ServerHelloMessageHook, certificateRequestMessageHook: config.CertificateRequestMessageHook, resumeState: resumeState, } conn := &Conn{ rAddr: rAddr, nextConn: netctx.NewPacketConn(nextConn), handshakeConfig: handshakeConfig, fragmentBuffer: newFragmentBuffer(), handshakeCache: newHandshakeCache(), maximumTransmissionUnit: mtu, paddingLengthGenerator: paddingLengthGenerator, decrypted: make(chan any, 1), log: logger, readDeadline: deadline.New(), writeDeadline: deadline.New(), 
reading: make(chan struct{}, 1), handshakeRecv: make(chan recvHandshakeState), closed: closer.NewCloser(), cancelHandshaker: func() {}, cancelHandshakeReader: func() {}, replayProtectionWindow: uint(replayProtectionWindow), //nolint:gosec // G115 state: State{ isClient: isClient, }, } conn.setRemoteEpoch(0) conn.setLocalEpoch(0) return conn, nil } // Handshake runs the client or server DTLS handshake // protocol if it has not yet been run. // // Most uses of this package need not call Handshake explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. // // For control over canceling or setting a timeout on a handshake, use // [Conn.HandshakeContext]. func (c *Conn) Handshake() error { return c.HandshakeContext(context.Background()) } // HandshakeContext runs the client or server DTLS handshake // protocol if it has not yet been run. // // The provided Context must be non-nil. If the context is canceled before // the handshake is complete, the handshake is interrupted and an error is returned. // Once the handshake has completed, cancellation of the context will not affect the // connection. // // Most uses of this package need not call HandshakeContext explicitly: the // first [Conn.Read] or [Conn.Write] will call it automatically. func (c *Conn) HandshakeContext(ctx context.Context) error { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() if c.isHandshakeCompletedSuccessfully() { return nil } handshakeDone := make(chan struct{}) defer close(handshakeDone) c.closeLock.Lock() c.handshakeDone = handshakeDone c.closeLock.Unlock() // rfc5246#section-7.4.3 // In addition, the hash and signature algorithms MUST be compatible // with the key in the server's end-entity certificate. 
if !c.state.isClient { cert, err := c.handshakeConfig.getCertificate(&ClientHelloInfo{}) if err != nil && !errors.Is(err, errNoCertificates) { return err } c.handshakeConfig.localCipherSuites = filterCipherSuitesForCertificate(cert, c.handshakeConfig.localCipherSuites) } var initialFlight flightVal var initialFSMState handshakeState if c.handshakeConfig.resumeState != nil { //nolint:nestif if c.state.isClient { initialFlight = flight5 } else { initialFlight = flight6 } initialFSMState = handshakeFinished c.state = *c.handshakeConfig.resumeState } else { if c.state.isClient { initialFlight = flight1 } else { initialFlight = flight0 } initialFSMState = handshakePreparing } // Do handshake if err := c.handshake(ctx, c.handshakeConfig, initialFlight, initialFSMState); err != nil { return err } c.log.Trace("Handshake Completed") return nil } // Dial connects to the given network address and establishes a DTLS connection on top. // // Deprecated: Use DialWithOptions instead. func Dial(network string, rAddr *net.UDPAddr, config *Config) (*Conn, error) { // net.ListenUDP is used rather than net.DialUDP as the latter prevents the // use of net.PacketConn.WriteTo. // https://github.com/golang/go/blob/ce5e37ec21442c6eb13a43e68ca20129102ebac0/src/net/udpsock_posix.go#L115 pConn, err := net.ListenUDP(network, nil) if err != nil { return nil, err } return Client(pConn, rAddr, config) } // DialWithOptions connects to the given network address and establishes a DTLS connection on top. func DialWithOptions(network string, rAddr *net.UDPAddr, opts ...ClientOption) (*Conn, error) { config, err := buildClientConfig(opts...) if err != nil { return nil, err } return Dial(network, rAddr, config) } // Client establishes a DTLS connection over an existing connection. // // Deprecated: Use ClientWithOptions instead. 
func Client(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { switch { case config == nil: return nil, errNoConfigProvided case config.PSK != nil && config.PSKIdentityHint == nil: return nil, errPSKAndIdentityMustBeSetForClient } if err := validateConfig(config); err != nil { return nil, err } return createConn(conn, rAddr, config, true, nil) } // ClientWithOptions establishes a DTLS connection over an existing connection. func ClientWithOptions(conn net.PacketConn, rAddr net.Addr, opts ...ClientOption) (*Conn, error) { config, err := buildClientConfig(opts...) if err != nil { return nil, err } return Client(conn, rAddr, config) } // serverWithConfig is an internal helper that accepts a *Config. func serverWithConfig(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } if config.OnConnectionAttempt != nil { if err := config.OnConnectionAttempt(rAddr); err != nil { return nil, err } } return createConn(conn, rAddr, config, false, nil) } // Server listens for incoming DTLS connections. // // Deprecated: Use ServerWithOptions instead. func Server(conn net.PacketConn, rAddr net.Addr, config *Config) (*Conn, error) { if config == nil { return nil, errNoConfigProvided } if err := validateConfig(config); err != nil { return nil, err } return serverWithConfig(conn, rAddr, config) } // ServerWithOptions listens for incoming DTLS connections. func ServerWithOptions(conn net.PacketConn, rAddr net.Addr, opts ...ServerOption) (*Conn, error) { config, err := buildServerConfig(opts...) if err != nil { return nil, err } return Server(conn, rAddr, config) } // Read reads data from the connection. 
func (c *Conn) Read(buff []byte) (n int, err error) { //nolint:cyclop if err := c.Handshake(); err != nil { return 0, err } select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded default: } for { select { case <-c.readDeadline.Done(): return 0, errDeadlineExceeded case out, ok := <-c.decrypted: if !ok { return 0, io.EOF } switch val := out.(type) { case ([]byte): if len(buff) < len(val) { return 0, errBufferTooSmall } copy(buff, val) return len(val), nil case (error): return 0, val } } } } // Write writes len(payload) bytes from payload to the DTLS connection. func (c *Conn) Write(payload []byte) (int, error) { if c.isConnectionClosed() { return 0, ErrConnClosed } select { case <-c.writeDeadline.Done(): return 0, errDeadlineExceeded default: } if err := c.Handshake(); err != nil { return 0, err } return len(payload), c.writePackets(c.writeDeadline, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: payload, }, }, shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: true, }, }) } // Close closes the connection. func (c *Conn) Close() error { err := c.close(true) //nolint:contextcheck c.closeLock.Lock() handshakeDone := c.handshakeDone c.closeLock.Unlock() if handshakeDone != nil { <-handshakeDone } return err } // ConnectionState returns basic DTLS details about the connection. // Note that this replaced the `Export` function of v1. func (c *Conn) ConnectionState() (State, bool) { c.lock.RLock() defer c.lock.RUnlock() stateClone, err := c.state.clone() if err != nil { return State{}, false } return *stateClone, true } // SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile. 
func (c *Conn) SelectedSRTPProtectionProfile() (SRTPProtectionProfile, bool) { profile := c.state.getSRTPProtectionProfile() if profile == 0 { return 0, false } return profile, true } // RemoteSRTPMasterKeyIdentifier returns the MasterKeyIdentifier value from the use_srtp. func (c *Conn) RemoteSRTPMasterKeyIdentifier() ([]byte, bool) { if profile := c.state.getSRTPProtectionProfile(); profile == 0 { return nil, false } return c.state.remoteSRTPMasterKeyIdentifier, true } func (c *Conn) writePackets(ctx context.Context, pkts []*packet) error { c.lock.Lock() defer c.lock.Unlock() var rawPackets [][]byte for _, pkt := range pkts { if dtlsHandshake, ok := pkt.record.Content.(*handshake.Handshake); ok { handshakeRaw, err := pkt.record.Marshal() if err != nil { return err } c.log.Tracef("[handshake:%v] -> %s (epoch: %d, seq: %d)", srvCliStr(c.state.isClient), dtlsHandshake.Header.Type.String(), pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence) c.handshakeCache.push( handshakeRaw[recordlayer.FixedHeaderSize:], pkt.record.Header.Epoch, dtlsHandshake.Header.MessageSequence, dtlsHandshake.Header.Type, c.state.isClient, ) rawHandshakePackets, err := c.processHandshakePacket(pkt, dtlsHandshake) if err != nil { return err } rawPackets = append(rawPackets, rawHandshakePackets...) 
} else { rawPacket, err := c.processPacket(pkt) if err != nil { return err } rawPackets = append(rawPackets, rawPacket) } } if len(rawPackets) == 0 { return nil } compactedRawPackets := c.compactRawPackets(rawPackets) for _, compactedRawPackets := range compactedRawPackets { if _, err := c.nextConn.WriteToContext(ctx, compactedRawPackets, c.rAddr); err != nil { return netError(err) } } return nil } func (c *Conn) compactRawPackets(rawPackets [][]byte) [][]byte { // avoid a useless copy in the common case if len(rawPackets) == 1 { return rawPackets } combinedRawPackets := make([][]byte, 0) currentCombinedRawPacket := make([]byte, 0) for _, rawPacket := range rawPackets { if len(currentCombinedRawPacket) > 0 && len(currentCombinedRawPacket)+len(rawPacket) >= c.maximumTransmissionUnit { combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) currentCombinedRawPacket = []byte{} } currentCombinedRawPacket = append(currentCombinedRawPacket, rawPacket...) } combinedRawPackets = append(combinedRawPackets, currentCombinedRawPacket) return combinedRawPackets } func (c *Conn) processPacket(pkt *packet) ([]byte, error) { //nolint:cyclop epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { // RFC 6347 Section 4.1.0 // The implementation must either abandon an association or rehandshake // prior to allowing the sequence number to wrap. return nil, errSequenceNumberOverflow } pkt.record.Header.SequenceNumber = seq var rawPacket []byte if pkt.shouldWrapCID { //nolint:nestif // Record must be marshaled to populate fields used in inner plaintext. 
if _, err := pkt.record.Marshal(); err != nil { return nil, err } content, err := pkt.record.Content.Marshal() if err != nil { return nil, err } inner := &recordlayer.InnerPlaintext{ Content: content, RealType: pkt.record.Header.ContentType, } rawInner, err := inner.Marshal() //nolint:govet if err != nil { return nil, err } cidHeader := &recordlayer.Header{ Version: pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, Epoch: pkt.record.Header.Epoch, ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { var err error rawPacket, err = pkt.record.Marshal() if err != nil { return nil, err } } if pkt.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } } return rawPacket, nil } //nolint:cyclop func (c *Conn) processHandshakePacket(pkt *packet, dtlsHandshake *handshake.Handshake) ([][]byte, error) { rawPackets := make([][]byte, 0) handshakeFragments, err := c.fragmentHandshake(dtlsHandshake) if err != nil { return nil, err } epoch := pkt.record.Header.Epoch for len(c.state.localSequenceNumber) <= int(epoch) { c.state.localSequenceNumber = append(c.state.localSequenceNumber, uint64(0)) } for _, handshakeFragment := range handshakeFragments { seq := atomic.AddUint64(&c.state.localSequenceNumber[epoch], 1) - 1 if seq > recordlayer.MaxSequenceNumber { return nil, errSequenceNumberOverflow } var rawPacket []byte if pkt.shouldWrapCID { inner := &recordlayer.InnerPlaintext{ Content: handshakeFragment, RealType: protocol.ContentTypeHandshake, Zeros: c.paddingLengthGenerator(uint(len(handshakeFragment))), } rawInner, err := inner.Marshal() //nolint:govet if err != nil { return nil, err } cidHeader := &recordlayer.Header{ Version: 
pkt.record.Header.Version, ContentType: protocol.ContentTypeConnectionID, Epoch: pkt.record.Header.Epoch, ContentLen: uint16(len(rawInner)), //nolint:gosec //G115 ConnectionID: c.state.remoteConnectionID, SequenceNumber: pkt.record.Header.SequenceNumber, } rawPacket, err = cidHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *cidHeader rawPacket = append(rawPacket, rawInner...) } else { recordlayerHeader := &recordlayer.Header{ Version: pkt.record.Header.Version, ContentType: pkt.record.Header.ContentType, ContentLen: uint16(len(handshakeFragment)), //nolint:gosec // G115 Epoch: pkt.record.Header.Epoch, SequenceNumber: seq, } rawPacket, err = recordlayerHeader.Marshal() if err != nil { return nil, err } pkt.record.Header = *recordlayerHeader rawPacket = append(rawPacket, handshakeFragment...) } if pkt.shouldEncrypt { var err error rawPacket, err = c.state.cipherSuite.Encrypt(pkt.record, rawPacket) if err != nil { return nil, err } } rawPackets = append(rawPackets, rawPacket) } return rawPackets, nil } func (c *Conn) fragmentHandshake(dtlsHandshake *handshake.Handshake) ([][]byte, error) { content, err := dtlsHandshake.Message.Marshal() if err != nil { return nil, err } fragmentedHandshakes := make([][]byte, 0) contentFragments := splitBytes(content, c.maximumTransmissionUnit) if len(contentFragments) == 0 { contentFragments = [][]byte{ {}, } } offset := 0 for _, contentFragment := range contentFragments { contentFragmentLen := len(contentFragment) headerFragment := &handshake.Header{ Type: dtlsHandshake.Header.Type, Length: dtlsHandshake.Header.Length, MessageSequence: dtlsHandshake.Header.MessageSequence, FragmentOffset: uint32(offset), //nolint:gosec // G115 FragmentLength: uint32(contentFragmentLen), //nolint:gosec // G115 } offset += contentFragmentLen fragmentedHandshake, err := headerFragment.Marshal() if err != nil { return nil, err } fragmentedHandshake = append(fragmentedHandshake, contentFragment...) 
fragmentedHandshakes = append(fragmentedHandshakes, fragmentedHandshake) } return fragmentedHandshakes, nil } var poolReadBuffer = sync.Pool{ //nolint:gochecknoglobals New: func() any { b := make([]byte, inboundBufferSize) return &b }, } func (c *Conn) readAndBuffer(ctx context.Context) error { //nolint:cyclop bufptr, ok := poolReadBuffer.Get().(*[]byte) if !ok { return errFailedToAccessPoolReadBuffer } defer poolReadBuffer.Put(bufptr) b := *bufptr i, rAddr, err := c.nextConn.ReadFromContext(ctx, b) if err != nil { return netError(err) } pkts, err := recordlayer.ContentAwareUnpackDatagram(b[:i], len(c.state.getLocalConnectionID())) if err != nil { return err } var hasHandshake, isRetransmit bool for _, p := range pkts { hs, rtx, alert, err := c.handleIncomingPacket(ctx, p, rAddr, true) if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } var e *alertError if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } if err != nil { return err } if hs { hasHandshake = true } if rtx { isRetransmit = true } } if hasHandshake { s := recvHandshakeState{ done: make(chan struct{}), isRetransmit: isRetransmit, } select { case c.handshakeRecv <- s: // If the other party may retransmit the flight, // we should respond even if it not a new message. 
<-s.done case <-c.fsm.Done(): } } return nil } func (c *Conn) handleQueuedPackets(ctx context.Context) error { c.lock.Lock() pkts := c.encryptedPackets c.encryptedPackets = nil c.lock.Unlock() for _, p := range pkts { _, _, alert, err := c.handleIncomingPacket(ctx, p.data, p.rAddr, false) // don't re-enqueue if alert != nil { if alertErr := c.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err == nil { err = alertErr } } } var e *alertError if errors.As(err, &e) && e.IsFatalOrCloseNotify() { return e } if err != nil { return err } } return nil } func (c *Conn) enqueueEncryptedPackets(packet addrPkt) bool { c.lock.Lock() defer c.lock.Unlock() if len(c.encryptedPackets) < maxAppDataPacketQueueSize { c.encryptedPackets = append(c.encryptedPackets, packet) return true } return false } //nolint:gocognit,gocyclo,cyclop,maintidx func (c *Conn) handleIncomingPacket( ctx context.Context, buf []byte, rAddr net.Addr, enqueue bool, ) (bool, bool, *alert.Alert, error) { header := &recordlayer.Header{} // Set connection ID size so that records of content type tls12_cid will // be parsed correctly. 
if len(c.state.getLocalConnectionID()) > 0 { header.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } if err := header.Unmarshal(buf); err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("discarded broken packet: %v", err) return false, false, nil, nil } // Validate epoch remoteEpoch := c.state.getRemoteEpoch() if header.Epoch > remoteEpoch { if header.Epoch > remoteEpoch+1 { c.log.Debugf("discarded future packet (epoch: %d, seq: %d)", header.Epoch, header.SequenceNumber, ) return false, false, nil, nil } if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("received packet of next epoch, queuing packet") } } return false, false, nil, nil } // Anti-replay protection for len(c.state.replayDetector) <= int(header.Epoch) { c.state.replayDetector = append(c.state.replayDetector, replaydetector.New(c.replayProtectionWindow, recordlayer.MaxSequenceNumber), ) } markPacketAsValid, ok := c.state.replayDetector[int(header.Epoch)].Check(header.SequenceNumber) if !ok { c.log.Debugf("discarded duplicated packet (epoch: %d, seq: %d)", header.Epoch, header.SequenceNumber, ) return false, false, nil, nil } // originalCID indicates whether the original record had content type // Connection ID. originalCID := false // Decrypt if header.Epoch != 0 { //nolint:nestif if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debug("handshake not finished, queuing packet") } } return false, false, nil, nil } // If a connection identifier had been negotiated and encryption is // enabled, the connection identifier MUST be sent. 
if len(c.state.getLocalConnectionID()) > 0 && header.ContentType != protocol.ContentTypeConnectionID { c.log.Debug("discarded packet missing connection ID after value negotiated") return false, false, nil, nil } var err error var hdr recordlayer.Header if header.ContentType == protocol.ContentTypeConnectionID { hdr.ConnectionID = make([]byte, len(c.state.getLocalConnectionID())) } buf, err = c.state.cipherSuite.Decrypt(hdr, buf) if err != nil { c.log.Debugf("%s: decrypt failed: %s", srvCliStr(c.state.isClient), err) return false, false, nil, nil } // If this is a connection ID record, make it look like a normal record for // further processing. if header.ContentType == protocol.ContentTypeConnectionID { originalCID = true ip := &recordlayer.InnerPlaintext{} if err := ip.Unmarshal(buf[header.Size():]); err != nil { //nolint:govet c.log.Debugf("unpacking inner plaintext failed: %s", err) return false, false, nil, nil } unpacked := &recordlayer.Header{ ContentType: ip.RealType, ContentLen: uint16(len(ip.Content)), //nolint:gosec // G115 Version: header.Version, Epoch: header.Epoch, SequenceNumber: header.SequenceNumber, } buf, err = unpacked.Marshal() if err != nil { c.log.Debugf("converting CID record to inner plaintext failed: %s", err) return false, false, nil, nil } buf = append(buf, ip.Content...) } // If connection ID does not match discard the packet. 
if !bytes.Equal(c.state.getLocalConnectionID(), header.ConnectionID) { c.log.Debug("unexpected connection ID") return false, false, nil, nil } } isHandshake, isRetransmit, err := c.fragmentBuffer.push(append([]byte{}, buf...)) if err != nil { // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] c.log.Debugf("defragment failed: %s", err) return false, false, nil, nil } else if isHandshake { markPacketAsValid() for out, epoch := c.fragmentBuffer.pop(); out != nil; out, epoch = c.fragmentBuffer.pop() { header := &handshake.Header{} if err := header.Unmarshal(out); err != nil { c.log.Debugf("%s: handshake parse failed: %s", srvCliStr(c.state.isClient), err) continue } c.handshakeCache.push(out, epoch, header.MessageSequence, header.Type, !c.state.isClient) } return true, isRetransmit, nil, nil } r := &recordlayer.RecordLayer{} if err := r.Unmarshal(buf); err != nil { return false, false, &alert.Alert{Level: alert.Fatal, Description: alert.DecodeError}, err } isLatestSeqNum := false switch content := r.Content.(type) { case *alert.Alert: c.log.Tracef("%s: <- %s", srvCliStr(c.state.isClient), content.String()) var a *alert.Alert if content.Description == alert.CloseNotify { // Respond with a close_notify [RFC5246 Section 7.2.1] a = &alert.Alert{Level: alert.Warning, Description: alert.CloseNotify} } _ = markPacketAsValid() return false, false, a, &alertError{content} case *protocol.ChangeCipherSpec: if c.state.cipherSuite == nil || !c.state.cipherSuite.IsInitialized() { if enqueue { if ok := c.enqueueEncryptedPackets(addrPkt{rAddr, buf}); ok { c.log.Debugf("CipherSuite not initialized, queuing packet") } } return false, false, nil, nil } newRemoteEpoch := header.Epoch + 1 c.log.Tracef("%s: <- ChangeCipherSpec (epoch: %d)", srvCliStr(c.state.isClient), newRemoteEpoch) if c.state.getRemoteEpoch()+1 == newRemoteEpoch { c.setRemoteEpoch(newRemoteEpoch) isLatestSeqNum = markPacketAsValid() } case *protocol.ApplicationData: if header.Epoch == 0 { return 
false, false, &alert.Alert{ Level: alert.Fatal, Description: alert.UnexpectedMessage, }, errApplicationDataEpochZero } isLatestSeqNum = markPacketAsValid() select { case c.decrypted <- content.Data: case <-c.closed.Done(): case <-ctx.Done(): } default: return false, false, &alert.Alert{ Level: alert.Fatal, Description: alert.UnexpectedMessage, }, fmt.Errorf("%w: %d", errUnhandledContextType, content.ContentType()) } // Any valid connection ID record is a candidate for updating the remote // address if it is the latest record received. // https://datatracker.ietf.org/doc/html/rfc9146#peer-address-update if originalCID && isLatestSeqNum { if rAddr != c.RemoteAddr() { c.lock.Lock() c.rAddr = rAddr c.lock.Unlock() } } return false, false, nil, nil } func (c *Conn) recvHandshake() <-chan recvHandshakeState { return c.handshakeRecv } func (c *Conn) notify(ctx context.Context, level alert.Level, desc alert.Description) error { if level == alert.Fatal && len(c.state.SessionID) > 0 { // According to the RFC, we need to delete the stored session. 
// https://datatracker.ietf.org/doc/html/rfc5246#section-7.2 if ss := c.fsm.cfg.sessionStore; ss != nil { c.log.Tracef("clean invalid session: %s", c.state.SessionID) if err := ss.Del(c.sessionKey()); err != nil { return err } } } return c.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: c.state.getLocalEpoch(), Version: protocol.Version1_2, }, Content: &alert.Alert{ Level: level, Description: desc, }, }, shouldWrapCID: len(c.state.remoteConnectionID) > 0, shouldEncrypt: c.isHandshakeCompletedSuccessfully(), }, }) } func (c *Conn) setHandshakeCompletedSuccessfully() bool { return c.handshakeCompletedSuccessfully.CompareAndSwap(false, true) } func (c *Conn) isHandshakeCompletedSuccessfully() bool { return c.handshakeCompletedSuccessfully.Load() } //nolint:cyclop,gocognit,contextcheck func (c *Conn) handshake( ctx context.Context, cfg *handshakeConfig, initialFlight flightVal, initialState handshakeState, ) error { c.fsm = newHandshakeFSM(&c.state, c.handshakeCache, cfg, initialFlight) done := make(chan struct{}) ctxRead, cancelRead := context.WithCancel(context.Background()) cfg.onFlightState = func(_ flightVal, s handshakeState) { if s == handshakeFinished && c.setHandshakeCompletedSuccessfully() { close(done) } } ctxHs, cancel := context.WithCancel(context.Background()) c.closeLock.Lock() c.cancelHandshaker = cancel c.cancelHandshakeReader = cancelRead c.closeLock.Unlock() firstErr := make(chan error, 1) var handshakeLoopsFinished sync.WaitGroup handshakeLoopsFinished.Add(2) // Handshake routine should be live until close. // The other party may request retransmission of the last flight to cope with packet drop. go func() { defer handshakeLoopsFinished.Done() err := c.fsm.Run(ctxHs, c, initialState) if !errors.Is(err, context.Canceled) { select { case firstErr <- err: default: } } }() go func() { defer func() { if c.isHandshakeCompletedSuccessfully() { // Escaping read loop. 
// It's safe to close decrypted channnel now. close(c.decrypted) } // Force stop handshaker when the underlying connection is closed. cancel() }() defer handshakeLoopsFinished.Done() for { if err := c.readAndBuffer(ctxRead); err != nil { //nolint:nestif var alertErr *alertError if errors.As(err, &alertErr) { if !alertErr.IsFatalOrCloseNotify() { if c.isHandshakeCompletedSuccessfully() { // Pass the error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } } continue // non-fatal alert must not stop read loop } } else { switch { case errors.Is(err, context.DeadlineExceeded), errors.Is(err, context.Canceled), errors.Is(err, io.EOF), errors.Is(err, net.ErrClosed): case errors.Is(err, recordlayer.ErrInvalidPacketLength): // Decode error must be silently discarded // [RFC6347 Section-4.1.2.7] continue default: if c.isHandshakeCompletedSuccessfully() { // Keep read loop and pass the read error to Read() select { case c.decrypted <- err: case <-c.closed.Done(): case <-ctxRead.Done(): } continue // non-fatal alert must not stop read loop } } } select { case firstErr <- err: default: } if alertErr != nil { if alertErr.IsFatalOrCloseNotify() { _ = c.close(false) //nolint:contextcheck } } if !c.isConnectionClosed() && errors.Is(err, context.Canceled) { c.log.Trace("handshake timeouts - closing underline connection") _ = c.close(false) //nolint:contextcheck } return } } }() select { case err := <-firstErr: cancelRead() cancel() handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(err) case <-ctx.Done(): cancelRead() cancel() handshakeLoopsFinished.Wait() return c.translateHandshakeCtxError(ctx.Err()) case <-done: return nil } } func (c *Conn) translateHandshakeCtxError(err error) error { if err == nil { return nil } if errors.Is(err, context.Canceled) && c.isHandshakeCompletedSuccessfully() { return nil } return &HandshakeError{Err: err} } func (c *Conn) close(byUser bool) error { c.closeLock.Lock() cancelHandshaker := 
c.cancelHandshaker cancelHandshakeReader := c.cancelHandshakeReader c.closeLock.Unlock() cancelHandshaker() cancelHandshakeReader() if c.isHandshakeCompletedSuccessfully() && byUser { // Discard error from notify() to return non-error on the first user call of Close() // even if the underlying connection is already closed. _ = c.notify(context.Background(), alert.Warning, alert.CloseNotify) } c.closeLock.Lock() // Don't return ErrConnClosed at the first time of the call from user. closedByUser := c.connectionClosedByUser if byUser { c.connectionClosedByUser = true } isClosed := c.isConnectionClosed() c.closed.Close() c.closeLock.Unlock() if closedByUser { return ErrConnClosed } if isClosed { return nil } return c.nextConn.Close() } func (c *Conn) isConnectionClosed() bool { select { case <-c.closed.Done(): return true default: return false } } func (c *Conn) setLocalEpoch(epoch uint16) { c.state.localEpoch.Store(epoch) } func (c *Conn) setRemoteEpoch(epoch uint16) { c.state.remoteEpoch.Store(epoch) } // LocalAddr implements net.Conn.LocalAddr. func (c *Conn) LocalAddr() net.Addr { return c.nextConn.LocalAddr() } // RemoteAddr implements net.Conn.RemoteAddr. func (c *Conn) RemoteAddr() net.Addr { c.lock.RLock() defer c.lock.RUnlock() return c.rAddr } func (c *Conn) sessionKey() []byte { if c.state.isClient { // As ServerName can be like 0.example.com, it's better to add // delimiter character which is not allowed to be in // neither address or domain name. return []byte(c.rAddr.String() + "_" + c.fsm.cfg.serverName) } return c.state.SessionID } // SetDeadline implements net.Conn.SetDeadline. func (c *Conn) SetDeadline(t time.Time) error { c.readDeadline.Set(t) return c.SetWriteDeadline(t) } // SetReadDeadline implements net.Conn.SetReadDeadline. func (c *Conn) SetReadDeadline(t time.Time) error { c.readDeadline.Set(t) // Read deadline is fully managed by this layer. // Don't set read deadline to underlying connection. 
return nil } // SetWriteDeadline implements net.Conn.SetWriteDeadline. func (c *Conn) SetWriteDeadline(t time.Time) error { c.writeDeadline.Set(t) // Write deadline is also fully managed by this layer. return nil } dtls-3.1.2/conn_go_test.go000066400000000000000000000067711514330267300155120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package dtls import ( "context" "errors" "net" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/transport/v4/dpipe" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) func TestContextConfig(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() report := test.CheckRoutines(t) defer report() addrListen, err := net.ResolveUDPAddr("udp", "localhost:0") assert.NoError(t, err) // Dummy listener listen, err := net.ListenUDP("udp", addrListen) assert.NoError(t, err) defer func() { _ = listen.Close() }() addr, ok := listen.LocalAddr().(*net.UDPAddr) assert.True(t, ok) cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientOpts := []ClientOption{ WithCertificates(cert), } serverOpts := []ServerOption{ WithCertificates(cert), } dials := map[string]struct { f func() (func() (net.Conn, error), func()) order []byte }{ "Dial": { f: func() (func() (net.Conn, error), func()) { ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := DialWithOptions("udp", addr, clientOpts...) 
if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { cancel() } }, order: []byte{0, 1, 2}, }, "Client": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := ClientWithOptions(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientOpts...) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() cancel() } }, order: []byte{0, 1, 2}, }, "Server": { f: func() (func() (net.Conn, error), func()) { ca, _ := dpipe.Pipe() ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) return func() (net.Conn, error) { conn, err := ServerWithOptions(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), serverOpts...) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) }, func() { _ = ca.Close() cancel() } }, order: []byte{0, 1, 2}, }, } for name, dial := range dials { dial := dial t.Run(name, func(t *testing.T) { done := make(chan struct{}) go func() { d, cancel := dial.f() conn, err := d() defer cancel() var netError net.Error if !errors.As(err, &netError) || !netError.Temporary() { //nolint:staticcheck assert.Fail(t, "Dial failed with unexpected error", "err: %v", err) close(done) return } done <- struct{}{} if err == nil { _ = conn.Close() } }() var order []byte early := time.After(20 * time.Millisecond) late := time.After(60 * time.Millisecond) func() { for len(order) < 3 { select { case <-early: order = append(order, 0) case _, ok := <-done: if !ok { return } order = append(order, 1) case <-late: order = append(order, 2) } } }() assert.Equal(t, dial.order, order, "Invalid cancel timing") }) } } dtls-3.1.2/conn_test.go000066400000000000000000003030071514330267300150150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" 
"crypto" "crypto/ecdsa" cryptoElliptic "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "encoding/hex" "errors" "fmt" "io" "net" "strings" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/crypto/signature" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v4/dpipe" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) var ( errTestPSKInvalidIdentity = errors.New("TestPSK: Server got invalid identity") errPSKRejected = errors.New("PSK Rejected") errNotExpectedChain = errors.New("not expected chain") errExpecedChain = errors.New("expected chain") errWrongCert = errors.New("wrong cert") ) func TestStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() // Run the test stressDuplex(t) } func stressDuplex(t *testing.T) { t.Helper() ca, cb, err := pipeMemory() assert.NoError(t, err) defer func() { assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }() opt := test.Options{ MsgSize: 2048, MsgCount: 100, } assert.NoError(t, test.StressDuplex(ca, cb, opt)) } func TestRoutineLeakOnClose(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ca, cb, err := pipeMemory() assert.NoError(t, err) _, err = ca.Write(make([]byte, 100)) 
assert.NoError(t, err) assert.NoError(t, cb.Close()) assert.NoError(t, ca.Close()) // Packet is sent, but not read. // inboundLoop routine should not be leaked. } func TestReadWriteDeadline(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() var netErr net.Error ca, cb, err := pipeMemory() assert.NoError(t, err) assert.NoError(t, ca.SetDeadline(time.Unix(0, 1))) _, werr := ca.Write(make([]byte, 100)) assert.ErrorAs(t, werr, &netErr, "Write must return net.Error") assert.True(t, netErr.Timeout(), "Deadline exceeded Write must return Timeout") assert.True(t, netErr.Temporary(), "Deadline exceeded Write must return Temporary") //nolint:staticcheck _, rerr := ca.Read(make([]byte, 100)) assert.ErrorAs(t, rerr, &netErr, "Read must return net.Error") assert.True(t, netErr.Timeout(), "Deadline exceeded Read must return Timeout") assert.True(t, netErr.Temporary(), "Deadline exceeded Read must return Temporary") //nolint:staticcheck assert.NoError(t, ca.SetDeadline(time.Time{})) assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) _, err = ca.Write(make([]byte, 100)) assert.ErrorIs(t, err, ErrConnClosed) _, err = ca.Read(make([]byte, 100)) assert.ErrorIs(t, err, io.EOF) } func TestSequenceNumberOverflow(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ApplicationData", func(t *testing.T) { ca, cb, err := pipeMemory() assert.NoError(t, err) atomic.StoreUint64(&ca.state.localSequenceNumber[1], recordlayer.MaxSequenceNumber) _, werr := ca.Write(make([]byte, 100)) assert.NoError(t, werr, "Write must send message with maximum sequence number") _, werr = ca.Write(make([]byte, 100)) assert.ErrorIs(t, werr, errSequenceNumberOverflow, "Write must abandonsend message with maximum sequence number") 
assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }) t.Run("Handshake", func(t *testing.T) { ca, cb, err := pipeMemory() assert.NoError(t, err) ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() atomic.StoreUint64(&ca.state.localSequenceNumber[0], recordlayer.MaxSequenceNumber+1) // Try to send handshake packet. werr := ca.writePackets(ctx, []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: make([]byte, 64), CipherSuiteIDs: cipherSuiteIDs(defaultCipherSuites()), CompressionMethods: defaultCompressionMethods(), }, }, }, }, }) assert.ErrorIs(t, werr, errSequenceNumberOverflow, "Connection must fail when handshake packet reaches maximum sequence num") assert.NoError(t, ca.Close()) assert.NoError(t, cb.Close()) }) } func pipeMemory() (*Conn, *Conn, error) { // In memory pipe ca, cb := dpipe.Pipe() return pipeConn(ca, cb) } func pipeConn(ca, cb net.Conn) (*Conn, *Conn, error) { type result struct { c *Conn err error } resultCh := make(chan result, 1) // Buffered to prevent goroutine leak ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // Setup client go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, }, true) resultCh <- result{client, err} }() // Setup server server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ SRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, }, true) if err != nil { // Read from resultCh to prevent goroutine leak if res := <-resultCh; res.c != nil { _ = res.c.Close() } return nil, nil, err } // Receive client res := <-resultCh if res.err != nil { _ = server.Close() return nil, nil, res.err } return res.c, server, nil 
} func testClient( ctx context.Context, pktConn net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool, ) (*Conn, error) { if generateCertificate { clientCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{clientCert} } cfg.InsecureSkipVerify = true conn, err := Client(pktConn, rAddr, cfg) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) } func testServer( ctx context.Context, c net.PacketConn, rAddr net.Addr, cfg *Config, generateCertificate bool, ) (*Conn, error) { if generateCertificate { serverCert, err := selfsign.GenerateSelfSigned() if err != nil { return nil, err } cfg.Certificates = []tls.Certificate{serverCert} } conn, err := Server(c, rAddr, cfg) if err != nil { return nil, err } return conn, conn.HandshakeContext(ctx) } func sendClientHello( cookie []byte, ca net.Conn, sequenceNumber uint64, extensions []extension.Extension, cipherSuiteIDsOverride ...uint16, ) error { cipherSuites := cipherSuiteIDsOverride if len(cipherSuites) == 0 { cipherSuites = cipherSuiteIDs(defaultCipherSuites()) } packet, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: sequenceNumber, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: uint16(sequenceNumber), //nolint:gosec // G115 }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, CipherSuiteIDs: cipherSuites, CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, }).Marshal() if err != nil { return err } if _, err = ca.Write(packet); err != nil { return err } return nil } func TestHandshakeWithAlert(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cases := 
map[string]struct { configServer, configClient *Config errServer, errClient error }{ "CipherSuiteNoIntersection": { configServer: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, }, errServer: errCipherSuiteNoIntersection, errClient: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, }, "SignatureSchemesNoIntersection": { configServer: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP256AndSHA256}, }, configClient: &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, }, errServer: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, errClient: errNoAvailableSignatureSchemes, }, } for name, testCase := range cases { testCase := testCase t.Run(name, func(t *testing.T) { clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() go func() { _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), testCase.configClient, true) clientErr <- err }() _, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), testCase.configServer, true) assert.ErrorIs(t, errServer, testCase.errServer) assert.ErrorIs(t, <-clientErr, testCase.errClient) }) } } func TestHandshakeWithInvalidRecord(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientErr := make(chan result, 1) ca, cb := dpipe.Pipe() caWithInvalidRecord := &connWithCallback{Conn: ca} var msgSeq atomic.Int32 // Send invalid record after first message 
caWithInvalidRecord.onWrite = func([]byte) { if msgSeq.Add(1) == 2 { _, err := ca.Write([]byte{0x01, 0x02}) assert.NoError(t, err) } } go func() { client, err := testClient( ctx, dtlsnet.PacketConnFromConn(caWithInvalidRecord), caWithInvalidRecord.RemoteAddr(), &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}}, true, ) clientErr <- result{client, err} }() server, errServer := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, }, true) errClient := <-clientErr defer func() { if server != nil { assert.NoError(t, server.Close()) } if errClient.c != nil { assert.NoError(t, errClient.c.Close()) } }() assert.NoError(t, errServer) assert.NoError(t, errClient.err) } func TestExportKeyingMaterial(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() var rand [28]byte exportLabel := "EXTRACTOR-dtls_srtp" expectedServerKey := []byte{0x61, 0x09, 0x9d, 0x7d, 0xcb, 0x08, 0x52, 0x2c, 0xe7, 0x7b} expectedClientKey := []byte{0x87, 0xf0, 0x40, 0x02, 0xf6, 0x1c, 0xf1, 0xfe, 0x8c, 0x77} conn := &Conn{ state: State{ localRandom: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand}, remoteRandom: handshake.Random{GMTUnixTime: time.Unix(1000, 0), RandomBytes: rand}, localSequenceNumber: []uint64{0, 0}, cipherSuite: &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, }, } conn.setLocalEpoch(0) conn.setRemoteEpoch(0) state, ok := conn.ConnectionState() assert.True(t, ok) _, err := state.ExportKeyingMaterial(exportLabel, nil, 0) assert.ErrorIs(t, err, errHandshakeInProgress, "ExportKeyingMaterial when epoch == 0 error mismatch") conn.setLocalEpoch(1) state, ok = conn.ConnectionState() assert.True(t, ok) _, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0) assert.ErrorIs(t, err, errContextUnsupported, "ExportKeyingMaterial with context mismatch") for k := range invalidKeyingLabels() { state, ok = 
conn.ConnectionState() assert.True(t, ok) _, err = state.ExportKeyingMaterial(k, nil, 0) assert.ErrorIs(t, err, errReservedExportKeyingMaterial, "ExportKeyingMaterial reserved label mismatch") } state, ok = conn.ConnectionState() assert.True(t, ok) keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10) assert.NoError(t, err, "ExportingKeyingMaterial as server error") assert.Equal(t, expectedServerKey, keyingMaterial, "ExportKeyingMaterial client export mismatch") conn.state.isClient = true state, ok = conn.ConnectionState() assert.True(t, ok) keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10) assert.NoError(t, err) assert.Equal(t, expectedClientKey, keyingMaterial, "ExportKeyingMaterial client report mismatch") } func TestPSK(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientIdentity []byte ServerIdentity []byte CipherSuites []CipherSuiteID ClientVerifyConnection func(*State) error ServerVerifyConnection func(*State) error WantFail bool ExpectedServerErr string ExpectedClientErr string }{ { Name: "Server identity specified", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "Server identity specified - Server verify connection fails", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, ServerVerifyConnection: func(*State) error { return errExample }, WantFail: true, ExpectedServerErr: errExample.Error(), ExpectedClientErr: alert.BadCertificate.String(), }, { Name: "Server identity specified - Client verify connection fails", ServerIdentity: []byte("Test Identity"), ClientIdentity: []byte("Client Identity"), CipherSuites: 
[]CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, ClientVerifyConnection: func(*State) error { return errExample }, WantFail: true, ExpectedServerErr: alert.BadCertificate.String(), ExpectedClientErr: errExample.Error(), }, { Name: "Server identity nil", ServerIdentity: nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, { Name: "TLS_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CBC_SHA256}, }, { Name: "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256", ServerIdentity: nil, ClientIdentity: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256}, }, { Name: "Client identity empty", ServerIdentity: nil, ClientIdentity: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() type result struct { c *Conn err error } clientRes := make(chan result, 1) ca, cb := dpipe.Pipe() go func() { conf := &Config{ PSK: func(hint []byte) ([]byte, error) { if !bytes.Equal(test.ServerIdentity, hint) { return nil, fmt.Errorf( //nolint:err113 "TestPSK: Client got invalid identity expected(% 02x) actual(% 02x)", test.ServerIdentity, hint, ) } return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ClientIdentity, CipherSuites: test.CipherSuites, VerifyConnection: test.ClientVerifyConnection, } c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientRes <- result{c, err} }() config := &Config{ PSK: func(hint []byte) ([]byte, error) { t.Log(hint) if !bytes.Equal(test.ClientIdentity, hint) { return nil, fmt.Errorf("%w: expected(% 02x) actual(% 02x)", errTestPSKInvalidIdentity, test.ClientIdentity, hint) } return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: test.ServerIdentity, CipherSuites: test.CipherSuites, 
VerifyConnection: test.ServerVerifyConnection, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) if test.WantFail { res := <-clientRes assert.Error(t, err) assert.True(t, strings.Contains(err.Error(), test.ExpectedServerErr), "TestPSK: Server expected error mismatch") assert.Error(t, res.err, "TestPSK: Client expected error mismatch") assert.True(t, strings.Contains(res.err.Error(), test.ExpectedClientErr), "TestPSK: Client expeected error mismatch") return } assert.NoError(t, err) state, ok := server.ConnectionState() assert.True(t, ok, "TestPSK: Server ConnectionState failed") actualPSKIdentityHint := state.IdentityHint assert.Equal(t, test.ClientIdentity, actualPSKIdentityHint, "TestPSK: Server ClientPSKIdentity Mismatch") defer func() { _ = server.Close() }() res := <-clientRes assert.NoError(t, res.err) assert.NoError(t, res.c.Close()) }) } } func TestPSKHintFail(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() serverAlertError := &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InternalError}} pskRejected := errPSKRejected // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) clientErr <- err }() config := &Config{ PSK: func([]byte) ([]byte, error) { return nil, pskRejected }, PSKIdentityHint: []byte{}, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, false) assert.ErrorIs(t, err, serverAlertError, "TestPSK: Server 
should fail with alert error") assert.ErrorIs(t, <-clientErr, pskRejected, "TestPSK: Client should fail with pskRejected error") } func TestPSKMismatchNoRetransmitLoop(t *testing.T) { lim := test.TimeOut(time.Second * 20) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() var serverWrites atomic.Int32 var clientWrites atomic.Int32 ca, cb := dpipe.Pipe() defer func() { _ = ca.Close() }() defer func() { _ = cb.Close() }() caCount := &connWithCallback{Conn: ca} caCount.onWrite = func([]byte) { clientWrites.Add(1) } cbCount := &connWithCallback{Conn: cb} cbCount.onWrite = func([]byte) { serverWrites.Add(1) } clientErr := make(chan error, 1) serverErr := make(chan error, 1) go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return []byte("client-psk"), nil }, PSKIdentityHint: []byte("Client Identity"), CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } c, err := testClient(ctx, dtlsnet.PacketConnFromConn(caCount), caCount.RemoteAddr(), conf, false) if c != nil { _ = c.Close() //nolint:contextcheck } clientErr <- err }() go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return []byte("server-psk"), nil }, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } s, err := testServer(ctx, dtlsnet.PacketConnFromConn(cbCount), cbCount.RemoteAddr(), conf, false) if s != nil { _ = s.Close() //nolint:contextcheck } serverErr <- err }() serverErrRes := <-serverErr clientErrRes := <-clientErr var serverHandshakeErr *HandshakeError var clientHandshakeErr *HandshakeError assert.ErrorAs(t, serverErrRes, &serverHandshakeErr) assert.ErrorAs(t, clientErrRes, &clientHandshakeErr) serverCount := serverWrites.Load() clientCount := clientWrites.Load() time.Sleep(2 * time.Second) assert.Equal(t, serverCount, serverWrites.Load(), "Server should not retransmit after handshake failure") assert.Equal(t, clientCount, clientWrites.Load(), "Client should 
not retransmit after handshake failure") assert.LessOrEqual(t, serverCount, int32(20), "Server retransmit count too high for backoff") assert.LessOrEqual(t, clientCount, int32(20), "Client retransmit count too high for backoff") } // Assert that ServerKeyExchange is only sent if Identity is set on server side. func TestPSKServerKeyExchange(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string SetIdentity bool }{ { Name: "Server Identity Set", SetIdentity: true, }, { Name: "Server Not Identity Set", SetIdentity: false, }, } { testCase := test t.Run(testCase.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() var gotServerKeyExchange atomic.Bool expectedServerKeyExchange := testCase.SetIdentity clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() cbAnalyzer := &connWithCallback{Conn: cb} cbAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) assert.NoError(t, err) for i := range messages { var header recordlayer.Header if err := header.Unmarshal(messages[i]); err != nil { continue } if header.ContentType != protocol.ContentTypeHandshake || header.Epoch != 0 { continue } payload := messages[i][recordlayer.FixedHeaderSize:] for len(payload) >= handshake.HeaderLength { var h handshake.Header if err := h.Unmarshal(payload); err != nil { break } if h.Type == handshake.TypeServerKeyExchange { gotServerKeyExchange.Store(true) break } fragLen := int(h.FragmentLength) if fragLen <= 0 || handshake.HeaderLength+fragLen > len(payload) { break } payload = payload[handshake.HeaderLength+fragLen:] } } } go func() { conf := &Config{ PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, PSKIdentityHint: []byte{0xAB, 0xC1, 0x23}, CipherSuites: 
[]CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } if client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false); err != nil { clientErr <- err } else { clientErr <- client.Close() //nolint } }() config := &Config{ PSK: func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }, CipherSuites: []CipherSuiteID{TLS_PSK_WITH_AES_128_CCM_8}, } if testCase.SetIdentity { config.PSKIdentityHint = []byte{0xAB, 0xC1, 0x23} } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cbAnalyzer), cbAnalyzer.RemoteAddr(), config, false) assert.NoError(t, err) // Read the value immediately after handshake completes, before closing receivedServerKeyExchange := gotServerKeyExchange.Load() assert.NoError(t, server.Close()) assert.NoError(t, <-clientErr, "TestPSK: Client erro") assert.Equal(t, expectedServerKeyExchange, receivedServerKeyExchange) }) } } func TestClientTimeout(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() clientErr := make(chan error, 1) ca, _ := dpipe.Pipe() go func() { conf := &Config{} c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, true) if err == nil { _ = c.Close() //nolint:contextcheck } clientErr <- err }() // no server! 
err := <-clientErr var netErr net.Error assert.ErrorAs(t, err, &netErr, "Client error exp(Temporary network error) failed") assert.True(t, netErr.Timeout(), "Client error exp(Timeout) failed") } func TestSRTPConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientSRTP []SRTPProtectionProfile ServerSRTP []SRTPProtectionProfile ClientSRTPMasterKeyIdentifier []byte ServerSRTPMasterKeyIdentifier []byte ExpectedProfile SRTPProtectionProfile WantClientError error WantServerError error }{ { Name: "No SRTP in use", ClientSRTP: nil, ServerSRTP: nil, ExpectedProfile: 0, WantClientError: nil, WantServerError: nil, }, { Name: "SRTP both ends", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, ClientSRTPMasterKeyIdentifier: []byte("ClientSRTPMKI"), ServerSRTPMasterKeyIdentifier: []byte("ServerSRTPMKI"), WantClientError: nil, WantServerError: nil, }, { Name: "SRTP client only", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ServerSRTP: nil, ExpectedProfile: 0, WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, WantServerError: errServerNoMatchingSRTPProfile, }, { Name: "SRTP server only", ClientSRTP: nil, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: 0, WantClientError: nil, WantServerError: nil, }, { Name: "Multiple Suites", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_80, WantClientError: nil, WantServerError: nil, }, { Name: "Multiple Suites, Server Chooses", ClientSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32}, 
ServerSRTP: []SRTPProtectionProfile{SRTP_AES128_CM_HMAC_SHA1_32, SRTP_AES128_CM_HMAC_SHA1_80}, ExpectedProfile: SRTP_AES128_CM_HMAC_SHA1_32, WantClientError: nil, WantServerError: nil, }, } { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ SRTPProtectionProfiles: test.ClientSRTP, SRTPMasterKeyIdentifier: test.ServerSRTPMasterKeyIdentifier, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ SRTPProtectionProfiles: test.ServerSRTP, SRTPMasterKeyIdentifier: test.ClientSRTPMasterKeyIdentifier, }, true) assert.ErrorIs(t, err, test.WantServerError, "TestSRTPConfiguration: Server Error Mismatch") if err == nil { defer func() { _ = server.Close() }() } res := <-resultCh if res.err == nil { defer func() { _ = res.c.Close() }() } assert.ErrorIsf(t, res.err, test.WantClientError, "TestSRTPConfiguration: Client Error Mismatch '%s'", test.Name) if res.c == nil { return } actualClientSRTP, _ := res.c.SelectedSRTPProtectionProfile() assert.Equalf(t, test.ExpectedProfile, actualClientSRTP, "TestSRTPConfiguration: Client SRTPProtectionProfile Mismatch '%s'", test.Name) actualServerSRTP, _ := server.SelectedSRTPProtectionProfile() assert.Equalf(t, test.ExpectedProfile, actualServerSRTP, "TestSRTPConfiguration: Server SRTPProtectionProfile Mismatch '%s'", test.Name) actualServerMKI, _ := server.RemoteSRTPMasterKeyIdentifier() assert.Truef(t, bytes.Equal(test.ServerSRTPMasterKeyIdentifier, actualServerMKI), "TestSRTPConfiguration: Server SRTPMKI Mismatch '%s'", test.Name) actualClientMKI, _ := res.c.RemoteSRTPMasterKeyIdentifier() assert.Truef(t, bytes.Equal(test.ClientSRTPMasterKeyIdentifier, actualClientMKI), "TestSRTPConfiguration: Client SRTPMKI Mismatch '%s'", 
test.Name) } } func TestClientCertificate(t *testing.T) { //nolint:gocyclo,cyclop,maintidx // Check for leaking routines report := test.CheckRoutines(t) defer report() srvCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) srvCAPool := x509.NewCertPool() srvCertificate, err := x509.ParseCertificate(srvCert.Certificate[0]) assert.NoError(t, err) srvCAPool.AddCert(srvCertificate) cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) certificate, err := x509.ParseCertificate(cert.Certificate[0]) assert.NoError(t, err) caPool := x509.NewCertPool() caPool.AddCert(certificate) t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak tests := map[string]struct { clientCfg *Config serverCfg *Config wantErr bool }{ "NoClientCert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, }, }, "NoClientCert_ServerVerifyConnectionFails": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, VerifyConnection: func(*State) error { return errExample }, }, wantErr: true, }, "NoClientCert_ClientVerifyConnectionFails": { clientCfg: &Config{RootCAs: srvCAPool, VerifyConnection: func(*State) error { return errExample }}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: NoClientCert, ClientCAs: caPool, }, wantErr: true, }, "NoClientCert_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, }, "RequestClientCert_cert_sigscheme": { // specify signature algorithm clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ SignatureSchemes: []tls.SignatureScheme{tls.ECDSAWithP521AndSHA512}, Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, }, 
}, "RequestClientCert_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, }, }, "RequestClientCert_no_cert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequestClientCert, ClientCAs: caPool, }, }, "RequireAnyClientCert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, }, "RequireAnyClientCert_error": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAnyClientCert, }, wantErr: true, }, "VerifyClientCertIfGiven_no_cert": { clientCfg: &Config{RootCAs: srvCAPool}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, ClientCAs: caPool, }, }, "VerifyClientCertIfGiven_cert": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, ClientCAs: caPool, }, }, "VerifyClientCertIfGiven_error": { clientCfg: &Config{RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: VerifyClientCertIfGiven, }, wantErr: true, }, "RequireAndVerifyClientCert": { clientCfg: &Config{ RootCAs: srvCAPool, Certificates: []tls.Certificate{cert}, VerifyConnection: func(s *State) error { if ok := bytes.Equal(s.PeerCertificates[0], srvCertificate.Raw); !ok { return errExample } return nil }, }, serverCfg: &Config{ Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, ClientCAs: caPool, VerifyConnection: func(s *State) error { if ok := bytes.Equal(s.PeerCertificates[0], certificate.Raw); !ok { return errExample } return nil }, }, }, 
"RequireAndVerifyClientCert_callbacks": { clientCfg: &Config{ RootCAs: srvCAPool, // Certificates: []tls.Certificate{cert}, GetClientCertificate: func(*CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil }, }, serverCfg: &Config{ GetCertificate: func(*ClientHelloInfo) (*tls.Certificate, error) { return &srvCert, nil }, // Certificates: []tls.Certificate{srvCert}, ClientAuth: RequireAndVerifyClientCert, ClientCAs: caPool, }, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { c *Conn err, hserr error } c := make(chan result) go func() { client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) c <- result{client, err, client.Handshake()} }() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) hserr := server.Handshake() res := <-c defer func() { if err == nil { _ = server.Close() } if res.err == nil { _ = res.c.Close() } }() if tt.wantErr { assert.True(t, err != nil || hserr != nil, "Error expected") return // Error expected, test succeeded } assert.NoError(t, err) assert.NoError(t, res.err) state, ok := server.ConnectionState() assert.True(t, ok, "Server connection state not available") actualClientCert := state.PeerCertificates //nolint:nestif if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert { assert.NotNil(t, actualClientCert, "Client did not provide a certificate") var cfgCert [][]byte if len(tt.clientCfg.Certificates) > 0 { cfgCert = tt.clientCfg.Certificates[0].Certificate } if tt.clientCfg.GetClientCertificate != nil { crt, err := tt.clientCfg.GetClientCertificate(&CertificateRequestInfo{}) assert.NoError(t, err, "Server configuration did not provide a certificate") cfgCert = crt.Certificate } assert.NotEmpty(t, cfgCert, "Client certificate was not communicated correctly") assert.Equal(t, actualClientCert[0], cfgCert[0], "Client certificate was not 
communicated correctly") } if tt.serverCfg.ClientAuth == NoClientCert { assert.Nil(t, actualClientCert, "Client certificate wasn't expected") } clientState, ok := res.c.ConnectionState() assert.True(t, ok, "Client connection state not available") actualServerCert := clientState.PeerCertificates assert.NotNil(t, actualServerCert, "server did not provide a certificate") var cfgCert [][]byte if len(tt.serverCfg.Certificates) > 0 { cfgCert = tt.serverCfg.Certificates[0].Certificate } if tt.serverCfg.GetCertificate != nil { crt, err := tt.serverCfg.GetCertificate(&ClientHelloInfo{}) assert.NoError(t, err, "Server configuration did not provide a certificate") cfgCert = crt.Certificate } assert.NotEmpty(t, cfgCert, "Server certificate was not communicated correctly") assert.Equal(t, actualServerCert[0], cfgCert[0], "Server certificate was not communicated correctly") }) } }) } func TestConnectionID(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() clientCID := []byte{5, 77, 33, 24, 93, 27, 45, 81} serverCID := []byte{64, 24, 73, 2, 17, 96, 38, 59} cidEcho := func(echo []byte) func() []byte { return func() []byte { return echo } } tests := map[string]struct { clientCfg *Config serverCfg *Config clientConnectionID []byte serverConnectionID []byte }{ "BidirectionalConnectionIDs": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(serverCID), }, clientConnectionID: clientCID, serverConnectionID: serverCID, }, "BothSupportOnlyClientSends": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(nil), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(serverCID), }, serverConnectionID: serverCID, }, "BothSupportOnlyServerSends": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{ ConnectionIDGenerator: cidEcho(nil), }, clientConnectionID: clientCID, }, "ClientDoesNotSupport": { clientCfg: &Config{}, serverCfg: &Config{ 
ConnectionIDGenerator: cidEcho(serverCID), }, }, "ServerDoesNotSupport": { clientCfg: &Config{ ConnectionIDGenerator: cidEcho(clientCID), }, serverCfg: &Config{}, }, "NeitherSupport": { clientCfg: &Config{}, serverCfg: &Config{}, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } c := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) assert.NoError(t, err) res := <-c assert.NoError(t, res.err) defer func() { if err == nil { _ = server.Close() } if res.err == nil { _ = res.c.Close() } }() assert.True(t, bytes.Equal(tt.clientConnectionID, res.c.state.getLocalConnectionID()), "Unexpected client local connection ID") assert.True(t, bytes.Equal(tt.serverConnectionID, res.c.state.remoteConnectionID), "Unexpected client remote connection ID") assert.True(t, bytes.Equal(tt.serverConnectionID, server.state.getLocalConnectionID()), "Unexpected server local connection ID") assert.True(t, bytes.Equal(tt.clientConnectionID, server.state.remoteConnectionID), "Unexpected server remote connection ID") }) } } func TestExtendedMasterSecret(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() tests := map[string]struct { clientCfg *Config serverCfg *Config expectedClientErr error expectedServerErr error }{ "Request_Request_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Request_Require_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, 
serverCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Request_Disable_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Require_Request_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Require_Require_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Require_Disable_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, expectedClientErr: errClientRequiredButNoServerEMS, expectedServerErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, }, "Disable_Request_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: RequestExtendedMasterSecret, }, expectedClientErr: nil, expectedServerErr: nil, }, "Disable_Require_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: RequireExtendedMasterSecret, }, expectedClientErr: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, expectedServerErr: errServerRequiredButNoClientEMS, }, "Disable_Disable_ExtendedMasterSecret": { clientCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, serverCfg: &Config{ ExtendedMasterSecret: DisableExtendedMasterSecret, }, 
expectedClientErr: nil, expectedServerErr: nil, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } c := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg, true) c <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg, true) res := <-c defer func() { if err == nil { _ = server.Close() } if res.err == nil { _ = res.c.Close() } }() assert.ErrorIs(t, res.err, tt.expectedClientErr) assert.ErrorIs(t, err, tt.expectedServerErr) }) } } func TestServerCertificate(t *testing.T) { //nolint:cyclop // Check for leaking routines report := test.CheckRoutines(t) defer report() cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) certificate, err := x509.ParseCertificate(cert.Certificate[0]) assert.NoError(t, err) caPool := x509.NewCertPool() caPool.AddCert(certificate) t.Run("parallel", func(t *testing.T) { // sync routines to check routine leak tests := map[string]struct { clientCfg *Config serverCfg *Config wantErr bool }{ "no_ca": { clientCfg: &Config{}, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, wantErr: true, }, "good_ca": { clientCfg: &Config{RootCAs: caPool}, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, }, "no_ca_skip_verify": { clientCfg: &Config{InsecureSkipVerify: true}, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, }, "good_ca_skip_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ Certificates: []tls.Certificate{cert}, ClientAuth: RequireAnyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { if len(chain) != 0 { 
return errNotExpectedChain } return nil }, }, }, "good_ca_verify_custom_verify_peer": { clientCfg: &Config{RootCAs: caPool, Certificates: []tls.Certificate{cert}}, serverCfg: &Config{ ClientCAs: caPool, Certificates: []tls.Certificate{cert}, ClientAuth: RequireAndVerifyClientCert, VerifyPeerCertificate: func(_ [][]byte, chain [][]*x509.Certificate) error { if len(chain) == 0 { return errExpecedChain } return nil }, }, }, "good_ca_custom_verify_peer": { clientCfg: &Config{ RootCAs: caPool, VerifyPeerCertificate: func([][]byte, [][]*x509.Certificate) error { return errWrongCert }, }, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, wantErr: true, }, "server_name": { clientCfg: &Config{RootCAs: caPool, ServerName: certificate.Subject.CommonName}, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, }, "server_name_error": { clientCfg: &Config{RootCAs: caPool, ServerName: "barfoo"}, serverCfg: &Config{Certificates: []tls.Certificate{cert}, ClientAuth: NoClientCert}, wantErr: true, }, } for name, tt := range tests { tt := tt t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() type result struct { c *Conn err, hserr error } srvCh := make(chan result) go func() { s, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), tt.serverCfg) srvCh <- result{s, err, s.Handshake()} }() cli, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), tt.clientCfg) hserr := cli.Handshake() if err == nil { _ = cli.Close() } if tt.wantErr { assert.True(t, err != nil || hserr != nil, "Expected error") } else { assert.NoError(t, err, "Client connection failed") assert.NoError(t, hserr, "Client handshake failed") } srv := <-srvCh if srv.err == nil { _ = srv.c.Close() } }) } }) } func TestCipherSuiteConfiguration(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientCipherSuites []CipherSuiteID ServerCipherSuites 
[]CipherSuiteID WantClientError error WantServerError error WantSelectedCipherSuite CipherSuiteID }{ { Name: "No CipherSuites specified", ClientCipherSuites: nil, ServerCipherSuites: nil, WantClientError: nil, WantServerError: nil, }, { Name: "Invalid CipherSuite", ClientCipherSuites: []CipherSuiteID{0x00}, ServerCipherSuites: []CipherSuiteID{0x00}, WantClientError: &invalidCipherSuiteError{0x00}, WantServerError: &invalidCipherSuiteError{0x00}, }, { Name: "Valid CipherSuites specified", ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, WantClientError: nil, WantServerError: nil, WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, }, { Name: "CipherSuites mismatch", ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, WantClientError: &alertError{&alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}}, WantServerError: errCipherSuiteNoIntersection, }, { Name: "Valid CipherSuites CCM specified", ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM}, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM}, WantClientError: nil, WantServerError: nil, WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM, }, { Name: "Valid CipherSuites CCM-8 specified", ClientCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8}, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8}, WantClientError: nil, WantServerError: nil, WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, }, { Name: "Server supports subset of client suites", ClientCipherSuites: []CipherSuiteID{ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, }, ServerCipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, WantClientError: nil, WantServerError: nil, 
WantSelectedCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: test.ClientCipherSuites, }, true) resultCh <- result{client, err} }() server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ CipherSuites: test.ServerCipherSuites, }, true) if err == nil { defer func() { _ = server.Close() }() } assert.ErrorIsf(t, err, test.WantServerError, "TestCipherSuiteConfiguration: Server Error Mismatch '%s'", test.Name) res := <-resultCh if err == nil { assert.NoError(t, server.Close()) assert.NoError(t, res.c.Close()) } assert.ErrorIsf(t, res.err, test.WantClientError, "TestCipherSuiteConfiguration: Client Error Mismatch '%s'") if test.WantSelectedCipherSuite != 0x00 { assert.Equal(t, test.WantSelectedCipherSuite, res.c.state.cipherSuite.ID(), "TestCipherSuiteConfiguration: Server Selected Bad Cipher Suite '%s'", test.Name) } }) } } func TestCertificateAndPSKServer(t *testing.T) { // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientPSK bool }{ { Name: "Client uses PKI", ClientPSK: false, }, { Name: "Client uses PSK", ClientPSK: true, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() type result struct { c *Conn err error } resultCh := make(chan result) go func() { config := &Config{CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}} if test.ClientPSK { config.PSK = func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil } config.PSKIdentityHint = []byte{0x00} 
config.CipherSuites = []CipherSuiteID{TLS_PSK_WITH_AES_128_GCM_SHA256} } client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false) resultCh <- result{client, err} }() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_PSK_WITH_AES_128_GCM_SHA256}, PSK: func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil }, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) assert.NoErrorf(t, err, "TestCertificateAndPSKServer: Server Error Mismatch '%s'", test.Name) if err != nil { defer func() { assert.NoError(t, server.Close()) }() } res := <-resultCh assert.NoErrorf(t, res.err, "TestCertificateAndPSKServer: Server Error Mismatch '%s'", test.Name) assert.NoError(t, server.Close()) assert.NoError(t, res.c.Close()) }) } }

// TestPSKConfiguration runs a client and a server against each other for
// several PSK / identity-hint / certificate combinations and checks that each
// peer returns the error listed for it in the table.
func TestPSKConfiguration(t *testing.T) { //nolint:cyclop
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// errMismatch reports whether got and want disagree: exactly one of them is
	// nil, or both are set but their messages differ. Errors are compared by
	// message text. Both peers go through this single helper so the two checks
	// cannot drift apart.
	errMismatch := func(got, want error) bool {
		if got == nil || want == nil {
			return got != nil || want != nil
		}

		return got.Error() != want.Error()
	}

	for _, test := range []struct {
		Name                 string
		ClientHasCertificate bool
		ServerHasCertificate bool
		ClientPSK            PSKCallback
		ServerPSK            PSKCallback
		ClientPSKIdentity    []byte
		ServerPSKIdentity    []byte
		WantClientError      error
		WantServerError      error
	}{
		{
			Name:                 "PSK and no certificate specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and certificate specified",
			ClientHasCertificate: true,
			ServerHasCertificate: true,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errNoAvailablePSKCipherSuite,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "PSK and no identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ServerPSK:            func([]byte) ([]byte, error) { return []byte{0x00, 0x01, 0x02}, nil },
			ClientPSKIdentity:    nil,
			ServerPSKIdentity:    nil,
			WantClientError:      errPSKAndIdentityMustBeSetForClient,
			WantServerError:      errNoAvailablePSKCipherSuite,
		},
		{
			Name:                 "No PSK and identity specified",
			ClientHasCertificate: false,
			ServerHasCertificate: false,
			ClientPSK:            nil,
			ServerPSK:            nil,
			ClientPSKIdentity:    []byte{0x00},
			ServerPSKIdentity:    []byte{0x00},
			WantClientError:      errIdentityNoPSK,
			WantServerError:      errIdentityNoPSK,
		},
	} {
		// Each case runs in its own function literal so the deferred cancel
		// fires at the end of the iteration instead of piling up until the
		// whole test returns.
		func() {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			resultCh := make(chan result)

			go func() {
				client, err := testClient(
					ctx,
					dtlsnet.PacketConnFromConn(ca),
					ca.RemoteAddr(),
					&Config{PSK: test.ClientPSK, PSKIdentityHint: test.ClientPSKIdentity},
					test.ClientHasCertificate,
				)
				resultCh <- result{client, err}
			}()

			_, err := testServer(
				ctx,
				dtlsnet.PacketConnFromConn(cb),
				cb.RemoteAddr(),
				&Config{PSK: test.ServerPSK, PSKIdentityHint: test.ServerPSKIdentity},
				test.ServerHasCertificate,
			)
			if errMismatch(err, test.WantServerError) {
				assert.Failf(t, "TestPSKConfiguration", "Server Error Mismatch '%s'", test.Name)
			}

			// Always drain the client result, otherwise its goroutine would
			// block forever on the unbuffered channel.
			res := <-resultCh
			if errMismatch(res.err, test.WantClientError) {
				assert.Failf(t, "TestPSKConfiguration", "Client Error Mismatch '%s'", test.Name)
			}
		}()
	}
}

func TestServerTimeout(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check
for leaking routines report := test.CheckRoutines(t) defer report() cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} cipherSuites := []CipherSuite{ &ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}, &ciphersuite.TLSEcdheRsaWithAes128GcmSha256{}, } extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.ECDSA}, {Hash: hash.SHA384, Signature: signature.ECDSA}, {Hash: hash.SHA512, Signature: signature.ECDSA}, {Hash: hash.SHA256, Signature: signature.RSA}, {Hash: hash.SHA384, Signature: signature.RSA}, {Hash: hash.SHA512, Signature: signature.RSA}, }, }, &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, } record := &recordlayer.RecordLayer{ Header: recordlayer.Header{ SequenceNumber: 0, Version: protocol.Version1_2, }, Content: &handshake.Handshake{ // sequenceNumber and messageSequence line up, may need to be re-evaluated Header: handshake.Header{ MessageSequence: 0, }, Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, Random: random, CipherSuiteIDs: cipherSuiteIDs(cipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, }, }, } packet, err := record.Marshal() assert.NoError(t, err) ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() // Client reader caReadChan := make(chan []byte, 1000) go func() { for { data := make([]byte, 8192) n, err := ca.Read(data) if err != nil { return } caReadChan <- data[:n] } }() // Start sending ClientHello packets until server responds with first packet go func() { for { select { case <-time.After(10 * time.Millisecond): _, err := 
ca.Write(packet) if err != nil { return } case <-caReadChan: // Once we receive the first reply from the server, stop return } } }() ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, FlightInterval: 100 * time.Millisecond, } _, serverErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) var netErr net.Error assert.ErrorAsf(t, serverErr, &netErr, "Client error exp(Temporary network error) failed(%v)", serverErr) assert.Truef(t, netErr.Timeout(), "Client error exp(Temporary network error) failed(%v)", serverErr) // Wait a little longer to ensure no additional messages have been sent by the server time.Sleep(300 * time.Millisecond) select { case msg := <-caReadChan: assert.Fail(t, "Expected no additional messages from server", "got: %+v", msg) default: } } func TestProtocolVersionValidation(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) var rand [28]byte random := handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: rand} config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, FlightInterval: 100 * time.Millisecond, } t.Run("Server", func(t *testing.T) { serverCases := map[string]struct { records []*recordlayer.RecordLayer }{ "ClientHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: 
defaultCompressionMethods(), }, }, }, }, }, "SecondsClientHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageClientHello{ Version: protocol.Version1_2, Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: defaultCompressionMethods(), }, }, }, { Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: 1, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: 1, }, Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Cookie: cookie, Random: random, CipherSuiteIDs: []uint16{uint16((&ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{}).ID())}, CompressionMethods: defaultCompressionMethods(), }, }, }, }, }, } for name, serverCase := range serverCases { serverCase := serverCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true, ) assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) for _, record := range serverCase.records { packet, err := record.Marshal() assert.NoError(t, err) _, werr := ca.Write(packet) assert.NoError(t, werr) n, rerr := ca.Read(resp[:cap(resp)]) assert.NoError(t, rerr) resp = resp[:n] } h := &recordlayer.Header{} assert.NoError(t, h.Unmarshal(resp)) assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) t.Run("Client", func(t *testing.T) { clientCases := map[string]struct { records []*recordlayer.RecordLayer }{ 
"ServerHelloVersion": { records: []*recordlayer.RecordLayer{ { Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: cookie, }, }, }, { Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: 1, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: 1, }, Message: &handshake.MessageServerHello{ Version: protocol.Version{Major: 0xfe, Minor: 0xff}, // try to downgrade Random: random, CipherSuiteID: func() *uint16 { id := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) return &id }(), CompressionMethod: defaultCompressionMethods()[0], }, }, }, }, }, } for name, clientCase := range clientCases { clientCase := clientCase t.Run(name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, err := testClient(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) assert.ErrorIs(t, err, errUnsupportedProtocolVersion) }() time.Sleep(50 * time.Millisecond) for _, record := range clientCase.records { _, err := ca.Read(make([]byte, 1024)) assert.NoError(t, err) packet, err := record.Marshal() assert.NoError(t, err) _, err = ca.Write(packet) assert.NoError(t, err) } resp := make([]byte, 1024) n, err := ca.Read(resp) assert.NoError(t, err) resp = resp[:n] h := &recordlayer.Header{} assert.NoError(t, h.Unmarshal(resp)) assert.Equal(t, protocol.ContentTypeAlert, h.ContentType, "Peer must return alert to unsupported protocol version") }) } }) } func TestMultipleHelloVerifyRequest(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() cookies := [][]byte{ // first clientHello contains 
an empty cookie {}, } var packets [][]byte for i := 0; i < 2; i++ { cookie := make([]byte, 20) _, err := rand.Read(cookie) assert.NoError(t, err) cookies = append(cookies, cookie) record := &recordlayer.RecordLayer{ Header: recordlayer.Header{ SequenceNumber: uint64(i), //nolint:gosec // G101 Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Header: handshake.Header{ MessageSequence: uint16(i), //nolint:gosec // G115 }, Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: cookie, }, }, } packet, err := record.Marshal() assert.NoError(t, err) packets = append(packets, packet) } ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{}, false) }() for i, cookie := range cookies { // read client hello resp := make([]byte, 1024) n, err := cb.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(resp[:n])) clientHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) assert.True(t, ok) assert.Equal(t, cookie, clientHello.Cookie) if len(packets) <= i { break } // write hello verify request _, err = cb.Write(packets[i]) assert.NoError(t, err) } cancel() } // Assert that a DTLS Server only responds with RenegotiationInfo if a ClientHello contained that // extension according to RFC5746 section 3.6, RFC5246 section 7.4.1.4 and RFC5746 section 4.2. 
func TestRenegotationInfo(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(10 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() resp := make([]byte, 1024) for _, testCase := range []struct { Name string ExpectRenegotiationInfo bool SendRenegotiationInfoExt bool IncludeRenegotiationSCSV bool }{ { Name: "Include RenegotiationInfo", ExpectRenegotiationInfo: true, SendRenegotiationInfoExt: true, }, { Name: "RenegotiationInfo SCSV", ExpectRenegotiationInfo: true, IncludeRenegotiationSCSV: true, }, { Name: "No RenegotiationInfo", ExpectRenegotiationInfo: false, }, } { test := testCase t.Run(test.Name, func(t *testing.T) { ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() ctx, cancel := context.WithCancel(context.Background()) defer cancel() go func() { _, err := testServer( ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true, ) assert.ErrorIs(t, err, context.Canceled) }() time.Sleep(50 * time.Millisecond) extensions := []extension.Extension{} if test.SendRenegotiationInfoExt { extensions = append(extensions, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }) } cipherSuites := cipherSuiteIDs(defaultCipherSuites()) if test.IncludeRenegotiationSCSV { cipherSuites = append(cipherSuites, renegotiationInfoSCSV) } err := sendClientHello([]byte{}, ca, 0, extensions, cipherSuites...) assert.NoError(t, err) n, err := ca.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) assert.True(t, ok) err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions, cipherSuites...) 
assert.NoError(t, err) n, err = ca.Read(resp) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp[:n]) assert.NoError(t, err) assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok) actualNegotationInfo := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.RenegotiationInfo); ok { actualNegotationInfo = true } } assert.True(t, test.ExpectRenegotiationInfo == actualNegotationInfo, "NegotationInfo state in ServerHello is incorrect: expected(%t) actual(%t)", test.ExpectRenegotiationInfo, actualNegotationInfo) }) } } func TestServerNameIndicationExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ServerName string Expected []byte IncludeSNI bool }{ { Name: "Server name is a valid hostname", ServerName: "example.com", Expected: []byte("example.com"), IncludeSNI: true, }, { Name: "Server name is an IP literal", ServerName: "1.2.3.4", Expected: []byte(""), IncludeSNI: false, }, { Name: "Server name is empty", ServerName: "", Expected: []byte(""), IncludeSNI: false, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ ServerName: test.ServerName, } _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) assert.NoError(t, err) r := &recordlayer.RecordLayer{} assert.NoError(t, r.Unmarshal(resp[:n])) clientHello, ok := r.Content.(*handshake.Handshake).Message.(*handshake.MessageClientHello) assert.True(t, ok) gotSNI := false var actualServerName string for _, v := range clientHello.Extensions { 
if _, ok := v.(*extension.ServerName); ok { gotSNI = true extensionServerName, ok := v.(*extension.ServerName) assert.True(t, ok) actualServerName = extensionServerName.ServerName } } assert.Equalf(t, test.IncludeSNI, gotSNI, "TestSNI: expected SNI inclusion '%s'", test.Name) assert.Equalf(t, test.Expected, []byte(actualServerName), "TestSNI: server name mismatch '%s'", test.Name) }) } } func TestALPNExtension(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string ClientProtocolNameList []string ServerProtocolNameList []string ExpectedProtocol string ExpectAlertFromClient bool ExpectAlertFromServer bool Alert alert.Description }{ { Name: "Negotiate a protocol", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{"spd/1"}, ExpectedProtocol: "spd/1", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Server doesn't support any", ClientProtocolNameList: []string{"http/1.1", "spd/1"}, ServerProtocolNameList: []string{}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Negotiate with higher server precedence", ClientProtocolNameList: []string{"http/1.1", "spd/1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "http/3", "spd/1"}, ExpectedProtocol: "http/3", ExpectAlertFromClient: false, ExpectAlertFromServer: false, Alert: 0, }, { Name: "Empty intersection", ClientProtocolNameList: []string{"http/1.1", "http/3"}, ServerProtocolNameList: []string{"ssh/2", "spd/1"}, ExpectedProtocol: "", ExpectAlertFromClient: false, ExpectAlertFromServer: true, Alert: alert.NoApplicationProtocol, }, { Name: "Multiple protocols in ServerHello", ClientProtocolNameList: []string{"http/1.1"}, ServerProtocolNameList: []string{"http/1.1"}, ExpectedProtocol: "http/1.1", 
ExpectAlertFromClient: true, ExpectAlertFromServer: false, Alert: alert.InternalError, }, } { test := test t.Run(test.Name, func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { conf := &Config{ SupportedProtocols: test.ClientProtocolNameList, } _, _ = testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), conf, false) }() // Receive ClientHello resp := make([]byte, 1024) n, err := cb.Read(resp) assert.NoError(t, err) ctx2, cancel2 := context.WithTimeout(context.Background(), 10*time.Second) defer cancel2() ca2, cb2 := dpipe.Pipe() go func() { conf := &Config{ SupportedProtocols: test.ServerProtocolNameList, } _, err2 := testServer(ctx2, dtlsnet.PacketConnFromConn(cb2), cb2.RemoteAddr(), conf, true) if test.ExpectAlertFromServer { assert.NotErrorIs(t, err2, context.Canceled) } }() time.Sleep(50 * time.Millisecond) // Forward ClientHello _, err = ca2.Write(resp[:n]) assert.NoError(t, err) // Receive HelloVerify resp2 := make([]byte, 1024) n, err = ca2.Read(resp2) assert.NoError(t, err) // Forward HelloVerify _, err = cb.Write(resp2[:n]) assert.NoError(t, err) // Receive ClientHello resp3 := make([]byte, 1024) n, err = cb.Read(resp3) assert.NoError(t, err) // Forward ClientHello _, err = ca2.Write(resp3[:n]) assert.NoError(t, err) // Receive ServerHello resp4 := make([]byte, 1024) n, err = ca2.Read(resp4) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp4[:n]) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, record.Unmarshal(messages[0])) if test.ExpectAlertFromServer { //nolint:nestif a, ok := record.Content.(*alert.Alert) assert.True(t, ok) assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } else { serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok) var negotiatedProtocol string for _, v := range serverHello.Extensions { if _, ok := 
v.(*extension.ALPN); ok { e, ok := v.(*extension.ALPN) assert.True(t, ok) negotiatedProtocol = e.ProtocolNameList[0] // Manipulate ServerHello if test.ExpectAlertFromClient { e.ProtocolNameList = append(e.ProtocolNameList, "oops") } } } assert.Equalf(t, test.ExpectedProtocol, negotiatedProtocol, "ALPN %v", test.Name) s, err := record.Marshal() assert.NoError(t, err) // Forward ServerHello _, err = cb.Write(s) assert.NoError(t, err) if test.ExpectAlertFromClient { resp5 := make([]byte, 1024) n, err = cb.Read(resp5) assert.NoError(t, err) r2 := &recordlayer.RecordLayer{} assert.NoError(t, r2.Unmarshal(resp5[:n])) a, ok := r2.Content.(*alert.Alert) assert.True(t, ok) assert.Equalf(t, test.Alert, a.Description, "ALPN %v", test.Name) } } time.Sleep(50 * time.Millisecond) // Give some time for returned errors }) } } // Make sure the supported_groups extension is not included in the ServerHello. func TestSupportedGroupsExtension(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() t.Run("ServerHello Supported Groups", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() go func() { _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) assert.ErrorIs(t, err, context.Canceled) }() extensions := []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: []elliptic.Curve{elliptic.X25519, elliptic.P256, elliptic.P384}, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, } time.Sleep(50 * time.Millisecond) resp := make([]byte, 1024) err := sendClientHello([]byte{}, ca, 0, extensions) assert.NoError(t, err) // Receive ServerHello n, err := ca.Read(resp) assert.NoError(t, err) record := &recordlayer.RecordLayer{} assert.NoError(t, 
record.Unmarshal(resp[:n])) helloVerifyRequest, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageHelloVerifyRequest) assert.True(t, ok, "Failed to cast MessageHelloVerifyRequest") err = sendClientHello(helloVerifyRequest.Cookie, ca, 1, extensions) assert.NoError(t, err) n, err = ca.Read(resp) assert.NoError(t, err) messages, err := recordlayer.UnpackDatagram(resp[:n]) assert.NoError(t, err) assert.NoError(t, record.Unmarshal(messages[0])) serverHello, ok := record.Content.(*handshake.Handshake).Message.(*handshake.MessageServerHello) assert.True(t, ok, "TestSupportedGroups: Failed to cast MessageServerHello") gotGroups := false for _, v := range serverHello.Extensions { if _, ok := v.(*extension.SupportedEllipticCurves); ok { gotGroups = true } } assert.False(t, gotGroups, "TestSupportedGroups: supported_groups extension was sent in ServerHello") }) }

// TestSessionResume covers two paths: "resumed", where a session is already
// present in the store shared by both peers and the server must end up with
// that session's ID and master secret, and "new session", where both stores
// start empty and must each hold the negotiated master secret afterwards.
func TestSessionResume(t *testing.T) {
	// Limit runtime in case of deadlocks
	lim := test.TimeOut(time.Second * 20)
	defer lim.Stop()

	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	// result carries the outcome of the client handshake back to the test goroutine.
	type result struct {
		c   *Conn
		err error
	}

	t.Run("resumed", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		clientRes := make(chan result, 1)

		id, err := hex.DecodeString("9b9fc92255634d9fb109febed42166717bb8ded8c738ba71bc7f2a0d9dae0306")
		assert.NoError(t, err)
		secret, err := hex.DecodeString(
			"2e942a37aca5241deb2295b5fcedac221c7078d2503d2b62aeb48c880d7da73c001238b708559686b9da6e829c05ead7",
		)
		assert.NoError(t, err)

		ca, cb := dpipe.Pipe()

		// Preload the shared store under both keys: the raw session ID and
		// "<remote addr>_<server name>".
		// NOTE(review): presumably the server looks sessions up by ID and the
		// client by address+name - confirm against the handshake code.
		ss := &memSessStore{}
		s := Session{ID: id, Secret: secret}
		assert.NoError(t, ss.Set(id, s))
		assert.NoError(t, ss.Set([]byte(ca.RemoteAddr().String()+"_example.com"), s))

		go func() {
			config := &Config{
				CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
				ServerName:   "example.com",
				SessionStore: ss,
				MTU:          100,
			}
			c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false)
			clientRes <- result{c, err}
		}()

		config := &Config{
			CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
			ServerName:   "example.com",
			SessionStore: ss,
			MTU:          100,
		}
		server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true)
		if !assert.NoError(t, err) {
			// server is unusable; clientRes is buffered so the client goroutine can still finish.
			return
		}
		defer func() {
			assert.NoError(t, server.Close())
		}()

		state, ok := server.ConnectionState()
		assert.True(t, ok)

		// testify takes (expected, actual): the preloaded values are the expectation.
		assert.Equal(t, id, state.SessionID, "TestSessionResumetion SessionID mismatch")
		assert.Equal(t, secret, state.masterSecret, "TestSessionResumetion masterSecret mismatch")

		res := <-clientRes
		if assert.NoError(t, res.err) {
			assert.NoError(t, res.c.Close())
		}
	})

	t.Run("new session", func(t *testing.T) {
		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		clientRes := make(chan result, 1)

		// s1 is the client's store, s2 the server's; both start empty.
		s1 := &memSessStore{}
		s2 := &memSessStore{}

		ca, cb := dpipe.Pipe()
		go func() {
			config := &Config{
				ServerName:   "example.com",
				SessionStore: s1,
			}
			c, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), config, false)
			clientRes <- result{c, err}
		}()

		config := &Config{
			SessionStore: s2,
		}
		server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true)
		if !assert.NoError(t, err) {
			return
		}
		defer func() {
			assert.NoError(t, server.Close())
		}()

		state, ok := server.ConnectionState()
		assert.True(t, ok)
		actualSessionID := state.SessionID
		actualMasterSecret := state.masterSecret

		// The server must have stored the session under its session ID.
		ss, err := s2.Get(actualSessionID)
		assert.NoError(t, err)
		assert.Equal(t, actualMasterSecret, ss.Secret, "TestSessionResumetion masterSecret mismatch")

		res := <-clientRes
		if !assert.NoError(t, res.err) {
			return
		}

		// The client must have stored the same secret under "<remote addr>_<server name>".
		cs, err := s1.Get([]byte(ca.RemoteAddr().String() + "_example.com"))
		assert.NoError(t, err)
		assert.Equal(t, actualMasterSecret, cs.Secret, "TestSessionResumetion mismatch")
		assert.NoError(t, res.c.Close())
	})
}

// memSessStore is an in-memory session store for tests, backed by a sync.Map
// whose keys are the hex encoding of the caller's byte key.
type memSessStore struct {
	sync.Map
}

// Set stores s under the hex form of key. It always returns nil.
func (ms *memSessStore) Set(key []byte, s Session) error {
	ms.Store(hex.EncodeToString(key), s)

	return nil
}

// Get returns the session stored under key. A missing entry, or an entry that
// is not a Session, yields the zero Session with a nil error.
func (ms *memSessStore) Get(key []byte) (Session, error) {
	if v, found := ms.Load(hex.EncodeToString(key)); found {
		if s, isSession := v.(Session); isSession {
			return s, nil
		}
	}

	return Session{}, nil
}

// Del removes whatever is stored under key. It always returns nil.
func (ms *memSessStore) Del(key []byte) error {
	ms.Delete(hex.EncodeToString(key))

	return nil
}

// Assert that the server only uses CipherSuites with a hash+signature that matches // the certificate. As specified in rfc5246#section-7.4.3 // . func TestCipherSuiteMatchesCertificateType(t *testing.T) { //nolint:cyclop // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() for _, test := range []struct { Name string cipherList []CipherSuiteID expectedCipher CipherSuiteID generateRSA bool }{ { Name: "ECDSA Certificate with RSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, }, { Name: "RSA Certificate with ECDSA CipherSuite first", cipherList: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, expectedCipher: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, generateRSA: true, }, } { test := test t.Run(test.Name, func(t *testing.T) { clientErr := make(chan error, 1) client := make(chan *Conn, 1) ca, cb := dpipe.Pipe() go func() { c, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ CipherSuites: test.cipherList, }, false) clientErr <- err client <- c }() var ( signer crypto.Signer err error ) if test.generateRSA { signer, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) } else { signer, err = ecdsa.GenerateKey(cryptoElliptic.P256(), rand.Reader) assert.NoError(t, err) } serverCert, err := selfsign.SelfSign(signer) assert.NoError(t, err) s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb),
cb.RemoteAddr(), &Config{ CipherSuites: test.cipherList, Certificates: []tls.Certificate{serverCert}, }, false) assert.NoError(t, err) assert.NoError(t, s.Close()) c := <-client assert.NoError(t, <-clientErr) assert.NoError(t, c.Close()) state, ok := c.ConnectionState() assert.True(t, ok) assert.Equal(t, test.expectedCipher, state.cipherSuite.ID()) }) } }

// Test that we return the proper certificate if we are serving multiple ServerNames on a single Server.
func TestMultipleServerCertificates(t *testing.T) {
	fooCert, err := selfsign.GenerateSelfSignedWithDNS("foo")
	assert.NoError(t, err)

	barCert, err := selfsign.GenerateSelfSignedWithDNS("bar")
	assert.NoError(t, err)

	// Both self-signed certificates go into the pool handed to the client as RootCAs.
	caPool := x509.NewCertPool()
	for _, cert := range []tls.Certificate{fooCert, barCert} {
		certificate, err := x509.ParseCertificate(cert.Certificate[0])
		assert.NoError(t, err)
		caPool.AddCert(certificate)
	}

	for _, test := range []struct {
		RequestServerName string
		ExpectedDNSName   string
	}{
		{
			"foo",
			"foo",
		},
		{
			"bar",
			"bar",
		},
		{
			// A name with no matching certificate is expected to be answered
			// with the first configured certificate.
			"invalid",
			"foo",
		},
	} {
		test := test
		t.Run(test.RequestServerName, func(t *testing.T) {
			clientErr := make(chan error, 2)
			client := make(chan *Conn, 1)

			ca, cb := dpipe.Pipe()
			go func() {
				clientConn, err := testClient(context.TODO(), dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{
					RootCAs:    caPool,
					ServerName: test.RequestServerName,
					// Reject the handshake unless the leaf certificate's first
					// DNS name is the one this case expects.
					VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
						certificate, err := x509.ParseCertificate(rawCerts[0])
						if err != nil {
							return err
						}

						if certificate.DNSNames[0] != test.ExpectedDNSName {
							return errWrongCert
						}

						return nil
					},
				}, false)
				clientErr <- err
				client <- clientConn
			}()

			s, err := testServer(context.TODO(), dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
				Certificates: []tls.Certificate{fooCert, barCert},
			}, false)
			assert.NoError(t, err)
			assert.NoError(t, s.Close())
			assert.NoError(t, <-clientErr)
			assert.NoError(t, (<-client).Close())
		})
	}
}

// TestEllipticCurveConfiguration completes a handshake for each curve
// configuration and checks the curve list the server ended up with: the
// default list when none is configured, otherwise the configured curves in
// their configured order.
func TestEllipticCurveConfiguration(t *testing.T) {
	// Check for leaking routines
	report := test.CheckRoutines(t)
	defer report()

	for _, test := range []struct {
		Name            string
		ConfigCurves    []elliptic.Curve
		HandshakeCurves []elliptic.Curve
	}{
		{
			Name:            "Curve defaulting",
			ConfigCurves:    nil,
			HandshakeCurves: defaultCurves,
		},
		{
			Name:            "Single curve",
			ConfigCurves:    []elliptic.Curve{elliptic.X25519},
			HandshakeCurves: []elliptic.Curve{elliptic.X25519},
		},
		{
			Name:            "Multiple curves",
			ConfigCurves:    []elliptic.Curve{elliptic.P384, elliptic.X25519},
			HandshakeCurves: []elliptic.Curve{elliptic.P384, elliptic.X25519},
		},
	} {
		test := test
		// A sub-test per case scopes the deferred cleanup to the case and puts
		// the case name into the failure output.
		t.Run(test.Name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()

			ca, cb := dpipe.Pipe()
			type result struct {
				c   *Conn
				err error
			}
			// Buffered so the client goroutine can finish even when the test
			// returns early without reading the result.
			resultCh := make(chan result, 1)

			go func() {
				client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{
					CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
					EllipticCurves: test.ConfigCurves,
				}, true)
				resultCh <- result{client, err}
			}()

			server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
				CipherSuites:   []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
				EllipticCurves: test.ConfigCurves,
			}, true)
			if !assert.NoError(t, err) {
				return
			}
			defer func() {
				assert.NoError(t, server.Close())
			}()

			serverCurves := server.fsm.cfg.ellipticCurves
			if len(test.ConfigCurves) == 0 {
				// Nothing configured: the server must have fallen back to the
				// expected default list. The previous check only compared two
				// table fields and could never fail.
				assert.Equal(t, len(test.HandshakeCurves), len(serverCurves), "Failed to default Elliptic curves")
			} else if assert.Equal(t, len(test.HandshakeCurves), len(serverCurves), "Failed to configure Elliptic curves") {
				// Lengths match, so indexing serverCurves below is safe.
				for i, c := range test.ConfigCurves {
					assert.Equal(t, c, serverCurves[i], "Failed to maintain Elliptic curve order")
				}
			}

			res := <-resultCh
			if assert.NoError(t, res.err, "Client error") {
				assert.NoError(t, res.c.Close())
			}
		})
	}
}

func TestSkipHelloVerify(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe()
certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{certificate}, LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerifyHello: true, }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } type connWithCallback struct { net.Conn onWrite func([]byte) } func (c *connWithCallback) Write(b []byte) (int, error) { if c.onWrite != nil { c.onWrite(b) } return c.Conn.Write(b) } func TestApplicationDataQueueLimited(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() defer func() { assert.NoError(t, ca.Close()) }() defer func() { assert.NoError(t, cb.Close()) }() done := make(chan struct{}) go func() { serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) cfg := &Config{} cfg.Certificates = []tls.Certificate{serverCert} dconn, err := createConn(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), cfg, false, nil) assert.NoError(t, err) go func() { for i := 0; i < 5; i++ { dconn.lock.RLock() qlen := len(dconn.encryptedPackets) dconn.lock.RUnlock() assert.GreaterOrEqual(t, 
maxAppDataPacketQueueSize, qlen, "too many encrypted packets enqueued") time.Sleep(1 * time.Second) } }() assert.Error(t, dconn.HandshakeContext(ctx)) close(done) }() extensions := []extension.Extension{} time.Sleep(50 * time.Millisecond) assert.NoError(t, sendClientHello([]byte{}, ca, 0, extensions)) time.Sleep(50 * time.Millisecond) for i := 0; i < 1000; i++ { // Send an application data packet packet, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, SequenceNumber: uint64(3), Epoch: 1, // use an epoch greater than 0 }, Content: &protocol.ApplicationData{ Data: []byte{1, 2, 3, 4}, }, }).Marshal() assert.NoError(t, err) _, err = ca.Write(packet) assert.NoError(t, err) if i%100 == 0 { time.Sleep(10 * time.Millisecond) } } time.Sleep(1 * time.Second) assert.NoError(t, ca.Close()) <-done } func TestHelloRandom(t *testing.T) { report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() ca, cb := dpipe.Pipe() certificate, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) gotHello := make(chan struct{}) chRandom := [handshake.RandomBytesLength]byte{} _, err = rand.Read(chRandom[:]) assert.NoError(t, err) go func() { server, sErr := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ GetCertificate: func(chi *ClientHelloInfo) (*tls.Certificate, error) { if len(chi.CipherSuites) == 0 { return &certificate, nil } assert.Equal(t, chRandom[:], chi.RandomBytes[:]) return &certificate, nil }, LoggerFactory: logging.NewDefaultLoggerFactory(), }, false) assert.NoError(t, sErr) buf := make([]byte, 1024) _, sErr = server.Read(buf) //nolint:contextcheck assert.NoError(t, sErr) gotHello <- struct{}{} assert.NoError(t, server.Close()) //nolint:contextcheck }() client, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ LoggerFactory: logging.NewDefaultLoggerFactory(), HelloRandomBytesGenerator: func() 
[handshake.RandomBytesLength]byte { return chRandom }, InsecureSkipVerify: true, }, false) assert.NoError(t, err) _, err = client.Write([]byte("hello")) assert.NoError(t, err) select { case <-gotHello: // OK case <-time.After(time.Second * 5): assert.Fail(t, "timeout") } assert.NoError(t, client.Close()) } func TestOnConnectionAttempt(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*20) defer cancel() var clientOnConnectionAttempt, serverOnConnectionAttempt atomic.Int32 ca, cb := dpipe.Pipe() clientErr := make(chan error, 1) go func() { _, err := testClient(ctx, dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { clientOnConnectionAttempt.Store(1) assert.NotNil(t, in) return nil }, }, true) clientErr <- err }() expectedErr := &FatalError{} _, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ OnConnectionAttempt: func(in net.Addr) error { serverOnConnectionAttempt.Store(1) assert.NotNil(t, in) return expectedErr }, }, true) assert.ErrorIs(t, err, expectedErr) assert.Error(t, <-clientErr) assert.Equal(t, int32(1), serverOnConnectionAttempt.Load(), "OnConnectionAttempt did not fire for server") assert.Equal(t, int32(0), clientOnConnectionAttempt.Load(), "OnConnectionAttempt fired for client") } func TestFragmentBuffer_Retransmission(t *testing.T) { fragmentBuffer := newFragmentBuffer() frag := []byte{ 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, } _, isRetransmission, err := fragmentBuffer.push(frag) assert.NoError(t, err) assert.False(t, isRetransmission) v, _ := fragmentBuffer.pop() assert.NotNil(t, v) _, isRetransmission, err = fragmentBuffer.push(frag) assert.NoError(t, err) assert.True(t, isRetransmission) } func TestConnectionState(t *testing.T) { ca, cb := dpipe.Pipe() // Setup client clientCfg := &Config{} 
clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCfg.Certificates = []tls.Certificate{clientCert} clientCfg.InsecureSkipVerify = true client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg) assert.NoError(t, err) defer func() { _ = client.Close() }() _, ok := client.ConnectionState() assert.False(t, ok) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() errorChannel := make(chan error) go func() { errC := client.HandshakeContext(ctx) errorChannel <- errC }() // Setup server server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true) assert.NoError(t, err) defer func() { _ = server.Close() }() err = <-errorChannel assert.NoError(t, err) _, ok = client.ConnectionState() assert.True(t, ok) } func TestMultiHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() ca, cb := dpipe.Pipe() serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) assert.NoError(t, err) go func() { _ = server.Handshake() }() clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), &Config{ Certificates: []tls.Certificate{clientCert}, }) assert.NoError(t, err) assert.Error(t, client.Handshake()) assert.Error(t, client.Handshake()) assert.NoError(t, server.Close()) assert.NoError(t, client.Close()) } func TestCloseDuringHandshake(t *testing.T) { defer test.CheckRoutines(t)() defer test.TimeOut(time.Second * 10).Stop() serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) for i := 0; i < 100; i++ { _, cb := dpipe.Pipe() server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{ Certificates: []tls.Certificate{serverCert}, }) assert.NoError(t, err) waitChan := 
// RandomCIDGenerator returns a Connection ID generator that yields a fresh
// random CID of exactly size bytes on every call. A size of 0 tells peers that
// sending a Connection ID is not necessary.
func RandomCIDGenerator(size int) func() []byte {
	generate := func() []byte {
		buf := make([]byte, size)
		_, readErr := rand.Read(buf)
		if readErr != nil {
			panic(readErr) //nolint -- nonrecoverable
		}

		return buf
	}

	return generate
}

// OnlySendCIDGenerator returns a generator that always yields a nil CID. It
// enables sending Connection IDs negotiated with a peer while indicating to
// the peer that sending Connection IDs in return is not necessary.
func OnlySendCIDGenerator() func() []byte {
	noCID := func() []byte {
		return nil
	}

	return noCID
}
func cidDatagramRouter(size int) func([]byte) (string, bool) { return func(packet []byte) (string, bool) { pkts, err := recordlayer.ContentAwareUnpackDatagram(packet, size) if err != nil || len(pkts) < 1 { return "", false } for _, pkt := range pkts { h := &recordlayer.Header{ ConnectionID: make([]byte, size), } if err := h.Unmarshal(pkt); err != nil { continue } if h.ContentType != protocol.ContentTypeConnectionID { continue } return string(h.ConnectionID), true } return "", false } } // cidConnIdentifier extracts connection IDs from outgoing ServerHello records // and associates them with the associated connection. // NOTE: a ServerHello should always be the first record in a datagram if // multiple are present, so we avoid iterating through all packets if the first // is not a ServerHello. func cidConnIdentifier() func([]byte) (string, bool) { //nolint:cyclop return func(packet []byte) (string, bool) { pkts, err := recordlayer.UnpackDatagram(packet) if err != nil || len(pkts) < 1 { return "", false } var h recordlayer.Header if hErr := h.Unmarshal(pkts[0]); hErr != nil { return "", false } if h.ContentType != protocol.ContentTypeHandshake { return "", false } var hh handshake.Header var sh handshake.MessageServerHello for _, pkt := range pkts { if hhErr := hh.Unmarshal(pkt[recordlayer.FixedHeaderSize:]); hhErr != nil { continue } if err = sh.Unmarshal(pkt[recordlayer.FixedHeaderSize+handshake.HeaderLength:]); err == nil { break } } if err != nil { return "", false } for _, ext := range sh.Extensions { if e, ok := ext.(*extension.ConnectionID); ok { return string(e.CID), true } } return "", false } } dtls-3.1.2/connection_id_test.go000066400000000000000000000171141514330267300166740ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" 
"github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/stretchr/testify/assert" ) func TestRandomConnectionIDGenerator(t *testing.T) { cases := map[string]struct { reason string size int }{ "LengthMatch": { reason: "Zero size should match length of generated CID.", size: 0, }, "LengthMatchSome": { reason: "Non-zero size should match length of generated CID with non-zero.", size: 8, }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { assert.Equal(t, tc.size, len(RandomCIDGenerator(tc.size)()), "%s\nRandomCIDGenerator mismatch", tc.reason) }) } } func TestOnlySendCIDGenerator(t *testing.T) { cases := map[string]struct { reason string }{ "LengthMatch": { reason: "CID length should always be zero.", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { assert.Equalf(t, 0, len(OnlySendCIDGenerator()()), "%s\nOnlySendCIDGenerator mismatch", tc.reason) }) } } func TestCIDDatagramRouter(t *testing.T) { cid := []byte("abcd1234") cidLen := 8 appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() assert.NoError(t, err) appData, err := (&protocol.ApplicationData{ Data: []byte("some data"), }).Marshal() assert.NoError(t, err) inner, err := (&recordlayer.InnerPlaintext{ Content: appData, RealType: protocol.ContentTypeApplicationData, }).Marshal() assert.NoError(t, err) cidHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: cid, SequenceNumber: 1, }).Marshal() assert.NoError(t, err) cases := map[string]struct { reason string size int datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", size: cidLen, datagram: []byte{}, ok: false, want: "", }, 
"NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", size: cidLen, datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAConnectionIDDatagram": { reason: "If datagram does not contain any Connection ID records, we cannot extract an identifier", size: cidLen, datagram: appRecord, ok: false, want: "", }, "OneRecordConnectionID": { reason: "If datagram contains one Connection ID record, we should be able to extract it.", size: cidLen, datagram: append(cidHeader, inner...), ok: true, want: string(cid), }, "OneRecordConnectionIDAltLength": { //nolint:lll reason: "If datagram contains one Connection ID record, but it has the wrong length we should not be able to extract it.", size: cidLen, datagram: func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("abcd"), SequenceNumber: 1, }).Marshal() assert.NoError(t, err) return append(altCIDHeader, inner...) }(), ok: false, want: "", }, "MultipleRecordOneConnectionID": { //nolint:lll reason: "If datagram contains multiple records and one is a Connection ID record, we should be able to extract it.", size: 8, datagram: append(append(appRecord, cidHeader...), inner...), ok: true, want: string(cid), }, "MultipleRecordMultipleConnectionID": { //nolint:lll reason: "If datagram contains multiple records and multiple are Connection ID records, we should extract the first one.", size: 8, datagram: append(append(append(appRecord, func() []byte { altCIDHeader, err := (&recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, ContentType: protocol.ContentTypeConnectionID, ContentLen: uint16(len(inner)), //nolint:gosec // G115 ConnectionID: []byte("1234abcd"), SequenceNumber: 1, }).Marshal() assert.NoError(t, err) return append(altCIDHeader, inner...) 
}()...), cidHeader...), inner...), ok: true, want: "1234abcd", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidDatagramRouter(tc.size)(tc.datagram) assert.Equal(t, tc.ok, ok, "%s\ncidDatagramRouter mismatch", tc.reason) assert.Equal(t, tc.want, cid, "%s\ncidDatagramRouter mismatch", tc.reason) }) } } func TestCIDConnIdentifier(t *testing.T) { cid := []byte("abcd1234") cs := uint16(TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256) sh, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 0, Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: handshake.Random{GMTUnixTime: time.Unix(500, 0), RandomBytes: [28]byte{}}, SessionID: []byte("hello"), CipherSuiteID: &cs, CompressionMethod: defaultCompressionMethods()[0], Extensions: []extension.Extension{ &extension.ConnectionID{ CID: cid, }, }, }, }, }).Marshal() assert.NoError(t, err) appRecord, err := (&recordlayer.RecordLayer{ Header: recordlayer.Header{ Epoch: 1, Version: protocol.Version1_2, }, Content: &protocol.ApplicationData{ Data: []byte("application data"), }, }).Marshal() assert.NoError(t, err) cases := map[string]struct { reason string datagram []byte ok bool want string }{ "EmptyDatagram": { reason: "If datagram is empty, we cannot extract an identifier", datagram: []byte{}, ok: false, want: "", }, "NotADTLSRecord": { reason: "If datagram is not a DTLS record, we cannot extract an identifier", datagram: []byte("not a DTLS record"), ok: false, want: "", }, "NotAServerhelloDatagram": { reason: "If datagram does not contain any ServerHello record, we cannot extract an identifier", datagram: appRecord, ok: false, want: "", }, "OneRecordServerHello": { reason: "If datagram contains one ServerHello record, we should be able to extract an identifier.", datagram: sh, ok: true, want: string(cid), }, "MultipleRecordFirstServerHello": { //nolint:lll reason: "If datagram contains multiple 
records and the first is a ServerHello record, we should be able to extract an identifier.", datagram: append(sh, appRecord...), ok: true, want: string(cid), }, "MultipleRecordNotFirstServerHello": { //nolint:lll reason: "If datagram contains multiple records and the first is not a ServerHello record, we should not be able to extract an identifier.", datagram: append(appRecord, sh...), ok: false, want: "", }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { cid, ok := cidConnIdentifier()(tc.datagram) assert.Equalf(t, tc.ok, ok, "%s\ncidConnIdentifier mismatch", tc.reason) assert.Equalf(t, tc.want, cid, "%s\ncidConnIdentifier mismatch", tc.reason) }) } } dtls-3.1.2/crypto.go000066400000000000000000000312501514330267300143370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/rand" "crypto/rsa" "crypto/x509" "crypto/x509/pkix" "encoding/asn1" "encoding/binary" "math/big" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" "github.com/pion/dtls/v3/pkg/crypto/signature" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" ) type ecdsaSignature struct { R, S *big.Int } func valueKeyMessage(clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve) []byte { serverECDHParams := make([]byte, 4) serverECDHParams[0] = 3 // named curve binary.BigEndian.PutUint16(serverECDHParams[1:], uint16(namedCurve)) serverECDHParams[3] = byte(len(publicKey)) plaintext := []byte{} plaintext = append(plaintext, clientRandom...) plaintext = append(plaintext, serverRandom...) plaintext = append(plaintext, serverECDHParams...) plaintext = append(plaintext, publicKey...) return plaintext } // validateSignatureAlgOID validates that the signature scheme matches the // certificate's public key algorithm OID. 
This is required by RFC 8446 Section 4.2.3: // - RSA_PSS_RSAE requires rsaEncryption OID // - RSA_PSS_PSS requires id-RSASSA-PSS OID // // Note: returns nil if the given signature.Algorithm is not PSS based. // // https://www.rfc-editor.org/rfc/rfc8446#section-4.2.3 func validateSignatureAlgOID(cert *x509.Certificate, sigAlg signature.Algorithm) error { if !sigAlg.IsPSS() { return nil } // Get the certificate's public key algorithm OID from the raw certificate // We need to parse the SubjectPublicKeyInfo to get the algorithm OID var spki struct { Algorithm pkix.AlgorithmIdentifier PublicKey asn1.BitString } if _, err := asn1.Unmarshal(cert.RawSubjectPublicKeyInfo, &spki); err != nil { return err } certOID := spki.Algorithm.Algorithm switch sigAlg { // Check RSAE variants (0x0804-0x0806) require rsaEncryption OID case signature.RSA_PSS_RSAE_SHA256, signature.RSA_PSS_RSAE_SHA384, signature.RSA_PSS_RSAE_SHA512: oidPublicKeyRSA := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 1} // OID: rsaEncryption if !certOID.Equal(oidPublicKeyRSA) { return errInvalidCertificateOID } return nil // Check PSS variants (0x0809-0x080b) require id-RSASSA-PSS OID case signature.RSA_PSS_PSS_SHA256, signature.RSA_PSS_PSS_SHA384, signature.RSA_PSS_PSS_SHA512: oidPublicKeyRSAPSS := asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 1, 10} // OID: id-RSASSA-PSS if !certOID.Equal(oidPublicKeyRSAPSS) { return errInvalidCertificateOID } return nil default: return nil } } // If the client provided a "signature_algorithms" extension, then all // certificates provided by the server MUST be signed by a // hash/signature algorithm pair that appears in that extension // // https://tools.ietf.org/html/rfc5246#section-7.4.2 func generateKeySignature( clientRandom, serverRandom, publicKey []byte, namedCurve elliptic.Curve, signer crypto.Signer, hashAlgorithm hash.Algorithm, signatureAlgorithm signature.Algorithm, ) ([]byte, error) { msg := valueKeyMessage(clientRandom, serverRandom, publicKey, namedCurve) 
switch signer.Public().(type) { case ed25519.PublicKey: // https://crypto.stackexchange.com/a/55483 return signer.Sign(rand.Reader, msg, crypto.Hash(0)) case *ecdsa.PublicKey: hashed := hashAlgorithm.Digest(msg) return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PublicKey: hashed := hashAlgorithm.Digest(msg) // Use RSA-PSS if the signature algorithm is PSS if signatureAlgorithm.IsPSS() { pssOpts := &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashAlgorithm.CryptoHash(), } return signer.Sign(rand.Reader, hashed, pssOpts) } // Otherwise use PKCS#1 v1.5 return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errKeySignatureGenerateUnimplemented } //nolint:dupl,cyclop func verifyKeySignature( message, remoteKeySignature []byte, hashAlgorithm hash.Algorithm, signatureAlgorithm signature.Algorithm, rawCertificates [][]byte, ) error { if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } // Validate that the signature algorithm matches the certificate's OID if err := validateSignatureAlgOID(certificate, signatureAlgorithm); err != nil { return err } switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(pubKey, message, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hashed := hashAlgorithm.Digest(message) if !ecdsa.Verify(pubKey, hashed, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: hashed := hashAlgorithm.Digest(message) // Use RSA-PSS verification if the signature algorithm is PSS if signatureAlgorithm.IsPSS() { pssOpts := &rsa.PSSOptions{ SaltLength: 
rsa.PSSSaltLengthEqualsHash, Hash: hashAlgorithm.CryptoHash(), } if err := rsa.VerifyPSS(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature, pssOpts); err != nil { return errKeySignatureMismatch } return nil } // Otherwise use PKCS#1 v1.5 if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hashed, remoteKeySignature) != nil { return errKeySignatureMismatch } return nil } return errKeySignatureVerifyUnimplemented } // If the server has sent a CertificateRequest message, the client MUST send the Certificate // message. The ClientKeyExchange message is now sent, and the content // of that message will depend on the public key algorithm selected // between the ClientHello and the ServerHello. If the client has sent // a certificate with signing ability, a digitally-signed // CertificateVerify message is sent to explicitly verify possession of // the private key in the certificate. // https://tools.ietf.org/html/rfc5246#section-7.3 func generateCertificateVerify( handshakeBodies []byte, signer crypto.Signer, hashAlgorithm hash.Algorithm, signatureAlgorithm signature.Algorithm, ) ([]byte, error) { if _, ok := signer.Public().(ed25519.PublicKey); ok { // https://pkg.go.dev/crypto/ed25519#PrivateKey.Sign // Sign signs the given message with priv. Ed25519 performs two passes over // messages to be signed and therefore cannot handle pre-hashed messages. 
return signer.Sign(rand.Reader, handshakeBodies, crypto.Hash(0)) } hashed := hashAlgorithm.Digest(handshakeBodies) switch signer.Public().(type) { case *ecdsa.PublicKey: return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) case *rsa.PublicKey: // Use RSA-PSS if the signature algorithm is PSS if signatureAlgorithm.IsPSS() { pssOpts := &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashAlgorithm.CryptoHash(), } return signer.Sign(rand.Reader, hashed, pssOpts) } // Otherwise use PKCS#1 v1.5 return signer.Sign(rand.Reader, hashed, hashAlgorithm.CryptoHash()) } return nil, errInvalidSignatureAlgorithm } //nolint:dupl,cyclop func verifyCertificateVerify( handshakeBodies []byte, hashAlgorithm hash.Algorithm, signatureAlgorithm signature.Algorithm, remoteKeySignature []byte, rawCertificates [][]byte, ) error { if len(rawCertificates) == 0 { return errLengthMismatch } certificate, err := x509.ParseCertificate(rawCertificates[0]) if err != nil { return err } // Validate that the signature algorithm matches the certificate's OID if err := validateSignatureAlgOID(certificate, signatureAlgorithm); err != nil { return err } switch pubKey := certificate.PublicKey.(type) { case ed25519.PublicKey: if ok := ed25519.Verify(pubKey, handshakeBodies, remoteKeySignature); !ok { return errKeySignatureMismatch } return nil case *ecdsa.PublicKey: ecdsaSig := &ecdsaSignature{} if _, err := asn1.Unmarshal(remoteKeySignature, ecdsaSig); err != nil { return err } if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { return errInvalidECDSASignature } hash := hashAlgorithm.Digest(handshakeBodies) if !ecdsa.Verify(pubKey, hash, ecdsaSig.R, ecdsaSig.S) { return errKeySignatureMismatch } return nil case *rsa.PublicKey: hash := hashAlgorithm.Digest(handshakeBodies) // Use RSA-PSS verification if the signature algorithm is PSS if signatureAlgorithm.IsPSS() { pssOpts := &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashAlgorithm.CryptoHash(), } if err 
:= rsa.VerifyPSS(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature, pssOpts); err != nil { return errKeySignatureMismatch } return nil } // Otherwise use PKCS#1 v1.5 if rsa.VerifyPKCS1v15(pubKey, hashAlgorithm.CryptoHash(), hash, remoteKeySignature) != nil { return errKeySignatureMismatch } return nil } return errKeySignatureVerifyUnimplemented } func loadCerts(rawCertificates [][]byte) ([]*x509.Certificate, error) { if len(rawCertificates) == 0 { return nil, errLengthMismatch } certs := make([]*x509.Certificate, 0, len(rawCertificates)) for _, rawCert := range rawCertificates { cert, err := x509.ParseCertificate(rawCert) if err != nil { return nil, err } certs = append(certs, cert) } return certs, nil } func verifyClientCert( rawCertificates [][]byte, roots *x509.CertPool, certSignatureSchemes []signaturehash.Algorithm, ) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), Intermediates: intermediateCAPool, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, } chains, err = certificate[0].Verify(opts) if err != nil { return nil, err } // Validate certificate signature algorithms if specified. // At least one chain must use only allowed signature algorithms. 
if len(certSignatureSchemes) > 0 && len(chains) > 0 { var validChainFound bool for _, chain := range chains { if err := validateCertificateSignatureAlgorithms(chain, certSignatureSchemes); err == nil { validChainFound = true break } } if !validChainFound { return nil, errInvalidCertificateSignatureAlgorithm } } return chains, nil } func verifyServerCert( rawCertificates [][]byte, roots *x509.CertPool, serverName string, certSignatureSchemes []signaturehash.Algorithm, ) (chains [][]*x509.Certificate, err error) { certificate, err := loadCerts(rawCertificates) if err != nil { return nil, err } intermediateCAPool := x509.NewCertPool() for _, cert := range certificate[1:] { intermediateCAPool.AddCert(cert) } opts := x509.VerifyOptions{ Roots: roots, CurrentTime: time.Now(), DNSName: serverName, Intermediates: intermediateCAPool, } chains, err = certificate[0].Verify(opts) if err != nil { return nil, err } // Validate certificate signature algorithms if specified. // At least one chain must use only allowed signature algorithms. if len(certSignatureSchemes) > 0 && len(chains) > 0 { var validChainFound bool for _, chain := range chains { if err := validateCertificateSignatureAlgorithms(chain, certSignatureSchemes); err == nil { validChainFound = true break } } if !validChainFound { return nil, errInvalidCertificateSignatureAlgorithm } } return chains, nil } // validateCertificateSignatureAlgorithms validates that all certificates in the chain // use signature algorithms that are in the allowed list. This implements the // signature_algorithms_cert extension validation per RFC 8446 Section 4.2.3. 
func validateCertificateSignatureAlgorithms( certs []*x509.Certificate, allowedAlgorithms []signaturehash.Algorithm, ) error { if len(allowedAlgorithms) == 0 { // No restrictions specified return nil } // Validate each certificate's signature algorithm (except the root, which we trust) for i := 0; i < len(certs)-1; i++ { cert := certs[i] certAlg, err := signaturehash.FromCertificate(cert) if err != nil { return err } // Check if this algorithm is in the allowed list found := false for _, allowed := range allowedAlgorithms { if certAlg.Hash == allowed.Hash && certAlg.Signature == allowed.Signature { found = true break } } if !found { return errInvalidCertificateSignatureAlgorithm } } return nil } dtls-3.1.2/crypto_test.go000066400000000000000000000517371514330267300154120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "crypto/rand" "crypto/x509" "encoding/pem" "math/big" "testing" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/hash" "github.com/pion/dtls/v3/pkg/crypto/signature" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/stretchr/testify/assert" ) // RSA-PSS certificate with id-RSASSA-PSS OID (1.2.840.113549.1.1.10) // Generated with: // // openssl genpkey -algorithm RSA-PSS -out rsa_pss_key.pem -pkeyopt rsa_keygen_bits:2048 // openssl req -new -x509 -key rsa_pss_key.pem -out rsa_pss_cert.pem -days 365 -subj "/CN=RSA-PSS-Test" // // Note: Go's x509.CreateCertificate does not support creating RSA-PSS certificates, // and x509.ParsePKCS8PrivateKey cannot parse RSA-PSS private keys (fails with // "PKCS#8 wrapping contained private key with unknown algorithm: 1.2.840.113549.1.1.10"). // Therefore we use this cert for OID validation testing only. 
// // nolint: gosec const rsaPSSCertificate = ` -----BEGIN CERTIFICATE----- MIIDdTCCAimgAwIBAgIUOvVXWgzlj9KVp4TQe+ZATB3PkvswQQYJKoZIhvcNAQEK MDSgDzANBglghkgBZQMEAgEFAKEcMBoGCSqGSIb3DQEBCDANBglghkgBZQMEAgEF AKIDAgEgMBcxFTATBgNVBAMMDFJTQS1QU1MtVGVzdDAeFw0yNjAxMjQwNDE1MzFa Fw0yNzAxMjQwNDE1MzFaMBcxFTATBgNVBAMMDFJTQS1QU1MtVGVzdDCCASAwCwYJ KoZIhvcNAQEKA4IBDwAwggEKAoIBAQCpwVkHm2eU336pNtW7VYuu7nWUkSZxr9Oz DAQrZbLsdcSeWj/sSe37/EPmtQrH8f8mK7OR7mY1DrodHyAqyGeeHIwTaAMdrrMX X0RiPbid7w6MU3QZ1q5Hp8IAf8sLrQofchFRLDw6XkMcI4hbWtVJ9GwZiOO2gpDk uS7SBLEiEzKHme+UzPMFUa2xCypYd/bpO0F+h9vtPDFTCRfK6EFf7mb/QAl1UwfO Xq5+hMMiKWyhK2OIKhYc98k7eV7nlC4rz5tMY2v1tUJA6/fAZEmAREVE740hxmkN qN5Enm5tF/ipROPbmQnyCkwtZxKTLi0tz8RTq7lZXRoQr9fo/6ufAgMBAAGjUzBR MB0GA1UdDgQWBBRpdc2ssJhWnWTm4DPJLW3aDy71WTAfBgNVHSMEGDAWgBRpdc2s sJhWnWTm4DPJLW3aDy71WTAPBgNVHRMBAf8EBTADAQH/MEEGCSqGSIb3DQEBCjA0 oA8wDQYJYIZIAWUDBAIBBQChHDAaBgkqhkiG9w0BAQgwDQYJYIZIAWUDBAIBBQCi AwIBIAOCAQEATkolVgnlASfTEvMElGmrLTRVPBovk7ZCpER+/H316xswuUDWKn9t BUhSCYinj5yywgwgx4sErnB5YkB+SR2kkE8WMAU0SNTh2kLUr4TrdqM1o0S5hGQT awGCPIWZjip3V0TeAqC4sWTgdy2EBYPEJ0AZGm50/yJlWiOzsdDbzceKjremCxLF Qgkrd/H9mRfIsybvQZ0SbhCWTbNiGpv+O3q4rJ8l3FiaNc9xt+9/FbzeRIipmVb3 ACeCkdjZt/3rjb/tZRHcURgXYi2109wQOaIE5tAQYFCvaKp3HNdWGU1K5+AO0SIY k2mwB2RsEXa29/Xzj1eMyG33CDgo55AtDw== -----END CERTIFICATE----- ` // nolint: gosec const rawPrivateKey = ` -----BEGIN RSA PRIVATE KEY----- MIIEowIBAAKCAQEAxIA2BrrnR2sIlATsp7aRBD/3krwZ7vt9dNeoDQAee0s6SuYP 6MBx/HPnAkwNvPS90R05a7pwRkoT6Ur4PfPhCVlUe8lV+0Eto3ZSEeHz3HdsqlM3 bso67L7Dqrc7MdVstlKcgJi8yeAoGOIL9/igOv0XBFCeznm9nznx6mnsR5cugw+1 ypXelaHmBCLV7r5SeVSh57+KhvZGbQ2fFpUaTPegRpJZXBNS8lSeWvtOv9d6N5UB ROTAJodMZT5AfX0jB0QB9IT/0I96H6BSENH08NXOeXApMuLKvnAf361rS7cRAfRL rWZqERMP4u6Cnk0Cnckc3WcW27kGGIbtwbqUIQIDAQABAoIBAGF7OVIdZp8Hejn0 N3L8HvT8xtUEe9kS6ioM0lGgvX5s035Uo4/T6LhUx0VcdXRH9eLHnLTUyN4V4cra ZkxVsE3zAvZl60G6E+oDyLMWZOP6Wu4kWlub9597A5atT7BpMIVCdmFVZFLB4SJ3 AXkC3nplFAYP+Lh1rJxRIrIn2g+pEeBboWbYA++oDNuMQffDZaokTkJ8Bn1JZYh0 
xEXKY8Bi2Egd5NMeZa1UFO6y8tUbZfwgVs6Enq5uOgtfayq79vZwyjj1kd29MBUD 8g8byV053ZKxbUOiOuUts97eb+fN3DIDRTcT2c+lXt/4C54M1FclJAbtYRK/qwsl pYWKQAECgYEA4ZUbqQnTo1ICvj81ifGrz+H4LKQqe92Hbf/W51D/Umk2kP702W22 HP4CvrJRtALThJIG9m2TwUjl/WAuZIBrhSAbIvc3Fcoa2HjdRp+sO5U1ueDq7d/S Z+PxRI8cbLbRpEdIaoR46qr/2uWZ943PHMv9h4VHPYn1w8b94hwD6vkCgYEA3v87 mFLzyM9ercnEv9zHMRlMZFQhlcUGQZvfb8BuJYl/WogyT6vRrUuM0QXULNEPlrin mBQTqc1nCYbgkFFsD2VVt1qIyiAJsB9MD1LNV6YuvE7T2KOSadmsA4fa9PUqbr71 hf3lTTq+LeR09LebO7WgSGYY+5YKVOEGpYMR1GkCgYEAxPVQmk3HKHEhjgRYdaG5 lp9A9ZE8uruYVJWtiHgzBTxx9TV2iST+fd/We7PsHFTfY3+wbpcMDBXfIVRKDVwH BMwchXH9+Ztlxx34bYJaegd0SmA0Hw9ugWEHNgoSEmWpM1s9wir5/ELjc7dGsFtz uzvsl9fpdLSxDYgAAdzeGtkCgYBAzKIgrVox7DBzB8KojhtD5ToRnXD0+H/M6OKQ srZPKhlb0V/tTtxrIx0UUEFLlKSXA6mPw6XDHfDnD86JoV9pSeUSlrhRI+Ysy6tq eIE7CwthpPZiaYXORHZ7wCqcK/HcpJjsCs9rFbrV0yE5S3FMdIbTAvgXg44VBB7O UbwIoQKBgDuY8gSrA5/A747wjjmsdRWK4DMTMEV4eCW1BEP7Tg7Cxd5n3xPJiYhr nhLGN+mMnVIcv2zEMS0/eNZr1j/0BtEdx+3IC6Eq+ONY0anZ4Irt57/5QeKgKn/L JPhfPySIPG4UmwE4gW8t79vfOKxnUu2fDD1ZXUYopan6EckACNH/ -----END RSA PRIVATE KEY----- ` func TestGenerateKeySignature(t *testing.T) { block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) clientRandom := []byte{ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, } serverRandom := []byte{ 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, } publicKey := []byte{ 0x20, 0x9f, 0xd7, 0xad, 0x6d, 0xcf, 0xf4, 0x29, 0x8d, 0xd3, 0xf9, 0x6d, 0x5b, 0x1b, 0x2a, 0xf9, 0x10, 0xa0, 0x53, 0x5b, 0x14, 0x88, 0xd7, 0xf8, 0xfa, 0xbb, 0x34, 0x9a, 0x98, 0x28, 0x80, 0xb6, 0x15, } expectedSignature := []byte{ 0x6f, 0x47, 0x97, 0x85, 0xcc, 0x76, 0x50, 0x93, 0xbd, 0xe2, 0x6a, 0x69, 0x0b, 0xc3, 
0x03, 0xd1, 0xb7, 0xe4, 0xab, 0x88, 0x7b, 0xa6, 0x52, 0x80, 0xdf, 0xaa, 0x25, 0x7a, 0xdb, 0x29, 0x32, 0xe4, 0xd8, 0x28, 0x28, 0xb3, 0xe8, 0x04, 0x3c, 0x38, 0x16, 0xfc, 0x78, 0xe9, 0x15, 0x7b, 0xc5, 0xbd, 0x7d, 0xfc, 0xcd, 0x83, 0x00, 0x57, 0x4a, 0x3c, 0x23, 0x85, 0x75, 0x6b, 0x37, 0xd5, 0x89, 0x72, 0x73, 0xf0, 0x44, 0x8c, 0x00, 0x70, 0x1f, 0x6e, 0xa2, 0x81, 0xd0, 0x09, 0xc5, 0x20, 0x36, 0xab, 0x23, 0x09, 0x40, 0x1f, 0x4d, 0x45, 0x96, 0x62, 0xbb, 0x81, 0xb0, 0x30, 0x72, 0xad, 0x3a, 0x0a, 0xac, 0x31, 0x63, 0x40, 0x52, 0x0a, 0x27, 0xf3, 0x34, 0xde, 0x27, 0x7d, 0xb7, 0x54, 0xff, 0x0f, 0x9f, 0x5a, 0xfe, 0x07, 0x0f, 0x4e, 0x9f, 0x53, 0x04, 0x34, 0x62, 0xf4, 0x30, 0x74, 0x83, 0x35, 0xfc, 0xe4, 0x7e, 0xbf, 0x5a, 0xc4, 0x52, 0xd0, 0xea, 0xf9, 0x61, 0x4e, 0xf5, 0x1c, 0x0e, 0x58, 0x02, 0x71, 0xfb, 0x1f, 0x34, 0x55, 0xe8, 0x36, 0x70, 0x3c, 0xc1, 0xcb, 0xc9, 0xb7, 0xbb, 0xb5, 0x1c, 0x44, 0x9a, 0x6d, 0x88, 0x78, 0x98, 0xd4, 0x91, 0x2e, 0xeb, 0x98, 0x81, 0x23, 0x30, 0x73, 0x39, 0x43, 0xd5, 0xbb, 0x70, 0x39, 0xba, 0x1f, 0xdb, 0x70, 0x9f, 0x91, 0x83, 0x56, 0xc2, 0xde, 0xed, 0x17, 0x6d, 0x2c, 0x3e, 0x21, 0xea, 0x36, 0xb4, 0x91, 0xd8, 0x31, 0x05, 0x60, 0x90, 0xfd, 0xc6, 0x74, 0xa9, 0x7b, 0x18, 0xfc, 0x1c, 0x6a, 0x1c, 0x6e, 0xec, 0xd3, 0xc1, 0xc0, 0x0d, 0x11, 0x25, 0x48, 0x37, 0x3d, 0x45, 0x11, 0xa2, 0x31, 0x14, 0x0a, 0x66, 0x9f, 0xd8, 0xac, 0x74, 0xa2, 0xcd, 0xc8, 0x79, 0xb3, 0x9e, 0xc6, 0x66, 0x25, 0xcf, 0x2c, 0x87, 0x5e, 0x5c, 0x36, 0x75, 0x86, } signature, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA) assert.NoError(t, err) assert.Equal(t, expectedSignature, signature) } func TestRSAPSSSignatureGeneration(t *testing.T) { clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} serverRandom := []byte{0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} publicKey := []byte{0x10, 0x11, 0x12, 0x13} // Parse the private key block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := 
x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) // Generate PSS signature sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) assert.NotNil(t, sig) // Verify that PSS signature is different from PKCS#1 v1.5 (PSS is randomized) sig2, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) // PSS signatures should be different each time due to random salt assert.NotEqual(t, sig, sig2) } func TestRSAPSSSignatureVerification(t *testing.T) { clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} serverRandom := []byte{0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} publicKey := []byte{0x10, 0x11, 0x12, 0x13} // Parse the private key block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) // Generate certificate with the public key cert := &x509.Certificate{ SerialNumber: big.NewInt(1), PublicKey: &key.PublicKey, } rawCert, err := x509.CreateCertificate(rand.Reader, cert, cert, &key.PublicKey, key) assert.NoError(t, err) // Generate PSS signature sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) // Verify PSS signature expectedMsg := valueKeyMessage(clientRandom, serverRandom, publicKey, elliptic.X25519) err = verifyKeySignature(expectedMsg, sig, hash.SHA256, signature.RSA_PSS_RSAE_SHA256, [][]byte{rawCert}) assert.NoError(t, err) // Verify that PKCS#1 v1.5 verification fails for PSS signature err = verifyKeySignature(expectedMsg, sig, hash.SHA256, signature.RSA, [][]byte{rawCert}) assert.Error(t, err) } func TestRSAPSSVsPKCS1v15(t *testing.T) { clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} serverRandom := []byte{0x08, 0x09, 0x0a, 0x0b, 
0x0c, 0x0d, 0x0e, 0x0f} publicKey := []byte{0x10, 0x11, 0x12, 0x13} // Parse the private key block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) // Generate certificate cert := &x509.Certificate{ SerialNumber: big.NewInt(1), PublicKey: &key.PublicKey, } rawCert, err := x509.CreateCertificate(rand.Reader, cert, cert, &key.PublicKey, key) assert.NoError(t, err) expectedMsg := valueKeyMessage(clientRandom, serverRandom, publicKey, elliptic.X25519) // Generate and verify PKCS#1 v1.5 signature pkcs1Sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA) assert.NoError(t, err) err = verifyKeySignature(expectedMsg, pkcs1Sig, hash.SHA256, signature.RSA, [][]byte{rawCert}) assert.NoError(t, err) // Generate and verify PSS signature pssSig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) err = verifyKeySignature(expectedMsg, pssSig, hash.SHA256, signature.RSA_PSS_RSAE_SHA256, [][]byte{rawCert}) assert.NoError(t, err) // Verify cross-verification fails err = verifyKeySignature(expectedMsg, pkcs1Sig, hash.SHA256, signature.RSA_PSS_RSAE_SHA256, [][]byte{rawCert}) assert.Error(t, err, "PKCS#1 v1.5 signature should not verify as PSS") err = verifyKeySignature(expectedMsg, pssSig, hash.SHA256, signature.RSA, [][]byte{rawCert}) assert.Error(t, err, "PSS signature should not verify as PKCS#1 v1.5") } func TestRSAPSSRSAEVariants(t *testing.T) { clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} serverRandom := []byte{0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} publicKey := []byte{0x10, 0x11, 0x12, 0x13} // Parse the private key block, _ := pem.Decode([]byte(rawPrivateKey)) key, err := x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) // Generate certificate with rsaEncryption OID (standard RSA cert) cert := 
&x509.Certificate{ SerialNumber: big.NewInt(1), PublicKey: &key.PublicKey, } rawCert, err := x509.CreateCertificate(rand.Reader, cert, cert, &key.PublicKey, key) assert.NoError(t, err) expectedMsg := valueKeyMessage(clientRandom, serverRandom, publicKey, elliptic.X25519) // Test RSA-PSS RSAE variants (work with standard RSA certs) // Note: We don't test RSA_PSS_PSS variants here because they require id-RSASSA-PSS OID certs, // which Go's x509.CreateCertificate doesn't support creating (and can't parse properly either). // OID validation is tested separately in TestCertificateOIDValidation. testCases := []struct { name string hashAlgo hash.Algorithm sigAlgo signature.Algorithm }{ {"RSA_PSS_RSAE_SHA256", hash.SHA256, signature.RSA_PSS_RSAE_SHA256}, {"RSA_PSS_RSAE_SHA384", hash.SHA384, signature.RSA_PSS_RSAE_SHA384}, {"RSA_PSS_RSAE_SHA512", hash.SHA512, signature.RSA_PSS_RSAE_SHA512}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { // Generate signature sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, key, tc.hashAlgo, tc.sigAlgo) assert.NoError(t, err) assert.NotNil(t, sig) assert.True(t, len(sig) > 0, "Signature should not be empty") // Verify signature err = verifyKeySignature(expectedMsg, sig, tc.hashAlgo, tc.sigAlgo, [][]byte{rawCert}) assert.NoError(t, err, "Signature verification should succeed") // Verify IsPSS() returns true assert.True(t, tc.sigAlgo.IsPSS(), "Should be identified as PSS algorithm") // Verify GetPSSHash() returns correct hash assert.Equal(t, tc.hashAlgo, tc.sigAlgo.GetPSSHash(), "Hash extraction should match") }) } } func TestCertificateOIDValidation(t *testing.T) { clientRandom := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} serverRandom := []byte{0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} publicKey := []byte{0x10, 0x11, 0x12, 0x13} // Load standard RSA key and cert (has rsaEncryption OID) block, _ := pem.Decode([]byte(rawPrivateKey)) rsaKey, err := 
x509.ParsePKCS1PrivateKey(block.Bytes) assert.NoError(t, err) rsaEncryptionCert := &x509.Certificate{ SerialNumber: big.NewInt(1), PublicKey: &rsaKey.PublicKey, } rsaEncryptionCertBytes, err := x509.CreateCertificate( rand.Reader, rsaEncryptionCert, rsaEncryptionCert, &rsaKey.PublicKey, rsaKey, ) assert.NoError(t, err) // Load RSA-PSS cert (has id-RSASSA-PSS OID) // We use a locally generated RSA-PSS cert since Go's x509.CreateCertificate doesn't support creating them. // We use the regular RSA key for signing because Go can't parse RSA-PSS private keys either. // For OID validation testing, only the cert's OID matters, not which key was used to sign. pssCertBlock, _ := pem.Decode([]byte(rsaPSSCertificate)) pssCertBytes := pssCertBlock.Bytes expectedMsg := valueKeyMessage(clientRandom, serverRandom, publicKey, elliptic.X25519) t.Run("RSAE_with_rsaEncryption_OID_succeeds", func(t *testing.T) { // Generate signature with RSAE algorithm using rsaEncryption cert sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, rsaKey, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) // Should succeed: RSAE + rsaEncryption OID is valid per RFC 8446 err = verifyKeySignature( expectedMsg, sig, hash.SHA256, signature.RSA_PSS_RSAE_SHA256, [][]byte{rsaEncryptionCertBytes}, ) assert.NoError(t, err) }) t.Run("PSS_with_idRSASSAPSS_OID_succeeds", func(t *testing.T) { t.Skip("Go's x509 library cannot extract public key from RSA-PSS certificates (OID 1.2.840.113549.1.1.10)") // This test would verify that PSS + id-RSASSA-PSS OID is valid per RFC 8446, // but Go's crypto/x509 doesn't fully support parsing RSA-PSS certs. // The important validation (that mismatches are rejected) is tested in other cases. 
}) t.Run("PSS_with_rsaEncryption_OID_fails", func(t *testing.T) { // Generate signature with PSS algorithm sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, rsaKey, hash.SHA256, signature.RSA_PSS_PSS_SHA256) assert.NoError(t, err) // Should fail: PSS algorithm requires id-RSASSA-PSS OID, not rsaEncryption err = verifyKeySignature( expectedMsg, sig, hash.SHA256, signature.RSA_PSS_PSS_SHA256, [][]byte{rsaEncryptionCertBytes}, ) assert.Error(t, err) assert.ErrorIs(t, err, errInvalidCertificateOID) }) t.Run("RSAE_with_idRSASSAPSS_OID_fails", func(t *testing.T) { // Generate signature with RSAE algorithm sig, err := generateKeySignature(clientRandom, serverRandom, publicKey, elliptic.X25519, rsaKey, hash.SHA256, signature.RSA_PSS_RSAE_SHA256) assert.NoError(t, err) // Should fail: RSAE algorithm requires rsaEncryption OID, not id-RSASSA-PSS err = verifyKeySignature(expectedMsg, sig, hash.SHA256, signature.RSA_PSS_RSAE_SHA256, [][]byte{pssCertBytes}) assert.Error(t, err) assert.ErrorIs(t, err, errInvalidCertificateOID) }) } func TestValidateCertificateSignatureAlgorithms(t *testing.T) { // Helper to create a test certificate with specific signature algorithm createTestCert := func(sigAlg x509.SignatureAlgorithm, isCA bool) *x509.Certificate { return &x509.Certificate{ SerialNumber: big.NewInt(1), SignatureAlgorithm: sigAlg, IsCA: isCA, } } t.Run("Empty allowed list passes", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), } err := validateCertificateSignatureAlgorithms(certs, nil) assert.NoError(t, err) }) t.Run("Single cert with allowed algorithm passes", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), createTestCert(x509.SHA256WithRSA, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("Single 
cert with disallowed algorithm fails", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), createTestCert(x509.SHA256WithRSA, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA384, Signature: signature.ECDSA}, // Different algorithm } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.ErrorIs(t, err, errInvalidCertificateSignatureAlgorithm) }) t.Run("Root certificate is not validated", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), // Leaf - validated createTestCert(x509.SHA384WithRSA, true), // Root - NOT validated } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, // Only allows SHA256 } // Should pass because root (SHA384) is not validated err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("Multi-cert chain with all allowed algorithms passes", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), // Leaf createTestCert(x509.SHA384WithRSA, false), // Intermediate createTestCert(x509.SHA512WithRSA, true), // Root (not validated) } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, {Hash: hash.SHA384, Signature: signature.RSA}, // SHA512 not needed since root is not validated } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("Multi-cert chain with one disallowed intermediate fails", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, false), // Leaf - allowed createTestCert(x509.SHA384WithRSA, false), // Intermediate - NOT allowed createTestCert(x509.SHA512WithRSA, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, // Only allows SHA256 } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.ErrorIs(t, err, errInvalidCertificateSignatureAlgorithm) 
}) t.Run("ECDSA certificates", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.ECDSAWithSHA256, false), createTestCert(x509.ECDSAWithSHA384, false), createTestCert(x509.ECDSAWithSHA512, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.ECDSA}, {Hash: hash.SHA384, Signature: signature.ECDSA}, } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("RSA-PSS certificates", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSAPSS, false), createTestCert(x509.SHA384WithRSAPSS, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("Ed25519 certificates", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.PureEd25519, false), createTestCert(x509.PureEd25519, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.None, Signature: signature.Ed25519}, } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) t.Run("Unsupported certificate algorithm", func(t *testing.T) { certs := []*x509.Certificate{ createTestCert(x509.MD5WithRSA, false), // MD5 not supported createTestCert(x509.SHA256WithRSA, true), } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA256, Signature: signature.RSA}, } err := validateCertificateSignatureAlgorithms(certs, allowed) assert.Error(t, err) // Should error from FromCertificate, not from algorithm mismatch }) t.Run("Single cert chain does not validate", func(t *testing.T) { // Single cert is treated as self-signed root, which is not validated certs := []*x509.Certificate{ createTestCert(x509.SHA256WithRSA, true), // Root } allowed := []signaturehash.Algorithm{ {Hash: hash.SHA384, Signature: signature.ECDSA}, // Different algorithm } // Should pass because single root cert is not validated err := 
validateCertificateSignatureAlgorithms(certs, allowed) assert.NoError(t, err) }) } dtls-3.1.2/dtls.go000066400000000000000000000002731514330267300137660ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package dtls implements Datagram Transport Layer Security (DTLS) 1.2 package dtls dtls-3.1.2/e2e/000077500000000000000000000000001514330267300131425ustar00rootroot00000000000000dtls-3.1.2/e2e/Dockerfile000066400000000000000000000004161514330267300151350ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2026 The Pion community # SPDX-License-Identifier: MIT FROM docker.io/library/golang:1.24-bullseye COPY . /go/src/github.com/pion/dtls WORKDIR /go/src/github.com/pion/dtls/e2e CMD ["go", "test", "-tags=openssl", "-v", "."] dtls-3.1.2/e2e/e2e.go000066400000000000000000000002511514330267300141420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package e2e contains end to end tests for pion/dtls package e2e dtls-3.1.2/e2e/e2e_lossy_test.go000066400000000000000000000137021514330267300164370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package e2e import ( "fmt" "math/rand" "testing" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" dtlsnet "github.com/pion/dtls/v3/pkg/net" transportTest "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) const ( flightInterval = time.Millisecond * 100 lossyTestTimeout = 30 * time.Second ) // DTLS Client/Server over a lossy transport, just asserts it can handle at increasing increments func TestPionE2ELossy(t *testing.T) { //nolint:cyclop // Check for leaking routines report := transportTest.CheckRoutines(t) defer report() type runResult struct { dtlsConn *dtls.Conn err error } serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCert, err := 
selfsign.GenerateSelfSigned() assert.NoError(t, err) for _, test := range []struct { LossChanceRange int DoClientAuth bool CipherSuites []dtls.CipherSuiteID MTU int DisableServerFlightInterval bool }{ { LossChanceRange: 0, }, { LossChanceRange: 10, }, { LossChanceRange: 20, }, { LossChanceRange: 50, }, { LossChanceRange: 0, DoClientAuth: true, }, { LossChanceRange: 10, DoClientAuth: true, }, { LossChanceRange: 20, DoClientAuth: true, }, { LossChanceRange: 50, DoClientAuth: true, }, { LossChanceRange: 0, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 20, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 50, CipherSuites: []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, }, { LossChanceRange: 10, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 20, MTU: 100, DoClientAuth: true, }, { LossChanceRange: 50, MTU: 100, DoClientAuth: true, }, // Incoming retransmitted handshakes should cause us to retransmit. 
Disabling the FlightInterval on one side // means that a incoming re-transmissions causes the retransmission to be fired { LossChanceRange: 10, DisableServerFlightInterval: true, }, { LossChanceRange: 20, DisableServerFlightInterval: true, }, { LossChanceRange: 50, DisableServerFlightInterval: true, }, } { name := fmt.Sprintf("Loss%d_MTU%d", test.LossChanceRange, test.MTU) if test.DoClientAuth { name += "_WithCliAuth" } for _, ciph := range test.CipherSuites { name += "_With" + ciph.String() } if test.DisableServerFlightInterval { name += "_WithNoServerFlightInterval" } test := test t.Run(name, func(t *testing.T) { // Limit runtime in case of deadlocks lim := transportTest.TimeOut(lossyTestTimeout + time.Second) defer lim.Stop() chosenLoss := rand.Intn(9) + test.LossChanceRange //nolint:gosec serverDone := make(chan runResult) clientDone := make(chan runResult) br := transportTest.NewBridge() assert.NoError(t, br.SetLossChance(chosenLoss)) go func() { clientOpts := []dtls.ClientOption{ dtls.WithFlightInterval(flightInterval), dtls.WithInsecureSkipVerify(true), dtls.WithDisableRetransmitBackoff(true), } if len(test.CipherSuites) > 0 { clientOpts = append(clientOpts, dtls.WithCipherSuites(test.CipherSuites...)) } if test.MTU > 0 { clientOpts = append(clientOpts, dtls.WithMTU(test.MTU)) } if test.DoClientAuth { clientOpts = append(clientOpts, dtls.WithCertificates(clientCert)) } client, startupErr := dtls.ClientWithOptions( dtlsnet.PacketConnFromConn(br.GetConn0()), br.GetConn0().RemoteAddr(), clientOpts..., ) clientDone <- runResult{client, startupErr} }() go func() { serverOpts := []dtls.ServerOption{ dtls.WithCertificates(serverCert), dtls.WithFlightInterval(flightInterval), dtls.WithDisableRetransmitBackoff(true), } if test.MTU > 0 { serverOpts = append(serverOpts, dtls.WithMTU(test.MTU)) } if test.DoClientAuth { serverOpts = append(serverOpts, dtls.WithClientAuth(dtls.RequireAnyClientCert)) } if test.DisableServerFlightInterval { serverOpts = append(serverOpts, 
dtls.WithFlightInterval(time.Hour)) } server, startupErr := dtls.ServerWithOptions( dtlsnet.PacketConnFromConn(br.GetConn1()), br.GetConn1().RemoteAddr(), serverOpts..., ) serverDone <- runResult{server, startupErr} }() testTimer := time.NewTimer(lossyTestTimeout) var serverConn, clientConn *dtls.Conn defer func() { if serverConn != nil { assert.NoError(t, serverConn.Close()) } if clientConn != nil { assert.NoError(t, clientConn.Close()) } }() for serverConn == nil || clientConn == nil { br.Tick() select { case serverResult := <-serverDone: if serverResult.err != nil { assert.Failf(t, "Fail, serverError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, serverResult.err) return } serverConn = serverResult.dtlsConn case clientResult := <-clientDone: if clientResult.err != nil { assert.Failf(t, "Fail, clientError", "clientComplete(%t) serverComplete(%t) LossChance(%d) error(%v)", clientConn != nil, serverConn != nil, chosenLoss, clientResult.err) return } clientConn = clientResult.dtlsConn case <-testTimer.C: assert.Failf(t, "Test expired", "clientComplete(%t) serverComplete(%t) LossChance(%d)", clientConn != nil, serverConn != nil, chosenLoss) return case <-time.After(10 * time.Millisecond): } } }) } } dtls-3.1.2/e2e/e2e_openssl_test.go000066400000000000000000000226041514330267300167520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build openssl && !js // +build openssl,!js package e2e import ( "crypto/tls" "crypto/x509" "encoding/pem" "errors" "fmt" "io/ioutil" "net" "os" "os/exec" "regexp" "strings" "testing" "time" "github.com/pion/dtls/v3" ) func serverOpenSSL(c *comm) { go func() { c.serverMutex.Lock() defer c.serverMutex.Unlock() // Use information stored in comm struct cipherSuites := c.serverCipherSuites certs := c.serverCertificates psk := c.serverPSK pskHint := c.serverPSKIdentityHint // create openssl arguments args := 
[]string{ "s_server", "-dtls1_2", "-quiet", "-verify_quiet", "-verify_return_error", fmt.Sprintf("-accept=%d", c.serverPort), } ciphers := ciphersFromSuites(cipherSuites) if ciphers != "" { args = append(args, fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if psk != nil { pskBytes, err := psk(nil) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-psk=%X", pskBytes)) if len(pskHint) > 0 { args = append(args, fmt.Sprintf("-psk_hint=%s", pskHint)) } } // certs arguments if len(certs) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEMFromCerts(certs) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } else { args = append(args, "-nocert") } // launch command // #nosec G204 cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.serverConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } // Ensure that server has started time.Sleep(500 * time.Millisecond) c.serverReady <- struct{}{} simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- cmd.Process.Kill() close(c.serverDone) }() } func clientOpenSSL(c *comm) { select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errors.New("waiting on serverReady err: timeout") } c.clientMutex.Lock() defer c.clientMutex.Unlock() // Use information stored in comm struct cipherSuites := c.clientCipherSuites certs := c.clientCertificates psk := c.clientPSK insecureSkipVerify := c.clientInsecureSkipVerify // create openssl arguments args := []string{ "s_client", "-dtls1_2", "-quiet", "-verify_quiet", "-servername=localhost", fmt.Sprintf("-connect=127.0.0.1:%d", c.serverPort), } ciphers := ciphersFromSuites(cipherSuites) if ciphers != "" { args = append(args, 
fmt.Sprintf("-cipher=%s", ciphers)) } // psk arguments if psk != nil { pskBytes, err := psk(nil) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-psk=%X", pskBytes)) } // certificate arguments if len(certs) > 0 { // create temporary cert files certPEM, keyPEM, err := writeTempPEMFromCerts(certs) if err != nil { c.errChan <- err return } args = append(args, fmt.Sprintf("-CAfile=%s", certPEM), fmt.Sprintf("-cert=%s", certPEM), fmt.Sprintf("-key=%s", keyPEM)) defer func() { _ = os.Remove(certPEM) _ = os.Remove(keyPEM) }() } if !insecureSkipVerify { args = append(args, "-verify_return_error") } // launch command // #nosec G204 cmd := exec.Command("openssl", args...) var inner net.Conn inner, c.clientConn = net.Pipe() cmd.Stdin = inner cmd.Stdout = inner cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { c.errChan <- err _ = inner.Close() return } simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- cmd.Process.Kill() close(c.clientDone) } func ciphersFromSuites(cipherSuites []dtls.CipherSuiteID) string { // See https://tls.mbed.org/supported-ssl-ciphersuites translate := map[dtls.CipherSuiteID]string{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM: "ECDHE-ECDSA-AES128-CCM", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8: "ECDHE-ECDSA-AES128-CCM8", dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256: "ECDHE-ECDSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384: "ECDHE-ECDSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256: "ECDHE-RSA-AES128-GCM-SHA256", dtls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384: "ECDHE-RSA-AES256-GCM-SHA384", dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA: "ECDHE-ECDSA-AES256-SHA", dtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: "ECDHE-RSA-AES256-SHA", dtls.TLS_PSK_WITH_AES_128_CCM: "PSK-AES128-CCM", dtls.TLS_PSK_WITH_AES_128_CCM_8: "PSK-AES128-CCM8", dtls.TLS_PSK_WITH_AES_256_CCM_8: "PSK-AES256-CCM8", dtls.TLS_PSK_WITH_AES_128_GCM_SHA256: "PSK-AES128-GCM-SHA256", 
dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256: "ECDHE-PSK-AES128-CBC-SHA256", } var ciphers []string for _, c := range cipherSuites { if text, ok := translate[c]; ok { ciphers = append(ciphers, text) } } return strings.Join(ciphers, ";") } func writeTempPEMFromCerts(certs []tls.Certificate) (string, string, error) { if len(certs) == 0 { return "", "", fmt.Errorf("no certificates provided") } certOut, err := ioutil.TempFile("", "cert.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } keyOut, err := ioutil.TempFile("", "key.pem") if err != nil { return "", "", fmt.Errorf("failed to create temporary file: %w", err) } cert := certs[0] derBytes := cert.Certificate[0] if err = pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to cert.pem: %w", err) } if err = certOut.Close(); err != nil { return "", "", fmt.Errorf("error closing cert.pem: %w", err) } priv := cert.PrivateKey var privBytes []byte privBytes, err = x509.MarshalPKCS8PrivateKey(priv) if err != nil { return "", "", fmt.Errorf("unable to marshal private key: %w", err) } if err = pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { return "", "", fmt.Errorf("failed to write data to key.pem: %w", err) } if err = keyOut.Close(); err != nil { return "", "", fmt.Errorf("error closing key.pem: %w", err) } return certOut.Name(), keyOut.Name(), nil } func minimumOpenSSLVersion(t *testing.T) bool { t.Helper() cmd := exec.Command("openssl", "version") allOut, err := cmd.CombinedOutput() if err != nil { t.Log("Cannot determine OpenSSL version: ", err) return false } verMatch := regexp.MustCompile(`(?i)^OpenSSL\s(?P(\d+\.)?(\d+\.)?(\*|\d+)(\w)?).+$`) match := verMatch.FindStringSubmatch(strings.TrimSpace(string(allOut))) params := map[string]string{} for i, name := range verMatch.SubexpNames() { if i > 0 && i <= len(match) { params[name] = match[i] } } var ver string if 
val, ok := params["version"]; !ok { t.Log("Could not extract OpenSSL version") return false } else { ver = val } cmp := strings.Compare(ver, "3.0.0") if cmp == -1 { return false } return true } func TestPionOpenSSLE2ESimple(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimple(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimple(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimplePSK(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimplePSK(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2EMTUs(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2EMTUs(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2EMTUs(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleED25519ClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { if !minimumOpenSSLVersion(t) { t.Skip("Cannot use OpenSSL < 3.0 as a DTLS server with ED25519 keys") } testPionE2ESimpleED25519ClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleECDSAClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleRSA(t *testing.T) { 
t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimpleRSA(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleRSA(t, serverPion, clientOpenSSL) }) } func TestPionOpenSSLE2ESimpleRSAClientCert(t *testing.T) { t.Run("OpenSSLServer", func(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverOpenSSL, clientPion) }) t.Run("OpenSSLClient", func(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientOpenSSL) }) } dtls-3.1.2/e2e/e2e_test.go000066400000000000000000001235541514330267300152150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package e2e import ( "context" "crypto" "crypto/ecdsa" "crypto/ed25519" "crypto/elliptic" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "errors" "fmt" "io" "math/big" "net" "sync" "sync/atomic" "testing" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) const ( testMessage = "Hello World" testTimeLimit = 5 * time.Second messageRetry = 200 * time.Millisecond ) var ( errServerTimeout = errors.New("waiting on serverReady err: timeout") errHookCiphersFailed = errors.New("hook failed to modify cipherlist") errHookAPLNFailed = errors.New("hook failed to modify APLN extension") ) func randomPort(tb testing.TB) int { tb.Helper() conn, err := net.ListenPacket("udp4", "127.0.0.1:0") // nolint: noctx assert.NoError(tb, err, "failed to pick port") defer func() { _ = conn.Close() }() switch addr := conn.LocalAddr().(type) { case *net.UDPAddr: return addr.Port default: assert.Fail(tb, "failed to acquire port", "unknown addr type %T", addr) return 0 } } func simpleReadWrite(errChan chan error, outChan chan string, conn io.ReadWriter, messageRecvCount *uint64) { go func() { 
buffer := make([]byte, 8192) n, err := conn.Read(buffer) if err != nil { errChan <- err return } outChan <- string(buffer[:n]) atomic.AddUint64(messageRecvCount, 1) }() for { if atomic.LoadUint64(messageRecvCount) == 2 { break } else if _, err := conn.Write([]byte(testMessage)); err != nil { errChan <- err break } time.Sleep(messageRetry) } } type comm struct { ctx context.Context //nolint:containedctx clientOpts []dtls.ClientOption serverOpts []dtls.ServerOption // OpenSSL test helpers need this information clientCipherSuites []dtls.CipherSuiteID serverCipherSuites []dtls.CipherSuiteID clientCertificates []tls.Certificate serverCertificates []tls.Certificate clientPSK dtls.PSKCallback serverPSK dtls.PSKCallback clientPSKIdentityHint []byte serverPSKIdentityHint []byte clientInsecureSkipVerify bool serverPort int messageRecvCount *uint64 // Counter to make sure both sides got a message clientMutex *sync.Mutex clientConn net.Conn clientDone chan error serverMutex *sync.Mutex serverConn net.Conn serverListener net.Listener serverReady chan struct{} serverDone chan error errChan chan error clientChan chan string serverChan chan string client func(*comm) server func(*comm) } func newComm( ctx context.Context, clientOpts []dtls.ClientOption, serverOpts []dtls.ServerOption, serverPort int, server, client func(*comm), ) *comm { messageRecvCount := uint64(0) com := &comm{ ctx: ctx, clientOpts: clientOpts, serverOpts: serverOpts, serverPort: serverPort, messageRecvCount: &messageRecvCount, clientMutex: &sync.Mutex{}, serverMutex: &sync.Mutex{}, serverReady: make(chan struct{}), serverDone: make(chan error), clientDone: make(chan error), errChan: make(chan error), clientChan: make(chan string), serverChan: make(chan string), server: server, client: client, } return com } // setOpenSSLInfo sets OpenSSL-specific information in the comm struct. // This is called by test functions that have this information available. 
func (c *comm) setOpenSSLInfo( clientCipherSuites []dtls.CipherSuiteID, serverCipherSuites []dtls.CipherSuiteID, clientCertificates []tls.Certificate, serverCertificates []tls.Certificate, clientPSK dtls.PSKCallback, serverPSK dtls.PSKCallback, clientPSKIdentityHint []byte, serverPSKIdentityHint []byte, clientInsecureSkipVerify bool, ) { c.clientCipherSuites = clientCipherSuites c.serverCipherSuites = serverCipherSuites c.clientCertificates = clientCertificates c.serverCertificates = serverCertificates c.clientPSK = clientPSK c.serverPSK = serverPSK c.clientPSKIdentityHint = clientPSKIdentityHint c.serverPSKIdentityHint = serverPSKIdentityHint c.clientInsecureSkipVerify = clientInsecureSkipVerify } func (c *comm) assert(t *testing.T) { //nolint:cyclop t.Helper() // DTLS Client go c.client(c) // DTLS Server go c.server(c) defer func() { if c.clientConn != nil { assert.NoError(t, c.clientConn.Close()) } if c.serverConn != nil { assert.NoError(t, c.serverConn.Close()) } if c.serverListener != nil { assert.NoError(t, c.serverListener.Close()) } }() func() { seenClient, seenServer := false, false for { select { case err := <-c.errChan: assert.NoError(t, err) case <-time.After(testTimeLimit): assert.Failf(t, "Test timeout", "seenClient %t seenServer %t", seenClient, seenServer) case clientMsg := <-c.clientChan: assert.Equal(t, testMessage, clientMsg) seenClient = true if seenClient && seenServer { return } case serverMsg := <-c.serverChan: assert.Equal(t, testMessage, serverMsg) seenServer = true if seenClient && seenServer { return } } } }() } func (c *comm) cleanup(t *testing.T) { t.Helper() clientDone, serverDone := false, false for { select { case err := <-c.clientDone: assert.NoError(t, err) clientDone = true if clientDone && serverDone { return } case err := <-c.serverDone: assert.NoError(t, err) serverDone = true if clientDone && serverDone { return } case <-time.After(testTimeLimit): assert.Fail(t, "Test timeout waiting for server shutdown") } } } func 
clientPion(c *comm) { //nolint:varnamelen select { case <-c.serverReady: // OK case <-time.After(time.Second): c.errChan <- errServerTimeout } c.clientMutex.Lock() defer c.clientMutex.Unlock() conn, err := dtls.DialWithOptions("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.clientOpts..., ) if err != nil { c.errChan <- err return } if err := conn.HandshakeContext(c.ctx); err != nil { c.errChan <- err return } c.clientConn = conn simpleReadWrite(c.errChan, c.clientChan, c.clientConn, c.messageRecvCount) c.clientDone <- nil close(c.clientDone) } func serverPion(c *comm) { //nolint:varnamelen c.serverMutex.Lock() defer c.serverMutex.Unlock() var err error c.serverListener, err = dtls.ListenWithOptions("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: c.serverPort}, c.serverOpts..., ) if err != nil { c.errChan <- err return } c.serverReady <- struct{}{} c.serverConn, err = c.serverListener.Accept() if err != nil { c.errChan <- err return } dtlsConn, ok := c.serverConn.(*dtls.Conn) if ok { if err := dtlsConn.HandshakeContext(c.ctx); err != nil { c.errChan <- err return } } simpleReadWrite(c.errChan, c.serverChan, c.serverConn, c.messageRecvCount) c.serverDone <- nil close(c.serverDone) } type dtlsTestOpts struct { clientOpts []dtls.ClientOption serverOpts []dtls.ServerOption } func withConnectionIDGenerator(g func() []byte) dtlsTestOpts { return dtlsTestOpts{ clientOpts: []dtls.ClientOption{dtls.WithConnectionIDGenerator(g)}, serverOpts: []dtls.ServerOption{dtls.WithConnectionIDGenerator(g)}, } } // Simple DTLS Client/Server can communicate // - Assert that you can send messages both ways // - Assert that Close() on both ends work // - Assert that no Goroutines are leaked // //nolint:dupl func testPionE2ESimple(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ 
dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) } serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{cipherSuite}, []dtls.CipherSuiteID{cipherSuite}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } } //nolint:dupl func testPionE2ESimpleRSA(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() priv, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) cert, err := selfsign.SelfSign(priv) assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), 
dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) } serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{cipherSuite}, []dtls.CipherSuiteID{cipherSuite}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimplePSK(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_PSK_WITH_AES_128_CCM, dtls.TLS_PSK_WITH_AES_128_CCM_8, dtls.TLS_PSK_WITH_AES_256_CCM_8, dtls.TLS_PSK_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() clientOpts := []dtls.ClientOption{ dtls.WithPSK(func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte{0x01, 0x02, 0x03, 0x04, 0x05}), dtls.WithCipherSuites(cipherSuite), } serverOpts := []dtls.ServerOption{ dtls.WithPSK(func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte{0x01, 0x02, 0x03, 0x04, 0x05}), dtls.WithCipherSuites(cipherSuite), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) pskFunc := func([]byte) ([]byte, error) { return []byte{0xAB, 0xC1, 0x23}, nil } pskHint := []byte{0x01, 0x02, 0x03, 0x04, 0x05} comm.setOpenSSLInfo( []dtls.CipherSuiteID{cipherSuite}, []dtls.CipherSuiteID{cipherSuite}, nil, nil, pskFunc, pskFunc, pskHint, pskHint, false) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2EMTUs(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() cipherSuite := dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 for _, mtu := range []int{ 10000, 1000, 100, } { mtu := mtu t.Run(fmt.Sprintf("MTU%d", mtu), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), dtls.WithMTU(mtu), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), dtls.WithMTU(mtu), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{cipherSuite}, []dtls.CipherSuiteID{cipherSuite}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimpleED25519(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() for _, cipherSuite := range []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8, dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, } { cipherSuite := cipherSuite t.Run(cipherSuite.String(), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, key, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) cert, err := selfsign.SelfSign(key) assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(cipherSuite), dtls.WithInsecureSkipVerify(true), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{cipherSuite}, []dtls.CipherSuiteID{cipherSuite}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } } func testPionE2ESimpleED25519ClientCert(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _, skey, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) scert, err := selfsign.SelfSign(skey) assert.NoError(t, err) _, ckey, err := ed25519.GenerateKey(rand.Reader) assert.NoError(t, err) ccert, err := selfsign.SelfSign(ckey) assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(ccert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(scert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithClientAuth(dtls.RequireAnyClientCert), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{ccert}, []tls.Certificate{scert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleECDSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() scert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) ccert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) clientCAs := x509.NewCertPool() caCert, err := x509.ParseCertificate(ccert.Certificate[0]) assert.NoError(t, err) clientCAs.AddCert(caCert) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(ccert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithClientCAs(clientCAs), dtls.WithCertificates(scert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithClientAuth(dtls.RequireAnyClientCert), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{ccert}, []tls.Certificate{scert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleRSAClientCert(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() spriv, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) scert, err := selfsign.SelfSign(spriv) assert.NoError(t, err) cpriv, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) ccert, err := selfsign.SelfSign(cpriv) assert.NoError(t, err) clientOpts := []dtls.ClientOption{ dtls.WithCertificates(ccert), dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(scert), dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), dtls.WithClientAuth(dtls.RequireAnyClientCert), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{ccert}, []tls.Certificate{scert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) } func testPionE2ESimpleClientHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() t.Run("ClientHello hook", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) modifiedCipher := dtls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA supportedList := []dtls.CipherSuiteID{ dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM, modifiedCipher, } clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithVerifyConnection(func(s *dtls.State) error { if s.CipherSuiteID != modifiedCipher { return errHookCiphersFailed } return nil }), dtls.WithCipherSuites(supportedList...), dtls.WithClientHelloMessageHook(func(ch handshake.MessageClientHello) handshake.Message { ch.CipherSuiteIDs = []uint16{uint16(modifiedCipher)} return &ch }), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(supportedList...), dtls.WithInsecureSkipVerify(true), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( supportedList, supportedList, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } func testPionE2ESimpleServerHelloHook(t *testing.T, server, client func(*comm), opts ...dtlsTestOpts) { t.Helper() lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() t.Run("ServerHello hook", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) supportedList := []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_CCM} apln := "APLN" clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithVerifyConnection(func(s *dtls.State) error { if s.NegotiatedProtocol != apln { return errHookAPLNFailed } return nil }), dtls.WithCipherSuites(supportedList...), dtls.WithInsecureSkipVerify(true), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(supportedList...), dtls.WithServerHelloMessageHook(func(sh handshake.MessageServerHello) handshake.Message { sh.Extensions = append(sh.Extensions, &extension.ALPN{ ProtocolNameList: []string{apln}, }) return &sh }), dtls.WithInsecureSkipVerify(true), } for _, o := range opts { clientOpts = append(clientOpts, o.clientOpts...) serverOpts = append(serverOpts, o.serverOpts...) 
} serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, server, client) comm.setOpenSSLInfo( supportedList, supportedList, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) }) } func TestPionE2ESimple(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion) } func TestPionE2ESimpleRSA(t *testing.T) { testPionE2ESimpleRSA(t, serverPion, clientPion) } func TestPionE2ESimplePSK(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientPion) } func TestPionE2EMTUs(t *testing.T) { testPionE2EMTUs(t, serverPion, clientPion) } func TestPionE2ESimpleED25519(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientPion) } func TestPionE2ESimpleED25519ClientCert(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleECDSAClientCert(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleRSAClientCert(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion) } func TestPionE2ESimpleCID(t *testing.T) { testPionE2ESimple(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimplePSKCID(t *testing.T) { testPionE2ESimplePSK(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2EMTUsCID(t *testing.T) { testPionE2EMTUs(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleED25519CID(t *testing.T) { testPionE2ESimpleED25519(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleED25519ClientCertCID(t *testing.T) { testPionE2ESimpleED25519ClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleECDSAClientCertCID(t *testing.T) { testPionE2ESimpleECDSAClientCert(t, serverPion, clientPion, 
withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleRSAClientCertCID(t *testing.T) { testPionE2ESimpleRSAClientCert(t, serverPion, clientPion, withConnectionIDGenerator(dtls.RandomCIDGenerator(8))) } func TestPionE2ESimpleClientHelloHook(t *testing.T) { testPionE2ESimpleClientHelloHook(t, serverPion, clientPion) } func TestPionE2ESimpleServerHelloHook(t *testing.T) { testPionE2ESimpleServerHelloHook(t, serverPion, clientPion) } // TestCertificateSignatureSchemesAllowed tests that connections succeed when // certificate chains use only allowed signature algorithms. func TestCertificateSignatureSchemesAllowed(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Generate ECDSA certificate (uses ECDSA-SHA256) cert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) // Client allows ECDSA signature schemes for certificates clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), dtls.WithCertificateSignatureSchemes( tls.ECDSAWithP256AndSHA256, tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512, ), } // Server allows ECDSA signature schemes for certificates serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), dtls.WithCertificateSignatureSchemes( tls.ECDSAWithP256AndSHA256, tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512, ), } serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, serverPion, clientPion) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, 
nil, true) defer comm.cleanup(t) comm.assert(t) } // TestCertificateSignatureSchemesRSA tests RSA certificates with signature schemes. func TestCertificateSignatureSchemesRSA(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Generate RSA certificate priv, err := rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) cert, err := selfsign.SelfSign(priv) assert.NoError(t, err) // Allow RSA-PSS signature schemes for certificates clientOpts := []dtls.ClientOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), dtls.WithCertificateSignatureSchemes( tls.PSSWithSHA256, tls.PSSWithSHA384, tls.PSSWithSHA512, tls.PKCS1WithSHA256, tls.PKCS1WithSHA384, tls.PKCS1WithSHA512, ), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(cert), dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), dtls.WithCertificateSignatureSchemes( tls.PSSWithSHA256, tls.PSSWithSHA384, tls.PSSWithSHA512, tls.PKCS1WithSHA256, tls.PKCS1WithSHA384, tls.PKCS1WithSHA512, ), } serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, serverPion, clientPion) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{cert}, []tls.Certificate{cert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) } // TestCertificateSignatureSchemesClientCert tests certificate signature validation // with client certificates using different ECDSA curves. 
func TestCertificateSignatureSchemesClientCert(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Server uses P-256 ECDSA serverCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) // Client uses P-384 ECDSA clientKey, err := ecdsa.GenerateKey(elliptic.P384(), rand.Reader) assert.NoError(t, err) clientCert, err := selfsign.SelfSign(clientKey) assert.NoError(t, err) clientCAs := x509.NewCertPool() caCert, err := x509.ParseCertificate(clientCert.Certificate[0]) assert.NoError(t, err) clientCAs.AddCert(caCert) // Both sides accept P-256 and P-384 ECDSA clientOpts := []dtls.ClientOption{ dtls.WithCertificates(clientCert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), dtls.WithCertificateSignatureSchemes( tls.ECDSAWithP256AndSHA256, tls.ECDSAWithP384AndSHA384, ), } serverOpts := []dtls.ServerOption{ dtls.WithClientCAs(clientCAs), dtls.WithCertificates(serverCert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithClientAuth(dtls.RequireAnyClientCert), dtls.WithCertificateSignatureSchemes( tls.ECDSAWithP256AndSHA256, tls.ECDSAWithP384AndSHA384, ), } serverPort := randomPort(t) comm := newComm(ctx, clientOpts, serverOpts, serverPort, serverPion, clientPion) comm.setOpenSSLInfo( []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []dtls.CipherSuiteID{dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, []tls.Certificate{clientCert}, []tls.Certificate{serverCert}, nil, nil, nil, nil, true) defer comm.cleanup(t) comm.assert(t) } // createCertChain creates a certificate chain with a root CA and a leaf certificate // signed by the CA. The CA uses the same key type as the leaf to ensure consistent // signature algorithms in the chain. This allows testing signature algorithm validation. 
func createCertChain(t *testing.T, leafKeyType string) (tls.Certificate, *x509.CertPool) { t.Helper() // Create root CA with matching key type var caKey crypto.Signer var caPubKey crypto.PublicKey var err error switch leafKeyType { case "ecdsa-p256": var k *ecdsa.PrivateKey k, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.NoError(t, err) caKey = k caPubKey = &k.PublicKey case "ecdsa-p384": var k *ecdsa.PrivateKey k, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) assert.NoError(t, err) caKey = k caPubKey = &k.PublicKey case "rsa": var k *rsa.PrivateKey k, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) caKey = k caPubKey = &k.PublicKey default: assert.FailNowf(t, "unknown key type", "unknown key type: %s", leafKeyType) } caTemplate := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: "Test CA"}, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, IsCA: true, } var caCertDER []byte caCertDER, err = x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, caPubKey, caKey) assert.NoError(t, err) var caCert *x509.Certificate caCert, err = x509.ParseCertificate(caCertDER) assert.NoError(t, err) // Create leaf certificate with same key type, signed by CA var leafKey crypto.Signer var leafPubKey crypto.PublicKey switch leafKeyType { case "ecdsa-p256": var k *ecdsa.PrivateKey k, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) assert.NoError(t, err) leafKey = k leafPubKey = &k.PublicKey case "ecdsa-p384": var k *ecdsa.PrivateKey k, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) assert.NoError(t, err) leafKey = k leafPubKey = &k.PublicKey case "rsa": var k *rsa.PrivateKey k, err = rsa.GenerateKey(rand.Reader, 2048) assert.NoError(t, err) leafKey = k leafPubKey = &k.PublicKey default: assert.FailNowf(t, "unknown key type", "unknown key type: %s", leafKeyType) } leafTemplate := 
&x509.Certificate{ SerialNumber: big.NewInt(2), Subject: pkix.Name{CommonName: "Test Leaf"}, NotBefore: time.Now(), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, } var leafCertDER []byte leafCertDER, err = x509.CreateCertificate(rand.Reader, leafTemplate, caCert, leafPubKey, caKey) assert.NoError(t, err) // Create tls.Certificate with full chain tlsCert := tls.Certificate{ Certificate: [][]byte{leafCertDER, caCertDER}, PrivateKey: leafKey, } // Create root CA pool rootCAs := x509.NewCertPool() rootCAs.AddCert(caCert) return tlsCert, rootCAs } // TestCertificateSignatureSchemesServerCertRejected tests that client rejects // server certificate when it uses a disallowed signature algorithm. // //nolint:dupl func TestCertificateSignatureSchemesServerCertRejected(t *testing.T) { lim := test.TimeOut(time.Second * 30) defer lim.Stop() report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // Server uses P-256 ECDSA certificate chain serverCert, serverRootCAs := createCertChain(t, "ecdsa-p256") // Client uses self-signed cert (doesn't matter since server doesn't validate) clientCert, err := selfsign.GenerateSelfSigned() assert.NoError(t, err) // Client only allows P-384 and P-521, but server leaf cert is P-256 // This should cause the handshake to fail clientOpts := []dtls.ClientOption{ dtls.WithCertificates(clientCert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithRootCAs(serverRootCAs), dtls.WithCertificateSignatureSchemes( tls.ECDSAWithP384AndSHA384, tls.ECDSAWithP521AndSHA512, ), } serverOpts := []dtls.ServerOption{ dtls.WithCertificates(serverCert), dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), dtls.WithInsecureSkipVerify(true), } serverPort := randomPort(t) // Start server serverReady := make(chan struct{}) 
serverDone := make(chan error, 1) var serverConn net.Conn go func() { listener, listenerErr := dtls.ListenWithOptions("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort}, serverOpts..., ) if listenerErr != nil { serverDone <- listenerErr return } defer func() { _ = listener.Close() }() serverReady <- struct{}{} var acceptErr error serverConn, acceptErr = listener.Accept() if acceptErr != nil { serverDone <- acceptErr return } defer func() { _ = serverConn.Close() }() // Try to do handshake var handshakeErr error if dtlsConn, ok := serverConn.(*dtls.Conn); ok { handshakeErr = dtlsConn.HandshakeContext(ctx) } serverDone <- handshakeErr }() // Wait for server to be ready select { case <-serverReady: case <-time.After(time.Second): assert.FailNow(t, "server not ready in time") } // Client should fail to connect conn, err := dtls.DialWithOptions("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort}, clientOpts..., ) if err == nil && conn != nil { err = conn.HandshakeContext(ctx) _ = conn.Close() } // We expect the handshake to fail due to invalid certificate signature algorithm assert.Error(t, err, "expected handshake to fail with disallowed signature scheme") // Wait for server to complete select { case serverErr := <-serverDone: // Server should also see an error assert.Error(t, serverErr) case <-time.After(2 * time.Second): t.Log("server did not complete in time") } } // TestCertificateSignatureSchemesClientCertRejected tests that server rejects // client certificate when it uses a disallowed signature algorithm. 
func TestCertificateSignatureSchemesClientCertRejected(t *testing.T) {
	// Abort the whole test if it runs longer than 30 seconds.
	lim := test.TimeOut(time.Second * 30)
	defer lim.Stop()

	// Verify on exit that no goroutines were leaked.
	report := test.CheckRoutines(t)
	defer report()

	// Both handshakes below share this 5 second deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Server uses self-signed cert (client won't validate)
	serverCert, err := selfsign.GenerateSelfSigned()
	assert.NoError(t, err)

	// Client uses P-256 ECDSA certificate chain
	clientCert, clientRootCAs := createCertChain(t, "ecdsa-p256")

	// Client sets no certificate signature scheme restriction and skips
	// verification of the server certificate.
	clientOpts := []dtls.ClientOption{
		dtls.WithCertificates(clientCert),
		dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
		dtls.WithInsecureSkipVerify(true),
	}

	// Server only allows P-384 and P-521 for client cert, but client uses P-256
	// This should cause the handshake to fail
	serverOpts := []dtls.ServerOption{
		dtls.WithClientCAs(clientRootCAs),
		dtls.WithCertificates(serverCert),
		dtls.WithCipherSuites(dtls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
		dtls.WithClientAuth(dtls.RequireAndVerifyClientCert),
		dtls.WithCertificateSignatureSchemes(
			tls.ECDSAWithP384AndSHA384,
			tls.ECDSAWithP521AndSHA512,
		),
	}

	serverPort := randomPort(t)

	// Start server
	// serverReady is unbuffered: the goroutine blocks until the test receives.
	// serverDone has room for one value so the goroutine can report and exit
	// even when the test has stopped waiting for it.
	serverReady := make(chan struct{})
	serverDone := make(chan error, 1)
	var serverConn net.Conn

	go func() {
		listener, listenerErr := dtls.ListenWithOptions("udp",
			&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort},
			serverOpts...,
		)
		if listenerErr != nil {
			serverDone <- listenerErr

			return
		}
		defer func() {
			_ = listener.Close()
		}()

		// Signal that the listener exists before blocking in Accept.
		serverReady <- struct{}{}

		var acceptErr error
		serverConn, acceptErr = listener.Accept()
		if acceptErr != nil {
			serverDone <- acceptErr

			return
		}
		defer func() {
			_ = serverConn.Close()
		}()

		// Try to do handshake
		// If the accepted conn is not a *dtls.Conn, no handshake is attempted
		// and a nil error is reported.
		var handshakeErr error
		if dtlsConn, ok := serverConn.(*dtls.Conn); ok {
			handshakeErr = dtlsConn.HandshakeContext(ctx)
		}
		serverDone <- handshakeErr
	}()

	// Wait for server to be ready
	select {
	case <-serverReady:
	case <-time.After(time.Second):
		assert.FailNow(t, "server not ready in time")
	}

	// Client tries to connect
	conn, err := dtls.DialWithOptions("udp",
		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort},
		clientOpts...,
	)
	// The handshake is driven explicitly; its error replaces the dial error.
	if err == nil && conn != nil {
		err = conn.HandshakeContext(ctx)
		_ = conn.Close()
	}

	// We expect the handshake to fail due to invalid client certificate signature algorithm
	assert.Error(t, err, "expected handshake to fail with disallowed client cert signature scheme")

	// Wait for server to complete
	select {
	case serverErr := <-serverDone:
		// Server should also see an error (certificate validation failed)
		assert.Error(t, serverErr)
	case <-time.After(2 * time.Second):
		// A slow server is only logged here; it does not fail the test.
		t.Log("server did not complete in time")
	}
}

// TestCertificateSignatureSchemesRSAMismatch tests that connections fail when
// RSA certificate is presented but only ECDSA schemes are allowed.
//
//nolint:dupl
func TestCertificateSignatureSchemesRSAMismatch(t *testing.T) {
	// Abort the whole test if it runs longer than 30 seconds.
	lim := test.TimeOut(time.Second * 30)
	defer lim.Stop()

	// Verify on exit that no goroutines were leaked.
	report := test.CheckRoutines(t)
	defer report()

	// Both handshakes below share this 5 second deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Server uses RSA certificate chain
	serverCert, serverRootCAs := createCertChain(t, "rsa")

	// Client uses self-signed cert
	clientCert, err := selfsign.GenerateSelfSigned()
	assert.NoError(t, err)

	// Client only allows ECDSA, but server uses RSA
	// This should cause the handshake to fail
	clientOpts := []dtls.ClientOption{
		dtls.WithCertificates(clientCert),
		dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
		dtls.WithRootCAs(serverRootCAs),
		dtls.WithCertificateSignatureSchemes(
			tls.ECDSAWithP256AndSHA256,
			tls.ECDSAWithP384AndSHA384,
		),
	}

	serverOpts := []dtls.ServerOption{
		dtls.WithCertificates(serverCert),
		dtls.WithCipherSuites(dtls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
		dtls.WithInsecureSkipVerify(true),
	}

	serverPort := randomPort(t)

	// Start server
	// serverReady is unbuffered: the goroutine blocks until the test receives.
	// serverDone has room for one value so the goroutine can report and exit
	// even when the test has stopped waiting for it.
	serverReady := make(chan struct{})
	serverDone := make(chan error, 1)
	var serverConn net.Conn

	go func() {
		listener, listenerErr := dtls.ListenWithOptions("udp",
			&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort},
			serverOpts...,
		)
		if listenerErr != nil {
			serverDone <- listenerErr

			return
		}
		defer func() {
			_ = listener.Close()
		}()

		// Signal that the listener exists before blocking in Accept.
		serverReady <- struct{}{}

		var acceptErr error
		serverConn, acceptErr = listener.Accept()
		if acceptErr != nil {
			serverDone <- acceptErr

			return
		}
		defer func() {
			_ = serverConn.Close()
		}()

		// Try to do handshake
		// If the accepted conn is not a *dtls.Conn, no handshake is attempted
		// and a nil error is reported.
		var handshakeErr error
		if dtlsConn, ok := serverConn.(*dtls.Conn); ok {
			handshakeErr = dtlsConn.HandshakeContext(ctx)
		}
		serverDone <- handshakeErr
	}()

	// Wait for server to be ready
	select {
	case <-serverReady:
	case <-time.After(time.Second):
		assert.FailNow(t, "server not ready in time")
	}

	// Client should fail to connect
	conn, err := dtls.DialWithOptions("udp",
		&net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: serverPort},
		clientOpts...,
	)
	// The handshake is driven explicitly; its error replaces the dial error.
	if err == nil && conn != nil {
		err = conn.HandshakeContext(ctx)
		_ = conn.Close()
	}

	// We expect the handshake to fail due to RSA cert with ECDSA-only schemes
	assert.Error(t, err, "expected handshake to fail with RSA cert but ECDSA-only schemes")

	// Wait for server to complete
	select {
	case serverErr := <-serverDone:
		// Server should also see an error
		assert.Error(t, serverErr)
	case <-time.After(2 * time.Second):
		// A slow server is only logged here; it does not fail the test.
		t.Log("server did not complete in time")
	}
}
dtls-3.1.2/errors.go000066400000000000000000000274761514330267300143520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community
// SPDX-License-Identifier: MIT

package dtls

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"os"

	"github.com/pion/dtls/v3/pkg/protocol"
	"github.com/pion/dtls/v3/pkg/protocol/alert"
)

// Typed errors.
var (
	// ErrConnClosed is the exported FatalError carrying the message "conn is closed".
	ErrConnClosed = &FatalError{Err: errors.New("conn is closed")} //nolint:err113

	// errDeadlineExceeded is a TimeoutError that wraps context.DeadlineExceeded via %w.
	errDeadlineExceeded = &TimeoutError{Err: fmt.Errorf("read/write timeout: %w", context.DeadlineExceeded)}

	// Temporary errors.
	errInvalidContentType = &TemporaryError{Err: errors.New("invalid content type")} //nolint:err113
	//nolint:err113
	errBufferTooSmall = &TemporaryError{Err: errors.New("buffer is too small")}
	//nolint:err113
	errContextUnsupported = &TemporaryError{Err: errors.New("context is not supported for ExportKeyingMaterial")}
	//nolint:err113
	errHandshakeInProgress = &TemporaryError{Err: errors.New("handshake is in progress")}
	//nolint:err113
	errReservedExportKeyingMaterial = &TemporaryError{
		Err: errors.New("ExportKeyingMaterial can not be used with a reserved label"),
	}
	//nolint:err113
	errApplicationDataEpochZero = &TemporaryError{Err: errors.New("ApplicationData with epoch of 0")}
	//nolint:err113
	errUnhandledContextType = &TemporaryError{Err: errors.New("unhandled contentType")}

	// Fatal errors.
	//nolint:err113
	errCertificateVerifyNoCertificate = &FatalError{
		Err: errors.New("client sent certificate verify but we have no certificate to verify"),
	}
	//nolint:err113
	errCipherSuiteNoIntersection = &FatalError{Err: errors.New("client+server do not support any shared cipher suites")}
	//nolint:err113
	errClientCertificateNotVerified = &FatalError{Err: errors.New("client sent certificate but did not verify it")}
	//nolint:err113
	errClientCertificateRequired = &FatalError{Err: errors.New("server required client verification, but got none")}
	//nolint:err113
	errClientNoMatchingSRTPProfile = &FatalError{Err: errors.New("server responded with SRTP Profile we do not support")}
	//nolint:err113
	errClientRequiredButNoServerEMS = &FatalError{
		Err: errors.New("client required Extended Master Secret extension, but server does not support it"),
	}
	//nolint:err113
	errCookieMismatch = &FatalError{Err: errors.New("client+server cookie does not match")}
	//nolint:err113
	errIdentityNoPSK = &FatalError{Err: errors.New("PSK Identity Hint provided but PSK is nil")}
	//nolint:err113
	errInvalidCertificate = &FatalError{Err: errors.New("no certificate provided")}
	//nolint:err113
	errInvalidCipherSuite = &FatalError{Err: errors.New("invalid or unknown cipher suite")}
	//nolint:err113
	errInvalidClientAuthType = &FatalError{Err: errors.New("invalid client auth type")}
	//nolint:err113
	errInvalidECDSASignature = &FatalError{Err: errors.New("ECDSA signature contained zero or negative values")}
	//nolint:err113
	errInvalidPrivateKey = &FatalError{Err: errors.New("invalid private key type")}
	//nolint:err113
	errInvalidSignatureAlgorithm = &FatalError{Err: errors.New("invalid signature algorithm")}
	//nolint:err113
	errInvalidExtendedMasterSecretType = &FatalError{Err: errors.New("invalid extended master secret type")}
	//nolint:err113
	errInvalidCertificateSignatureAlgorithm = &FatalError{
		Err: errors.New("certificate uses a signature algorithm that is not allowed"),
	}
	//nolint:err113
	errKeySignatureMismatch = &FatalError{Err: errors.New("expected and actual key signature do not match")}
	//nolint:err113
	errInvalidCertificateOID = &FatalError{Err: errors.New("certificate OID does not match signature algorithm")}
	//nolint:err113
	errNilNextConn = &FatalError{Err: errors.New("Conn can not be created with a nil nextConn")}
	//nolint:err113
	errNoAvailableCipherSuites = &FatalError{
		Err: errors.New("connection can not be created, no CipherSuites satisfy this Config"),
	}
	//nolint:err113
	errNoAvailablePSKCipherSuite = &FatalError{
		Err: errors.New("connection can not be created, pre-shared key present but no compatible CipherSuite"),
	}
	//nolint:err113
	errNoAvailableCertificateCipherSuite = &FatalError{
		Err: errors.New("connection can not be created, certificate present but no compatible CipherSuite"),
	}
	//nolint:err113
	errNoAvailableSignatureSchemes = &FatalError{
		Err: errors.New("connection can not be created, no SignatureScheme satisfy this Config"),
	}
	//nolint:err113
	errNoCertificates = &FatalError{Err: errors.New("no certificates configured")}
	//nolint:err113
	errNoConfigProvided = &FatalError{Err: errors.New("no config provided")}
	//nolint:err113
	errNoSupportedEllipticCurves = &FatalError{
		Err: errors.New("client requested zero or more elliptic curves that are not supported by the server"),
	}
	//nolint:err113
	errUnsupportedProtocolVersion = &FatalError{Err: errors.New("unsupported protocol version")}
	//nolint:err113
	errPSKAndIdentityMustBeSetForClient = &FatalError{
		Err: errors.New("PSK and PSK Identity Hint must both be set for client"),
	}
	//nolint:err113
	errRequestedButNoSRTPExtension = &FatalError{
		Err: errors.New("SRTP support was requested but server did not respond with use_srtp extension"),
	}
	//nolint:err113
	errServerNoMatchingSRTPProfile = &FatalError{Err: errors.New("client requested SRTP but we have no matching profiles")}
	//nolint:err113
	errServerRequiredButNoClientEMS = &FatalError{
		Err: errors.New("server requires the Extended Master Secret extension, but the client does not support it"),
	}
	//nolint:err113
	errVerifyDataMismatch = &FatalError{Err: errors.New("expected and actual verify data does not match")}
	//nolint:err113
	errNotAcceptableCertificateChain = &FatalError{Err: errors.New("certificate chain is not signed by an acceptable CA")}

	// Internal errors.
	//nolint:err113
	errInvalidFlight = &InternalError{Err: errors.New("invalid flight number")}
	//nolint:err113
	errKeySignatureGenerateUnimplemented = &InternalError{
		Err: errors.New("unable to generate key signature, unimplemented"),
	}
	//nolint:err113
	errKeySignatureVerifyUnimplemented = &InternalError{Err: errors.New("unable to verify key signature, unimplemented")}
	//nolint:err113
	errLengthMismatch = &InternalError{Err: errors.New("data length and declared length do not match")}
	//nolint:err113
	errSequenceNumberOverflow = &InternalError{Err: errors.New("sequence number overflow")}
	//nolint:err113
	errInvalidFSMTransition = &InternalError{Err: errors.New("invalid state machine transition")}
	//nolint:err113
	errFailedToAccessPoolReadBuffer = &InternalError{Err: errors.New("failed to access pool read buffer")}
	//nolint:err113
	errFragmentBufferOverflow = &InternalError{Err: errors.New("fragment buffer overflow")}

	// Fatal errors whose messages name the option that received an invalid value.
	//nolint:err113
	errEmptyCertificates = &FatalError{Err: errors.New("certificates option requires at least one certificate")}
	//nolint:err113
	errEmptyCipherSuites = &FatalError{Err: errors.New("cipher suites option requires at least one cipher suite")}
	//nolint:err113
	errNilCustomCipherSuites = &FatalError{Err: errors.New("custom cipher suites option requires a non-nil function")}
	//nolint:err113
	errEmptySignatureSchemes = &FatalError{Err: errors.New("signature schemes option requires at least one scheme")}
	//nolint:err113
	errEmptyCertificateSignatureSchemes = &FatalError{
		Err: errors.New("certificate signature schemes option requires at least one scheme"),
	}
	//nolint:err113
	errEmptySRTPProtectionProfiles = &FatalError{
		Err: errors.New("SRTP protection profiles option requires at least one profile"),
	}
	//nolint:err113
	errInvalidFlightInterval = &FatalError{Err: errors.New("flight interval must be positive")}
	//nolint:err113
	errNilPSKCallback = &FatalError{Err: errors.New("PSK option requires a non-nil callback")}
	//nolint:err113
	errNilVerifyPeerCertificate = &FatalError{
		Err: errors.New("verify peer certificate option requires a non-nil callback"),
	}
	//nolint:err113
	errNilVerifyConnection = &FatalError{Err: errors.New("verify connection option requires a non-nil callback")}
	//nolint:err113
	errInvalidMTU = &FatalError{Err: errors.New("MTU must be positive")}
	//nolint:err113
	errInvalidReplayProtectionWindow = &FatalError{Err: errors.New("replay protection window must be non-negative")}
	//nolint:err113
	errEmptySupportedProtocols = &FatalError{
		Err: errors.New("supported protocols option requires at least one protocol"),
	}
	//nolint:err113
	errEmptyEllipticCurves = &FatalError{Err: errors.New("elliptic curves option requires at least one curve")}
	//nolint:err113
	errNilGetClientCertificate = &FatalError{
		Err: errors.New("get client certificate option requires a non-nil callback"),
	}
	//nolint:err113
	errNilConnectionIDGenerator = &FatalError{
		Err: errors.New("connection ID generator option requires a non-nil function"),
	}
	//nolint:err113
	errNilPaddingLengthGenerator = &FatalError{
		Err: errors.New("padding length generator option requires a non-nil function"),
	}
	//nolint:err113
	errNilHelloRandomBytesGenerator = &FatalError{
		Err: errors.New("hello random bytes generator option requires a non-nil function"),
	}
	//nolint:err113
	errNilClientHelloMessageHook = &FatalError{
		Err: errors.New("client hello message hook option requires a non-nil function"),
	}
	//nolint:err113
	errNilGetCertificate = &FatalError{Err: errors.New("get certificate option requires a non-nil callback")}
	//nolint:err113
	errNilServerHelloMessageHook = &FatalError{
		Err: errors.New("server hello message hook option requires a non-nil function"),
	}
	//nolint:err113
	errNilCertificateRequestMessageHook = &FatalError{
		Err: errors.New("certificate request message hook option requires a non-nil function"),
	}
	//nolint:err113
	errNilOnConnectionAttempt = &FatalError{
		Err: errors.New("on connection attempt option requires a non-nil callback"),
	}
)

// FatalError indicates that the DTLS connection is no longer available.
// It is mainly caused by wrong configuration of server or client.
type FatalError = protocol.FatalError

// InternalError indicates an internal error caused by the implementation,
// and that the DTLS connection is no longer available.
// It is mainly caused by bugs or by attempts to use unimplemented features.
type InternalError = protocol.InternalError

// TemporaryError indicates that the DTLS connection is still available, but the request failed temporarily.
type TemporaryError = protocol.TemporaryError

// TimeoutError indicates that the request was timed out.
type TimeoutError = protocol.TimeoutError

// HandshakeError indicates that the handshake failed.
type HandshakeError = protocol.HandshakeError // errInvalidCipherSuite indicates an attempt at using an unsupported cipher suite. type invalidCipherSuiteError struct { id CipherSuiteID } func (e *invalidCipherSuiteError) Error() string { return fmt.Sprintf("CipherSuite with id(%d) is not valid", e.id) } func (e *invalidCipherSuiteError) Is(err error) bool { var other *invalidCipherSuiteError if errors.As(err, &other) { return e.id == other.id } return false } // errAlert wraps DTLS alert notification as an error. type alertError struct { *alert.Alert } func (e *alertError) Error() string { return fmt.Sprintf("alert: %s", e.Alert.String()) } func (e *alertError) IsFatalOrCloseNotify() bool { return e.Level == alert.Fatal || e.Description == alert.CloseNotify } func (e *alertError) Is(err error) bool { var other *alertError if errors.As(err, &other) { return e.Level == other.Level && e.Description == other.Description } return false } // netError translates an error from underlying Conn to corresponding net.Error. func netError(err error) error { switch { case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded): // Return io.EOF and context errors as is. return err } var ( ne net.Error opError *net.OpError se *os.SyscallError ) if errors.As(err, &opError) { //nolint:nestif if errors.As(opError, &se) { if se.Timeout() { return &TimeoutError{Err: err} } if isOpErrorTemporary(se) { return &TemporaryError{Err: err} } } } if errors.As(err, &ne) { return err } return &FatalError{Err: err} } dtls-3.1.2/errors_errno.go000066400000000000000000000012321514330267300155350ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT //go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows // +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows // For systems having syscall.Errno. 
// Update the build targets by running the following command:
//  $ grep -R ECONN $(go env GOROOT)/src/syscall/zerrors_*.go \
//    | tr "." "_" | cut -d"_" -f"2" | sort | uniq

package dtls

import (
	"errors"
	"os"
	"syscall"
)

// isOpErrorTemporary reports whether the error wrapped by the given
// os.SyscallError is (or wraps) syscall.ECONNREFUSED.
// err must be non-nil, since err.Err is read unconditionally.
func isOpErrorTemporary(err *os.SyscallError) bool {
	return errors.Is(err.Err, syscall.ECONNREFUSED)
}
dtls-3.1.2/errors_errno_test.go000066400000000000000000000027601514330267300166030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community
// SPDX-License-Identifier: MIT

//go:build aix || darwin || dragonfly || freebsd || linux || nacl || nacljs || netbsd || openbsd || solaris || windows
// +build aix darwin dragonfly freebsd linux nacl nacljs netbsd openbsd solaris windows

// For systems having syscall.Errno.
// The build targets must be the same as in errors_errno.go.

package dtls

import (
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

// TestErrorsTemporary checks that netError turns the read error obtained after
// writing to a UDP port without a listener into a net.Error that is temporary
// and not a timeout. The test is skipped when the read returns no error or
// when it only times out.
func TestErrorsTemporary(t *testing.T) {
	// Allocate a UDP port no one is listening on.
	// The listener is opened only to learn a free port and is closed right away.
	addrListen, err := net.ResolveUDPAddr("udp", "localhost:0")
	assert.NoError(t, err)
	listener, err := net.ListenUDP("udp", addrListen)
	assert.NoError(t, err)
	raddr, ok := listener.LocalAddr().(*net.UDPAddr)
	assert.True(t, ok)
	assert.NoError(t, listener.Close())

	// Server is not listening.
	conn, errDial := net.DialUDP("udp", nil, raddr)
	assert.NoError(t, errDial)

	_, _ = conn.Write([]byte{0x00}) // trigger
	// Avoid indefinite blocking on platforms that don't surface ICMP errors reliably - Windows :)
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, err = conn.Read(make([]byte, 10))
	_ = conn.Close()

	if err == nil {
		t.Skip("ECONNREFUSED is not set by system")
	}

	var ne net.Error
	assert.ErrorAs(t, netError(err), &ne)
	// A timeout means only the read deadline fired, so there is nothing to check.
	if ne.Timeout() {
		t.Skip("timed out waiting for ICMP error; skipping on this platform")
	}
	assert.False(t, ne.Timeout())
	assert.True(t, ne.Temporary()) //nolint:staticcheck
}
dtls-3.1.2/errors_noerrno.go000066400000000000000000000010141514330267300160700ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community
// SPDX-License-Identifier: MIT

//go:build !aix && !darwin && !dragonfly && !freebsd && !linux && !nacl && !nacljs && !netbsd && !openbsd && !solaris && !windows
// +build !aix,!darwin,!dragonfly,!freebsd,!linux,!nacl,!nacljs,!netbsd,!openbsd,!solaris,!windows

// For systems without syscall.Errno.
// Build targets must be inverse of errors_errno.go package dtls import ( "os" ) func isOpErrorTemporary(err *os.SyscallError) bool { return false } dtls-3.1.2/errors_test.go000066400000000000000000000036731514330267300154020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "errors" "fmt" "net" "testing" "github.com/stretchr/testify/assert" ) var errExample = errors.New("an example error") func TestErrorUnwrap(t *testing.T) { cases := []struct { err error errUnwrapped []error }{ { &FatalError{Err: errExample}, []error{errExample}, }, { &TemporaryError{Err: errExample}, []error{errExample}, }, { &InternalError{Err: errExample}, []error{errExample}, }, { &TimeoutError{Err: errExample}, []error{errExample}, }, { &HandshakeError{Err: errExample}, []error{errExample}, }, } for _, c := range cases { c := c t.Run(fmt.Sprintf("%T", c.err), func(t *testing.T) { err := c.err for _, unwrapped := range c.errUnwrapped { assert.ErrorIs(t, errors.Unwrap(err), unwrapped) } }) } } func TestErrorNetError(t *testing.T) { cases := []struct { err error str string timeout, temporary bool }{ {&FatalError{Err: errExample}, "dtls fatal: an example error", false, false}, {&TemporaryError{Err: errExample}, "dtls temporary: an example error", false, true}, {&InternalError{Err: errExample}, "dtls internal: an example error", false, false}, {&TimeoutError{Err: errExample}, "dtls timeout: an example error", true, true}, {&HandshakeError{Err: errExample}, "handshake error: an example error", false, false}, {&HandshakeError{Err: &TimeoutError{Err: errExample}}, "handshake error: dtls timeout: an example error", true, true}, } for _, testCase := range cases { testCase := testCase t.Run(fmt.Sprintf("%T", testCase.err), func(t *testing.T) { var ne net.Error assert.ErrorAs(t, testCase.err, &ne) assert.Equal(t, testCase.timeout, ne.Timeout()) assert.Equal(t, testCase.temporary, ne.Temporary()) //nolint:staticcheck 
assert.Equal(t, testCase.str, ne.Error()) }) } } dtls-3.1.2/examples/000077500000000000000000000000001514330267300143055ustar00rootroot00000000000000dtls-3.1.2/examples/certificates/000077500000000000000000000000001514330267300167525ustar00rootroot00000000000000dtls-3.1.2/examples/certificates/README.md000066400000000000000000000024321514330267300202320ustar00rootroot00000000000000# Certificates The certificates for the examples are generated using the commands shown below. Note that this was run on OpenSSL 1.1.1d, of which the arguments can be found in the [OpenSSL Manpages](https://www.openssl.org/docs/man1.1.1/man1), and is not guaranteed to work on different OpenSSL versions. ```shell # Extensions required for certificate validation. $ EXTFILE='extfile.conf' $ echo 'subjectAltName = IP:127.0.0.1\nbasicConstraints = critical,CA:true' > "${EXTFILE}" # Server. $ SERVER_NAME='server' $ openssl ecparam -name prime256v1 -genkey -noout -out "${SERVER_NAME}.pem" $ openssl req -key "${SERVER_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${SERVER_NAME}.csr" $ openssl x509 -req -in "${SERVER_NAME}.csr" -extfile "${EXTFILE}" -days 365 -signkey "${SERVER_NAME}.pem" -sha256 -out "${SERVER_NAME}.pub.pem" # Client. $ CLIENT_NAME='client' $ openssl ecparam -name prime256v1 -genkey -noout -out "${CLIENT_NAME}.pem" $ openssl req -key "${CLIENT_NAME}.pem" -new -sha256 -subj '/C=NL' -out "${CLIENT_NAME}.csr" $ openssl x509 -req -in "${CLIENT_NAME}.csr" -extfile "${EXTFILE}" -days 365 -CA "${SERVER_NAME}.pub.pem" -CAkey "${SERVER_NAME}.pem" -set_serial '0xabcd' -sha256 -out "${CLIENT_NAME}.pub.pem" # Cleanup.
$ rm "${EXTFILE}" "${SERVER_NAME}.csr" "${CLIENT_NAME}.csr" ``` dtls-3.1.2/examples/certificates/client.pem000066400000000000000000000005061514330267300207340ustar00rootroot00000000000000SPDX-FileCopyrightText: 2026 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN EC PRIVATE KEY----- MHcCAQEEIGOO78dEAcepxdUIeDzC28jMcFrJr2q7x+UdhgtJ/RS3oAoGCCqGSM49 AwEHoUQDQgAEGLSNxlkJ9mETKI2Hogq3Cyh06pJKA1YMgcKqYKS6yQQlvvk5rU88 +RojFPgXJukymhfIJmw4eGxxEMSjuEZY7w== -----END EC PRIVATE KEY----- dtls-3.1.2/examples/certificates/client.pub.pem000066400000000000000000000010701514330267300215160ustar00rootroot00000000000000SPDX-FileCopyrightText: 2026 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN CERTIFICATE----- MIIBLTCB1aADAgECAgMAq80wCgYIKoZIzj0EAwIwDTELMAkGA1UEBhMCTkwwHhcN MjAwMzIwMDk0NjQ0WhcNMjEwMzIwMDk0NjQ0WjANMQswCQYDVQQGEwJOTDBZMBMG ByqGSM49AgEGCCqGSM49AwEHA0IABBi0jcZZCfZhEyiNh6IKtwsodOqSSgNWDIHC qmCkuskEJb75Oa1PPPkaIxT4FybpMpoXyCZsOHhscRDEo7hGWO+jJDAiMA8GA1Ud EQQIMAaHBH8AAAEwDwYDVR0TAQH/BAUwAwEB/zAKBggqhkjOPQQDAgNHADBEAiBx sIkcADN9E60veZOFOeANaRWAiQaLWZfUxqkOmfHztQIgI2CfHMjDQwJZFh35HvFs NOPJj8wxFhqR5pqMF23cgOY= -----END CERTIFICATE----- dtls-3.1.2/examples/certificates/server.pem000066400000000000000000000005061514330267300207640ustar00rootroot00000000000000SPDX-FileCopyrightText: 2026 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN EC PRIVATE KEY----- MHcCAQEEIDT8Xyx5RpPP+98ulYZKsvKIVdBUJug/L9H2M8JThv+GoAoGCCqGSM49 AwEHoUQDQgAE6Wf0qQqIb5G7g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMy bJ6MGIwGcVBMgoL3pfxDKdZ3mnzmoibU0w== -----END EC PRIVATE KEY----- dtls-3.1.2/examples/certificates/server.pub.pem000066400000000000000000000011201514330267300215420ustar00rootroot00000000000000SPDX-FileCopyrightText: 2026 The Pion community SPDX-License-Identifier: CC0-1.0 -----BEGIN CERTIFICATE----- MIIBPzCB5qADAgECAhRtzyVTL+9D0KHfbcKYeKckpLVRmTAKBggqhkjOPQQDAjAN MQswCQYDVQQGEwJOTDAeFw0yMDAzMjAwOTQ2NDRaFw0yMTAzMjAwOTQ2NDRaMA0x 
CzAJBgNVBAYTAk5MMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6Wf0qQqIb5G7 g51P83Dh1Yst52kyntGYz1Bt6S7crpmQFs9ZRZMybJ6MGIwGcVBMgoL3pfxDKdZ3 mnzmoibU06MkMCIwDwYDVR0RBAgwBocEfwAAATAPBgNVHRMBAf8EBTADAQH/MAoG CCqGSM49BAMCA0gAMEUCIQD000SU+klkNLGvHZcMYNVkCFsImnGKIqPMy3LELSiF 0gIgSGIFkNEIAyNxn44CXZJu3piyz1ouK2fLefDJMYfcXgM= -----END CERTIFICATE----- dtls-3.1.2/examples/dial/000077500000000000000000000000001514330267300152165ustar00rootroot00000000000000dtls-3.1.2/examples/dial/cid/000077500000000000000000000000001514330267300157555ustar00rootroot00000000000000dtls-3.1.2/examples/dial/cid/main.go000066400000000000000000000025311514330267300172310ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS client using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithOptions("udp", addr, dtls.WithPSK(func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte("Pion DTLS Client")), dtls.WithCipherSuites(dtls.TLS_PSK_WITH_AES_128_CCM_8), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), dtls.WithConnectionIDGenerator(dtls.OnlySendCIDGenerator()), ) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-3.1.2/examples/dial/psk/000077500000000000000000000000001514330267300160135ustar00rootroot00000000000000dtls-3.1.2/examples/dial/psk/main.go000066400000000000000000000024101514330267300172630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS client using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithOptions("udp", addr, dtls.WithPSK(func(hint []byte) ([]byte, error) { fmt.Printf("Server's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte{}), dtls.WithCipherSuites(dtls.TLS_PSK_WITH_AES_128_CCM_8), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), ) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-3.1.2/examples/dial/selfsign/000077500000000000000000000000001514330267300170305ustar00rootroot00000000000000dtls-3.1.2/examples/dial/selfsign/main.go000066400000000000000000000024421514330267300203050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS client using self-signed certificates. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithOptions("udp", addr, dtls.WithCertificates(certificate), dtls.WithInsecureSkipVerify(true), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), ) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-3.1.2/examples/dial/verify/000077500000000000000000000000001514330267300165225ustar00rootroot00000000000000dtls-3.1.2/examples/dial/verify/main.go000066400000000000000000000027551514330267300200060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS client using a client certificate. package main import ( "context" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/client.pem", "examples/certificates/client.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) // Connect to a DTLS server ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() dtlsConn, err := dtls.DialWithOptions("udp", addr, dtls.WithCertificates(certificate), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), dtls.WithRootCAs(certPool), ) util.Check(err) defer func() { util.Check(dtlsConn.Close()) }() if err := dtlsConn.HandshakeContext(ctx); err != nil { fmt.Printf("Failed to handshake with server: %v\n", err) return } fmt.Println("Connected; type 'exit' to shutdown gracefully") // Simulate a chat session util.Chat(dtlsConn) } dtls-3.1.2/examples/listen/000077500000000000000000000000001514330267300156035ustar00rootroot00000000000000dtls-3.1.2/examples/listen/cid/000077500000000000000000000000001514330267300163425ustar00rootroot00000000000000dtls-3.1.2/examples/listen/cid/main.go000066400000000000000000000034001514330267300176120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS server using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// listener, err := dtls.ListenWithOptions("udp", addr, dtls.WithPSK(func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte("Pion DTLS Server")), dtls.WithCipherSuites(dtls.TLS_PSK_WITH_AES_128_CCM_8), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), dtls.WithConnectionIDGenerator(dtls.RandomCIDGenerator(8)), ) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } dtls-3.1.2/examples/listen/psk/000077500000000000000000000000001514330267300164005ustar00rootroot00000000000000dtls-3.1.2/examples/listen/psk/main.go000066400000000000000000000033021514330267300176510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements a DTLS server using a pre-shared key. package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// listener, err := dtls.ListenWithOptions("udp", addr, dtls.WithPSK(func(hint []byte) ([]byte, error) { fmt.Printf("Client's hint: %s \n", hint) return []byte{0xAB, 0xC1, 0x23}, nil }), dtls.WithPSKIdentityHint([]byte("Pion DTLS Server")), dtls.WithCipherSuites(dtls.TLS_PSK_WITH_AES_128_CCM_8), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), ) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } dtls-3.1.2/examples/listen/selfsign/000077500000000000000000000000001514330267300174155ustar00rootroot00000000000000dtls-3.1.2/examples/listen/selfsign/main.go000066400000000000000000000032671514330267300207000ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server using self-signed certificates. 
package main import ( "context" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" "github.com/pion/dtls/v3/pkg/crypto/selfsign" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() util.Check(genErr) // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // listener, err := dtls.ListenWithOptions("udp", addr, dtls.WithCertificates(certificate), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), ) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub if err == nil { hub.Register(conn) } } }() // Start chatting hub.Chat() } dtls-3.1.2/examples/listen/verify-brute-force-protection/000077500000000000000000000000001514330267300235065ustar00rootroot00000000000000dtls-3.1.2/examples/listen/verify-brute-force-protection/main.go000066400000000000000000000101571514330267300247650ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server which verifies client certificates. // It also implements a basic Brute Force Attack protection. 
package main import ( "context" "crypto/x509" "fmt" "net" "sync" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. // // ************ Variables used to implement a basic Brute Force Attack protection ************* var ( attempts = make(map[string]int) // Map of attempts for each IP address. attemptsMutex sync.Mutex // Mutex for the map of attempts. attemptsCleaner = time.Now() // Time to be able to clean the map of attempts every X minutes. ) certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) listener, err := dtls.ListenWithOptions("udp", addr, dtls.WithCertificates(certificate), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), dtls.WithClientAuth(dtls.RequireAndVerifyClientCert), dtls.WithClientCAs(certPool), // This function will be called on each connection attempt. dtls.WithOnConnectionAttempt(func(addr net.Addr) error { // *************** Brute Force Attack protection *************** // Check if the IP address is in the map, and if the IP address has exceeded the limit attemptsMutex.Lock() defer attemptsMutex.Unlock() // Here I implement a time cleaner for the map of attempts, every 5 minutes I will // decrement by 1 the number of attempts for each IP address. 
if time.Now().After(attemptsCleaner.Add(time.Minute * 5)) { attemptsCleaner = time.Now() for k, v := range attempts { if v > 0 { attempts[k]-- } if attempts[k] == 0 { delete(attempts, k) } } } // Check if the IP address is in the map, and the IP address has exceeded the limit (Brute Force Attack protection) attemptIP := addr.(*net.UDPAddr).IP.String() //nolint if attempts[attemptIP] > 10 { return fmt.Errorf("too many attempts from this IP address") //nolint } // Here I increment the number of attempts for this IP address (Brute Force Attack protection) attempts[attemptIP]++ // *************** END Brute Force Attack protection END *************** return nil }), ) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. 
// *************** Brute Force Attack protection *************** // Here I decrease the number of attempts for this IP address attemptsMutex.Lock() attemptIP := conn.(*dtls.Conn).RemoteAddr().(*net.UDPAddr).IP.String() //nolint attempts[attemptIP]-- // If the number of attempts for this IP address is 0, I delete the IP address from the map if attempts[attemptIP] == 0 { delete(attempts, attemptIP) } attemptsMutex.Unlock() // *************** END Brute Force Attack protection END *************** // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub hub.Register(conn) } }() // Start chatting hub.Chat() } dtls-3.1.2/examples/listen/verify/000077500000000000000000000000001514330267300171075ustar00rootroot00000000000000dtls-3.1.2/examples/listen/verify/main.go000066400000000000000000000037201514330267300203640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package main implements an example DTLS server which verifies client certificates. package main import ( "context" "crypto/x509" "fmt" "net" "time" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/examples/util" ) func main() { // Prepare the IP to connect to addr := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 4444} // // Everything below is the pion-DTLS API! Thanks for using it ❤️. 
// certificate, err := util.LoadKeyAndCertificate("examples/certificates/server.pem", "examples/certificates/server.pub.pem") util.Check(err) rootCertificate, err := util.LoadCertificate("examples/certificates/server.pub.pem") util.Check(err) certPool := x509.NewCertPool() cert, err := x509.ParseCertificate(rootCertificate.Certificate[0]) util.Check(err) certPool.AddCert(cert) listener, err := dtls.ListenWithOptions("udp", addr, dtls.WithCertificates(certificate), dtls.WithExtendedMasterSecret(dtls.RequireExtendedMasterSecret), dtls.WithClientAuth(dtls.RequireAndVerifyClientCert), dtls.WithClientCAs(certPool), ) util.Check(err) defer func() { util.Check(listener.Close()) }() fmt.Println("Listening") // Simulate a chat session hub := util.NewHub() go func() { for { // Wait for a connection. conn, err := listener.Accept() util.Check(err) // defer conn.Close() // TODO: graceful shutdown // `conn` is of type `net.Conn` but may be casted to `dtls.Conn` // using `dtlsConn := conn.(*dtls.Conn)` in order to to expose // functions like `ConnectionState` etc. // Perform the handshake with a 30-second timeout ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) dtlsConn, ok := conn.(*dtls.Conn) if ok { util.Check(dtlsConn.HandshakeContext(ctx)) } cancel() // Register the connection with the chat hub hub.Register(conn) } }() // Start chatting hub.Chat() } dtls-3.1.2/examples/util/000077500000000000000000000000001514330267300152625ustar00rootroot00000000000000dtls-3.1.2/examples/util/hub.go000066400000000000000000000031551514330267300163730ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package util import ( "bufio" "fmt" "net" "os" "strings" "sync" ) // Hub is a helper to handle one to many chat. type Hub struct { conns map[string]net.Conn lock sync.RWMutex } // NewHub builds a new hub. func NewHub() *Hub { return &Hub{conns: make(map[string]net.Conn)} } // Register adds a new conn to the Hub. 
func (h *Hub) Register(conn net.Conn) { fmt.Printf("Connected to %s\n", conn.RemoteAddr()) h.lock.Lock() defer h.lock.Unlock() h.conns[conn.RemoteAddr().String()] = conn go h.readLoop(conn) } func (h *Hub) readLoop(conn net.Conn) { b := make([]byte, bufSize) for { n, err := conn.Read(b) if err != nil { h.unregister(conn) return } fmt.Printf("Got message: %s\n", string(b[:n])) } } func (h *Hub) unregister(conn net.Conn) { h.lock.Lock() defer h.lock.Unlock() delete(h.conns, conn.RemoteAddr().String()) err := conn.Close() if err != nil { fmt.Println("Failed to disconnect", conn.RemoteAddr(), err) } else { fmt.Println("Disconnected ", conn.RemoteAddr()) } } func (h *Hub) broadcast(msg []byte) { h.lock.RLock() defer h.lock.RUnlock() for _, conn := range h.conns { _, err := conn.Write(msg) if err != nil { fmt.Printf("Failed to write message to %s: %v\n", conn.RemoteAddr(), err) } } } // Chat starts the stdin readloop to dispatch messages to the hub. func (h *Hub) Chat() { reader := bufio.NewReader(os.Stdin) for { msg, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(msg) == "exit" { return } h.broadcast([]byte(msg)) } } dtls-3.1.2/examples/util/util.go000066400000000000000000000040551514330267300165720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT // Package util provides auxiliary utilities used in examples package util //nolint: revive import ( "bufio" "crypto/tls" "encoding/pem" "errors" "fmt" "io" "net" "os" "path/filepath" "strings" ) const bufSize = 8192 var ( errBlockIsNotCertificate = errors.New("block is not a certificate, unable to load certificates") errNoCertificateFound = errors.New("no certificate found, unable to load certificates") ) // Chat simulates a simple text chat session over the connection. 
func Chat(conn io.ReadWriter) { go func() { b := make([]byte, bufSize) for { n, err := conn.Read(b) Check(err) fmt.Printf("Got message: %s\n", string(b[:n])) } }() reader := bufio.NewReader(os.Stdin) for { text, err := reader.ReadString('\n') Check(err) if strings.TrimSpace(text) == "exit" { return } _, err = conn.Write([]byte(text)) Check(err) } } // Check is a helper to throw errors in the examples. func Check(err error) { var netError net.Error if errors.As(err, &netError) && netError.Temporary() { //nolint:staticcheck fmt.Printf("Warning: %v\n", err) } else if err != nil { fmt.Printf("error: %v\n", err) panic(err) } } // LoadKeyAndCertificate reads certificates or key from file. func LoadKeyAndCertificate(keyPath string, certificatePath string) (tls.Certificate, error) { return tls.LoadX509KeyPair(certificatePath, keyPath) } // LoadCertificate Load/read certificate(s) from file. func LoadCertificate(path string) (*tls.Certificate, error) { rawData, err := os.ReadFile(filepath.Clean(path)) if err != nil { return nil, err } var certificate tls.Certificate for { block, rest := pem.Decode(rawData) if block == nil { break } if block.Type != "CERTIFICATE" { return nil, errBlockIsNotCertificate } certificate.Certificate = append(certificate.Certificate, block.Bytes) rawData = rest } if len(certificate.Certificate) == 0 { return nil, errNoCertificateFound } return &certificate, nil } dtls-3.1.2/flight.go000066400000000000000000000061541514330267300143010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls /* DTLS messages are grouped into a series of message flights, according to the diagrams below. Although each flight of messages may consist of a number of messages, they should be viewed as monolithic for the purpose of timeout and retransmission. 
https://tools.ietf.org/html/rfc4347#section-4.2.4 Message flights for full handshake: Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 <------- HelloVerifyRequest Flight 2 ClientHello --------> Flight 3 ServerHello \ Certificate* \ ServerKeyExchange* Flight 4 CertificateRequest* / <-------- ServerHelloDone / Certificate* \ ClientKeyExchange \ CertificateVerify* Flight 5 [ChangeCipherSpec] / Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / Message flights for session-resuming handshake (no cookie exchange): Client Server ------ ------ Waiting Flight 0 ClientHello --------> Flight 1 ServerHello \ [ChangeCipherSpec] Flight 4b <-------- Finished / [ChangeCipherSpec] \ Flight 5b Finished --------> / [ChangeCipherSpec] \ Flight 6 <-------- Finished / */ type flightVal uint8 const ( flight0 flightVal = iota + 1 flight1 flight2 flight3 flight4 flight4b flight5 flight5b flight6 ) func (f flightVal) String() string { //nolint:cyclop switch f { case flight0: return "Flight 0" case flight1: return "Flight 1" case flight2: return "Flight 2" case flight3: return "Flight 3" case flight4: return "Flight 4" case flight4b: return "Flight 4b" case flight5: return "Flight 5" case flight5b: return "Flight 5b" case flight6: return "Flight 6" default: return "Invalid Flight" } } func (f flightVal) isLastSendFlight() bool { return f == flight6 || f == flight5b } func (f flightVal) isLastRecvFlight() bool { return f == flight5 || f == flight4b } dtls-3.1.2/flight0handler.go000066400000000000000000000133501514330267300157130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/rand" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) // renegotiationInfoSCSV is 
TLS_EMPTY_RENEGOTIATION_INFO_SCSV defined in RFC 5746. // https://datatracker.ietf.org/doc/html/rfc5746#section-3.3. const renegotiationInfoSCSV uint16 = 0x00ff //nolint:cyclop,gocognit func flight0Parse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(0, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } // Connection Identifiers must be negotiated afresh on session resumption. // https://datatracker.ietf.org/doc/html/rfc9146#name-the-connection_id-extension state.setLocalConnectionID(nil) state.remoteConnectionID = nil state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.remoteRandom = clientHello.Random cipherSuites := []CipherSuite{} for _, id := range clientHello.CipherSuiteIDs { if id == renegotiationInfoSCSV { state.remoteSupportsRenegotiation = true continue } if c := cipherSuiteForID(CipherSuiteID(id), cfg.customCipherSuites); c != nil { cipherSuites = append(cipherSuites, c) } } if state.cipherSuite, ok = findMatchingCipherSuite(cipherSuites, cfg.localCipherSuites); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } for _, val := range clientHello.Extensions { switch ext := val.(type) { case *extension.SupportedEllipticCurves: if len(ext.EllipticCurves) == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, 
errNoSupportedEllipticCurves } state.namedCurve = ext.EllipticCurves[0] case *extension.UseSRTP: profile, ok := findMatchingSRTPProfile(cfg.localSRTPProtectionProfiles, ext.ProtectionProfiles) if !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ServerName: state.serverName = ext.ServerName // remote server name case *extension.RenegotiationInfo: state.remoteSupportsRenegotiation = true case *extension.ALPN: state.peerSupportedProtocols = ext.ProtocolNameList case *extension.ConnectionID: // Only set connection ID to be sent if server supports connection // IDs. if cfg.connectionIDGenerator != nil { state.remoteConnectionID = ext.CID } case *extension.SignatureAlgorithmsCert: // Store the client's certificate signature schemes for later validation state.remoteCertSignatureSchemes = ext.SignatureHashAlgorithms } } // If the client doesn't support connection IDs, the server should not // expect one to be sent. 
if state.remoteConnectionID == nil { state.setLocalConnectionID(nil) } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errServerRequiredButNoClientEMS } if state.localKeypair == nil { var err error state.localKeypair, err = elliptic.GenerateKeypair(state.namedCurve) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } nextFlight := flight2 if cfg.insecureSkipHelloVerify { nextFlight = flight4 } return handleHelloResume(clientHello.SessionID, state, cfg, nextFlight) } func handleHelloResume( sessionID []byte, state *State, cfg *handshakeConfig, next flightVal, ) (flightVal, *alert.Alert, error) { if len(sessionID) > 0 && cfg.sessionStore != nil { if s, err := cfg.sessionStore.Get(sessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] resume session: %x", sessionID) state.SessionID = sessionID state.masterSecret = s.Secret if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } clientRandom := state.localRandom.MarshalFixed() cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return flight4b, nil, nil } } return next, nil, nil } func flight0Generate( _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { // Initialize if !cfg.insecureSkipHelloVerify { state.cookie = make([]byte, cookieLength) if _, err := rand.Read(state.cookie); err != nil { return nil, nil, err } } var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve if err := state.localRandom.Populate(); err != nil { return nil, nil, err } return nil, nil, nil } 
dtls-3.1.2/flight1handler.go000066400000000000000000000133221514330267300157130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight1Parse( ctx context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { // HelloVerifyRequest can be skipped by the server, // so allow ServerHello during flight1 also seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, true}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok := msgs[handshake.TypeServerHello]; ok { // Flight1 and flight2 were skipped. // Parse as flight3. return flight3Parse(ctx, conn, state, cache, cfg) } if h, ok := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); ok { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) 
state.handshakeRecvSequence = seq return flight3, nil, nil } return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } //nolint:cyclop func flight1Generate( conn flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var zeroEpoch uint16 state.localEpoch.Store(zeroEpoch) state.remoteEpoch.Store(zeroEpoch) state.namedCurve = defaultNamedCurve state.cookie = nil if err := state.localRandom.Populate(); err != nil { return nil, nil, err } if cfg.helloRandomBytesGenerator != nil { state.localRandom.RandomBytes = cfg.helloRandomBytesGenerator() } extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, }, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }, } if len(cfg.localCertSignatureSchemes) > 0 { extensions = append(extensions, &extension.SignatureAlgorithmsCert{ SignatureHashAlgorithms: cfg.localCertSignatureSchemes, }) } var setEllipticCurveCryptographyClientHelloExtensions bool for _, c := range cfg.localCipherSuites { if c.ECC() { setEllipticCurveCryptographyClientHelloExtensions = true break } } if setEllipticCurveCryptographyClientHelloExtensions { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: cfg.ellipticCurves, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, }...) 
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } if cfg.sessionStore != nil { cfg.log.Tracef("[handshake] try to resume session") if s, err := cfg.sessionStore.Get(conn.sessionKey()); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } else if s.ID != nil { cfg.log.Tracef("[handshake] get saved session: %x", s.ID) state.SessionID = s.ID state.masterSecret = s.Secret } } // If we have a connection ID generator, use it. The CID may be zero length, // in which case we are just requesting that the server send us a CID to // use. if cfg.connectionIDGenerator != nil { state.setLocalConnectionID(cfg.connectionIDGenerator()) // The presence of a generator indicates support for connection IDs. We // use the presence of a non-nil local CID in flight 3 to determine // whether we send a CID in the second ClientHello, so we convert any // nil CID returned by a generator to []byte{}. 
if state.getLocalConnectionID() == nil { state.setLocalConnectionID([]byte{}) } extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, } var content handshake.Handshake if cfg.clientHelloMessageHook != nil { content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} } else { content = handshake.Handshake{Message: clientHello} } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }, }, nil, nil } dtls-3.1.2/flight1handler_test.go000066400000000000000000000345421514330267300167610ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) type flight1TestMockFlightConn struct{} func (f *flight1TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } func (f *flight1TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } func (f *flight1TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight1TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight1TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight1TestMockFlightConn) sessionKey() []byte { return nil } type flight1TestMockCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 t *testing.T } func (f 
*flight1TestMockCipherSuite) IsInitialized() bool { assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify") return true } // When "server hello" arrives later than "certificate", // "server key exchange", "certificate request", "server hello done", // is it normal for the flight1Parse method to handle it. func TestFlight1_Process_ServerHelloLateArrival(t *testing.T) { //nolint:maintidx // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() mockConn := &flight1TestMockFlightConn{} state := &State{ cipherSuite: &flight1TestMockCipherSuite{t: t}, } cache := newHandshakeCache() cfg := &handshakeConfig{ localSRTPProtectionProfiles: []SRTPProtectionProfile{SRTP_AEAD_AES_128_GCM}, localCipherSuites: []CipherSuite{}, } cfg.localCipherSuites = []CipherSuite{cipherSuiteForID(TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, nil)} cfg.log = logging.NewDefaultLoggerFactory().NewLogger("dtls") serverHello := []byte{ 0x02, 0x00, 0x00, 0x62, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x62, 0xfe, 0xfd, 0x07, 0x46, 0xb7, 0xbf, 0xde, 0x78, 0xab, 0x38, 0x69, 0x36, 0x74, 0x10, 0xa6, 0x50, 0x67, 0x7b, 0x4b, 0x85, 0xdf, 0x71, 0x71, 0x62, 0x3a, 0xb1, 0xd7, 0xa4, 0x79, 0x6a, 0x38, 0x13, 0x5e, 0xa1, 0x20, 0xbd, 0x64, 0xaf, 0xb3, 0x36, 0x77, 0x73, 0x8a, 0x62, 0x75, 0xb2, 0x64, 0xbe, 0xf6, 0x2a, 0xb1, 0x6e, 0x7b, 0xf6, 0x00, 0xd6, 0x24, 0xd5, 0xb1, 0x1e, 0x54, 0xa3, 0x76, 0xb3, 0xac, 0x76, 0x8f, 0xc0, 0x2f, 0x00, 0x00, 0x1a, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0b, 0x00, 0x04, 0x03, 0x00, 0x01, 0x02, 0x00, 0x0e, 0x00, 0x05, 0x00, 0x02, 0x00, 0x07, 0x00, 0x00, 0x17, 0x00, 0x00, } certificate1 := []byte{ 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0xe4, 0x00, 0x05, 0x58, 0x00, 0x05, 0x55, 0x30, 0x82, 0x05, 0x51, 0x30, 0x82, 0x04, 0x39, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x0c, 0x56, 0x8b, 0xb4, 0x68, 0xed, 0x70, 0xce, 0xb6, 0x8d, 0x44, 0x65, 
0x4b, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x30, 0x66, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x42, 0x45, 0x31, 0x19, 0x30, 0x17, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x10, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x6e, 0x76, 0x2d, 0x73, 0x61, 0x31, 0x3c, 0x30, 0x3a, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x33, 0x47, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x53, 0x69, 0x67, 0x6e, 0x20, 0x4f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x43, 0x41, 0x20, 0x2d, 0x20, 0x53, 0x48, 0x41, 0x32, 0x35, 0x36, 0x20, 0x2d, 0x20, 0x47, 0x32, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x37, 0x30, 0x34, 0x32, 0x30, 0x31, 0x31, 0x31, 0x39, 0x35, 0x39, 0x5a, 0x17, 0x0d, 0x31, 0x38, 0x30, 0x34, 0x32, 0x31, 0x31, 0x31, 0x31, 0x39, 0x35, 0x39, 0x5a, 0x30, 0x81, 0x84, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x43, 0x4e, 0x31, 0x12, 0x30, 0x10, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x09, 0x67, 0x75, 0x61, 0x6e, 0x67, 0x64, 0x6f, 0x6e, 0x67, 0x31, 0x11, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x04, 0x07, 0x13, 0x08, 0x73, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, 0x6e, 0x31, 0x36, 0x30, 0x34, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x2d, 0x54, 0x65, 0x6e, 0x63, 0x65, 0x6e, 0x74, 0x20, 0x54, 0x65, 0x63, 0x68, 0x6e, 0x6f, 0x6c, 0x6f, 0x67, 0x79, 0x20, 0x28, 0x53, 0x68, 0x65, 0x6e, 0x7a, 0x68, 0x65, 0x6e, 0x29, 0x20, 0x43, 0x6f, 0x6d, 0x70, 0x61, 0x6e, 0x79, 0x20, 0x4c, 0x69, 0x6d, 0x69, 0x74, 0x65, 0x64, 0x31, 0x16, 0x30, 0x14, 0x06, 0x03, 0x55, 0x04, 0x03, 0x13, 0x0d, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x82, 0x01, 0x22, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05, 0x00, 0x03, 0x82, 0x01, 0x0f, 0x00, 0x30, 0x82, 0x01, 0x0a, 0x02, 0x82, 0x01, 0x01, 0x00, 0xb6, 0x00, 0xa7, 0x09, 0x0a, 0xc4, 0x96, 0x24, 0x72, 0xa0, 0x09, 0xda, 0xac, 
0x63, 0xe4, 0x9a, 0xfe, 0x8b, 0x9b, 0x99, 0x8c, 0xe3, 0xab, 0x4b, 0x7c, 0xbd, 0x4f, 0x31, 0x1e, 0x2f, 0xff, 0x34, 0x54, 0xb5, 0xb0, 0x99, 0xcd, 0x00, 0x7c, 0x5b, 0x12, 0x96, 0xfa, 0x9b, 0x6b, 0x79, 0xc7, 0xfb, 0x00, 0x53, 0xaf, 0xb6, 0x00, 0x45, 0x46, 0x20, 0x7d, 0x95, 0xca, 0x86, 0xcc, 0x4b, 0xe8, 0x25, 0x52, 0x5b, 0x9c, 0xe7, 0x58, 0xcd, 0xd0, 0x8f, 0x4a, 0xd8, 0x77, 0x7d, 0x45, 0xa0, 0x70, 0xe8, 0x16, 0x45, 0x23, 0xfb, 0xbc, 0x43, 0x36, 0xdd, 0x5b, 0x8f, 0x01, 0xc3, 0xc0, 0xa2, 0xab, 0x80, 0xf1, 0x97, 0x72, 0x38, 0xab, 0x6f, 0xa1, 0x28, 0x09, 0xdd, 0x31, 0x7e, 0x50, 0xc8, 0x51, 0xde, 0x8d, 0x05, 0xbc, 0x72, 0x79, 0x94, 0x6e, 0xd4, 0xb7, 0xf0, 0x97, 0xd0, 0x76, 0x9c, 0x9d, 0xb4, 0x34, 0xf1, 0x8a, 0x82, 0x20, 0x9b, 0x24, 0x4b, 0x38, 0xc9, 0x63, 0xe6, 0x02, 0xf5, 0xb2, 0x9b, 0x70, 0xa4, 0x97, 0x9f, 0xaa, 0x1f, 0x36, 0x9c, 0xfd, 0x81, 0x93, 0x81, 0xd7, 0x4e, 0xca, 0xd2, 0xa7, 0x7c, 0x29, 0x9d, 0x28, 0xf2, 0x3e, 0x3b, 0xea, 0xe6, 0x22, 0x51, 0x8f, 0x0b, 0xe7, 0x65, 0xa1, 0x28, 0xdd, 0x55, 0x6a, 0x59, 0x53, 0x67, 0xb6, 0xb3, 0xd2, 0x4c, 0x90, 0x69, 0xd1, 0x1e, 0x62, 0xab, 0x33, 0x47, 0x29, 0x45, 0x18, 0x1f, 0xeb, 0x6d, 0x13, 0xb4, 0x61, 0xf5, 0x15, 0x03, 0xf7, 0x4f, 0x9c, 0x4c, 0x2c, 0xae, 0x5e, 0xde, 0xd2, 0x11, 0x32, 0xb5, 0x17, 0xb5, 0xe8, 0xa3, 0xb2, 0x1f, 0xc3, 0x9f, 0x78, 0xa1, 0xf5, 0x80, 0xb4, 0x96, 0x90, 0x6b, 0x77, 0x9e, 0xe9, 0x39, 0x61, 0x2c, 0x18, 0xf5, 0x7b, 0xab, 0x1e, 0x09, 0x88, 0x7d, 0xc3, 0x75, 0x5e, 0x4d, 0xcf, 0xf3, 0x02, 0x03, 0x01, 0x00, 0x01, 0xa3, 0x82, 0x01, 0xde, 0x30, 0x82, 0x01, 0xda, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x05, 0xa0, 0x30, 0x81, 0xa0, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01, 0x04, 0x81, 0x93, 0x30, 0x81, 0x90, 0x30, 0x4d, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x02, 0x86, 0x41, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x73, 0x65, 0x63, 0x75, 0x72, 0x65, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 
0x2f, 0x63, 0x61, 0x63, 0x65, 0x72, 0x74, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x72, 0x31, 0x2e, 0x63, 0x72, 0x74, 0x30, 0x3f, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x86, 0x33, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x6f, 0x63, 0x73, 0x70, 0x32, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x30, 0x56, 0x06, 0x03, 0x55, 0x1d, 0x20, 0x04, 0x4f, 0x30, 0x4d, 0x30, 0x41, 0x06, 0x09, 0x2b, 0x06, 0x01, 0x04, 0x01, 0xa0, 0x32, 0x01, 0x14, 0x30, 0x34, 0x30, 0x32, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x02, 0x01, 0x16, 0x26, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x77, 0x77, 0x77, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x72, 0x65, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x6f, 0x72, 0x79, 0x2f, 0x30, 0x08, 0x06, 0x06, 0x67, 0x81, 0x0c, 0x01, 0x02, 0x02, 0x30, 0x09, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x04, 0x02, 0x30, 0x00, 0x30, 0x49, 0x06, 0x03, 0x55, 0x1d, 0x1f, 0x04, 0x42, 0x30, 0x40, 0x30, 0x3e, 0xa0, 0x3c, 0xa0, 0x3a, 0x86, 0x38, 0x68, 0x74, 0x74, 0x70, 0x3a, 0x2f, 0x2f, 0x63, 0x72, 0x6c, 0x2e, 0x67, 0x6c, 0x6f, 0x62, 0x61, 0x6c, 0x73, 0x69, 0x67, 0x6e, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x73, 0x2f, 0x67, 0x73, 0x6f, 0x72, 0x67, 0x61, 0x6e, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x76, 0x61, 0x6c, 0x73, 0x68, 0x61, 0x32, 0x67, 0x32, 0x2e, 0x63, 0x72, 0x6c, 0x30, 0x18, 0x06, 0x03, 0x55, 0x1d, 0x11, 0x04, 0x11, 0x30, 0x0f, 0x82, 0x0d, 0x77, 0x65, 0x62, 0x72, 0x74, 0x63, 0x2e, 0x71, 0x71, 0x2e, 0x63, 0x6f, 0x6d, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 
0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0x28, 0xff, 0xe2, 0x97, 0xf3, 0x6f, 0x2a, 0xef, 0x0f, 0xbc, 0x4c, 0x61, 0x9b, 0xd9, 0x23, 0x7b, 0x3a, 0xef, 0xc2, 0xe7, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x18, 0x30, 0x16, 0x80, 0x14, 0x96, 0xde, 0x61, 0xf1, 0xbd, 0x1c, 0x16, 0x29, 0x53, 0x1c, 0xc0, 0xcc, 0x7d, 0x3b, 0x83, 0x00, 0x40, 0xe6, 0x1a, 0x7c, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b, 0x05, 0x00, 0x03, 0x82, 0x01, 0x01, 0x00, 0x30, 0xc1, 0xcc, 0xd6, 0x97, 0xf7, 0xf5, 0xa7, 0x93, 0xa5, 0x78, 0xc8, 0xcb, 0x81, 0x44, 0xd4, 0x1f, 0x2a, 0xa6, 0xc1, 0x48, 0xa8, 0x1a, 0xbd, 0x17, 0x10, 0x0e, 0xdf, 0x21, 0xea, 0x02, 0x3e, 0xb3, 0xbd, 0x45, 0x1e, 0x64, 0x85, 0x3f, 0x04, 0x9a, 0xc0, 0x78, 0xf4, 0x81, 0x2e, 0x38, 0x39, 0x3a, 0x04, 0x2d, 0x5f, 0xec, 0xc4, 0x10, 0x57, 0xfb, 0x1b, 0x32, 0xe0, 0x8e, 0xfc, 0xe3, 0x6d, 0x4b, 0xc6, 0xf0, 0x07, 0xb7, 0xc6, 0x19, 0xd7, 0x99, 0x93, 0xbd, 0x60, 0x58, 0xad, 0xbb, 0x94, 0xcf, 0xd8, 0x05, 0x5c, 0x14, 0x70, 0xec, 0x2e, 0xb7, 0x60, 0x52, 0x3c, 0xd3, 0x03, 0xf8, 0xcd, 0xe5, 0x4e, 0x84, 0xcf, 0xef, 0x2f, 0x12, 0xdd, 0x74, 0xfd, 0x95, 0x9d, 0x03, 0xa9, 0x81, 0x18, 0x3a, 0x6e, 0xe6, 0xc2, 0xdd, 0x07, 0x1e, 0xea, 0x8c, 0xe6, 0xd9, 0x31, 0x72, 0x63, 0x25, 0xcd, 0xf2, 0x19, 0xf2, 0x4e, 0x3c, 0x18, 0xfb, 0xb2, 0x74, } certificate2 := []byte{ 0x0b, 0x00, 0x05, 0x5b, 0x00, 0x01, 0x00, 0x04, 0xe4, 0x00, 0x00, 0x77, 0xc1, 0x6b, 0x67, 0xec, 0x34, 0x05, 0xe8, 0x63, 0xfc, 0x74, 0x4b, 0x11, 0x3f, 0x3a, 0xe4, 0x4e, 0x06, 0x89, 0x96, 0x24, 0x3c, 0x15, 0x83, 0xc5, 0x1d, 0xeb, 0xc0, 0x19, 0x71, 0x35, 0x6c, 0xfa, 0xf1, 0x51, 0x06, 0x0e, 0x8e, 0xfb, 0x9b, 0x4e, 0xaa, 0x50, 0x24, 0x77, 0xac, 0x86, 0x14, 0x50, 0x52, 0x35, 0x68, 0x15, 0x9b, 0xdd, 0x8b, 0xdb, 0x83, 0x1d, 0xed, 0x45, 0x05, 0x78, 0x53, 0xd6, 0xc4, 0x21, 0xaf, 0x68, 0x45, 0x91, 0xe7, 0x30, 0x36, 0x4c, 0xb1, 0xfb, 0xf1, 0x65, 0x9a, 0xe4, 0x49, 0x90, 0x1c, 0x0c, 0xa8, 0x63, 0xe9, 0x04, 0xe3, 0x17, 0x61, 0x8d, 0x20, 0x29, 0xca, 0x41, 
0xa6, 0x8b, 0x32, 0x53, 0xa5, 0x84, 0x29, 0x5a, 0x62, 0xe7, 0x84, 0x38, 0x32, 0x56, 0xbb, 0x8b, 0xbc, 0x25, 0xc7, 0xa3, 0x28, 0x3b, 0x35, } serverKeyExchange := []byte{ 0x0c, 0x00, 0x01, 0x28, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x01, 0x28, 0x03, 0x00, 0x1d, 0x20, 0x59, 0xa2, 0x0f, 0xc4, 0x7b, 0xd8, 0x03, 0xf6, 0xb0, 0xcf, 0x5d, 0xf0, 0x45, 0x7f, 0x7e, 0xf2, 0x98, 0xab, 0xc0, 0x24, 0xf1, 0xdf, 0xba, 0x63, 0x3e, 0xfb, 0xe5, 0x02, 0x31, 0xcf, 0xd1, 0x05, 0x04, 0x01, 0x01, 0x00, 0x7b, 0x52, 0x9c, 0xe7, 0x54, 0x8b, 0xb0, 0xc9, 0xfd, 0xaf, 0xe2, 0x91, 0x19, 0x9d, 0x6c, 0xb8, 0xbe, 0xa5, 0xe1, 0x48, 0xa0, 0xfd, 0xc5, 0x76, 0x62, 0x47, 0xf2, 0xd1, 0x35, 0x76, 0x4e, 0x33, 0xf4, 0xa1, 0xf1, 0x58, 0xdc, 0xd5, 0x45, 0x3f, 0x76, 0x64, 0x40, 0xba, 0x32, 0xe3, 0x07, 0xb7, 0x4b, 0xbe, 0xe2, 0x77, 0x99, 0xad, 0x11, 0x73, 0x54, 0xe6, 0xbb, 0xfb, 0xd4, 0xb1, 0x83, 0x9f, 0xc6, 0x50, 0xc6, 0xd8, 0xbb, 0x92, 0x0d, 0x93, 0xf9, 0x63, 0x29, 0xf9, 0xc3, 0xce, 0x24, 0x40, 0x29, 0x95, 0x43, 0xf0, 0x32, 0x00, 0x21, 0xde, 0xdf, 0x64, 0xfe, 0xb6, 0x11, 0xa0, 0x11, 0x44, 0x12, 0x2a, 0x1c, 0x96, 0x44, 0x4b, 0x79, 0x31, 0x23, 0x46, 0x4e, 0xe8, 0x16, 0x5b, 0xf5, 0x9a, 0x5f, 0x51, 0x10, 0x5b, 0x11, 0xa3, 0xb8, 0x1f, 0xb7, 0xf1, 0x11, 0xad, 0x05, 0x82, 0x2b, 0xc3, 0x65, 0x8c, 0x41, 0xb4, 0x8e, 0x60, 0x42, 0x89, 0x92, 0xd1, 0x83, 0x73, 0xe7, 0x35, 0xb4, 0xc9, 0xd1, 0xbc, 0x5c, 0x84, 0x5b, 0xdb, 0x44, 0x34, 0xea, 0xd8, 0x06, 0xe4, 0xfb, 0xbd, 0x40, 0x35, 0x18, 0x60, 0x33, 0xb6, 0xed, 0xbc, 0x9b, 0x3a, 0xff, 0x2f, 0xa1, 0xe8, 0x5d, 0x5c, 0xbb, 0xe8, 0xe1, 0xa6, 0xbb, 0x84, 0x0f, 0x50, 0x51, 0x0d, 0xa5, 0x8f, 0x96, 0xb6, 0x35, 0x37, 0x7b, 0x58, 0xaf, 0x4f, 0x77, 0x9d, 0x5d, 0xb2, 0xff, 0x5f, 0xd6, 0xb8, 0x82, 0x64, 0x5f, 0x79, 0xd0, 0x06, 0x44, 0x6d, 0x3a, 0x82, 0x25, 0x21, 0xca, 0xbb, 0xa0, 0x79, 0xdd, 0x6e, 0x15, 0xb6, 0x57, 0x9b, 0x04, 0x84, 0x63, 0x88, 0x1d, 0x41, 0xff, 0xe1, 0x20, 0x61, 0xd5, 0x3f, 0xc7, 0xca, 0x0c, 0xd9, 0xe0, 0x74, 0x86, 0x78, 0xed, 0x60, 0x18, 0x2d, 0x9e, 0x69, 0x66, 0x77, 0xf7, 
0xd0, 0xe9, 0x9c, } certificateRequest := []byte{ 0x0d, 0x00, 0x00, 0x26, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x03, 0x01, 0x02, 0x40, 0x00, 0x1e, 0x06, 0x01, 0x06, 0x02, 0x06, 0x03, 0x05, 0x01, 0x05, 0x02, 0x05, 0x03, 0x04, 0x01, 0x04, 0x02, 0x04, 0x03, 0x03, 0x01, 0x03, 0x02, 0x03, 0x03, 0x02, 0x01, 0x02, 0x02, 0x02, 0x03, 0x00, 0x00, } serverHelloDone := []byte{ 0x0e, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } cache.push(certificate2, 0, 2, handshake.TypeCertificate, false) cache.push(serverKeyExchange, 0, 3, handshake.TypeServerKeyExchange, false) cache.push(certificateRequest, 0, 4, handshake.TypeCertificateRequest, false) cache.push(serverHelloDone, 0, 5, handshake.TypeServerHelloDone, false) _, alt, err := flight1Parse(context.TODO(), mockConn, state, cache, cfg) assert.NoError(t, err) assert.Nil(t, alt) cache.push(serverHello, 0, 0, handshake.TypeServerHello, false) cache.push(certificate1, 0, 1, handshake.TypeCertificate, false) _, alt, err = flight1Parse(context.TODO(), mockConn, state, cache, cfg) assert.NoError(t, err) assert.Nil(t, alt) } dtls-3.1.2/flight2handler.go000066400000000000000000000037611514330267300157220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight2Parse( ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) if !ok { // Client may retransmit the first ClientHello when HelloVerifyRequest is dropped. // Parse as flight 0 in this case. 
return flight0Parse(ctx, c, state, cache, cfg) } state.handshakeRecvSequence = seq var clientHello *handshake.MessageClientHello // Validate type if clientHello, ok = msgs[handshake.TypeClientHello].(*handshake.MessageClientHello); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if !clientHello.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } if len(clientHello.Cookie) == 0 { return 0, nil, nil } if !bytes.Equal(state.cookie, clientHello.Cookie) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.AccessDenied}, errCookieMismatch } return flight4, nil, nil } func flight2Generate( _ flightConn, state *State, _ *handshakeCache, _ *handshakeConfig, ) ([]*packet, *alert.Alert, error) { state.handshakeSendSequence = 0 return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageHelloVerifyRequest{ Version: protocol.Version1_2, Cookie: state.cookie, }, }, }, }, }, nil, nil } dtls-3.1.2/flight3handler.go000066400000000000000000000314271514330267300157230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/internal/ciphersuite/types" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) //nolint:gocognit,gocyclo,maintidx,cyclop func flight3Parse( ctx context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { // Clients may receive multiple HelloVerifyRequest 
messages with different cookies. // Clients SHOULD handle this by sending a new ClientHello with a cookie in response // to the new HelloVerifyRequest. RFC 6347 Section 4.2.1 seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeHelloVerifyRequest, cfg.initialEpoch, false, true}, ) if ok { if h, msgOk := msgs[handshake.TypeHelloVerifyRequest].(*handshake.MessageHelloVerifyRequest); msgOk { // DTLS 1.2 clients must not assume that the server will use the protocol version // specified in HelloVerifyRequest message. RFC 6347 Section 4.2.1 if !h.Version.Equal(protocol.Version1_0) && !h.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } state.cookie = append([]byte{}, h.Cookie...) state.handshakeRecvSequence = seq return flight3, nil, nil } } _, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, ) if !ok { // Don't have enough messages. 
Keep reading return 0, nil, nil } if serverHelloMsg, msgOk := msgs[handshake.TypeServerHello].(*handshake.MessageServerHello); msgOk { //nolint:nestif if !serverHelloMsg.Version.Equal(protocol.Version1_2) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.ProtocolVersion}, errUnsupportedProtocolVersion } for _, v := range serverHelloMsg.Extensions { switch ext := v.(type) { case *extension.UseSRTP: profile, found := findMatchingSRTPProfile(ext.ProtectionProfiles, cfg.localSRTPProtectionProfiles) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, errClientNoMatchingSRTPProfile } state.setSRTPProtectionProfile(profile) state.remoteSRTPMasterKeyIdentifier = ext.MasterKeyIdentifier case *extension.UseExtendedMasterSecret: if cfg.extendedMasterSecret != DisableExtendedMasterSecret { state.extendedMasterSecret = true } case *extension.ALPN: if len(ext.ProtocolNameList) > 1 { // This should be exactly 1, the zero case is handle when unmarshalling return 0, &alert.Alert{ Level: alert.Fatal, Description: alert.InternalError, }, extension.ErrALPNInvalidFormat // Meh, internal error? } state.NegotiatedProtocol = ext.ProtocolNameList[0] case *extension.ConnectionID: // Only set connection ID to be sent if client supports connection // IDs. if cfg.connectionIDGenerator != nil { state.remoteConnectionID = ext.CID } } } // If the server doesn't support connection IDs, the client should not // expect one to be sent. 
if state.remoteConnectionID == nil { state.setLocalConnectionID(nil) } if cfg.extendedMasterSecret == RequireExtendedMasterSecret && !state.extendedMasterSecret { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errClientRequiredButNoServerEMS } if len(cfg.localSRTPProtectionProfiles) > 0 && state.getSRTPProtectionProfile() == 0 { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errRequestedButNoSRTPExtension } remoteCipherSuite := cipherSuiteForID(CipherSuiteID(*serverHelloMsg.CipherSuiteID), cfg.customCipherSuites) if remoteCipherSuite == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errCipherSuiteNoIntersection } selectedCipherSuite, found := findMatchingCipherSuite([]CipherSuite{remoteCipherSuite}, cfg.localCipherSuites) if !found { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } state.cipherSuite = selectedCipherSuite state.remoteRandom = serverHelloMsg.Random cfg.log.Tracef("[handshake] use cipher suite: %s", selectedCipherSuite.String()) if len(serverHelloMsg.SessionID) > 0 && bytes.Equal(state.SessionID, serverHelloMsg.SessionID) { return handleResumption(ctx, conn, state, cache, cfg) } if len(state.SessionID) > 0 { cfg.log.Tracef("[handshake] clean old session : %s", state.SessionID) if err := cfg.sessionStore.Del(state.SessionID); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } if cfg.sessionStore == nil { state.SessionID = []byte{} } else { state.SessionID = serverHelloMsg.SessionID } state.masterSecret = []byte{} } if cfg.localPSKCallback != nil { seq, msgs, ok = cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, ) } else { seq, msgs, ok = 
cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, true}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, ) } if !ok { // Don't have enough messages. Keep reading return 0, nil, nil } state.handshakeRecvSequence = seq if h, ok := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); ok { state.PeerCertificates = h.Certificate } else if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errInvalidCertificate } if h, ok := msgs[handshake.TypeServerKeyExchange].(*handshake.MessageServerKeyExchange); ok { alertPtr, err := handleServerKeyExchange(conn, state, cfg, h) if err != nil { return 0, alertPtr, err } } if creq, ok := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok { state.remoteCertRequestAlgs = creq.SignatureHashAlgorithms state.remoteRequestedCertificate = true } return flight5, nil, nil } func handleResumption( ctx context.Context, c flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { if err := state.initCipherSuite(); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } // Now, encrypted packets can be handled if err := c.handleQueuedPackets(ctx); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence+1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. 
Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, ) expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } clientRandom := state.localRandom.MarshalFixed() cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) return flight5b, nil, nil } //nolint:cyclop func handleServerKeyExchange( _ flightConn, state *State, cfg *handshakeConfig, keyExchangeMessage *handshake.MessageServerKeyExchange, ) (*alert.Alert, error) { var err error if state.cipherSuite == nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } if cfg.localPSKCallback != nil { //nolint:nestif var psk []byte if psk, err = cfg.localPSKCallback(keyExchangeMessage.IdentityHint); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.IdentityHint = keyExchangeMessage.IdentityHint switch state.cipherSuite.KeyExchangeAlgorithm() { case types.KeyExchangeAlgorithmPsk: state.preMasterSecret = prf.PSKPreMasterSecret(psk) case (types.KeyExchangeAlgorithmEcdhe | types.KeyExchangeAlgorithmPsk): if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.preMasterSecret, err = 
prf.EcdhePSKPreMasterSecret( psk, keyExchangeMessage.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve, ) if err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } default: return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errInvalidCipherSuite } } else { if state.localKeypair, err = elliptic.GenerateKeypair(keyExchangeMessage.NamedCurve); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if state.preMasterSecret, err = prf.PreMasterSecret( keyExchangeMessage.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve, ); err != nil { return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } return nil, nil //nolint:nilnil } func flight3Generate( _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{ &extension.SupportedSignatureAlgorithms{ SignatureHashAlgorithms: cfg.localSignatureSchemes, }, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }, } if len(cfg.localCertSignatureSchemes) > 0 { extensions = append(extensions, &extension.SignatureAlgorithmsCert{ SignatureHashAlgorithms: cfg.localCertSignatureSchemes, }) } if state.namedCurve != 0 { extensions = append(extensions, []extension.Extension{ &extension.SupportedEllipticCurves{ EllipticCurves: cfg.ellipticCurves, }, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }, }...) 
} if len(cfg.localSRTPProtectionProfiles) > 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: cfg.localSRTPProtectionProfiles, }) } if cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if len(cfg.serverName) > 0 { extensions = append(extensions, &extension.ServerName{ServerName: cfg.serverName}) } if len(cfg.supportedProtocols) > 0 { extensions = append(extensions, &extension.ALPN{ProtocolNameList: cfg.supportedProtocols}) } // If we sent a connection ID on the first ClientHello, send it on the // second. if state.getLocalConnectionID() != nil { extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } clientHello := &handshake.MessageClientHello{ Version: protocol.Version1_2, SessionID: state.SessionID, Cookie: state.cookie, Random: state.localRandom, CipherSuiteIDs: cipherSuiteIDs(cfg.localCipherSuites), CompressionMethods: defaultCompressionMethods(), Extensions: extensions, } var content handshake.Handshake if cfg.clientHelloMessageHook != nil { content = handshake.Handshake{Message: cfg.clientHelloMessageHook(*clientHello)} } else { content = handshake.Handshake{Message: clientHello} } return []*packet{ { record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }, }, nil, nil } dtls-3.1.2/flight3handler_test.go000066400000000000000000000053431514330267300167600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "math/rand" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" dtlsnet "github.com/pion/dtls/v3/pkg/net" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" 
"github.com/pion/transport/v4/dpipe" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) // Assert that SupportedEllipticCurves is only sent when a ECC CipherSuite is available. func TestSupportedEllipticCurves(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() expectedCurves := defaultCurves var actualCurves []elliptic.Curve rand.Shuffle(len(expectedCurves), func(i, j int) { expectedCurves[i], expectedCurves[j] = expectedCurves[j], expectedCurves[i] }) clientErr := make(chan error, 1) ca, cb := dpipe.Pipe() caAnalyzer := &connWithCallback{Conn: ca} caAnalyzer.onWrite = func(in []byte) { messages, err := recordlayer.UnpackDatagram(in) assert.NoError(t, err) for i := range messages { h := &handshake.Handshake{} _ = h.Unmarshal(messages[i][recordlayer.FixedHeaderSize:]) if h.Header.Type == handshake.TypeClientHello { //nolint:nestif clientHello := &handshake.MessageClientHello{} msg, err := h.Message.Marshal() assert.NoError(t, err) assert.NoError(t, clientHello.Unmarshal(msg)) for _, e := range clientHello.Extensions { if e.TypeValue() == extension.SupportedEllipticCurvesTypeValue { if c, ok := e.(*extension.SupportedEllipticCurves); ok { actualCurves = c.EllipticCurves } } } } } } go func() { conf := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, EllipticCurves: expectedCurves, } if client, err := testClient( ctx, dtlsnet.PacketConnFromConn(caAnalyzer), caAnalyzer.RemoteAddr(), conf, false, ); err != nil { clientErr <- err } else { clientErr <- client.Close() // nolint:errcheck,contextcheck } }() config := &Config{ CipherSuites: []CipherSuiteID{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, } server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), config, true) 
assert.NoError(t, err) assert.NoError(t, server.Close()) assert.NoError(t, <-clientErr) for i := range expectedCurves { assert.Equal(t, expectedCurves[i], actualCurves[i], "curves in SupportedEllipticCurves mismatch") } } dtls-3.1.2/flight4bhandler.go000066400000000000000000000116251514330267300160640ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight4bParse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. 
Keep reading return 0, nil, nil } var finished *handshake.MessageFinished if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) expectedVerifyData, err := prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if !bytes.Equal(expectedVerifyData, finished.VerifyData) { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch } // Other party may re-transmit the last flight. Keep state to be flight4b. return flight4b, nil, nil } //nolint:cyclop func flight4bGenerate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var pkts []*packet extensions := []extension.Extension{&extension.RenegotiationInfo{ RenegotiatedConnection: 0, }} if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err } if selectedProto != "" { extensions = 
append(extensions, &extension.ALPN{ ProtocolNameList: []string{selectedProto}, }) state.NegotiatedProtocol = selectedProto } cipherSuiteID := uint16(state.cipherSuite.ID()) var serverHello handshake.Handshake serverHelloMessage := &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: state.localRandom, SessionID: state.SessionID, CipherSuiteID: &cipherSuiteID, CompressionMethod: defaultCompressionMethods()[0], Extensions: extensions, } if cfg.serverHelloMessageHook != nil { serverHello = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHelloMessage)} } else { serverHello = handshake.Handshake{Message: serverHelloMessage} } serverHello.Header.MessageSequence = uint16(state.handshakeSendSequence) //nolint:gosec // G115 if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, ) raw, err := serverHello.Marshal() if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } plainText = append(plainText, raw...) 
state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &serverHello, }, }, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) return pkts, nil, nil } dtls-3.1.2/flight4handler.go000066400000000000000000000423521514330267300157230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto" "crypto/rand" "crypto/x509" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/clientcertificate" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) //nolint:gocognit,gocyclo,lll,cyclop,maintidx func flight4Parse( ctx context.Context, conn flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { seq, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, true}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, 
true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, true}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } // Validate type var clientKeyExchange *handshake.MessageClientKeyExchange if clientKeyExchange, ok = msgs[handshake.TypeClientKeyExchange].(*handshake.MessageClientKeyExchange); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if h, hasCert := msgs[handshake.TypeCertificate].(*handshake.MessageCertificate); hasCert { state.PeerCertificates = h.Certificate // If the client offer its certificate, just disable session resumption. // Otherwise, we have to store the certificate identitfication and expire time. // And we have to check whether this certificate expired, revoked or changed. // // https://curl.se/docs/CVE-2016-5419.html state.SessionID = nil } //nolint:nestif if verify, hasVerify := msgs[handshake.TypeCertificateVerify].(*handshake.MessageCertificateVerify); hasVerify { if state.PeerCertificates == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errCertificateVerifyNoCertificate } plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, ) // Verify that the pair of hash algorithm and signiture is listed. 
var validSignatureScheme bool for _, ss := range cfg.localSignatureSchemes { if ss.Hash == verify.HashAlgorithm && ss.Signature == verify.SignatureAlgorithm { validSignatureScheme = true break } } if !validSignatureScheme { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes } if err := verifyCertificateVerify( plainText, verify.HashAlgorithm, verify.SignatureAlgorithm, verify.Signature, state.PeerCertificates, ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } var chains [][]*x509.Certificate var err error var verified bool if cfg.clientAuth >= VerifyClientCertIfGiven { // Use cert-specific algorithms if present, otherwise fall back to signature_algorithms per RFC 8446 certAlgs := cfg.localCertSignatureSchemes if len(certAlgs) == 0 { certAlgs = cfg.localSignatureSchemes } if chains, err = verifyClientCert(state.PeerCertificates, cfg.clientCAs, certAlgs); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } verified = true } if cfg.verifyPeerCertificate != nil { if err := cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } state.peerCertificatesVerified = verified } else if state.PeerCertificates != nil { // A certificate was received, but we haven't seen a CertificateVerify // keep reading until we receive one return 0, nil, nil } if !state.cipherSuite.IsInitialized() { //nolint:nestif serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() var err error var preMasterSecret []byte if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypePreSharedKey { var psk []byte if psk, err = cfg.localPSKCallback(clientKeyExchange.IdentityHint); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } 
state.IdentityHint = clientKeyExchange.IdentityHint switch state.cipherSuite.KeyExchangeAlgorithm() { case CipherSuiteKeyExchangeAlgorithmPsk: preMasterSecret = prf.PSKPreMasterSecret(psk) case (CipherSuiteKeyExchangeAlgorithmPsk | CipherSuiteKeyExchangeAlgorithmEcdhe): if preMasterSecret, err = prf.EcdhePSKPreMasterSecret( psk, clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve, ); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } default: return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidCipherSuite } } else { preMasterSecret, err = prf.PreMasterSecret( clientKeyExchange.PublicKey, state.localKeypair.PrivateKey, state.localKeypair.Curve, ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err } } if state.extendedMasterSecret { var sessionHash []byte sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.masterSecret, err = prf.ExtendedMasterSecret(preMasterSecret, sessionHash, state.cipherSuite.HashFunc()) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } else { state.masterSecret, err = prf.MasterSecret( preMasterSecret, clientRandom[:], serverRandom[:], state.cipherSuite.HashFunc(), ) if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } if err := state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], false); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret) } if len(state.SessionID) > 0 { s := Session{ ID: state.SessionID, Secret: state.masterSecret, } cfg.log.Tracef("[handshake] save new session: %x", s.ID) if err := 
cfg.sessionStore.Set(state.SessionID, s); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } // Now, encrypted packets can be handled if err := conn.handleQueuedPackets(ctx); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } seq, msgs, ok = cache.fullPullMap(seq, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } state.handshakeRecvSequence = seq if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous { //nolint:nestif if cfg.verifyConnection != nil { stateClone, err := state.clone() if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } return flight6, nil, nil } switch cfg.clientAuth { case RequireAnyClientCert: if state.PeerCertificates == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired } case VerifyClientCertIfGiven: if state.PeerCertificates != nil && !state.peerCertificatesVerified { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified } case RequireAndVerifyClientCert: if state.PeerCertificates == nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.NoCertificate}, errClientCertificateRequired } if !state.peerCertificatesVerified { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errClientCertificateNotVerified } case NoClientCert, RequestClientCert: // go to flight6 } if cfg.verifyConnection != nil { 
stateClone, err := state.clone() if err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } if err := cfg.verifyConnection(stateClone); err != nil { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err } } return flight6, nil, nil } //nolint:gocognit,cyclop,maintidx func flight4Generate( _ flightConn, state *State, _ *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { extensions := []extension.Extension{} if (cfg.extendedMasterSecret == RequestExtendedMasterSecret || cfg.extendedMasterSecret == RequireExtendedMasterSecret) && state.extendedMasterSecret { extensions = append(extensions, &extension.UseExtendedMasterSecret{ Supported: true, }) } if state.getSRTPProtectionProfile() != 0 { extensions = append(extensions, &extension.UseSRTP{ ProtectionProfiles: []SRTPProtectionProfile{state.getSRTPProtectionProfile()}, MasterKeyIdentifier: cfg.localSRTPMasterKeyIdentifier, }) } if state.remoteSupportsRenegotiation { extensions = append(extensions, &extension.RenegotiationInfo{ RenegotiatedConnection: 0, }) } if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { extensions = append(extensions, &extension.SupportedPointFormats{ PointFormats: []elliptic.CurvePointFormat{elliptic.CurvePointFormatUncompressed}, }) } selectedProto, err := extension.ALPNProtocolSelection(cfg.supportedProtocols, state.peerSupportedProtocols) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.NoApplicationProtocol}, err } if selectedProto != "" { extensions = append(extensions, &extension.ALPN{ ProtocolNameList: []string{selectedProto}, }) state.NegotiatedProtocol = selectedProto } // If we have a connection ID generator, we are willing to use connection // IDs. We already know whether the client supports connection IDs from // parsing the ClientHello, so avoid setting local connection ID if the // client won't send it. 
if cfg.connectionIDGenerator != nil && state.remoteConnectionID != nil { state.setLocalConnectionID(cfg.connectionIDGenerator()) extensions = append(extensions, &extension.ConnectionID{CID: state.getLocalConnectionID()}) } var pkts []*packet cipherSuiteID := uint16(state.cipherSuite.ID()) if cfg.sessionStore != nil { state.SessionID = make([]byte, sessionLength) if _, err := rand.Read(state.SessionID); err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } serverHello := &handshake.MessageServerHello{ Version: protocol.Version1_2, Random: state.localRandom, SessionID: state.SessionID, CipherSuiteID: &cipherSuiteID, CompressionMethod: defaultCompressionMethods()[0], Extensions: extensions, } var content handshake.Handshake if cfg.serverHelloMessageHook != nil { content = handshake.Handshake{Message: cfg.serverHelloMessageHook(*serverHello)} } else { content = handshake.Handshake{Message: serverHello} } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }) switch { case state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate: certificate, err := cfg.getCertificate(&ClientHelloInfo{ ServerName: state.serverName, CipherSuites: []ciphersuite.ID{state.cipherSuite.ID()}, RandomBytes: state.remoteRandom.RandomBytes, }) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageCertificate{ Certificate: certificate.Certificate, }, }, }, }) serverRandom := state.localRandom.MarshalFixed() clientRandom := state.remoteRandom.MarshalFixed() signer, ok := certificate.PrivateKey.(crypto.Signer) if !ok { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, 
errInvalidPrivateKey } // Find compatible signature scheme signatureHashAlgo, err := signaturehash.SelectSignatureScheme(cfg.localSignatureSchemes, signer) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err } signature, err := generateKeySignature( clientRandom[:], serverRandom[:], state.localKeypair.PublicKey, state.namedCurve, signer, signatureHashAlgo.Hash, signatureHashAlgo.Signature, ) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } state.localKeySignature = signature pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageServerKeyExchange{ EllipticCurveType: elliptic.CurveTypeNamedCurve, NamedCurve: state.namedCurve, PublicKey: state.localKeypair.PublicKey, HashAlgorithm: signatureHashAlgo.Hash, SignatureAlgorithm: signatureHashAlgo.Signature, Signature: state.localKeySignature, }, }, }, }) if cfg.clientAuth > NoClientCert { // An empty list of certificateAuthorities signals to // the client that it may send any certificate in response // to our request. When we know the CAs we trust, then // we can send them down, so that the client can choose // an appropriate certificate to give to us. var certificateAuthorities [][]byte if cfg.clientCAs != nil { // nolint:staticcheck // ignoring tlsCert.RootCAs.Subjects is deprecated ERR // because cert does not come from SystemCertPool and it's ok if certificate // authorities is empty. 
certificateAuthorities = cfg.clientCAs.Subjects() } certReq := &handshake.MessageCertificateRequest{ CertificateTypes: []clientcertificate.Type{clientcertificate.RSASign, clientcertificate.ECDSASign}, SignatureHashAlgorithms: cfg.localSignatureSchemes, CertificateAuthoritiesNames: certificateAuthorities, } var content handshake.Handshake if cfg.certificateRequestMessageHook != nil { content = handshake.Handshake{Message: cfg.certificateRequestMessageHook(*certReq)} } else { content = handshake.Handshake{Message: certReq} } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &content, }, }) } case cfg.localPSKIdentityHint != nil || state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe): // To help the client in selecting which identity to use, the server // can provide a "PSK identity hint" in the ServerKeyExchange message. // If no hint is provided and cipher suite doesn't use elliptic curve, // the ServerKeyExchange message is omitted. 
// // https://tools.ietf.org/html/rfc4279#section-2 srvExchange := &handshake.MessageServerKeyExchange{ IdentityHint: cfg.localPSKIdentityHint, } if state.cipherSuite.KeyExchangeAlgorithm().Has(CipherSuiteKeyExchangeAlgorithmEcdhe) { srvExchange.EllipticCurveType = elliptic.CurveTypeNamedCurve srvExchange.NamedCurve = state.namedCurve srvExchange.PublicKey = state.localKeypair.PublicKey } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: srvExchange, }, }, }) } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &handshake.Handshake{ Message: &handshake.MessageServerHelloDone{}, }, }, }) return pkts, nil, nil } dtls-3.1.2/flight4handler_test.go000066400000000000000000000150611514330267300167570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/tls" "testing" "time" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) type flight4TestMockFlightConn struct{} func (f *flight4TestMockFlightConn) notify(context.Context, alert.Level, alert.Description) error { return nil } func (f *flight4TestMockFlightConn) writePackets(context.Context, []*packet) error { return nil } func (f *flight4TestMockFlightConn) recvHandshake() <-chan recvHandshakeState { return nil } func (f *flight4TestMockFlightConn) setLocalEpoch(uint16) {} func (f *flight4TestMockFlightConn) handleQueuedPackets(context.Context) error { return nil } func (f *flight4TestMockFlightConn) sessionKey() []byte { return nil 
} type flight4TestMockCipherSuite struct { ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256 t *testing.T } func (f *flight4TestMockCipherSuite) IsInitialized() bool { assert.Fail(f.t, "IsInitialized called with Certificate but not CertificateVerify") return true } // Assert that if a Client sends a certificate they // must also send a CertificateVerify message. // The flight4handler must not interact with the CipherSuite // if the CertificateVerify is missing. func TestFlight4_Process_CertificateVerify(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() mockConn := &flight4TestMockFlightConn{} state := &State{ cipherSuite: &flight4TestMockCipherSuite{t: t}, } cache := newHandshakeCache() cfg := &handshakeConfig{} rawCertificate := []byte{ 0x0b, 0x00, 0x01, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x9b, 0x00, 0x01, 0x98, 0x00, 0x01, 0x95, 0x30, 0x82, 0x01, 0x91, 0x30, 0x82, 0x01, 0x38, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x11, 0x01, 0x65, 0x03, 0x3f, 0x4d, 0x0b, 0x9a, 0x62, 0x91, 0xdb, 0x4d, 0x28, 0x2c, 0x1f, 0xd6, 0x73, 0x32, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x30, 0x00, 0x30, 0x1e, 0x17, 0x0d, 0x32, 0x32, 0x30, 0x35, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x17, 0x0d, 0x32, 0x32, 0x30, 0x36, 0x31, 0x35, 0x31, 0x38, 0x34, 0x33, 0x35, 0x35, 0x5a, 0x30, 0x00, 0x30, 0x59, 0x30, 0x13, 0x06, 0x07, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x02, 0x01, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x03, 0x01, 0x07, 0x03, 0x42, 0x00, 0x04, 0xc3, 0xb7, 0x13, 0x1a, 0x0a, 0xfc, 0xd0, 0x82, 0xf8, 0x94, 0x5e, 0xc0, 0x77, 0x07, 0x81, 0x28, 0xc9, 0xcb, 0x08, 0x84, 0x50, 0x6b, 0xf0, 0x22, 0xe8, 0x79, 0xb9, 0x15, 0x33, 0xc4, 0x56, 0xa1, 0xd3, 0x1b, 0x24, 0xe3, 0x61, 0xbd, 0x4d, 0x65, 0x80, 0x6b, 0x5d, 0x96, 0x48, 0xa2, 0x44, 0x9e, 0xce, 0xe8, 0x65, 0xd6, 0x3c, 0xe0, 0x9b, 0x6b, 0xa1, 0x36, 0x34, 0xb2, 0x39, 0xe2, 
0x03, 0x00, 0xa3, 0x81, 0x92, 0x30, 0x81, 0x8f, 0x30, 0x0e, 0x06, 0x03, 0x55, 0x1d, 0x0f, 0x01, 0x01, 0xff, 0x04, 0x04, 0x03, 0x02, 0x02, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x25, 0x04, 0x16, 0x30, 0x14, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x02, 0x06, 0x08, 0x2b, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01, 0x30, 0x0f, 0x06, 0x03, 0x55, 0x1d, 0x13, 0x01, 0x01, 0xff, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16, 0x04, 0x14, 0xb1, 0x1a, 0xe3, 0xeb, 0x6f, 0x7c, 0xc3, 0x8f, 0xba, 0x6f, 0x1c, 0xe8, 0xf0, 0x23, 0x08, 0x50, 0x8d, 0x3c, 0xea, 0x31, 0x30, 0x2e, 0x06, 0x03, 0x55, 0x1d, 0x11, 0x01, 0x01, 0xff, 0x04, 0x24, 0x30, 0x22, 0x82, 0x20, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x0a, 0x06, 0x08, 0x2a, 0x86, 0x48, 0xce, 0x3d, 0x04, 0x03, 0x02, 0x03, 0x47, 0x00, 0x30, 0x44, 0x02, 0x20, 0x06, 0x31, 0x43, 0xac, 0x03, 0x45, 0x79, 0x3c, 0xd7, 0x5f, 0x6e, 0x6a, 0xf8, 0x0e, 0xfd, 0x35, 0x49, 0xee, 0x1b, 0xbc, 0x47, 0xce, 0xe3, 0x39, 0xec, 0xe4, 0x62, 0xe1, 0x30, 0x1a, 0xa1, 0x89, 0x02, 0x20, 0x35, 0xcd, 0x7a, 0x15, 0x68, 0x09, 0x50, 0x49, 0x9e, 0x3e, 0x05, 0xd7, 0xc2, 0x69, 0x3f, 0x9c, 0x0c, 0x98, 0x92, 0x65, 0xec, 0xae, 0x44, 0xfe, 0xe5, 0x68, 0xb8, 0x09, 0x78, 0x7f, 0x6b, 0x77, } rawClientKeyExchange := []byte{ 0x10, 0x00, 0x00, 0x21, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x21, 0x20, 0x96, 0xed, 0x0c, 0xee, 0xf3, 0x11, 0xb1, 0x9d, 0x8b, 0x1c, 0x02, 0x7f, 0x06, 0x7c, 0x57, 0x7a, 0x14, 0xa6, 0x41, 0xde, 0x63, 0x57, 0x9e, 0xcd, 0x34, 0x54, 0xba, 0x37, 0x4d, 0x34, 0x15, 0x18, } cache.push(rawCertificate, 0, 0, handshake.TypeCertificate, true) cache.push(rawClientKeyExchange, 0, 1, handshake.TypeClientKeyExchange, true) _, _, err := flight4Parse(context.TODO(), mockConn, state, cache, cfg) assert.NoError(t, err) } func 
TestFlight4_CertificateRequestHook(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(5 * time.Second) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() localKeypair, err := elliptic.GenerateKeypair(elliptic.P256) assert.NoError(t, err) mockConn := &flight4TestMockFlightConn{} state := &State{ cipherSuite: &flight4TestMockCipherSuite{t: t}, localKeypair: localKeypair, } cert, err := selfsign.GenerateSelfSignedWithDNS("localhost") assert.NoError(t, err) cfg := &handshakeConfig{ localCertificates: []tls.Certificate{cert}, localSignatureSchemes: signaturehash.Algorithms(), clientAuth: 1, certificateRequestMessageHook: func(mcr handshake.MessageCertificateRequest) handshake.Message { mcr.SignatureHashAlgorithms = []signaturehash.Algorithm{} return &mcr }, } pkts, _, err := flight4Generate(mockConn, state, nil, cfg) assert.NoError(t, err) for _, p := range pkts { if h, ok := p.record.Content.(*handshake.Handshake); ok { //nolint:nestif if h.Message.Type() == handshake.TypeCertificateRequest { mcr := &handshake.MessageCertificateRequest{} msg, err := h.Message.Marshal() assert.NoError(t, err) assert.NoError(t, mcr.Unmarshal(msg)) if len(mcr.SignatureHashAlgorithms) == 0 { return } } } } assert.Fail(t, "hook failed to modify SignatureHashAlgorithms") } dtls-3.1.2/flight5bhandler.go000066400000000000000000000045401514330267300160630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight5bParse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, 
state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight5b. return flight5b, nil, nil } func flight5bGenerate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { //nolint:gocognit var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false}, ) var err error state.localVerifyData, err = prf.VerifyDataClient(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldEncrypt: true, resetLocalSequenceNumber: true, }) return pkts, nil, nil } dtls-3.1.2/flight5handler.go000066400000000000000000000342361514330267300157260ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto" "crypto/x509" "github.com/pion/dtls/v3/pkg/crypto/prf" 
"github.com/pion/dtls/v3/pkg/crypto/signaturehash"
	"github.com/pion/dtls/v3/pkg/protocol"
	"github.com/pion/dtls/v3/pkg/protocol/alert"
	"github.com/pion/dtls/v3/pkg/protocol/handshake"
	"github.com/pion/dtls/v3/pkg/protocol/recordlayer"
)

// flight5Parse processes the Finished message that is pulled from the handshake
// cache with isClient=false at epoch cfg.initialEpoch+1.
//
// While no such message is available it returns the zero flightVal with a nil
// alert and a nil error so that the caller keeps reading. Once the message is
// present, its verify data is compared with the value computed by
// prf.VerifyDataServer over the cached handshake transcript; a mismatch yields a
// fatal HandshakeFailure alert together with errVerifyDataMismatch. When
// state.SessionID is non-empty, the session (ID and master secret) is stored
// through cfg.sessionStore before flight5 is returned.
func flight5Parse(
	_ context.Context,
	conn flightConn,
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
) (flightVal, *alert.Alert, error) {
	_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence, state.cipherSuite,
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, false, false},
	)
	if !ok {
		// No valid message received. Keep reading
		return 0, nil, nil
	}

	var finished *handshake.MessageFinished
	if finished, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil
	}
	// Transcript used for the verify data: every cached handshake message listed
	// below, in this order (missing entries are skipped by pullAndMerge).
	plainText := cache.pullAndMerge(
		handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
		handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
		handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
		handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
	)

	expectedVerifyData, err := prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc())
	if err != nil {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}
	// NOTE(review): bytes.Equal is not a constant-time comparison; consider
	// crypto/subtle.ConstantTimeCompare for verify data — confirm with maintainers.
	if !bytes.Equal(expectedVerifyData, finished.VerifyData) {
		return 0, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errVerifyDataMismatch
	}

	// Only sessions that carry an ID are persisted.
	if len(state.SessionID) > 0 {
		s := Session{
			ID:     state.SessionID,
			Secret: state.masterSecret,
		}
		cfg.log.Tracef("[handshake] save new session: %x", s.ID)
		if err := cfg.sessionStore.Set(conn.sessionKey(), s); err != nil {
			return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	return flight5, nil, nil
}

// flight5Generate builds the packets of flight 5: an optional Certificate (when
// state.remoteRequestedCertificate is set), ClientKeyExchange, an optional
// CertificateVerify (when a signer is available), ChangeCipherSpec and Finished.
//
// As a side effect it initializes the cipher suite through initializeCipherSuite
// and stores the generated values in state.localCertificatesVerify and
// state.localVerifyData. On failure it returns nil packets together with the
// alert to send and the underlying error.
//
//nolint:gocognit,cyclop,maintidx
func flight5Generate(
	conn flightConn,
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
) ([]*packet, *alert.Alert, error) {
	var signer crypto.Signer
	var pkts []*packet
	if state.remoteRequestedCertificate { //nolint:nestif
		_, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-2, state.cipherSuite,
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false})
		if !ok {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
		}
		reqInfo := CertificateRequestInfo{}
		if r, ok2 := msgs[handshake.TypeCertificateRequest].(*handshake.MessageCertificateRequest); ok2 {
			reqInfo.AcceptableCAs = r.CertificateAuthoritiesNames
		} else {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errClientCertificateRequired
		}
		certificate, err := cfg.getClientCertificate(&reqInfo)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, err
		}
		if certificate == nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errNotAcceptableCertificateChain
		}
		// A signer is only needed when there is a certificate chain to prove
		// possession of; the Certificate message itself is sent either way.
		if certificate.Certificate != nil {
			signer, ok = certificate.PrivateKey.(crypto.Signer)
			if !ok {
				return nil, &alert.Alert{Level: alert.Fatal, Description: alert.HandshakeFailure}, errInvalidPrivateKey
			}
		}
		pkts = append(pkts,
			&packet{
				record: &recordlayer.RecordLayer{
					Header: recordlayer.Header{
						Version: protocol.Version1_2,
					},
					Content: &handshake.Handshake{
						Message: &handshake.MessageCertificate{
							Certificate: certificate.Certificate,
						},
					},
				},
			})
	}

	clientKeyExchange := &handshake.MessageClientKeyExchange{}
	// NOTE(review): state.localKeypair is dereferenced here without the nil check
	// used a few lines below — confirm earlier flights guarantee it is set
	// whenever no PSK callback is configured.
	if cfg.localPSKCallback == nil {
		clientKeyExchange.PublicKey = state.localKeypair.PublicKey
	} else {
		clientKeyExchange.IdentityHint = cfg.localPSKIdentityHint
	}
	// The public key is also attached in the PSK case when a local key pair exists.
	if state != nil && state.localKeypair != nil && len(state.localKeypair.PublicKey) > 0 {
		clientKeyExchange.PublicKey = state.localKeypair.PublicKey
	}

	pkts = append(pkts,
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: clientKeyExchange,
				},
			},
		})

	serverKeyExchangeData := cache.pullAndMerge(
		handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
	)

	serverKeyExchange := &handshake.MessageServerKeyExchange{}

	// handshakeMessageServerKeyExchange is optional for PSK
	if len(serverKeyExchangeData) == 0 {
		alertPtr, err := handleServerKeyExchange(conn, state, cfg, &handshake.MessageServerKeyExchange{})
		if err != nil {
			return nil, alertPtr, err
		}
	} else {
		rawHandshake := &handshake.Handshake{
			KeyExchangeAlgorithm: state.cipherSuite.KeyExchangeAlgorithm(),
		}
		err := rawHandshake.Unmarshal(serverKeyExchangeData)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, err
		}

		switch h := rawHandshake.Message.(type) {
		case *handshake.MessageServerKeyExchange:
			serverKeyExchange = h
		default:
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.UnexpectedMessage}, errInvalidContentType
		}
	}

	// Append not-yet-sent packets
	merged := []byte{}
	// seqPred predicts the message sequence numbers of the packets built above so
	// that their marshaled form can be included in the transcript.
	seqPred := uint16(state.handshakeSendSequence) //nolint:gosec // G115
	for _, p := range pkts {
		h, ok := p.record.Content.(*handshake.Handshake)
		if !ok {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
		}
		h.Header.MessageSequence = seqPred
		seqPred++
		raw, err := h.Marshal()
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		merged = append(merged, raw...)
	}

	if alertPtr, err := initializeCipherSuite(state, cache, cfg, serverKeyExchange, merged); err != nil {
		return nil, alertPtr, err
	}

	// If the client has sent a certificate with signing ability, a digitally-signed
	// CertificateVerify message is sent to explicitly verify possession of the
	// private key in the certificate.
	if state.remoteRequestedCertificate && signer != nil {
		plainText := append(cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
		), merged...)
		// Find compatible signature scheme

		signatureHashAlgo, err := signaturehash.SelectSignatureScheme(state.remoteCertRequestAlgs, signer)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, err
		}

		certVerify, err := generateCertificateVerify(plainText, signer, signatureHashAlgo.Hash, signatureHashAlgo.Signature)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		state.localCertificatesVerify = certVerify

		pkt := &packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageCertificateVerify{
						HashAlgorithm:      signatureHashAlgo.Hash,
						SignatureAlgorithm: signatureHashAlgo.Signature,
						Signature:          state.localCertificatesVerify,
					},
				},
			},
		}
		pkts = append(pkts, pkt)

		h, ok := pkt.record.Content.(*handshake.Handshake)
		if !ok {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errInvalidContentType
		}
		h.Header.MessageSequence = seqPred
		// seqPred++ // this is the last use of seqPred

		raw, err := h.Marshal()
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
		// CertificateVerify becomes part of the transcript used for Finished below.
		merged = append(merged, raw...)
	}

	pkts = append(pkts,
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
				},
				Content: &protocol.ChangeCipherSpec{},
			},
		})

	// The verify data is computed once and reused when this flight is generated again.
	if len(state.localVerifyData) == 0 {
		plainText := cache.pullAndMerge(
			handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false},
			handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false},
			handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false},
		)

		var err error
		state.localVerifyData, err = prf.VerifyDataClient(
			state.masterSecret,
			append(plainText, merged...),
			state.cipherSuite.HashFunc(),
		)
		if err != nil {
			return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	// Finished is the only packet of this flight that is sent encrypted, at epoch 1.
	pkts = append(pkts,
		&packet{
			record: &recordlayer.RecordLayer{
				Header: recordlayer.Header{
					Version: protocol.Version1_2,
					Epoch:   1,
				},
				Content: &handshake.Handshake{
					Message: &handshake.MessageFinished{
						VerifyData: state.localVerifyData,
					},
				},
			},
			shouldWrapCID:            len(state.remoteConnectionID) > 0,
			shouldEncrypt:            true,
			resetLocalSequenceNumber: true,
		})

	return pkts, nil, nil
}

// initializeCipherSuite derives the master secret and initializes
// state.cipherSuite. It does nothing when the cipher suite is already
// initialized.
//
// sendingPlainText is handed to cache.sessionHash as additional data when the
// extended master secret is in use. For certificate-authenticated suites the
// signature scheme and signature carried in handshakeKeyExchange and the peer
// certificates are verified before the suite is initialized. A non-nil alert is
// only ever returned together with a non-nil error.
//
//nolint:gocognit,cyclop
func initializeCipherSuite(
	state *State,
	cache *handshakeCache,
	cfg *handshakeConfig,
	handshakeKeyExchange *handshake.MessageServerKeyExchange,
	sendingPlainText []byte,
) (*alert.Alert, error) {
	if state.cipherSuite.IsInitialized() {
		return nil, nil //nolint
	}

	clientRandom := state.localRandom.MarshalFixed()
	serverRandom := state.remoteRandom.MarshalFixed()

	var err error

	if state.extendedMasterSecret {
		var sessionHash []byte
		sessionHash, err = cache.sessionHash(state.cipherSuite.HashFunc(), cfg.initialEpoch, sendingPlainText)
		if err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}

		state.masterSecret, err = prf.ExtendedMasterSecret(state.preMasterSecret, sessionHash, state.cipherSuite.HashFunc())
		if err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.IllegalParameter}, err
		}
	} else {
		state.masterSecret, err = prf.MasterSecret(
			state.preMasterSecret,
			clientRandom[:],
			serverRandom[:],
			state.cipherSuite.HashFunc(),
		)
		if err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
		}
	}

	if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeCertificate { //nolint:nestif
		// Verify that the pair of hash algorithm and signature is listed.
		var validSignatureScheme bool
		for _, ss := range cfg.localSignatureSchemes {
			if ss.Hash == handshakeKeyExchange.HashAlgorithm && ss.Signature == handshakeKeyExchange.SignatureAlgorithm {
				validSignatureScheme = true

				break
			}
		}
		if !validSignatureScheme {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InsufficientSecurity}, errNoAvailableSignatureSchemes
		}

		expectedMsg := valueKeyMessage(
			clientRandom[:],
			serverRandom[:],
			handshakeKeyExchange.PublicKey,
			handshakeKeyExchange.NamedCurve,
		)
		if err = verifyKeySignature(
			expectedMsg,
			handshakeKeyExchange.Signature,
			handshakeKeyExchange.HashAlgorithm,
			handshakeKeyExchange.SignatureAlgorithm,
			state.PeerCertificates,
		); err != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
		}
		// chains stays nil when certificate verification is skipped; the custom
		// verifyPeerCertificate callback is still invoked in that case.
		var chains [][]*x509.Certificate
		if !cfg.insecureSkipVerify {
			certAlgs := cfg.localCertSignatureSchemes
			if len(certAlgs) == 0 {
				certAlgs = cfg.localSignatureSchemes
			}
			if chains, err = verifyServerCert(state.PeerCertificates, cfg.rootCAs, cfg.serverName, certAlgs); err != nil {
				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
		if cfg.verifyPeerCertificate != nil {
			if err = cfg.verifyPeerCertificate(state.PeerCertificates, chains); err != nil {
				return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
			}
		}
	}
	// The connection callback receives a clone of the state, not the live state.
	if cfg.verifyConnection != nil {
		stateClone, errC := state.clone()
		if errC != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC
		}
		if errC = cfg.verifyConnection(stateClone); errC != nil {
			return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC
		}
	}

	if err = state.cipherSuite.Init(state.masterSecret, clientRandom[:], serverRandom[:], true); err != nil {
		return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err
	}

	cfg.writeKeyLog(keyLogLabelTLS12, clientRandom[:], state.masterSecret)

	return nil, nil //nolint
}
dtls-3.1.2/flight6handler.go000066400000000000000000000057771514330267300157370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) func flight6Parse( _ context.Context, _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) (flightVal, *alert.Alert, error) { _, msgs, ok := cache.fullPullMap(state.handshakeRecvSequence-1, state.cipherSuite, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) if !ok { // No valid message received. Keep reading return 0, nil, nil } if _, ok = msgs[handshake.TypeFinished].(*handshake.MessageFinished); !ok { return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, nil } // Other party may re-transmit the last flight. Keep state to be flight6. 
return flight6, nil, nil } func flight6Generate( _ flightConn, state *State, cache *handshakeCache, cfg *handshakeConfig, ) ([]*packet, *alert.Alert, error) { var pkts []*packet pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, }, Content: &protocol.ChangeCipherSpec{}, }, }) if len(state.localVerifyData) == 0 { plainText := cache.pullAndMerge( handshakeCachePullRule{handshake.TypeClientHello, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeServerHello, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerKeyExchange, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificateRequest, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeServerHelloDone, cfg.initialEpoch, false, false}, handshakeCachePullRule{handshake.TypeCertificate, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeClientKeyExchange, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeCertificateVerify, cfg.initialEpoch, true, false}, handshakeCachePullRule{handshake.TypeFinished, cfg.initialEpoch + 1, true, false}, ) var err error state.localVerifyData, err = prf.VerifyDataServer(state.masterSecret, plainText, state.cipherSuite.HashFunc()) if err != nil { return nil, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err } } pkts = append(pkts, &packet{ record: &recordlayer.RecordLayer{ Header: recordlayer.Header{ Version: protocol.Version1_2, Epoch: 1, }, Content: &handshake.Handshake{ Message: &handshake.MessageFinished{ VerifyData: state.localVerifyData, }, }, }, shouldWrapCID: len(state.remoteConnectionID) > 0, shouldEncrypt: true, resetLocalSequenceNumber: true, }, ) return pkts, nil, nil } dtls-3.1.2/flighthandler.go000066400000000000000000000033611514330267300156340ustar00rootroot00000000000000// 
SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "github.com/pion/dtls/v3/pkg/protocol/alert" ) // Parse received handshakes and return next flightVal. type flightParser func( context.Context, flightConn, *State, *handshakeCache, *handshakeConfig, ) (flightVal, *alert.Alert, error) // Generate flights. type flightGenerator func(flightConn, *State, *handshakeCache, *handshakeConfig) ([]*packet, *alert.Alert, error) func (f flightVal) getFlightParser() (flightParser, error) { //nolint:cyclop switch f { case flight0: return flight0Parse, nil case flight1: return flight1Parse, nil case flight2: return flight2Parse, nil case flight3: return flight3Parse, nil case flight4: return flight4Parse, nil case flight4b: return flight4bParse, nil case flight5: return flight5Parse, nil case flight5b: return flight5bParse, nil case flight6: return flight6Parse, nil default: return nil, errInvalidFlight } } func (f flightVal) getFlightGenerator() (gen flightGenerator, retransmit bool, err error) { //nolint:cyclop switch f { case flight0: return flight0Generate, true, nil case flight1: return flight1Generate, true, nil case flight2: // https://tools.ietf.org/html/rfc6347#section-3.2.1 // HelloVerifyRequests must not be retransmitted. 
return flight2Generate, false, nil case flight3: return flight3Generate, true, nil case flight4: return flight4Generate, true, nil case flight4b: return flight4bGenerate, true, nil case flight5: return flight5Generate, true, nil case flight5b: return flight5bGenerate, true, nil case flight6: return flight6Generate, true, nil default: return nil, false, errInvalidFlight } } dtls-3.1.2/fragment_buffer.go000066400000000000000000000106721514330267300161600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" ) const ( // 2 megabytes. fragmentBufferMaxSize = 2000000 fragmentBufferMaxCount = 1000 ) type fragment struct { recordLayerHeader recordlayer.Header handshakeHeader handshake.Header data []byte } type fragments struct { fragmentByOffset map[uint32]*fragment fragmentsLength uint32 handshakeLength uint32 } type fragmentBuffer struct { // map of MessageSequenceNumbers that hold slices of fragments cache map[uint16]*fragments currentMessageSequenceNumber uint16 totalBufferSize int totalFragmentCount int } func newFragmentBuffer() *fragmentBuffer { return &fragmentBuffer{cache: map[uint16]*fragments{}} } // current total size of buffer. func (f *fragmentBuffer) size() int { return f.totalBufferSize } // Attempts to push a DTLS packet to the fragmentBuffer // when it returns true it means the fragmentBuffer has inserted and the buffer shouldn't be handled // when an error returns it is fatal, and the DTLS connection should be stopped. 
func (f *fragmentBuffer) push(buf []byte) (isHandshake, isRetransmit bool, err error) { //nolint:cyclop if f.size()+len(buf) >= fragmentBufferMaxSize || f.totalFragmentCount >= fragmentBufferMaxCount { return false, false, errFragmentBufferOverflow } recordLayerHeader := recordlayer.Header{} if err := recordLayerHeader.Unmarshal(buf); err != nil { return false, false, err } // fragment isn't a handshake, we don't need to handle it if recordLayerHeader.ContentType != protocol.ContentTypeHandshake { return false, false, nil } frag := new(fragment) for buf = buf[recordlayer.FixedHeaderSize:]; len(buf) != 0; frag = new(fragment) { //nolint:gosec // G602 if err := frag.handshakeHeader.Unmarshal(buf); err != nil { return false, false, err } // Fragment is a retransmission. We have already assembled it before successfully isRetransmit = frag.handshakeHeader.FragmentOffset == 0 && frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber end := int(handshake.HeaderLength + frag.handshakeHeader.FragmentLength) if end > len(buf) { return false, false, errBufferTooSmall } if frag.handshakeHeader.MessageSequence < f.currentMessageSequenceNumber { buf = buf[end:] continue } messageFragments, ok := f.cache[frag.handshakeHeader.MessageSequence] if !ok { messageFragments = &fragments{ fragmentByOffset: map[uint32]*fragment{}, handshakeLength: frag.handshakeHeader.Length, } f.cache[frag.handshakeHeader.MessageSequence] = messageFragments } // Discard all headers, when rebuilding the packet we will re-build frag.data = append([]byte{}, buf[handshake.HeaderLength:end]...) 
frag.recordLayerHeader = recordLayerHeader if _, ok = messageFragments.fragmentByOffset[frag.handshakeHeader.FragmentOffset]; !ok { messageFragments.fragmentByOffset[frag.handshakeHeader.FragmentOffset] = frag messageFragments.fragmentsLength += frag.handshakeHeader.FragmentLength f.totalBufferSize += int(frag.handshakeHeader.FragmentLength) f.totalFragmentCount++ } buf = buf[end:] } return true, isRetransmit, nil } func (f *fragmentBuffer) pop() (content []byte, epoch uint16) { frags, ok := f.cache[f.currentMessageSequenceNumber] if !ok { return nil, 0 } if frags.fragmentsLength != frags.handshakeLength { return nil, 0 } var rawMessage []byte targetOffset := uint32(0) for i := 0; i < len(frags.fragmentByOffset) && targetOffset < frags.handshakeLength; i++ { if frag, ok := frags.fragmentByOffset[targetOffset]; ok { rawMessage = append(rawMessage, frag.data...) targetOffset = frag.handshakeHeader.FragmentOffset + frag.handshakeHeader.FragmentLength } else { return nil, 0 } } if int(frags.handshakeLength) != len(rawMessage) { return nil, 0 } firstHeader := frags.fragmentByOffset[0].handshakeHeader firstHeader.FragmentOffset = 0 firstHeader.FragmentLength = firstHeader.Length rawHeader, _ := firstHeader.Marshal() messageEpoch := frags.fragmentByOffset[0].recordLayerHeader.Epoch f.totalBufferSize -= int(frags.fragmentsLength) f.totalFragmentCount -= len(frags.fragmentByOffset) delete(f.cache, f.currentMessageSequenceNumber) f.currentMessageSequenceNumber++ return append(rawHeader, rawMessage...), messageEpoch } dtls-3.1.2/fragment_buffer_test.go000066400000000000000000000162041514330267300172140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "github.com/stretchr/testify/assert" ) func TestFragmentBuffer(t *testing.T) { for _, test := range []struct { Name string In [][]byte Expected [][]byte Epoch uint16 }{ { Name: "Single Fragment", In: [][]byte{ { 0x16, 0xfe, 0xff, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 0, }, { Name: "Single Fragment Epoch 3", In: [][]byte{ { 0x16, 0xfe, 0xff, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00, }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00}, }, Epoch: 3, }, { Name: "Multiple Fragments", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, }, }, Expected: [][]byte{ { 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, }, }, Epoch: 0, }, { Name: "Multiple Unordered Fragments", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x0A, 0x00, 0x00, 0x05, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0F, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 
0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, }, }, Expected: [][]byte{ { 0x0b, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, }, }, Epoch: 0, }, { Name: "Multiple Handshakes in Single Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x30, /* record header */ 0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 1*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 2*/ 0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01, /*handshake msg 3*/ }, }, Expected: [][]byte{ {0x03, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, {0x03, 0x00, 0x00, 0x04, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x01, 0x01}, }, Epoch: 0, }, // Assert that a zero length fragment doesn't cause the fragmentBuffer to enter an infinite loop { Name: "Zero Length Fragment", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, Expected: [][]byte{ {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, Epoch: 0, }, // Not aligned fragments should not be reassembled { Name: "Not Aligned Fragments", In: [][]byte{ { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x01, 0x02, 0x03, 0x04, }, { 0x16, 0xfe, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x81, 0x0b, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x05, 0x05, 0x06, 0x07, 0x08, 0x09, }, }, Expected: 
[][]byte{nil},
			Epoch:    0,
		},
	} {
		// Each table entry gets its own buffer: push every input record, then
		// pop once per expected message and once more to confirm it is drained.
		fragmentBuffer := newFragmentBuffer()
		for _, frag := range test.In {
			status, _, err := fragmentBuffer.push(frag)
			assert.NoError(t, err)
			assert.Truef(t, status, "fragmentBuffer didn't accept fragments for '%s'", test.Name)
		}

		for _, expected := range test.Expected {
			out, epoch := fragmentBuffer.pop()
			assert.Equalf(t, expected, out, "fragmentBuffer '%s' pop should return expected output", test.Name)
			assert.Equalf(t, test.Epoch, epoch, "fragmentBuffer returend wrong epoch")
		}

		frag, _ := fragmentBuffer.pop()
		assert.Nilf(t, frag, "fragmentBuffer '%s' pop should return nil when no more fragments are available", test.Name)
	}
}

// TestFragmentBuffer_Overflow checks that push rejects a record of
// fragmentBufferMaxSize bytes with errFragmentBufferOverflow.
func TestFragmentBuffer_Overflow(t *testing.T) {
	fragmentBuffer := newFragmentBuffer()

	// Push a buffer that doesn't exceed size limits
	_, _, err := fragmentBuffer.push([]byte{
		0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00,
		0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0xfe, 0xff, 0x00,
	})
	assert.NoError(t, err)

	// Allocate a buffer that exceeds cache size
	largeBuffer := make([]byte, fragmentBufferMaxSize)
	_, _, err = fragmentBuffer.push(largeBuffer)
	assert.ErrorIs(t, err, errFragmentBufferOverflow, "Pushing a large buffer should return an overflow error")
}

// TestFragmentBuffer_TooSmall checks that push returns errBufferTooSmall when
// the fragment length in the handshake header exceeds the bytes that follow.
func TestFragmentBuffer_TooSmall(t *testing.T) {
	fragmentBuffer := newFragmentBuffer()

	// Push a buffer that is smaller than fragment length
	_, _, err := fragmentBuffer.push([]byte{
		0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03, 0x00, 0x00,
		0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0xfe, 0xff, 0x00,
	})
	assert.ErrorIs(t, err, errBufferTooSmall, "Pushing a buffer that is smaller than fragment length should return an error")
}

// TestFragmentBuffer_UnmarshalInvalid checks that push reports an error for a
// truncated record layer header and for a truncated handshake header.
func TestFragmentBuffer_UnmarshalInvalid(t *testing.T) {
	fragmentBuffer := newFragmentBuffer()

	// Push a buffer with partial record layer header
	_, _, err := fragmentBuffer.push([]byte{
		0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
	})
	assert.Error(t, err, "Pushing a buffer with partial record layer header should return an error")

	// Push a buffer with partial handshake header
	_, _, err = fragmentBuffer.push([]byte{
		0x16, 0xfe, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0F, 0x03,
	})
	assert.Error(t, err, "Pushing a buffer with partial handshake header should return an error")
}
dtls-3.1.2/fuzz_test.go000066400000000000000000000010631514330267300150530ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community
// SPDX-License-Identifier: MIT

package dtls

import (
	"os"
	"testing"
)

// FuzzUnmarshalBinary feeds arbitrary bytes to State.UnmarshalBinary, seeded
// with two recorded states from testdata/seed.
func FuzzUnmarshalBinary(f *testing.F) {
	// NOTE(review): a missing seed file makes the function return without
	// registering the fuzz target — confirm this silent skip is intended.
	TestResumeClient, err := os.ReadFile("testdata/seed/TestResumeClient.raw")
	if err != nil {
		return
	}
	f.Add(TestResumeClient)

	TestResumeServer, err := os.ReadFile("testdata/seed/TestResumeServer.raw")
	if err != nil {
		return
	}
	f.Add(TestResumeServer)

	f.Fuzz(func(_ *testing.T, data []byte) {
		deserialized := &State{}
		// The returned error is discarded; presumably only panics are of interest here.
		_ = deserialized.UnmarshalBinary(data)
	})
}
dtls-3.1.2/go.mod000066400000000000000000000006751514330267300136050ustar00rootroot00000000000000module github.com/pion/dtls/v3

require (
	github.com/pion/logging v0.2.4
	github.com/pion/transport/v4 v4.0.1
	github.com/stretchr/testify v1.11.1
	golang.org/x/crypto v0.45.0
	golang.org/x/net v0.47.0
)

require (
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.24.0

// Retract version with broken RSA interop with OpenSSL DTLS 1.2.
retract v3.1.0 dtls-3.1.2/go.sum000066400000000000000000000027771514330267300136370ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/transport/v4 v4.0.1 h1:sdROELU6BZ63Ab7FrOLn13M6YdJLY20wldXW2Cu2k8o= github.com/pion/transport/v4 v4.0.1/go.mod h1:nEuEA4AD5lPdcIegQDpVLgNoDGreqM/YqmEx3ovP4jM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q= golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= dtls-3.1.2/handshake_cache.go000066400000000000000000000114161514330267300160720ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "sync" "github.com/pion/dtls/v3/pkg/crypto/prf" "github.com/pion/dtls/v3/pkg/protocol/handshake" ) type handshakeCacheItem struct { typ 
handshake.Type
	isClient        bool
	epoch           uint16
	messageSequence uint16
	data            []byte
}

// handshakeCachePullRule selects cached handshake messages by type, epoch and
// sender. optional marks a rule whose absence does not fail fullPullMap.
type handshakeCachePullRule struct {
	typ      handshake.Type
	epoch    uint16
	isClient bool
	optional bool
}

// handshakeCache stores the handshake messages seen so far; mu guards cache.
type handshakeCache struct {
	cache []*handshakeCacheItem
	mu    sync.Mutex
}

// newHandshakeCache returns an empty handshakeCache.
func newHandshakeCache() *handshakeCache {
	return new(handshakeCache)
}

// push appends a handshake message to the cache. data is copied, so the caller
// may reuse its slice afterwards.
func (h *handshakeCache) push(data []byte, epoch, messageSequence uint16, typ handshake.Type, isClient bool) {
	item := &handshakeCacheItem{
		typ:             typ,
		isClient:        isClient,
		epoch:           epoch,
		messageSequence: messageSequence,
		data:            append([]byte{}, data...),
	}

	h.mu.Lock()
	defer h.mu.Unlock()

	h.cache = append(h.cache, item)
}

// latestMatchLocked returns the cached item matching rule that has the highest
// messageSequence (the earliest cached one wins on ties), or nil when nothing
// matches. The caller must hold h.mu.
func (h *handshakeCache) latestMatchLocked(rule handshakeCachePullRule) *handshakeCacheItem {
	var latest *handshakeCacheItem
	for _, entry := range h.cache {
		if entry.typ != rule.typ || entry.isClient != rule.isClient || entry.epoch != rule.epoch {
			continue
		}
		if latest == nil || latest.messageSequence < entry.messageSequence {
			latest = entry
		}
	}

	return latest
}

// pull returns one entry per rule, in rule order.
// An entry is nil when no cached message satisfies its rule.
// When several messages match a rule only the one with the highest message
// sequence is returned (ie ClientHello with cookies).
func (h *handshakeCache) pull(rules ...handshakeCachePullRule) []*handshakeCacheItem {
	h.mu.Lock()
	defer h.mu.Unlock()

	matches := make([]*handshakeCacheItem, len(rules))
	for idx := range rules {
		matches[idx] = h.latestMatchLocked(rules[idx])
	}

	return matches
}

// fullPullMap looks up the latest cached message for every rule, decodes the
// ones found and returns them keyed by handshake type.
//
// The decoded messages must carry consecutive message sequence numbers starting
// at startSeq, in rule order. On success it returns the sequence number
// following the last decoded message, the map, and true. It returns false when
// a non-optional rule has no match, a message fails to decode, the sequence
// numbers have a gap, or nothing was decoded at all.
//
//nolint:cyclop
func (h *handshakeCache) fullPullMap(
	startSeq int,
	cipherSuite CipherSuite,
	rules ...handshakeCachePullRule,
) (int, map[handshake.Type]handshake.Message, bool) {
	h.mu.Lock()
	defer h.mu.Unlock()

	selected := make(map[handshake.Type]*handshakeCacheItem)
	for _, rule := range rules {
		found := h.latestMatchLocked(rule)
		if found == nil && !rule.optional {
			// Missing mandatory message.
			return startSeq, nil, false
		}
		selected[rule.typ] = found
	}

	parsed := make(map[handshake.Type]handshake.Message)
	nextSeq := startSeq
	for _, rule := range rules {
		entry := selected[rule.typ]
		if entry == nil {
			continue
		}

		var keyExchangeAlgorithm CipherSuiteKeyExchangeAlgorithm
		if cipherSuite != nil {
			keyExchangeAlgorithm = cipherSuite.KeyExchangeAlgorithm()
		}

		decoded := &handshake.Handshake{KeyExchangeAlgorithm: keyExchangeAlgorithm}
		if err := decoded.Unmarshal(entry.data); err != nil {
			return startSeq, nil, false
		}
		if uint16(nextSeq) != decoded.Header.MessageSequence { //nolint:gosec // G115
			// There is a gap. Some messages are not arrived.
			return startSeq, nil, false
		}

		nextSeq++
		parsed[rule.typ] = decoded.Message
	}

	// Every decoded message adds a key, so an empty map means nothing was decoded.
	if len(parsed) == 0 {
		return nextSeq, nil, false
	}

	return nextSeq, parsed, true
}

// pullAndMerge concatenates the data of the messages returned by pull, in rule
// order, skipping rules without a match. The result is never nil.
func (h *handshakeCache) pullAndMerge(rules ...handshakeCachePullRule) []byte {
	merged := []byte{}

	for _, item := range h.pull(rules...) {
		if item == nil {
			continue
		}
		merged = append(merged, item.data...)
	}

	return merged
}

// sessionHash returns the session hash for Extended Master Secret support
// https://tools.ietf.org/html/draft-ietf-tls-session-hash-06#section-4
func (h *handshakeCache) sessionHash(hf prf.HashFunc, epoch uint16, additional ...[]byte) ([]byte, error) {
	merged := []byte{}

	// Order defined by https://tools.ietf.org/html/rfc5246#section-7.3
	handshakeBuffer := h.pull(
		handshakeCachePullRule{handshake.TypeClientHello, epoch, true, false},
		handshakeCachePullRule{handshake.TypeServerHello, epoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificate, epoch, false, false},
		handshakeCachePullRule{handshake.TypeServerKeyExchange, epoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificateRequest, epoch, false, false},
		handshakeCachePullRule{handshake.TypeServerHelloDone, epoch, false, false},
		handshakeCachePullRule{handshake.TypeCertificate, epoch, true, false},
		handshakeCachePullRule{handshake.TypeClientKeyExchange, epoch, true, false},
	)

	for _, p := range handshakeBuffer {
		if p == nil {
			continue
		}

		merged = append(merged, p.data...)
	}
	for _, a := range additional {
		merged = append(merged, a...)
} hash := hf() if _, err := hash.Write(merged); err != nil { return []byte{}, err } return hash.Sum(nil), nil } dtls-3.1.2/handshake_cache_test.go000066400000000000000000000160061514330267300171310ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "github.com/pion/dtls/v3/internal/ciphersuite" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/stretchr/testify/assert" ) func TestHandshakeCacheSinglePush(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Single Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, }, Expected: []byte{0x00}, }, { Name: "Multi Push", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Rules set order", Input: []handshakeCacheItem{ {2, true, 0, 2, []byte{0x02}}, {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {2, 0, true, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Dupe Seqnum", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, true, 0, 1, []byte{0x01}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, }, Expected: []byte{0x00, 0x01}, }, { Name: "Multi Push, Dupe Seqnum Client/Server", Input: []handshakeCacheItem{ {0, true, 0, 0, []byte{0x00}}, {1, true, 0, 1, []byte{0x01}}, {1, false, 0, 1, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {0, 0, true, false}, {1, 0, true, false}, {1, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi 
Push, Dupe Seqnum with Unique HandshakeType", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 0, 1, []byte{0x01}}, {3, false, 0, 0, []byte{0x02}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 0, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, { Name: "Multi Push, Wrong epoch", Input: []handshakeCacheItem{ {1, true, 0, 0, []byte{0x00}}, {2, true, 1, 1, []byte{0x01}}, {2, true, 0, 2, []byte{0x11}}, {3, false, 0, 0, []byte{0x02}}, {3, false, 1, 0, []byte{0x12}}, {3, false, 2, 0, []byte{0x12}}, }, Rule: []handshakeCachePullRule{ {1, 0, true, false}, {2, 1, true, false}, {3, 0, false, false}, }, Expected: []byte{0x00, 0x01, 0x02}, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } verifyData := h.pullAndMerge(test.Rule...) assert.Equal(t, test.Expected, verifyData) } } func TestHandshakeCacheSessionHash(t *testing.T) { for _, test := range []struct { Name string Rule []handshakeCachePullRule Input []handshakeCacheItem Expected []byte }{ { Name: "Standard Handshake", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeServerHelloDone, false, 0, 4, []byte{0x04}}, {handshake.TypeClientKeyExchange, true, 0, 5, []byte{0x05}}, }, Expected: []byte{ 0x17, 0xe8, 0x8d, 0xb1, 0x87, 0xaf, 0xd6, 0x2c, 0x16, 0xe5, 0xde, 0xbf, 0x3e, 0x65, 0x27, 0xcd, 0x00, 0x6b, 0xc0, 0x12, 0xbc, 0x90, 0xb5, 0x1a, 0x81, 0x0c, 0xd8, 0x0c, 0x2d, 0x51, 0x1f, 0x43, }, }, { Name: "Handshake With Client Cert Request", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 
0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, { Name: "Handshake Ignores after ClientKeyExchange", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 1, 7, []byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, { Name: "Handshake Ignores wrong epoch", Input: []handshakeCacheItem{ {handshake.TypeClientHello, true, 0, 0, []byte{0x00}}, {handshake.TypeServerHello, false, 0, 1, []byte{0x01}}, {handshake.TypeCertificate, false, 0, 2, []byte{0x02}}, {handshake.TypeServerKeyExchange, false, 0, 3, []byte{0x03}}, {handshake.TypeCertificateRequest, false, 0, 4, []byte{0x04}}, {handshake.TypeServerHelloDone, false, 0, 5, []byte{0x05}}, {handshake.TypeClientKeyExchange, true, 0, 6, []byte{0x06}}, {handshake.TypeCertificateVerify, true, 0, 7, []byte{0x07}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, {handshake.TypeFinished, true, 1, 7, 
[]byte{0x08}}, {handshake.TypeFinished, false, 1, 7, []byte{0x09}}, {handshake.TypeFinished, true, 0, 7, []byte{0xf0}}, {handshake.TypeFinished, false, 0, 7, []byte{0xf1}}, }, Expected: []byte{ 0x57, 0x35, 0x5a, 0xc3, 0x30, 0x3c, 0x14, 0x8f, 0x11, 0xae, 0xf7, 0xcb, 0x17, 0x94, 0x56, 0xb9, 0x23, 0x2c, 0xde, 0x33, 0xa8, 0x18, 0xdf, 0xda, 0x2c, 0x2f, 0xcb, 0x93, 0x25, 0x74, 0x9a, 0x6b, }, }, } { h := newHandshakeCache() for _, i := range test.Input { h.push(i.data, i.epoch, i.messageSequence, i.typ, i.isClient) } cipherSuite := ciphersuite.TLSEcdheEcdsaWithAes128GcmSha256{} verifyData, err := h.sessionHash(cipherSuite.HashFunc(), 0) assert.NoError(t, err) assert.Equal(t, test.Expected, verifyData, "handshakeCacheSessionHash") } } dtls-3.1.2/handshake_test.go000066400000000000000000000034041514330267300160040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "testing" "time" "github.com/pion/dtls/v3/pkg/protocol" "github.com/pion/dtls/v3/pkg/protocol/extension" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/stretchr/testify/assert" ) func TestHandshakeMessage(t *testing.T) { rawHandshakeMessage := []byte{ 0x01, 0x00, 0x00, 0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x29, 0xfe, 0xfd, 0xb6, 0x2f, 0xce, 0x5c, 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 0xd8, 0x3d, 0xdc, 0x4b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, } parsedHandshake := &handshake.Handshake{ Header: handshake.Header{ Length: 0x29, FragmentLength: 0x29, Type: handshake.TypeClientHello, }, Message: &handshake.MessageClientHello{ Version: protocol.Version{Major: 0xFE, Minor: 0xFD}, Random: handshake.Random{ GMTUnixTime: time.Unix(3056586332, 0), RandomBytes: [28]byte{ 0x42, 0x54, 0xff, 0x86, 0xe1, 0x24, 0x41, 0x91, 0x42, 0x62, 0x15, 0xad, 0x16, 0xc9, 0x15, 0x8d, 0x95, 0x71, 0x8a, 0xbb, 0x22, 0xd7, 0x47, 0xec, 
0xd8, 0x3d, 0xdc, 0x4b, }, }, SessionID: []byte{}, Cookie: []byte{}, CipherSuiteIDs: []uint16{}, CompressionMethods: []*protocol.CompressionMethod{}, Extensions: []extension.Extension{}, }, } h := &handshake.Handshake{} assert.NoError(t, h.Unmarshal(rawHandshakeMessage)) assert.Equal(t, parsedHandshake, h, "handshakeMessageClientHello unmarshal") raw, err := h.Marshal() assert.NoError(t, err) assert.Equal(t, rawHandshakeMessage, raw, "handshakeMessageClientHello marshal") } dtls-3.1.2/handshaker.go000066400000000000000000000255661514330267300151440ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "context" "crypto/tls" "crypto/x509" "fmt" "io" "sync" "time" "github.com/pion/dtls/v3/pkg/crypto/elliptic" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/logging" ) // [RFC6347 Section-4.2.4] // +-----------+ // +---> | PREPARING | <--------------------+ // | +-----------+ | // | | | // | | Buffer next flight | // | | | // | \|/ | // | +-----------+ | // | | SENDING |<------------------+ | Send // | +-----------+ | | HelloRequest // Receive | | | | // next | | Send flight | | or // flight | +--------+ | | // | | | Set retransmit timer | | Receive // | | \|/ | | HelloRequest // | | +-----------+ | | Send // +--)--| WAITING |-------------------+ | ClientHello // | | +-----------+ Timer expires | | // | | | | | // | | +------------------------+ | // Receive | | Send Read retransmit | // last | | last | // flight | | flight | // | | | // \|/\|/ | // +-----------+ | // | FINISHED | -------------------------------+ // +-----------+ // | /|\ // | | // +---+ // Read retransmit // Retransmit last flight type handshakeState uint8 const ( handshakeErrored handshakeState = iota handshakePreparing handshakeSending handshakeWaiting handshakeFinished ) func (s handshakeState) String() string 
{ switch s { case handshakeErrored: return "Errored" case handshakePreparing: return "Preparing" case handshakeSending: return "Sending" case handshakeWaiting: return "Waiting" case handshakeFinished: return "Finished" default: return "Unknown" } } type handshakeFSM struct { currentFlight flightVal flights []*packet retransmit bool retransmitInterval time.Duration state *State cache *handshakeCache cfg *handshakeConfig closed chan struct{} } type handshakeConfig struct { localPSKCallback PSKCallback localPSKIdentityHint []byte localCipherSuites []CipherSuite // Available CipherSuites localSignatureSchemes []signaturehash.Algorithm // Available signature schemes localCertSignatureSchemes []signaturehash.Algorithm // Available signature schemes for certificates extendedMasterSecret ExtendedMasterSecretType // Policy for the Extended Master Support extension localSRTPProtectionProfiles []SRTPProtectionProfile // Available SRTPProtectionProfiles, if empty no SRTP support localSRTPMasterKeyIdentifier []byte serverName string supportedProtocols []string clientAuth ClientAuthType // If we are a client should we request a client certificate localCertificates []tls.Certificate nameToCertificate map[string]*tls.Certificate insecureSkipVerify bool verifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error verifyConnection func(*State) error sessionStore SessionStore rootCAs *x509.CertPool clientCAs *x509.CertPool initialRetransmitInterval time.Duration disableRetransmitBackoff bool customCipherSuites func() []CipherSuite ellipticCurves []elliptic.Curve insecureSkipHelloVerify bool connectionIDGenerator func() []byte helloRandomBytesGenerator func() [handshake.RandomBytesLength]byte onFlightState func(flightVal, handshakeState) log logging.LeveledLogger keyLogWriter io.Writer localGetCertificate func(*ClientHelloInfo) (*tls.Certificate, error) localGetClientCertificate func(*CertificateRequestInfo) (*tls.Certificate, error) initialEpoch uint16 mu 
sync.Mutex clientHelloMessageHook func(handshake.MessageClientHello) handshake.Message serverHelloMessageHook func(handshake.MessageServerHello) handshake.Message certificateRequestMessageHook func(handshake.MessageCertificateRequest) handshake.Message resumeState *State } type flightConn interface { notify(ctx context.Context, level alert.Level, desc alert.Description) error writePackets(context.Context, []*packet) error recvHandshake() <-chan recvHandshakeState setLocalEpoch(epoch uint16) handleQueuedPackets(context.Context) error sessionKey() []byte } func (c *handshakeConfig) writeKeyLog(label string, clientRandom, secret []byte) { if c.keyLogWriter == nil { return } c.mu.Lock() defer c.mu.Unlock() _, err := fmt.Fprintf(c.keyLogWriter, "%s %x %x\n", label, clientRandom, secret) if err != nil { c.log.Debugf("failed to write key log file: %s", err) } } func srvCliStr(isClient bool) string { if isClient { return "client" } return "server" } func newHandshakeFSM( s *State, cache *handshakeCache, cfg *handshakeConfig, initialFlight flightVal, ) *handshakeFSM { return &handshakeFSM{ currentFlight: initialFlight, state: s, cache: cache, cfg: cfg, retransmitInterval: cfg.initialRetransmitInterval, closed: make(chan struct{}), } } func (s *handshakeFSM) Run(ctx context.Context, conn flightConn, initialState handshakeState) error { state := initialState defer func() { close(s.closed) }() for { s.cfg.log.Tracef("[handshake:%s] %s: %s", srvCliStr(s.state.isClient), s.currentFlight.String(), state.String()) if s.cfg.onFlightState != nil { s.cfg.onFlightState(s.currentFlight, state) } var err error switch state { case handshakePreparing: state, err = s.prepare(ctx, conn) case handshakeSending: state, err = s.send(ctx, conn) case handshakeWaiting: state, err = s.wait(ctx, conn) case handshakeFinished: state, err = s.finish(ctx, conn) default: return errInvalidFSMTransition } if err != nil { return err } } } func (s *handshakeFSM) Done() <-chan struct{} { return s.closed } 
func (s *handshakeFSM) prepare(ctx context.Context, conn flightConn) (handshakeState, error) { s.flights = nil // Prepare flights var ( dtlsAlert *alert.Alert err error pkts []*packet ) gen, retransmit, errFlight := s.currentFlight.getFlightGenerator() if errFlight != nil { err = errFlight dtlsAlert = &alert.Alert{Level: alert.Fatal, Description: alert.InternalError} } else { pkts, dtlsAlert, err = gen(conn, s.state, s.cache, s.cfg) s.retransmit = retransmit } if dtlsAlert != nil { if alertErr := conn.notify(ctx, dtlsAlert.Level, dtlsAlert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } s.flights = pkts epoch := s.cfg.initialEpoch nextEpoch := epoch for _, p := range s.flights { p.record.Header.Epoch += epoch if p.record.Header.Epoch > nextEpoch { nextEpoch = p.record.Header.Epoch } if h, ok := p.record.Content.(*handshake.Handshake); ok { h.Header.MessageSequence = uint16(s.state.handshakeSendSequence) //nolint:gosec // G115 s.state.handshakeSendSequence++ } } if epoch != nextEpoch { s.cfg.log.Tracef("[handshake:%s] -> changeCipherSpec (epoch: %d)", srvCliStr(s.state.isClient), nextEpoch) conn.setLocalEpoch(nextEpoch) } return handshakeSending, nil } func (s *handshakeFSM) send(ctx context.Context, c flightConn) (handshakeState, error) { // Send flights if err := c.writePackets(ctx, s.flights); err != nil { return handshakeErrored, err } if s.currentFlight.isLastSendFlight() { return handshakeFinished, nil } return handshakeWaiting, nil } func (s *handshakeFSM) wait(ctx context.Context, conn flightConn) (handshakeState, error) { //nolint:gocognit,cyclop parse, errFlight := s.currentFlight.getFlightParser() if errFlight != nil { if alertErr := conn.notify(ctx, alert.Fatal, alert.InternalError); alertErr != nil { return handshakeErrored, alertErr } return handshakeErrored, errFlight } retransmitTimer := time.NewTimer(s.retransmitInterval) for { select { case state := <-conn.recvHandshake(): if 
state.isRetransmit { close(state.done) // ignore incoming retransmit hints, only rely on the timer-driven path below // https://github.com/pion/dtls/issues/758 continue } nextFlight, alert, err := parse(ctx, conn, s.state, s.cache, s.cfg) s.retransmitInterval = s.cfg.initialRetransmitInterval close(state.done) if alert != nil { if alertErr := conn.notify(ctx, alert.Level, alert.Description); alertErr != nil { if err != nil { err = alertErr } } } if err != nil { return handshakeErrored, err } if nextFlight == 0 { break } s.cfg.log.Tracef( "[handshake:%s] %s -> %s", srvCliStr(s.state.isClient), s.currentFlight.String(), nextFlight.String(), ) if nextFlight.isLastRecvFlight() && s.currentFlight == nextFlight { return handshakeFinished, nil } s.currentFlight = nextFlight return handshakePreparing, nil case <-retransmitTimer.C: if !s.retransmit { return handshakeWaiting, nil } // RFC 4347 4.2.4.1: // Implementations SHOULD use an initial timer value of 1 second (the minimum defined in RFC 2988 [RFC2988]) // and double the value at each retransmission, up to no less than the RFC 2988 maximum of 60 seconds. 
if !s.cfg.disableRetransmitBackoff { s.retransmitInterval *= 2 } if s.retransmitInterval > time.Second*60 { s.retransmitInterval = time.Second * 60 } return handshakeSending, nil case <-ctx.Done(): s.retransmitInterval = s.cfg.initialRetransmitInterval return handshakeErrored, ctx.Err() } } } func (s *handshakeFSM) finish(ctx context.Context, c flightConn) (handshakeState, error) { select { case state := <-c.recvHandshake(): close(state.done) if s.state.isClient { return handshakeFinished, nil } else { return handshakeSending, nil } case <-ctx.Done(): return handshakeErrored, ctx.Err() } } dtls-3.1.2/handshaker_test.go000066400000000000000000000275361514330267300162020ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2026 The Pion community // SPDX-License-Identifier: MIT package dtls import ( "bytes" "context" "crypto/tls" "errors" "sync" "testing" "time" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/pion/dtls/v3/pkg/crypto/signaturehash" "github.com/pion/dtls/v3/pkg/protocol/alert" "github.com/pion/dtls/v3/pkg/protocol/handshake" "github.com/pion/dtls/v3/pkg/protocol/recordlayer" "github.com/pion/logging" "github.com/pion/transport/v4/test" "github.com/stretchr/testify/assert" ) const nonZeroRetransmitInterval = 100 * time.Millisecond // Test that writes to the key log are in the correct format and only applies // when a key log writer is given. func TestWriteKeyLog(t *testing.T) { var buf bytes.Buffer cfg := handshakeConfig{ keyLogWriter: &buf, } cfg.writeKeyLog("LABEL", []byte{0xAA, 0xBB, 0xCC}, []byte{0xDD, 0xEE, 0xFF}) // Secrets follow the format