pax_global_header 0000666 0000000 0000000 00000000064 14765760516 0014533 g ustar 00root root 0000000 0000000 52 comment=4af2a2caa7f535d1f6240628c43c616a8c8121af
golang-github-lucas-clemente-quic-go-0.50.0/ 0000775 0000000 0000000 00000000000 14765760516 0020605 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/.clusterfuzzlite/ 0000775 0000000 0000000 00000000000 14765760516 0024141 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/.clusterfuzzlite/Dockerfile 0000664 0000000 0000000 00000001113 14765760516 0026127 0 ustar 00root root 0000000 0000000 FROM gcr.io/oss-fuzz-base/base-builder-go:v1
ARG TARGETPLATFORM
RUN echo "TARGETPLATFORM: ${TARGETPLATFORM}"
# Go toolchain version that replaces the one shipped with the base image.
ENV GOVERSION=1.24.0
# Replace the contents of /root/.go with the GOVERSION toolchain for the target
# platform, e.g. TARGETPLATFORM=linux/amd64 -> go1.24.0.linux-amd64.tar.gz.
RUN platform=$(echo "${TARGETPLATFORM}" | tr '/' '-') && \
    filename="go${GOVERSION}.${platform}.tar.gz" && \
    wget "https://dl.google.com/go/${filename}" && \
    mkdir temp-go && \
    rm -rf /root/.go/* && \
    tar -C temp-go/ -xzf "${filename}" && \
    mv temp-go/go/* /root/.go/ && \
    rm -r "${filename}" temp-go
# Remove the apt package lists in the same layer, so they don't persist in the image.
RUN apt-get update && \
    apt-get install -y make autoconf automake libtool && \
    rm -rf /var/lib/apt/lists/*
COPY . $SRC/quic-go
# Use an absolute path (matching the COPY destination above) instead of relying
# on the base image's working directory to resolve a relative WORKDIR.
WORKDIR $SRC/quic-go
COPY .clusterfuzzlite/build.sh $SRC/
golang-github-lucas-clemente-quic-go-0.50.0/.clusterfuzzlite/build.sh 0000775 0000000 0000000 00000000755 14765760516 0025606 0 ustar 00root root 0000000 0000000 #!/bin/bash -eu
# Link against libresolv as well.
export CXX="${CXX} -lresolv" # required by Go 1.20

# Fuzz targets, one per entry, written as "<package below fuzzing/>:<fuzzer name>".
# Every target uses the exported function "Fuzz" as its entry point.
fuzz_targets=(
  "frames:frame_fuzzer"
  "header:header_fuzzer"
  "transportparameters:transportparameter_fuzzer"
  "tokens:token_fuzzer"
  "handshake:handshake_fuzzer"
)
for target in "${fuzz_targets[@]}"; do
  compile_go_fuzzer "github.com/quic-go/quic-go/fuzzing/${target%%:*}" Fuzz "${target#*:}"
done
golang-github-lucas-clemente-quic-go-0.50.0/.clusterfuzzlite/project.yaml 0000664 0000000 0000000 00000000015 14765760516 0026467 0 ustar 00root root 0000000 0000000 language: go
golang-github-lucas-clemente-quic-go-0.50.0/.githooks/ 0000775 0000000 0000000 00000000000 14765760516 0022512 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/.githooks/README.md 0000664 0000000 0000000 00000000231 14765760516 0023765 0 ustar 00root root 0000000 0000000 # Git Hooks
This directory contains useful Git hooks for working with quic-go.
Install them by running
```bash
git config core.hooksPath .githooks
```
golang-github-lucas-clemente-quic-go-0.50.0/.githooks/pre-commit 0000775 0000000 0000000 00000001622 14765760516 0024515 0 ustar 00root root 0000000 0000000 #!/bin/bash
# Check that test files don't contain focussed test cases.
errored=false
# Read the staged file names line by line (instead of word-splitting a command
# substitution), so that paths containing spaces are handled correctly.
while IFS= read -r f; do
  if [[ $f != *_test.go ]]; then continue; fi
  output=$(git show :"$f" | grep -n -e "FIt(" -e "FContext(" -e "FDescribe(")
  if [ $? -eq 0 ]; then
    echo "$f contains a focussed test:"
    echo "$output"
    echo ""
    errored=true
  fi
done < <(git diff --diff-filter=d --cached --name-only)

# Check that go.mod / go.sum in integrationtests/gomodvendor are tidy.
# `go mod tidy` rewrites the files in the working tree; an unstaged diff
# afterwards is reported as an error.
pushd ./integrationtests/gomodvendor > /dev/null
go mod tidy
if [[ -n $(git diff --diff-filter=d --name-only -- "go.mod" "go.sum") ]]; then
  echo "go.mod / go.sum in integrationtests/gomodvendor not tidied"
  errored=true
fi
popd > /dev/null

# Check that all staged Go files are properly gofumpt-ed.
# Only invoke gofumpt when at least one Go file is staged: called without file
# arguments, gofumpt (like gofmt) processes standard input instead of checking nothing.
go_files=()
while IFS= read -r f; do
  go_files+=("$f")
done < <(git diff --diff-filter=d --cached --name-only -- '*.go')
if [ ${#go_files[@]} -gt 0 ]; then
  output=$(gofumpt -d "${go_files[@]}")
  if [ -n "$output" ]; then
    echo "Found files that are not properly gofumpt-ed."
    echo "$output"
    errored=true
  fi
fi

if [ "$errored" = true ]; then
  exit 1
fi
golang-github-lucas-clemente-quic-go-0.50.0/.github/ 0000775 0000000 0000000 00000000000 14765760516 0022145 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/.github/FUNDING.yml 0000664 0000000 0000000 00000001464 14765760516 0023767 0 ustar 00root root 0000000 0000000 # These are supported funding model platforms
github: [marten-seemann] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
patreon: # Replace with a single Patreon username
open_collective: # Replace with a single Open Collective username
ko_fi: # Replace with a single Ko-fi username
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
liberapay: # Replace with a single Liberapay username
issuehunt: # Replace with a single IssueHunt username
otechie: # Replace with a single Otechie username
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
golang-github-lucas-clemente-quic-go-0.50.0/.github/dependabot.yml 0000664 0000000 0000000 00000000166 14765760516 0025000 0 ustar 00root root 0000000 0000000 version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/ 0000775 0000000 0000000 00000000000 14765760516 0024202 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/build-interop-docker.yml 0000664 0000000 0000000 00000003077 14765760516 0030756 0 ustar 00root root 0000000 0000000 name: Build interop Docker image
on:
push:
branches:
- master
tags:
- 'v*'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
interop:
runs-on: ${{ fromJSON(vars['DOCKER_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
steps:
- uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
with:
platforms: linux/amd64,linux/arm64
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: set tag name
id: tag
# Tagged releases won't be picked up by the interop runner automatically,
# but they can be useful when debugging regressions.
run: |
if [[ $GITHUB_REF == refs/tags/* ]]; then
echo "tag=${GITHUB_REF#refs/tags/}" | tee -a $GITHUB_OUTPUT;
echo "gitref=${GITHUB_REF#refs/tags/}" | tee -a $GITHUB_OUTPUT;
else
echo 'tag=latest' | tee -a $GITHUB_OUTPUT;
echo 'gitref=${{ github.sha }}' | tee -a $GITHUB_OUTPUT;
fi
- uses: docker/build-push-action@v6
with:
context: "{{defaultContext}}:interop"
platforms: linux/amd64,linux/arm64
push: true
build-args: |
GITREF=${{ steps.tag.outputs.gitref }}
tags: martenseemann/quic-go-interop:${{ steps.tag.outputs.tag }}
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/clusterfuzz-lite-pr.yml 0000664 0000000 0000000 00000003501 14765760516 0030676 0 ustar 00root root 0000000 0000000 name: ClusterFuzzLite PR fuzzing
on:
pull_request:
paths:
- '**'
permissions: read-all
jobs:
PR:
runs-on: ${{ fromJSON(vars['CLUSTERFUZZ_LITE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
concurrency:
group: ${{ github.workflow }}-${{ matrix.sanitizer }}-${{ github.ref }}
cancel-in-progress: true
strategy:
fail-fast: false
matrix:
sanitizer:
- address
steps:
- name: Build Fuzzers (${{ matrix.sanitizer }})
id: build
uses: google/clusterfuzzlite/actions/build_fuzzers@v1
with:
language: go
github-token: ${{ secrets.GITHUB_TOKEN }}
sanitizer: ${{ matrix.sanitizer }}
# Optional but recommended: used to only run fuzzers that are affected
# by the PR.
# See later section on "Git repo for storage".
# storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git
# storage-repo-branch: main # Optional. Defaults to "main"
# storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages".
- name: Run Fuzzers (${{ matrix.sanitizer }})
id: run
uses: google/clusterfuzzlite/actions/run_fuzzers@v1
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
fuzz-seconds: 480
mode: 'code-change'
sanitizer: ${{ matrix.sanitizer }}
output-sarif: true
parallel-fuzzing: true
# Optional but recommended: used to download the corpus produced by
# batch fuzzing.
# See later section on "Git repo for storage".
# storage-repo: https://${{ secrets.PERSONAL_ACCESS_TOKEN }}@github.com/OWNER/STORAGE-REPO-NAME.git
# storage-repo-branch: main # Optional. Defaults to "main"
# storage-repo-branch-coverage: gh-pages # Optional. Defaults to "gh-pages".
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/cross-compile.sh 0000775 0000000 0000000 00000001502 14765760516 0027316 0 ustar 00root root 0000000 0000000 #!/bin/bash
set -e

dist="$1"
goos=$(echo "$dist" | cut -d "/" -f1)
goarch=$(echo "$dist" | cut -d "/" -f2)

# Targets that are skipped entirely:
# - android: cross-compiling for android is a pain...
# - ios: builds require Cgo, see https://github.com/golang/go/issues/43343.
#   Cgo would then need a C cross compilation setup. Not worth the hassle.
case "$goos" in
  android | ios) exit 0 ;;
esac

# Buffer all output in a temporary file instead of writing it to stdout.
# This script runs in parallel, and buffering keeps each target's output in one piece.
log_file=$(mktemp)

# On any failing command: dump the buffered log to stderr, clean up, and fail.
error_handler() {
  cat "$log_file" >&2
  rm "$log_file"
  exit 1
}
trap 'error_handler' ERR

echo "$dist" >> "$log_file"
binary="main-$goos-$goarch"
GOOS=$goos GOARCH=$goarch go build -o "$binary" example/main.go >> "$log_file" 2>&1
rm "$binary"

# Success: emit the buffered log in one go.
cat "$log_file"
rm "$log_file"
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/cross-compile.yml 0000664 0000000 0000000 00000001477 14765760516 0027515 0 ustar 00root root 0000000 0000000 on: [push, pull_request]
jobs:
crosscompile:
strategy:
fail-fast: false
matrix:
go: [ "1.23.x", "1.24.x" ]
runs-on: ${{ fromJSON(vars['CROSS_COMPILE_RUNNER_UBUNTU'] || '"ubuntu-latest"') }}
name: "Cross Compilation (Go ${{matrix.go}})"
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Install build utils
run: |
sudo apt-get update
sudo apt-get install -y gcc-multilib
- name: Install dependencies
run: go build example/main.go
- name: Run cross compilation
# run in parallel on as many cores as are available on the machine
run: go tool dist list | xargs -I % -P "$(nproc)" .github/workflows/cross-compile.sh %
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/go-generate.sh 0000775 0000000 0000000 00000001114 14765760516 0026733 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash
set -e

# Delete all go-generated files (that adhere to the comment convention).
# Only .go files tracked by git are considered: their names are handed to grep as
# file arguments via xargs. (Piping a file list into `grep -r` has no effect: with
# -r and no file operand, grep ignores stdin and recursively searches the working
# directory, including untracked files.) NUL separators keep unusual paths intact.
git ls-files -z -- '*.go' | xargs -0 grep -lI --null "^// Code generated .* DO NOT EDIT\.$" | xargs -0 rm -f

# First regenerate sys_conn_buffers_write.go.
# If it doesn't exist, the following mockgen calls will fail.
go generate -run "sys_conn_buffers_write.go"

# now generate everything
go generate ./...

# Check if any files were changed
git diff --exit-code || (
  echo "Generated files are not up to date. Please run 'go generate ./...' and commit the changes."
  exit 1
)
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/integration.yml 0000664 0000000 0000000 00000006742 14765760516 0027261 0 ustar 00root root 0000000 0000000 on: [push, pull_request]
jobs:
integration:
strategy:
fail-fast: false
matrix:
os: [ "ubuntu" ]
go: [ "1.23.x", "1.24.x" ]
race: [ false ]
use32bit: [ false ]
include:
- os: "ubuntu"
go: "1.24.x"
race: true
- os: "ubuntu"
go: "1.24.x"
use32bit: true
- os: "windows"
go: "1.24.x"
race: false
- os: "macos"
go: "1.24.x"
race: false
runs-on: ${{ fromJSON(vars[format('INTEGRATION_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }}
timeout-minutes: 30
defaults:
run:
shell: bash # by default Windows uses PowerShell, which uses a different syntax for setting environment variables
env:
DEBUG: false # set this to true to export qlogs and save them as artifacts
TIMESCALE_FACTOR: 3
name: "Integration (${{ matrix.os }}, Go ${{ matrix.go }}${{ matrix.race && ', race' || '' }}${{ matrix.use32bit && ', 32-bit' || '' }})"
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: Set qlogger
if: env.DEBUG == 'true'
run: echo "QLOGFLAG= -qlog" >> $GITHUB_ENV
- name: Enable race detector
if: ${{ matrix.race }}
run: echo "RACEFLAG= -race" >> $GITHUB_ENV
- name: Enable 32-bit build
if: ${{ matrix.use32bit }}
run: echo "GOARCH=386" >> $GITHUB_ENV
- run: go version
- name: Run tools tests
run: go test ${{ env.RACEFLAG }} -v -timeout 30s -shuffle=on ./integrationtests/tools
- name: Run version negotiation tests
run: go test ${{ env.RACEFLAG }} -v -timeout 30s -shuffle=on ./integrationtests/versionnegotiation ${{ env.QLOGFLAG }}
- name: Run self tests, using QUIC v1
if: success() || failure() # run this step even if the previous one failed
run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }}
- name: Run self tests, using QUIC v2
if: ${{ !matrix.race && (success() || failure()) }} # run this step even if the previous one failed
run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=2 ${{ env.QLOGFLAG }}
- name: Run self tests, with GSO disabled
if: ${{ matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed
env:
QUIC_GO_DISABLE_GSO: true
run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }}
- name: Run self tests, with ECN disabled
if: ${{ !matrix.race && matrix.os == 'ubuntu' && (success() || failure()) }} # run this step even if the previous one failed
env:
QUIC_GO_DISABLE_ECN: true
run: go test ${{ env.RACEFLAG }} -v -timeout 5m -shuffle=on ./integrationtests/self -version=1 ${{ env.QLOGFLAG }}
- name: Run benchmarks
if: ${{ !matrix.race }}
run: go test -v -run=^$ -timeout 5m -shuffle=on -bench=. ./integrationtests/self
- name: save qlogs
if: ${{ always() && env.DEBUG == 'true' }}
uses: actions/upload-artifact@v4
with:
name: qlogs-${{ matrix.os }}-go${{ matrix.go }}-race${{ matrix.race }}${{ matrix.use32bit && '-32bit' || '' }}
path: integrationtests/self/*.qlog
retention-days: 7
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/lint.yml 0000664 0000000 0000000 00000005777 14765760516 0025713 0 ustar 00root root 0000000 0000000 on: [push, pull_request]
jobs:
check:
runs-on: ubuntu-latest
timeout-minutes: 15
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: "1.24.x"
- name: Check that no non-test files import Ginkgo or Gomega
run: .github/workflows/no_ginkgo.sh
- name: Check for //go:build ignore in .go files
run: |
IGNORED_FILES=$(grep -rl '//go:build ignore' . --include='*.go') || true
if [ -n "$IGNORED_FILES" ]; then
echo "::error::Found ignored Go files: $IGNORED_FILES"
exit 1
fi
- name: Check that go.mod is tidied
if: success() || failure() # run this step even if the previous one failed
run: |
cp go.mod go.mod.orig
cp go.sum go.sum.orig
go mod tidy
diff go.mod go.mod.orig
diff go.sum go.sum.orig
- name: Run code generators
if: success() || failure() # run this step even if the previous one failed
run: .github/workflows/go-generate.sh
- name: Check that go mod vendor works
if: success() || failure() # run this step even if the previous one failed
run: |
cd integrationtests/gomodvendor
go mod vendor
golangci-lint:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
go: [ "1.23.x", "1.24.x" ]
env:
GOLANGCI_LINT_VERSION: v1.64.4
name: golangci-lint (Go ${{ matrix.go }})
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- name: golangci-lint (Linux)
uses: golangci/golangci-lint-action@v6
with:
args: --timeout=3m
version: ${{ env.GOLANGCI_LINT_VERSION }}
- name: golangci-lint (Windows)
if: success() || failure() # run this step even if the previous one failed
uses: golangci/golangci-lint-action@v6
env:
GOOS: "windows"
with:
args: --timeout=3m
version: ${{ env.GOLANGCI_LINT_VERSION }}
- name: golangci-lint (OSX)
if: success() || failure() # run this step even if the previous one failed
uses: golangci/golangci-lint-action@v6
env:
GOOS: "darwin"
with:
args: --timeout=3m
version: ${{ env.GOLANGCI_LINT_VERSION }}
- name: golangci-lint (FreeBSD)
if: success() || failure() # run this step even if the previous one failed
uses: golangci/golangci-lint-action@v6
env:
GOOS: "freebsd"
with:
args: --timeout=3m
version: ${{ env.GOLANGCI_LINT_VERSION }}
- name: golangci-lint (others)
if: success() || failure() # run this step even if the previous one failed
uses: golangci/golangci-lint-action@v6
env:
GOOS: "solaris" # some OS that we don't have any build tags for
with:
args: --timeout=3m
version: ${{ env.GOLANGCI_LINT_VERSION }}
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/no_ginkgo.sh 0000775 0000000 0000000 00000000726 14765760516 0026520 0 ustar 00root root 0000000 0000000 #!/usr/bin/env bash
# Verify that no non-test files import Ginkgo or Gomega.
set -e
HAS_TESTING=false
cd ..
for f in $(find . -name "*.go" ! -name "*_test.go" ! -name "tools.go"); do
if grep -q "github.com/onsi/ginkgo" $f; then
echo "$f imports github.com/onsi/ginkgo/v2"
HAS_TESTING=true
fi
if grep -q "github.com/onsi/gomega" $f; then
echo "$f imports github.com/onsi/gomega"
HAS_TESTING=true
fi
done
if "$HAS_TESTING"; then
exit 1
fi
exit 0
golang-github-lucas-clemente-quic-go-0.50.0/.github/workflows/unit.yml 0000664 0000000 0000000 00000004372 14765760516 0025712 0 ustar 00root root 0000000 0000000 on: [push, pull_request]
jobs:
unit:
strategy:
fail-fast: false
matrix:
os: [ "ubuntu", "windows", "macos" ]
go: [ "1.23.x", "1.24.x" ]
runs-on: ${{ fromJSON(vars[format('UNIT_RUNNER_{0}', matrix.os)] || format('"{0}-latest"', matrix.os)) }}
name: Unit tests (${{ matrix.os}}, Go ${{ matrix.go }})
timeout-minutes: 30
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: ${{ matrix.go }}
- run: go version
- name: Run tests
env:
TIMESCALE_FACTOR: 10
run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -cover -randomize-all -randomize-suites -trace -skip-package integrationtests
- name: Run tests as root
if: ${{ matrix.os == 'ubuntu' }}
env:
TIMESCALE_FACTOR: 10
FILE: sys_conn_helper_linux_test.go
run: |
test -f $FILE # make sure the file actually exists
go run github.com/onsi/ginkgo/v2/ginkgo build -cover -tags root .
sudo ./quic-go.test -ginkgo.v -ginkgo.trace -ginkgo.randomize-all -ginkgo.focus-file=$FILE -test.coverprofile coverage-root.txt
rm quic-go.test
- name: Run tests (32 bit)
if: ${{ matrix.os != 'macos' }} # can't run 32 bit tests on OSX.
env:
TIMESCALE_FACTOR: 10
GOARCH: 386
run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -cover -coverprofile coverage.txt -output-dir . -randomize-all -randomize-suites -trace -skip-package integrationtests
- name: Run tests with race detector
if: ${{ matrix.os == 'ubuntu' }} # speed things up. Windows and OSX VMs are slow
env:
TIMESCALE_FACTOR: 20
run: go run github.com/onsi/ginkgo/v2/ginkgo -r -v -race -randomize-all -randomize-suites -trace -skip-package integrationtests
- name: Run benchmark tests
run: go test -v -run=^$ -benchtime 0.5s -bench=. $(go list ./... | grep -v integrationtests/self)
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
env:
OS: ${{ matrix.os }}
GO: ${{ matrix.go }}
with:
files: coverage.txt,coverage-root.txt
env_vars: OS,GO
token: ${{ secrets.CODECOV_TOKEN }}
golang-github-lucas-clemente-quic-go-0.50.0/.gitignore 0000664 0000000 0000000 00000000332 14765760516 0022573 0 ustar 00root root 0000000 0000000 debug
debug.test
main
mockgen_tmp.go
*.qtr
*.qlog
*.sqlog
*.txt
race.[0-9]*
fuzzing/*/*.zip
fuzzing/*/coverprofile
fuzzing/*/crashers
fuzzing/*/sonarprofile
fuzzing/*/suppressions
fuzzing/*/corpus/
gomock_reflect_*/
golang-github-lucas-clemente-quic-go-0.50.0/.golangci.yml 0000664 0000000 0000000 00000001647 14765760516 0023201 0 ustar 00root root 0000000 0000000 linters-settings:
misspell:
ignore-words:
- ect
depguard:
rules:
quicvarint:
list-mode: strict
files:
- "**/github.com/quic-go/quic-go/quicvarint/*"
- "!$test"
allow:
- $gostd
linters:
disable-all: true
enable:
- asciicheck
- copyloopvar
- depguard
- exhaustive
- goimports
- gofmt # redundant, since gofmt *should* be a no-op after gofumpt
- gofumpt
- gosimple
- govet
- ineffassign
- misspell
- prealloc
- staticcheck
- stylecheck
- unconvert
- unparam
- unused
issues:
exclude-files:
- internal/handshake/cipher_suite.go
exclude-rules:
- path: internal/qtls
linters:
- depguard
- path: _test\.go
linters:
- exhaustive
- prealloc
- unparam
- path: _test\.go
text: "SA1029:"
linters:
- staticcheck
golang-github-lucas-clemente-quic-go-0.50.0/Changelog.md 0000664 0000000 0000000 00000010613 14765760516 0023017 0 ustar 00root root 0000000 0000000 # Changelog
## v0.22.0 (2021-07-25)
- Use `ReadBatch` to read multiple UDP packets from the socket with a single syscall
- Add a config option (`Config.DisableVersionNegotiationPackets`) to disable sending of Version Negotiation packets
- Drop support for QUIC draft versions 32 and 34
- Remove the `RetireBugBackwardsCompatibilityMode`, which was intended to mitigate a bug when retiring connection IDs in quic-go in v0.17.2 and earlier
## v0.21.2 (2021-07-15)
- Update qtls (for Go 1.15, 1.16 and 1.17rc1) to include the fix for the crypto/tls panic (see https://groups.google.com/g/golang-dev/c/5LJ2V7rd-Ag/m/YGLHVBZ6AAAJ for details)
## v0.21.0 (2021-06-01)
- quic-go now supports RFC 9000!
## v0.20.0 (2021-03-19)
- Remove the `quic.Config.HandshakeTimeout`. Introduce a `quic.Config.HandshakeIdleTimeout`.
## v0.17.1 (2020-06-20)
- Supports QUIC WG draft-29.
- Improve bundling of ACK frames (#2543).
## v0.16.0 (2020-05-31)
- Supports QUIC WG draft-28.
## v0.15.0 (2020-03-01)
- Supports QUIC WG draft-27.
- Add support for 0-RTT.
- Remove `Session.Close()`. Applications need to pass an application error code to the transport using `Session.CloseWithError()`.
- Make the TLS Cipher Suites configurable (via `tls.Config.CipherSuites`).
## v0.14.0 (2019-12-04)
- Supports QUIC WG draft-24.
## v0.13.0 (2019-11-05)
- Supports QUIC WG draft-23.
- Add an `EarlyListener` that allows sending of 0.5-RTT data.
- Add a `TokenStore` to store address validation tokens.
- Issue and use new connection IDs during a connection.
## v0.12.0 (2019-08-05)
- Implement HTTP/3.
- Rename `quic.Cookie` to `quic.Token` and `quic.Config.AcceptCookie` to `quic.Config.AcceptToken`.
- Distinguish between Retry tokens and tokens sent in NEW_TOKEN frames.
- Enforce application protocol negotiation (via `tls.Config.NextProtos`).
- Use a varint for error codes.
- Add support for [quic-trace](https://github.com/google/quic-trace).
- Add a context to `Listener.Accept`, `Session.Accept{Uni}Stream` and `Session.Open{Uni}StreamSync`.
- Implement TLS key updates.
## v0.11.0 (2019-04-05)
- Drop support for gQUIC. For gQUIC support, please switch to the *gquic* branch.
- Implement QUIC WG draft-19.
- Use [qtls](https://github.com/marten-seemann/qtls) for TLS 1.3.
- Return a `tls.ConnectionState` from `quic.Session.ConnectionState()`.
- Remove the error return values from `quic.Stream.CancelRead()` and `quic.Stream.CancelWrite()`
## v0.10.0 (2018-08-28)
- Add support for QUIC 44, drop support for QUIC 42.
## v0.9.0 (2018-08-15)
- Add a `quic.Config` option for the length of the connection ID (for IETF QUIC).
- Split Session.Close into one method for regular closing and one for closing with an error.
## v0.8.0 (2018-06-26)
- Add support for unidirectional streams (for IETF QUIC).
- Add a `quic.Config` option for the maximum number of incoming streams.
- Add support for QUIC 42 and 43.
- Add dial functions that use a context.
- Multiplex clients on a net.PacketConn, when using Dial(conn).
## v0.7.0 (2018-02-03)
- The lower boundary for packets included in ACKs is now derived, and the value sent in STOP_WAITING frames is ignored.
- Remove `DialNonFWSecure` and `DialAddrNonFWSecure`.
- Expose the `ConnectionState` in the `Session` (experimental API).
- Implement packet pacing.
## v0.6.0 (2017-12-12)
- Add support for QUIC 39, drop support for QUIC 35 - 37
- Added `quic.Config` options for maximal flow control windows
- Add a `quic.Config` option for QUIC versions
- Add a `quic.Config` option to request omission of the connection ID from a server
- Add a `quic.Config` option to configure the source address validation
- Add a `quic.Config` option to configure the handshake timeout
- Add a `quic.Config` option to configure the idle timeout
- Add a `quic.Config` option to configure keep-alive
- Rename the STK to Cookie
- Implement `net.Conn`-style deadlines for streams
- Remove the `tls.Config` from the `quic.Config`. The `tls.Config` must now be passed to the `Dial` and `Listen` functions as a separate parameter. See the [Godoc](https://godoc.org/github.com/quic-go/quic-go) for details.
- Changed the log level environment variable to only accept strings ("DEBUG", "INFO", "ERROR"), see [the wiki](https://github.com/quic-go/quic-go/wiki/Logging) for more details.
- Rename the `h2quic.QuicRoundTripper` to `h2quic.RoundTripper`
- Changed `h2quic.Server.Serve()` to accept a `net.PacketConn`
- Drop support for Go 1.7 and 1.8.
- Various bugfixes
golang-github-lucas-clemente-quic-go-0.50.0/LICENSE 0000664 0000000 0000000 00000002103 14765760516 0021606 0 ustar 00root root 0000000 0000000 MIT License
Copyright (c) 2016 the quic-go authors & Google, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
golang-github-lucas-clemente-quic-go-0.50.0/README.md 0000664 0000000 0000000 00000020330 14765760516 0022062 0 ustar 00root root 0000000 0000000 # A QUIC implementation in pure Go
[](https://quic-go.net/docs/)
[](https://pkg.go.dev/github.com/quic-go/quic-go)
[](https://codecov.io/gh/quic-go/quic-go/)
[](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:quic-go)
quic-go is an implementation of the QUIC protocol ([RFC 9000](https://datatracker.ietf.org/doc/html/rfc9000), [RFC 9001](https://datatracker.ietf.org/doc/html/rfc9001), [RFC 9002](https://datatracker.ietf.org/doc/html/rfc9002)) in Go. It has support for HTTP/3 ([RFC 9114](https://datatracker.ietf.org/doc/html/rfc9114)), including QPACK ([RFC 9204](https://datatracker.ietf.org/doc/html/rfc9204)) and HTTP Datagrams ([RFC 9297](https://datatracker.ietf.org/doc/html/rfc9297)).
In addition to these base RFCs, it also implements the following RFCs:
* Unreliable Datagram Extension ([RFC 9221](https://datatracker.ietf.org/doc/html/rfc9221))
* Datagram Packetization Layer Path MTU Discovery (DPLPMTUD, [RFC 8899](https://datatracker.ietf.org/doc/html/rfc8899))
* QUIC Version 2 ([RFC 9369](https://datatracker.ietf.org/doc/html/rfc9369))
* QUIC Event Logging using qlog ([draft-ietf-quic-qlog-main-schema](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-main-schema/) and [draft-ietf-quic-qlog-quic-events](https://datatracker.ietf.org/doc/draft-ietf-quic-qlog-quic-events/))
Support for WebTransport over HTTP/3 ([draft-ietf-webtrans-http3](https://datatracker.ietf.org/doc/draft-ietf-webtrans-http3/)) is implemented in [webtransport-go](https://github.com/quic-go/webtransport-go).
Detailed documentation can be found on [quic-go.net](https://quic-go.net/docs/).
## Projects using quic-go
| Project | Description | Stars |
| ---------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
| [AdGuardHome](https://github.com/AdguardTeam/AdGuardHome) | Free and open source, powerful network-wide ads & trackers blocking DNS server. |  |
| [algernon](https://github.com/xyproto/algernon) | Small self-contained pure-Go web server with Lua, Markdown, HTTP/2, QUIC, Redis and PostgreSQL support |  |
| [caddy](https://github.com/caddyserver/caddy/) | Fast, multi-platform web server with automatic HTTPS |  |
| [cloudflared](https://github.com/cloudflare/cloudflared) | A tunneling daemon that proxies traffic from the Cloudflare network to your origins |  |
| [frp](https://github.com/fatedier/frp) | A fast reverse proxy to help you expose a local server behind a NAT or firewall to the internet |  |
| [go-libp2p](https://github.com/libp2p/go-libp2p) | libp2p implementation in Go, powering [Kubo](https://github.com/ipfs/kubo) (IPFS) and [Lotus](https://github.com/filecoin-project/lotus) (Filecoin), among others |  |
| [gost](https://github.com/go-gost/gost) | A simple security tunnel written in Go |  |
| [Hysteria](https://github.com/apernet/hysteria) | A powerful, lightning fast and censorship resistant proxy |  |
| [Mercure](https://github.com/dunglas/mercure) | An open, easy, fast, reliable and battery-efficient solution for real-time communications |  |
| [OONI Probe](https://github.com/ooni/probe-cli) | Next generation OONI Probe. Library and CLI tool. |  |
| [reverst](https://github.com/flipt-io/reverst) | Reverse Tunnels in Go over HTTP/3 and QUIC |  |
| [RoadRunner](https://github.com/roadrunner-server/roadrunner) | High-performance PHP application server, process manager written in Go and powered with plugins |  |
| [syncthing](https://github.com/syncthing/syncthing/) | Open Source Continuous File Synchronization |  |
| [traefik](https://github.com/traefik/traefik) | The Cloud Native Application Proxy |  |
| [v2ray-core](https://github.com/v2fly/v2ray-core) | A platform for building proxies to bypass network restrictions |  |
| [YoMo](https://github.com/yomorun/yomo) | Streaming Serverless Framework for Geo-distributed System |  |
If you'd like to see your project added to this list, please send us a PR.
## Release Policy
quic-go always aims to support the latest two Go releases.
## Contributing
We are always happy to welcome new contributors! We have a number of self-contained issues that are suitable for first-time contributors; they are tagged with [help wanted](https://github.com/quic-go/quic-go/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22). If you have any questions, please feel free to reach out by opening an issue or leaving a comment.
golang-github-lucas-clemente-quic-go-0.50.0/SECURITY.md 0000664 0000000 0000000 00000001507 14765760516 0022401 0 ustar 00root root 0000000 0000000 # Security Policy
quic-go is still in development. This means that there may be problems in our protocols,
or there may be mistakes in our implementations.
We take security vulnerabilities very seriously. If you discover a security issue,
please bring it to our attention right away!
## Reporting a Vulnerability
If you find a vulnerability that may affect live deployments -- for example, by exposing
a remote execution exploit -- please [**report privately**](https://github.com/quic-go/quic-go/security/advisories/new).
Please **DO NOT file a public issue**.
If the issue is an implementation weakness that cannot be immediately exploited or
something not yet deployed, just discuss it openly.
## Reporting a non security bug
For non-security bugs, please simply file a GitHub [issue](https://github.com/quic-go/quic-go/issues/new).
golang-github-lucas-clemente-quic-go-0.50.0/buffer_pool.go 0000664 0000000 0000000 00000004324 14765760516 0023441 0 ustar 00root root 0000000 0000000 package quic
import (
"sync"
"github.com/quic-go/quic-go/internal/protocol"
)
// packetBuffer is a byte buffer obtained from (and returned to) one of the
// sync.Pools below, together with a reference count of the packets using it.
type packetBuffer struct {
// Data holds the packet bytes. Its capacity determines which pool the
// buffer is returned to (see putBack).
Data []byte
// refCount counts how many packets Data is used in.
// It doesn't support concurrent use.
// It is > 1 when used for coalesced packets.
refCount int
}
// Split registers one more user of this buffer by bumping refCount.
// Call it whenever a single buffer backs more than one packet, for example
// while splitting a datagram that contains coalesced packets.
func (b *packetBuffer) Split() {
	b.refCount += 1
}
// Decrement drops one reference to the buffer. It never returns the buffer
// to the pool; use MaybeRelease or Release for that.
// It panics if the reference count would become negative.
func (b *packetBuffer) Decrement() {
	if b.refCount--; b.refCount < 0 {
		panic("negative packetBuffer refCount")
	}
}
// MaybeRelease hands the buffer back to its pool, but only if no packet
// references it any more (refCount is 0). Otherwise it is a no-op.
func (b *packetBuffer) MaybeRelease() {
	if b.refCount != 0 {
		return // still referenced by at least one packet
	}
	b.putBack()
}
// Release drops the last reference and returns the buffer to its pool.
// Use it once processing is definitely finished. It panics if other
// references remain after decrementing (or if Decrement itself panics).
func (b *packetBuffer) Release() {
	b.Decrement()
	if b.refCount == 0 {
		b.putBack()
		return
	}
	panic("packetBuffer refCount not zero")
}
// Len reports how many bytes are currently stored in Data.
func (b *packetBuffer) Len() protocol.ByteCount {
	n := len(b.Data)
	return protocol.ByteCount(n)
}

// Cap reports the capacity of the underlying Data slice.
func (b *packetBuffer) Cap() protocol.ByteCount {
	c := cap(b.Data)
	return protocol.ByteCount(c)
}
// putBack returns the buffer to the pool matching its capacity.
// A buffer whose capacity matches neither pool is a programming error
// (e.g. Data was replaced by a differently sized slice), so it panics.
func (b *packetBuffer) putBack() {
	switch cap(b.Data) {
	case protocol.MaxPacketBufferSize:
		bufferPool.Put(b)
	case protocol.MaxLargePacketBufferSize:
		largeBufferPool.Put(b)
	default:
		panic("putPacketBuffer called with packet of wrong size!")
	}
}
// bufferPool recycles buffers with capacity protocol.MaxPacketBufferSize;
// largeBufferPool recycles buffers with capacity protocol.MaxLargePacketBufferSize.
var bufferPool, largeBufferPool sync.Pool

// getPacketBuffer returns an empty, standard-size buffer with refCount 1.
func getPacketBuffer() *packetBuffer {
	return takePacketBuffer(&bufferPool)
}

// getLargePacketBuffer returns an empty, large buffer with refCount 1.
func getLargePacketBuffer() *packetBuffer {
	return takePacketBuffer(&largeBufferPool)
}

// takePacketBuffer fetches a buffer from pool and resets it for a new user:
// a single reference and zero-length (but full-capacity) Data.
func takePacketBuffer(pool *sync.Pool) *packetBuffer {
	buf := pool.Get().(*packetBuffer)
	buf.refCount = 1
	buf.Data = buf.Data[:0]
	return buf
}
func init() {
bufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxPacketBufferSize)}
}
largeBufferPool.New = func() any {
return &packetBuffer{Data: make([]byte, 0, protocol.MaxLargePacketBufferSize)}
}
}
golang-github-lucas-clemente-quic-go-0.50.0/buffer_pool_test.go 0000664 0000000 0000000 00000002111 14765760516 0024470 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/stretchr/testify/require"
)
// TestBufferPoolSizes checks the capacity and initial length of buffers from both pools.
func TestBufferPoolSizes(t *testing.T) {
	small := getPacketBuffer()
	require.Equal(t, protocol.MaxPacketBufferSize, cap(small.Data))
	require.Zero(t, small.Len())
	small.Data = append(small.Data, "foobar"...)
	require.Equal(t, protocol.ByteCount(6), small.Len())

	large := getLargePacketBuffer()
	require.Equal(t, protocol.MaxLargePacketBufferSize, cap(large.Data))
	require.Zero(t, large.Len())
}
// TestBufferPoolRelease checks that misuse of Release panics.
func TestBufferPoolRelease(t *testing.T) {
	buf := getPacketBuffer()
	buf.Release()
	// releasing the same buffer a second time must panic
	require.Panics(t, buf.Release)

	// a buffer whose capacity matches neither pool must panic on release
	large := getLargePacketBuffer()
	large.Data = make([]byte, 10) // replace the underlying slice
	require.Panics(t, large.Release)
}
// TestBufferPoolSplitting checks that a buffer split into N parts can be
// decremented exactly N times before Decrement panics.
func TestBufferPoolSplitting(t *testing.T) {
	const parts = 3
	buf := getPacketBuffer()
	for i := 1; i < parts; i++ {
		buf.Split()
	}
	for i := 0; i < parts; i++ {
		buf.Decrement()
	}
	require.Panics(t, buf.Decrement)
}
golang-github-lucas-clemente-quic-go-0.50.0/client.go 0000664 0000000 0000000 00000006460 14765760516 0022420 0 ustar 00root root 0000000 0000000 package quic
import (
"context"
"crypto/tls"
"errors"
"net"
"github.com/quic-go/quic-go/internal/protocol"
)
// make it possible to mock connection ID for initial generation in the tests:
// this is a package-level variable (rather than a direct call to
// protocol.GenerateConnectionIDForInitial) so tests can replace it.
var generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
// DialAddr establishes a new QUIC connection to a server.
// It resolves the address, and then creates a new UDP connection to dial the QUIC server.
// When the QUIC connection is closed, this UDP connection is closed.
// If any step fails, the UDP connection created here is closed before
// returning, so no socket is leaked.
// See Dial for more details.
func DialAddr(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (Connection, error) {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		return nil, err
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		// udpConn is not owned by a Transport yet, so we must close it ourselves.
		udpConn.Close()
		return nil, err
	}
	tr, err := setupTransport(udpConn, tlsConf, true)
	if err != nil {
		udpConn.Close()
		return nil, err
	}
	conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, false)
	if err != nil {
		// Same cleanup as DialAddrEarly: the Transport was created with
		// createdConn=true, so closing it is responsible for udpConn.
		tr.Close()
		return nil, err
	}
	return conn, nil
}
// DialAddrEarly establishes a new 0-RTT QUIC connection to a server.
// If any step fails, the UDP connection created here is closed before
// returning, so no socket is leaked.
// See DialAddr for more details.
func DialAddrEarly(ctx context.Context, addr string, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
	udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
	if err != nil {
		return nil, err
	}
	udpAddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		// udpConn is not owned by a Transport yet, so we must close it ourselves.
		udpConn.Close()
		return nil, err
	}
	tr, err := setupTransport(udpConn, tlsConf, true)
	if err != nil {
		udpConn.Close()
		return nil, err
	}
	conn, err := tr.dial(ctx, udpAddr, addr, tlsConf, conf, true)
	if err != nil {
		tr.Close()
		return nil, err
	}
	return conn, nil
}
// DialEarly establishes a new 0-RTT QUIC connection to a server over the
// caller-supplied net.PacketConn. The PacketConn is not closed by this
// function; only the single-use Transport wrapping it is closed on failure.
// See Dial for more details.
func DialEarly(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (EarlyConnection, error) {
	tr, err := setupTransport(c, tlsConf, false)
	if err != nil {
		return nil, err
	}
	conn, err := tr.DialEarly(ctx, addr, tlsConf, conf)
	if err == nil {
		return conn, nil
	}
	tr.Close()
	return nil, err
}
// Dial establishes a new QUIC connection to a server using a net.PacketConn.
// If the PacketConn satisfies the OOBCapablePacketConn interface (as a net.UDPConn does),
// ECN and packet info support will be enabled. In this case, ReadMsgUDP and WriteMsgUDP
// will be used instead of ReadFrom and WriteTo to read/write packets.
// The tls.Config must define an application protocol (using NextProtos).
//
// This is a convenience function. More advanced use cases should instantiate a Transport,
// which offers configuration options for a more fine-grained control of the connection establishment,
// including reusing the underlying UDP socket for multiple QUIC connections.
func Dial(ctx context.Context, c net.PacketConn, addr net.Addr, tlsConf *tls.Config, conf *Config) (Connection, error) {
	tr, err := setupTransport(c, tlsConf, false)
	if err != nil {
		return nil, err
	}
	conn, err := tr.Dial(ctx, addr, tlsConf, conf)
	if err == nil {
		return conn, nil
	}
	// dialing failed: tear down the single-use Transport we created
	tr.Close()
	return nil, err
}
// setupTransport wraps c in a single-use Transport. createdPacketConn records
// whether the PacketConn was created by quic-go (as opposed to the caller).
// It returns an error if tlsConf is nil.
func setupTransport(c net.PacketConn, tlsConf *tls.Config, createdPacketConn bool) (*Transport, error) {
	if tlsConf != nil {
		return &Transport{
			Conn:        c,
			createdConn: createdPacketConn,
			isSingleUse: true,
		}, nil
	}
	return nil, errors.New("quic: tls.Config not set")
}
golang-github-lucas-clemente-quic-go-0.50.0/client_test.go 0000664 0000000 0000000 00000004153 14765760516 0023454 0 ustar 00root root 0000000 0000000 package quic
import (
"context"
"crypto/tls"
"net"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestDial(t *testing.T) {
t.Run("Dial", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
conn := newUPDConnLocalhost(t)
_, err := Dial(ctx, conn, addr, &tls.Config{}, nil)
return err
},
false,
)
})
t.Run("DialEarly", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
conn := newUPDConnLocalhost(t)
_, err := DialEarly(ctx, conn, addr, &tls.Config{}, nil)
return err
},
false,
)
})
t.Run("DialAddr", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
_, err := DialAddr(ctx, addr.String(), &tls.Config{}, nil)
return err
},
true,
)
})
t.Run("DialAddrEarly", func(t *testing.T) {
testDial(t,
func(ctx context.Context, addr net.Addr) error {
_, err := DialAddrEarly(ctx, addr.String(), &tls.Config{}, nil)
return err
},
true,
)
})
}
// testDial starts dialFn against a local UDP socket acting as the server,
// waits until the first client packet arrives, then cancels the dial context.
// It checks that dialFn returns context.Canceled within one second, and that
// the client's socket is closed (shouldCloseConn) or still bound (otherwise).
// It also requires that areTransportsRunning (defined elsewhere) reports true
// while dialing and false at the end.
func testDial(t *testing.T,
dialFn func(context.Context, net.Addr) error,
shouldCloseConn bool,
) {
server := newUPDConnLocalhost(t)
ctx, cancel := context.WithCancel(context.Background())
errChan := make(chan error, 1)
go func() { errChan <- dialFn(ctx, server.LocalAddr()) }()
// Block until the client sends its first packet; addr is the client's local address.
_, addr, err := server.ReadFrom(make([]byte, 1500))
require.NoError(t, err)
require.True(t, areTransportsRunning())
cancel()
select {
case err := <-errChan:
require.ErrorIs(t, err, context.Canceled)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// The socket that the client used for dialing should be closed now.
// Binding to the same address would error if the address was still in use.
conn, err := net.ListenUDP("udp", addr.(*net.UDPAddr))
if shouldCloseConn {
require.NoError(t, err)
defer conn.Close()
} else {
require.Error(t, err)
// The "address in use" error text differs by operating system.
if runtime.GOOS == "windows" {
require.ErrorContains(t, err, "bind: Only one usage of each socket address")
} else {
require.ErrorContains(t, err, "address already in use")
}
}
require.False(t, areTransportsRunning())
}
golang-github-lucas-clemente-quic-go-0.50.0/closed_conn.go 0000664 0000000 0000000 00000003430 14765760516 0023422 0 ustar 00root root 0000000 0000000 package quic
import (
"math/bits"
"net"
"sync/atomic"
"github.com/quic-go/quic-go/internal/utils"
)
// A closedLocalConn is a connection that we closed locally.
// When receiving packets for such a connection, we need to retransmit the packet containing the CONNECTION_CLOSE frame,
// with an exponential backoff.
type closedLocalConn struct {
// counter is the number of packets received for this closed connection;
// handlePacket increments it atomically.
counter atomic.Uint32
logger utils.Logger
// sendPacket retransmits the CONNECTION_CLOSE packet to the given remote address.
sendPacket func(net.Addr, packetInfo)
}
// Compile-time check that closedLocalConn implements packetHandler.
var _ packetHandler = &closedLocalConn{}
// newClosedLocalConn creates a closedLocalConn that calls sendPacket to
// retransmit the CONNECTION_CLOSE packet when packets keep arriving for the
// closed connection. It only constructs the value; no goroutine is started.
func newClosedLocalConn(sendPacket func(net.Addr, packetInfo), logger utils.Logger) packetHandler {
	return &closedLocalConn{
		sendPacket: sendPacket,
		logger:     logger,
	}
}
// handlePacket counts the incoming packet and, with exponential backoff,
// resends the CONNECTION_CLOSE to its sender: only the 1st, 2nd, 4th, 8th,
// 16th, ... packet triggers a retransmission (counts with exactly one bit set).
func (c *closedLocalConn) handlePacket(p receivedPacket) {
	received := c.counter.Add(1)
	if isPowerOfTwo := bits.OnesCount32(received) == 1; isPowerOfTwo {
		c.logger.Debugf("Received %d packets after sending CONNECTION_CLOSE. Retransmitting.", received)
		c.sendPacket(p.remoteAddr, p.info)
	}
}
// destroy is a no-op for a connection that was already closed locally.
func (*closedLocalConn) destroy(error) {}

// closeWithTransportError is a no-op for a connection that was already closed locally.
func (*closedLocalConn) closeWithTransportError(TransportErrorCode) {}
// A closedRemoteConn stands in for a connection that the peer closed.
// Packets sent before the peer's CONNECTION_CLOSE may still arrive reordered;
// they are silently dropped.
type closedRemoteConn struct{}

// Compile-time check that closedRemoteConn implements packetHandler.
var _ packetHandler = &closedRemoteConn{}

// newClosedRemoteConn returns a packetHandler that discards everything.
func newClosedRemoteConn() packetHandler {
	return new(closedRemoteConn)
}

// handlePacket drops the packet.
func (*closedRemoteConn) handlePacket(receivedPacket) {}

// destroy is a no-op.
func (*closedRemoteConn) destroy(error) {}

// closeWithTransportError is a no-op.
func (*closedRemoteConn) closeWithTransportError(TransportErrorCode) {}
golang-github-lucas-clemente-quic-go-0.50.0/closed_conn_test.go 0000664 0000000 0000000 00000001540 14765760516 0024461 0 ustar 00root root 0000000 0000000 package quic
import (
"net"
"testing"
"github.com/quic-go/quic-go/internal/utils"
"github.com/stretchr/testify/require"
)
// TestClosedLocalConnection checks the exponential backoff of CONNECTION_CLOSE
// retransmissions: one is sent for every packet count that is a power of two.
func TestClosedLocalConnection(t *testing.T) {
	sent := make(chan net.Addr, 1)
	conn := newClosedLocalConn(func(addr net.Addr, _ packetInfo) { sent <- addr }, utils.DefaultLogger)
	remote := &net.UDPAddr{IP: net.IPv4(127, 1, 2, 3), Port: 1337}
	for n := 1; n <= 20; n++ {
		conn.handlePacket(receivedPacket{remoteAddr: remote})
		expectSend := n&(n-1) == 0 // 1, 2, 4, 8, 16
		select {
		case gotAddr := <-sent:
			if !expectSend {
				t.Fatalf("unexpected address received: %v", gotAddr)
			}
			require.Equal(t, remote, gotAddr) // receive the CONNECTION_CLOSE
		default:
			if expectSend {
				t.Fatal("expected to receive address")
			}
		}
	}
}
golang-github-lucas-clemente-quic-go-0.50.0/codecov.yml 0000664 0000000 0000000 00000000616 14765760516 0022755 0 ustar 00root root 0000000 0000000 coverage:
round: nearest
ignore:
- http3/gzip_reader.go
- interop/
- internal/handshake/cipher_suite.go
- internal/utils/linkedlist/linkedlist.go
- internal/testdata
- logging/connection_tracer_multiplexer.go
- logging/tracer_multiplexer.go
- testutils/
- fuzzing/
- metrics/
status:
project:
default:
threshold: 0.5
patch: false
golang-github-lucas-clemente-quic-go-0.50.0/config.go 0000664 0000000 0000000 00000010310 14765760516 0022374 0 ustar 00root root 0000000 0000000 package quic
import (
"fmt"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/quicvarint"
)
// Clone clones a Config
func (c *Config) Clone() *Config {
copy := *c
return ©
}
// handshakeTimeout is the overall time budget for the handshake:
// twice the configured HandshakeIdleTimeout.
func (c *Config) handshakeTimeout() time.Duration {
	return c.HandshakeIdleTimeout * 2
}

// maxRetryTokenAge is how long a Retry token stays valid; it equals the
// handshake timeout.
func (c *Config) maxRetryTokenAge() time.Duration {
	age := c.handshakeTimeout()
	return age
}
// validateConfig clamps out-of-range values in config in place and returns an
// error if config lists a QUIC version that is not supported. A nil config is valid.
//
// Clamping rules:
//   - stream limits are capped at 1<<60,
//   - receive windows are capped at quicvarint.Max,
//   - a non-zero InitialPacketSize is raised to protocol.MinInitialPacketSize and
//     capped at protocol.MaxPacketBufferSize (0 means "use the default").
func validateConfig(config *Config) error {
	if config == nil {
		return nil
	}

	const maxStreams = 1 << 60
	config.MaxIncomingStreams = min(config.MaxIncomingStreams, maxStreams)
	config.MaxIncomingUniStreams = min(config.MaxIncomingUniStreams, maxStreams)

	config.MaxStreamReceiveWindow = min(config.MaxStreamReceiveWindow, quicvarint.Max)
	config.MaxConnectionReceiveWindow = min(config.MaxConnectionReceiveWindow, quicvarint.Max)

	if size := config.InitialPacketSize; size > 0 && size < protocol.MinInitialPacketSize {
		config.InitialPacketSize = protocol.MinInitialPacketSize
	}
	config.InitialPacketSize = min(config.InitialPacketSize, protocol.MaxPacketBufferSize)

	for _, v := range config.Versions {
		if protocol.IsValidVersion(v) {
			continue
		}
		return fmt.Errorf("invalid QUIC version: %s", v)
	}
	return nil
}
// populateConfig returns a new Config in which fields left at their zero value
// are replaced by defaults. A nil config is treated like an empty Config.
// The passed-in config itself is never modified.
// Negative stream limits are normalized to 0 (no incoming streams allowed).
func populateConfig(config *Config) *Config {
	if config == nil {
		config = &Config{}
	}

	versions := config.Versions
	if len(versions) == 0 {
		versions = protocol.SupportedVersions
	}

	handshakeIdleTimeout := config.HandshakeIdleTimeout
	if handshakeIdleTimeout == 0 {
		handshakeIdleTimeout = protocol.DefaultHandshakeIdleTimeout
	}
	idleTimeout := config.MaxIdleTimeout
	if idleTimeout == 0 {
		idleTimeout = protocol.DefaultIdleTimeout
	}

	initialStreamReceiveWindow := config.InitialStreamReceiveWindow
	if initialStreamReceiveWindow == 0 {
		initialStreamReceiveWindow = protocol.DefaultInitialMaxStreamData
	}
	maxStreamReceiveWindow := config.MaxStreamReceiveWindow
	if maxStreamReceiveWindow == 0 {
		maxStreamReceiveWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
	}
	initialConnectionReceiveWindow := config.InitialConnectionReceiveWindow
	if initialConnectionReceiveWindow == 0 {
		initialConnectionReceiveWindow = protocol.DefaultInitialMaxData
	}
	maxConnectionReceiveWindow := config.MaxConnectionReceiveWindow
	if maxConnectionReceiveWindow == 0 {
		maxConnectionReceiveWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
	}

	maxIncomingStreams := config.MaxIncomingStreams
	switch {
	case maxIncomingStreams == 0:
		maxIncomingStreams = protocol.DefaultMaxIncomingStreams
	case maxIncomingStreams < 0:
		maxIncomingStreams = 0
	}
	maxIncomingUniStreams := config.MaxIncomingUniStreams
	switch {
	case maxIncomingUniStreams == 0:
		maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
	case maxIncomingUniStreams < 0:
		maxIncomingUniStreams = 0
	}

	initialPacketSize := config.InitialPacketSize
	if initialPacketSize == 0 {
		initialPacketSize = protocol.InitialPacketSize
	}

	return &Config{
		GetConfigForClient:             config.GetConfigForClient,
		Versions:                       versions,
		HandshakeIdleTimeout:           handshakeIdleTimeout,
		MaxIdleTimeout:                 idleTimeout,
		KeepAlivePeriod:                config.KeepAlivePeriod,
		InitialStreamReceiveWindow:     initialStreamReceiveWindow,
		MaxStreamReceiveWindow:         maxStreamReceiveWindow,
		InitialConnectionReceiveWindow: initialConnectionReceiveWindow,
		MaxConnectionReceiveWindow:     maxConnectionReceiveWindow,
		AllowConnectionWindowIncrease:  config.AllowConnectionWindowIncrease,
		MaxIncomingStreams:             maxIncomingStreams,
		MaxIncomingUniStreams:          maxIncomingUniStreams,
		TokenStore:                     config.TokenStore,
		EnableDatagrams:                config.EnableDatagrams,
		InitialPacketSize:              initialPacketSize,
		DisablePathMTUDiscovery:        config.DisablePathMTUDiscovery,
		Allow0RTT:                      config.Allow0RTT,
		Tracer:                         config.Tracer,
	}
}
golang-github-lucas-clemente-quic-go-0.50.0/config_test.go 0000664 0000000 0000000 00000014720 14765760516 0023444 0 ustar 00root root 0000000 0000000 package quic
import (
"context"
"errors"
"reflect"
"testing"
"time"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/logging"
"github.com/quic-go/quic-go/quicvarint"
"github.com/stretchr/testify/require"
)
// TestConfigValidation checks that validateConfig accepts nil and ordinary
// configs, and clamps stream limits, flow control windows and the initial
// packet size to their allowed ranges.
func TestConfigValidation(t *testing.T) {
t.Run("nil config", func(t *testing.T) {
require.NoError(t, validateConfig(nil))
})
t.Run("config with a few values set", func(t *testing.T) {
conf := populateConfig(&Config{
MaxIncomingStreams: 5,
MaxStreamReceiveWindow: 10,
})
require.NoError(t, validateConfig(conf))
require.Equal(t, int64(5), conf.MaxIncomingStreams)
require.Equal(t, uint64(10), conf.MaxStreamReceiveWindow)
})
// values above 1<<60 are clamped to 1<<60
t.Run("stream limits", func(t *testing.T) {
conf := &Config{
MaxIncomingStreams: 1<<60 + 1,
MaxIncomingUniStreams: 1<<60 + 2,
}
require.NoError(t, validateConfig(conf))
require.Equal(t, int64(1<<60), conf.MaxIncomingStreams)
require.Equal(t, int64(1<<60), conf.MaxIncomingUniStreams)
})
// values above quicvarint.Max are clamped to quicvarint.Max
t.Run("flow control windows", func(t *testing.T) {
conf := &Config{
MaxStreamReceiveWindow: quicvarint.Max + 1,
MaxConnectionReceiveWindow: quicvarint.Max + 2,
}
require.NoError(t, validateConfig(conf))
require.Equal(t, uint64(quicvarint.Max), conf.MaxStreamReceiveWindow)
require.Equal(t, uint64(quicvarint.Max), conf.MaxConnectionReceiveWindow)
})
t.Run("initial packet size", func(t *testing.T) {
// not set
conf := &Config{InitialPacketSize: 0}
require.NoError(t, validateConfig(conf))
require.Zero(t, conf.InitialPacketSize)
// too small
conf = &Config{InitialPacketSize: 10}
require.NoError(t, validateConfig(conf))
require.Equal(t, uint16(1200), conf.InitialPacketSize)
// too large
conf = &Config{InitialPacketSize: protocol.MaxPacketBufferSize + 1}
require.NoError(t, validateConfig(conf))
require.Equal(t, uint16(protocol.MaxPacketBufferSize), conf.InitialPacketSize)
})
}
// TestConfigHandshakeIdleTimeout checks that the handshake timeout is twice
// the handshake idle timeout (5.5s -> 11s).
func TestConfigHandshakeIdleTimeout(t *testing.T) {
	conf := &Config{HandshakeIdleTimeout: 5500 * time.Millisecond}
	require.Equal(t, 11*time.Second, conf.handshakeTimeout())
}
// configWithNonZeroNonFunctionFields returns a Config whose exported,
// non-function fields are all set to non-zero values via reflection.
// It fails the test when Config has an exported field not listed in the
// switch, so that new fields must be accounted for in the clone/default tests.
func configWithNonZeroNonFunctionFields(t *testing.T) *Config {
t.Helper()
c := &Config{}
v := reflect.ValueOf(c).Elem()
typ := v.Type()
for i := 0; i < typ.NumField(); i++ {
f := v.Field(i)
if !f.CanSet() {
// unexported field; not cloned.
continue
}
switch fn := typ.Field(i).Name; fn {
case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
// Can't compare functions.
case "Versions":
f.Set(reflect.ValueOf([]Version{1, 2, 3}))
case "ConnectionIDLength":
f.Set(reflect.ValueOf(8))
case "ConnectionIDGenerator":
f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
case "HandshakeIdleTimeout":
f.Set(reflect.ValueOf(time.Second))
case "MaxIdleTimeout":
f.Set(reflect.ValueOf(time.Hour))
case "TokenStore":
f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
case "InitialStreamReceiveWindow":
f.Set(reflect.ValueOf(uint64(1234)))
case "MaxStreamReceiveWindow":
f.Set(reflect.ValueOf(uint64(9)))
case "InitialConnectionReceiveWindow":
f.Set(reflect.ValueOf(uint64(4321)))
case "MaxConnectionReceiveWindow":
f.Set(reflect.ValueOf(uint64(10)))
case "MaxIncomingStreams":
f.Set(reflect.ValueOf(int64(11)))
case "MaxIncomingUniStreams":
f.Set(reflect.ValueOf(int64(12)))
case "StatelessResetKey":
f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4}))
case "KeepAlivePeriod":
f.Set(reflect.ValueOf(time.Second))
case "EnableDatagrams":
f.Set(reflect.ValueOf(true))
case "DisableVersionNegotiationPackets":
f.Set(reflect.ValueOf(true))
case "InitialPacketSize":
f.Set(reflect.ValueOf(uint16(1350)))
case "DisablePathMTUDiscovery":
f.Set(reflect.ValueOf(true))
case "Allow0RTT":
f.Set(reflect.ValueOf(true))
default:
// Guard: forces this helper to be updated when Config gains a field.
t.Fatalf("all fields must be accounted for, but saw unknown field %q", fn)
}
}
return c
}
// TestConfigCloning checks that Clone preserves function-valued fields,
// copies every non-function field, and returns an independent copy.
func TestConfigCloning(t *testing.T) {
t.Run("function fields", func(t *testing.T) {
var calledAllowConnectionWindowIncrease, calledTracer bool
c1 := &Config{
GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer {
calledTracer = true
return nil
},
}
c2 := c1.Clone()
// Functions can't be compared, so verify them by calling them on the clone.
c2.AllowConnectionWindowIncrease(nil, 1234)
require.True(t, calledAllowConnectionWindowIncrease)
_, err := c2.GetConfigForClient(&ClientHelloInfo{})
require.EqualError(t, err, "nope")
c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
require.True(t, calledTracer)
})
t.Run("clones non-function fields", func(t *testing.T) {
c := configWithNonZeroNonFunctionFields(t)
require.Equal(t, c, c.Clone())
})
// Mutating the clone must not affect the original.
t.Run("returns a copy", func(t *testing.T) {
c1 := &Config{MaxIncomingStreams: 100}
c2 := c1.Clone()
c2.MaxIncomingStreams = 200
require.EqualValues(t, 100, c1.MaxIncomingStreams)
})
}
// TestConfigDefaultValues checks that populateConfig keeps explicitly set
// values and fills in protocol defaults for fields left at their zero value.
func TestConfigDefaultValues(t *testing.T) {
// if set, the values should be copied
c := configWithNonZeroNonFunctionFields(t)
require.Equal(t, c, populateConfig(c))
// if not set, some fields use default values
c = populateConfig(&Config{})
require.Equal(t, protocol.SupportedVersions, c.Versions)
require.Equal(t, protocol.DefaultHandshakeIdleTimeout, c.HandshakeIdleTimeout)
require.Equal(t, protocol.DefaultIdleTimeout, c.MaxIdleTimeout)
require.EqualValues(t, protocol.DefaultInitialMaxStreamData, c.InitialStreamReceiveWindow)
require.EqualValues(t, protocol.DefaultMaxReceiveStreamFlowControlWindow, c.MaxStreamReceiveWindow)
require.EqualValues(t, protocol.DefaultInitialMaxData, c.InitialConnectionReceiveWindow)
require.EqualValues(t, protocol.DefaultMaxReceiveConnectionFlowControlWindow, c.MaxConnectionReceiveWindow)
require.EqualValues(t, protocol.DefaultMaxIncomingStreams, c.MaxIncomingStreams)
require.EqualValues(t, protocol.DefaultMaxIncomingUniStreams, c.MaxIncomingUniStreams)
require.False(t, c.DisablePathMTUDiscovery)
require.Nil(t, c.GetConfigForClient)
}
func TestConfigZeroLimits(t *testing.T) {
config := &Config{
MaxIncomingStreams: -1,
MaxIncomingUniStreams: -1,
}
c := populateConfig(config)
require.Zero(t, c.MaxIncomingStreams)
require.Zero(t, c.MaxIncomingUniStreams)
}
golang-github-lucas-clemente-quic-go-0.50.0/conn_id_generator.go 0000664 0000000 0000000 00000010417 14765760516 0024616 0 ustar 00root root 0000000 0000000 package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
// connIDGenerator issues and retires the connection IDs this endpoint uses as
// source connection IDs. It reports changes through the callback fields.
type connIDGenerator struct {
generator ConnectionIDGenerator
// highestSeq is the highest sequence number issued so far.
highestSeq uint64
// activeSrcConnIDs maps sequence numbers to the currently active connection IDs.
activeSrcConnIDs map[uint64]protocol.ConnectionID
initialClientDestConnID *protocol.ConnectionID // nil for the client
addConnectionID func(protocol.ConnectionID)
statelessResetter *statelessResetter
removeConnectionID func(protocol.ConnectionID)
retireConnectionID func(protocol.ConnectionID)
replaceWithClosed func([]protocol.ConnectionID, []byte)
// queueControlFrame is used to send NEW_CONNECTION_ID frames.
queueControlFrame func(wire.Frame)
}
// newConnIDGenerator returns a generator whose only active connection ID is
// initialConnectionID, registered under sequence number 0.
// initialClientDestConnID is nil for the client.
func newConnIDGenerator(
	initialConnectionID protocol.ConnectionID,
	initialClientDestConnID *protocol.ConnectionID, // nil for the client
	addConnectionID func(protocol.ConnectionID),
	statelessResetter *statelessResetter,
	removeConnectionID func(protocol.ConnectionID),
	retireConnectionID func(protocol.ConnectionID),
	replaceWithClosed func([]protocol.ConnectionID, []byte),
	queueControlFrame func(wire.Frame),
	generator ConnectionIDGenerator,
) *connIDGenerator {
	return &connIDGenerator{
		generator:               generator,
		activeSrcConnIDs:        map[uint64]protocol.ConnectionID{0: initialConnectionID},
		initialClientDestConnID: initialClientDestConnID,
		addConnectionID:         addConnectionID,
		statelessResetter:       statelessResetter,
		removeConnectionID:      removeConnectionID,
		retireConnectionID:      retireConnectionID,
		replaceWithClosed:       replaceWithClosed,
		queueControlFrame:       queueControlFrame,
	}
}
// SetMaxActiveConnIDs issues new connection IDs until the number of active
// ones reaches the peer's active_connection_id_limit (capped at
// protocol.MaxIssuedConnectionIDs). Nothing is issued when zero-length
// connection IDs are in use.
//
// The limit includes the connection ID used during the handshake and the one
// sent in the preferred_address transport parameter. Since we don't send
// preferred_address, we can issue (limit - 1) additional connection IDs.
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
	if m.generator.ConnectionIDLen() == 0 {
		return nil
	}
	target := min(limit, protocol.MaxIssuedConnectionIDs)
	for active := uint64(len(m.activeSrcConnIDs)); active < target; active++ {
		if err := m.issueNewConnID(); err != nil {
			return err
		}
	}
	return nil
}
// Retire handles a RETIRE_CONNECTION_ID frame for sequence number seq that
// arrived in a packet addressed to sentWithDestConnID.
// It returns a PROTOCOL_VIOLATION if seq was never issued, or if the peer
// retires the very connection ID the packet was sent to. Duplicate
// retirements are ignored. Every retired ID except the initial one
// (seq 0) is replaced by a freshly issued connection ID.
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
	if seq > m.highestSeq {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
		}
	}
	connID, active := m.activeSrcConnIDs[seq]
	if !active {
		return nil // already retired: duplicate frame
	}
	if connID == sentWithDestConnID {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, connID),
		}
	}
	m.retireConnectionID(connID)
	delete(m.activeSrcConnIDs, seq)
	if seq != 0 {
		return m.issueNewConnID()
	}
	return nil
}
// issueNewConnID generates a connection ID under the next sequence number,
// registers it via addConnectionID, and queues a NEW_CONNECTION_ID frame that
// carries its stateless reset token.
func (m *connIDGenerator) issueNewConnID() error {
	connID, err := m.generator.GenerateConnectionID()
	if err != nil {
		return err
	}
	seq := m.highestSeq + 1
	m.activeSrcConnIDs[seq] = connID
	m.addConnectionID(connID)
	m.queueControlFrame(&wire.NewConnectionIDFrame{
		SequenceNumber:      seq,
		ConnectionID:        connID,
		StatelessResetToken: m.statelessResetter.GetStatelessResetToken(connID),
	})
	m.highestSeq = seq
	return nil
}
// SetHandshakeComplete retires the initial client destination connection ID,
// if there is one, and forgets it.
func (m *connIDGenerator) SetHandshakeComplete() {
	if m.initialClientDestConnID == nil {
		return
	}
	m.retireConnectionID(*m.initialClientDestConnID)
	m.initialClientDestConnID = nil
}
// RemoveAll calls removeConnectionID for the initial client destination
// connection ID (if still set) and for every active source connection ID.
func (m *connIDGenerator) RemoveAll() {
	if initial := m.initialClientDestConnID; initial != nil {
		m.removeConnectionID(*initial)
	}
	for _, id := range m.activeSrcConnIDs {
		m.removeConnectionID(id)
	}
}
// ReplaceWithClosed passes every connection ID of this connection (the initial
// client destination connection ID, if still set, plus all active ones)
// together with connClose to the replaceWithClosed callback.
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
	all := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
	if initial := m.initialClientDestConnID; initial != nil {
		all = append(all, *initial)
	}
	for _, id := range m.activeSrcConnIDs {
		all = append(all, id)
	}
	m.replaceWithClosed(all, connClose)
}
golang-github-lucas-clemente-quic-go-0.50.0/conn_id_generator_test.go 0000664 0000000 0000000 00000015602 14765760516 0025656 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// TestConnIDGeneratorIssueAndRetire runs the issue/retire scenario both with
// and without an initial client destination connection ID.
func TestConnIDGeneratorIssueAndRetire(t *testing.T) {
	for _, withInitial := range []bool{true, false} {
		name := "without initial client destination connection ID"
		if withInitial {
			name = "with initial client destination connection ID"
		}
		t.Run(name, func(t *testing.T) {
			testConnIDGeneratorIssueAndRetire(t, withInitial)
		})
	}
}
// testConnIDGeneratorIssueAndRetire checks issuing connection IDs up to the
// peer's limit, and retiring the initial client destination connection ID on
// handshake completion. It also checks Retire's error cases, the replacement
// of a retired ID and the handling of duplicate retirements.
// hasInitialClientDestConnID selects whether the generator is created with an
// initial client destination connection ID.
func testConnIDGeneratorIssueAndRetire(t *testing.T, hasInitialClientDestConnID bool) {
	var (
		added   []protocol.ConnectionID
		retired []protocol.ConnectionID
	)
	var queuedFrames []wire.Frame
	sr := newStatelessResetter(&StatelessResetKey{1, 2, 3, 4})
	var initialClientDestConnID *protocol.ConnectionID
	if hasInitialClientDestConnID {
		connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
		initialClientDestConnID = &connID
	}
	g := newConnIDGenerator(
		protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
		initialClientDestConnID,
		func(c protocol.ConnectionID) { added = append(added, c) },
		sr,
		func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
		func(c protocol.ConnectionID) { retired = append(retired, c) },
		func([]protocol.ConnectionID, []byte) {},
		func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
		&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
	)
	require.Empty(t, added)
	// a limit of 4 includes the initial connection ID, so 3 new ones are issued
	require.NoError(t, g.SetMaxActiveConnIDs(4))
	require.Len(t, added, 3)
	require.Len(t, queuedFrames, 3)
	require.Empty(t, retired)
	connIDs := make(map[uint64]protocol.ConnectionID)
	// connection IDs 1, 2 and 3 were issued
	for i, f := range queuedFrames {
		ncid := f.(*wire.NewConnectionIDFrame)
		require.EqualValues(t, i+1, ncid.SequenceNumber)
		require.Equal(t, added[i], ncid.ConnectionID)
		require.Equal(t, sr.GetStatelessResetToken(ncid.ConnectionID), ncid.StatelessResetToken)
		connIDs[ncid.SequenceNumber] = ncid.ConnectionID
	}

	// completing the handshake retires the initial client destination connection ID
	added = added[:0]
	queuedFrames = queuedFrames[:0]
	g.SetHandshakeComplete()
	require.Empty(t, added)
	require.Empty(t, queuedFrames)
	if hasInitialClientDestConnID {
		require.Equal(t, []protocol.ConnectionID{*initialClientDestConnID}, retired)
		retired = retired[:0]
	} else {
		require.Empty(t, retired)
	}

	// it's invalid to retire a connection ID that hasn't been issued yet
	err := g.Retire(4, protocol.ParseConnectionID([]byte{3, 3, 3, 3}))
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
	require.ErrorContains(t, err, "retired connection ID 4 (highest issued: 3)")
	// it's invalid to retire a connection ID in a packet that uses that connection ID
	err = g.Retire(3, connIDs[3])
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
	require.ErrorContains(t, err, "was used as the Destination Connection ID on this packet")

	// retiring a connection ID makes us issue a new one
	require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
	require.Equal(t, []protocol.ConnectionID{connIDs[2]}, retired)
	require.Len(t, queuedFrames, 1)
	require.EqualValues(t, 4, queuedFrames[0].(*wire.NewConnectionIDFrame).SequenceNumber)
	queuedFrames = queuedFrames[:0]
	retired = retired[:0]

	// duplicate retirements don't do anything
	require.NoError(t, g.Retire(2, protocol.ParseConnectionID([]byte{3, 3, 3, 3})))
	require.Empty(t, queuedFrames)
	require.Empty(t, retired)
}
// TestConnIDGeneratorRemoveAll runs the RemoveAll scenario both with and
// without an initial client destination connection ID.
func TestConnIDGeneratorRemoveAll(t *testing.T) {
	for _, withInitial := range []bool{true, false} {
		name := "without initial client destination connection ID"
		if withInitial {
			name = "with initial client destination connection ID"
		}
		t.Run(name, func(t *testing.T) {
			testConnIDGeneratorRemoveAll(t, withInitial)
		})
	}
}
// testConnIDGeneratorRemoveAll issues the maximum number of connection IDs and
// checks that RemoveAll removes every active connection ID, plus the initial
// client destination connection ID when one was provided.
func testConnIDGeneratorRemoveAll(t *testing.T, hasInitialClientDestConnID bool) {
var initialClientDestConnID *protocol.ConnectionID
if hasInitialClientDestConnID {
connID := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
initialClientDestConnID = &connID
}
var (
added []protocol.ConnectionID
removed []protocol.ConnectionID
)
g := newConnIDGenerator(
protocol.ParseConnectionID([]byte{1, 1, 1, 1}),
initialClientDestConnID,
func(c protocol.ConnectionID) { added = append(added, c) },
newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
func(c protocol.ConnectionID) { removed = append(removed, c) },
func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
func([]protocol.ConnectionID, []byte) {},
func(f wire.Frame) {},
&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
)
// A limit above MaxIssuedConnectionIDs is capped; the initial ID counts
// towards the cap, so one fewer is added.
require.NoError(t, g.SetMaxActiveConnIDs(1000))
require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)
g.RemoveAll()
if hasInitialClientDestConnID {
require.Len(t, removed, protocol.MaxIssuedConnectionIDs+1)
require.Contains(t, removed, *initialClientDestConnID)
} else {
require.Len(t, removed, protocol.MaxIssuedConnectionIDs)
}
for _, id := range added {
require.Contains(t, removed, id)
}
require.Contains(t, removed, protocol.ParseConnectionID([]byte{1, 1, 1, 1}))
}
// TestConnIDGeneratorReplaceWithClosed runs testConnIDGeneratorReplaceWithClosed as two
// subtests: once with and once without an initial client destination connection ID.
func TestConnIDGeneratorReplaceWithClosed(t *testing.T) {
	cases := []struct {
		name                       string
		hasInitialClientDestConnID bool
	}{
		{name: "with initial client destination connection ID", hasInitialClientDestConnID: true},
		{name: "without initial client destination connection ID", hasInitialClientDestConnID: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			testConnIDGeneratorReplaceWithClosed(t, tc.hasInitialClientDestConnID)
		})
	}
}
// testConnIDGeneratorReplaceWithClosed checks that ReplaceWithClosed hands every
// connection ID the generator knows about to the replace callback, together with the
// bytes passed to ReplaceWithClosed. That covers the initial source connection ID, all
// issued IDs and, if set, the client's initial destination connection ID. No removals
// or retirements are expected.
func testConnIDGeneratorReplaceWithClosed(t *testing.T, hasInitialClientDestConnID bool) {
	initialSrcConnID := protocol.ParseConnectionID([]byte{1, 1, 1, 1})
	var clientDestConnID *protocol.ConnectionID
	if hasInitialClientDestConnID {
		c := protocol.ParseConnectionID([]byte{2, 2, 2, 2})
		clientDestConnID = &c
	}

	var added, replaced []protocol.ConnectionID
	var replacedWith []byte
	onReplace := func(connIDs []protocol.ConnectionID, b []byte) {
		replaced = connIDs
		replacedWith = b
	}
	g := newConnIDGenerator(
		initialSrcConnID,
		clientDestConnID,
		func(c protocol.ConnectionID) { added = append(added, c) },
		newStatelessResetter(&StatelessResetKey{1, 2, 3, 4}),
		func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID removals") },
		func(c protocol.ConnectionID) { t.Fatal("didn't expect conn ID retirements") },
		onReplace,
		func(f wire.Frame) {},
		&protocol.DefaultConnectionIDGenerator{ConnLen: 5},
	)
	require.NoError(t, g.SetMaxActiveConnIDs(1000))
	require.Len(t, added, protocol.MaxIssuedConnectionIDs-1)

	g.ReplaceWithClosed([]byte("foobar"))

	wantReplaced := protocol.MaxIssuedConnectionIDs
	if hasInitialClientDestConnID {
		wantReplaced++
	}
	require.Len(t, replaced, wantReplaced)
	if hasInitialClientDestConnID {
		require.Contains(t, replaced, *clientDestConnID)
	}
	for _, connID := range added {
		require.Contains(t, replaced, connID)
	}
	require.Contains(t, replaced, initialSrcConnID)
	require.Equal(t, []byte("foobar"), replacedWith)
}
golang-github-lucas-clemente-quic-go-0.50.0/conn_id_manager.go 0000664 0000000 0000000 00000022110 14765760516 0024233 0 ustar 00root root 0000000 0000000 package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
list "github.com/quic-go/quic-go/internal/utils/linkedlist"
"github.com/quic-go/quic-go/internal/wire"
)
// newConnID is a connection ID provided by the peer (via a NEW_CONNECTION_ID frame,
// or through AddFromPreferredAddress). It carries its sequence number and the
// stateless reset token associated with it.
type newConnID struct {
	SequenceNumber      uint64
	ConnectionID        protocol.ConnectionID
	StatelessResetToken protocol.StatelessResetToken
}
// connIDManager manages the connection IDs provided by the peer, i.e. the
// destination connection IDs used for the packets we send. It tracks the active
// connection ID, the spare ones, and the ones allocated to probing paths.
type connIDManager struct {
	// connection IDs that are neither active nor allocated for path probing,
	// kept sorted by sequence number (see addConnectionID)
	queue list.List[newConnID]
	// sequence number of the connection ID most recently allocated for path probing
	highestProbingID uint64
	pathProbing      map[pathID]newConnID // initialized lazily

	handshakeComplete         bool
	activeSequenceNumber      uint64
	// the highest RetirePriorTo value received, or the highest retired active sequence number
	highestRetired            uint64
	activeConnectionID        protocol.ConnectionID
	activeStatelessResetToken *protocol.StatelessResetToken

	// We change the connection ID after sending on average
	// protocol.PacketsPerConnectionID packets. The actual value is randomized
	// to hide the packet loss rate from on-path observers.
	rand                   utils.Rand
	packetsSinceLastChange uint32
	packetsPerConnectionID uint32

	addStatelessResetToken    func(protocol.StatelessResetToken)
	removeStatelessResetToken func(protocol.StatelessResetToken)
	queueControlFrame         func(wire.Frame)

	// set by Close; methods calling assertNotClosed panic once it is set
	closed bool
}
// newConnIDManager creates a connIDManager that uses initialDestConnID as its active
// connection ID (sequence number 0, the zero value).
// The callbacks are invoked when a stateless reset token starts or stops being valid,
// and when a control frame (e.g. RETIRE_CONNECTION_ID) needs to be sent.
func newConnIDManager(
	initialDestConnID protocol.ConnectionID,
	addStatelessResetToken func(protocol.StatelessResetToken),
	removeStatelessResetToken func(protocol.StatelessResetToken),
	queueControlFrame func(wire.Frame),
) *connIDManager {
	m := new(connIDManager)
	m.activeConnectionID = initialDestConnID
	m.addStatelessResetToken = addStatelessResetToken
	m.removeStatelessResetToken = removeStatelessResetToken
	m.queueControlFrame = queueControlFrame
	return m
}
// AddFromPreferredAddress stores the connection ID and stateless reset token received
// with the peer's preferred address. This connection ID always uses sequence number 1.
func (h *connIDManager) AddFromPreferredAddress(connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
	const preferredAddressSeqNum uint64 = 1
	return h.addConnectionID(preferredAddressSeqNum, connID, resetToken)
}
// Add processes a NEW_CONNECTION_ID frame received from the peer.
// It returns an error if the frame is invalid or conflicts with previously received
// frames. It returns a CONNECTION_ID_LIMIT_ERROR if the number of queued connection
// IDs reaches protocol.MaxActiveConnectionIDs.
func (h *connIDManager) Add(f *wire.NewConnectionIDFrame) error {
	err := h.add(f)
	switch {
	case err != nil:
		return err
	case h.queue.Len() >= protocol.MaxActiveConnectionIDs:
		return &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}
	default:
		return nil
	}
}
// add processes a NEW_CONNECTION_ID frame without enforcing the connection ID limit
// (Add does that). It:
//   - rejects the frame if zero-length connection IDs are in use,
//   - immediately retires connection IDs that arrive out of order or were already retired,
//   - retires connection IDs (probing and queued) with a sequence number below RetirePriorTo,
//   - stores the new connection ID and, if the active one was retired, switches to
//     the next queued connection ID.
//
// NOTE(review): a retransmitted frame for a connection ID currently held in
// h.pathProbing is not recognized as a duplicate. If its sequence number is below
// max(activeSequenceNumber, highestProbingID), a RETIRE_CONNECTION_ID is queued while
// the path still uses it. Otherwise it is inserted into the queue a second time.
// Confirm whether this is reachable.
func (h *connIDManager) add(f *wire.NewConnectionIDFrame) error {
	if h.activeConnectionID.Len() == 0 {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "received NEW_CONNECTION_ID frame but zero-length connection IDs are in use",
		}
	}
	// If the NEW_CONNECTION_ID frame is reordered, such that its sequence number is smaller than the currently active
	// connection ID or if it was already retired, send the RETIRE_CONNECTION_ID frame immediately.
	if f.SequenceNumber < max(h.activeSequenceNumber, h.highestProbingID) || f.SequenceNumber < h.highestRetired {
		h.queueControlFrame(&wire.RetireConnectionIDFrame{
			SequenceNumber: f.SequenceNumber,
		})
		return nil
	}

	// Retire connection IDs allocated for path probing that the peer asked us to retire.
	if f.RetirePriorTo != 0 && h.pathProbing != nil {
		for id, entry := range h.pathProbing {
			if entry.SequenceNumber < f.RetirePriorTo {
				h.queueControlFrame(&wire.RetireConnectionIDFrame{
					SequenceNumber: entry.SequenceNumber,
				})
				delete(h.pathProbing, id)
			}
		}
	}
	// Retire elements in the queue.
	// Doesn't retire the active connection ID.
	if f.RetirePriorTo > h.highestRetired {
		var next *list.Element[newConnID]
		for el := h.queue.Front(); el != nil; el = next {
			// the queue is sorted, so all remaining elements are >= RetirePriorTo
			if el.Value.SequenceNumber >= f.RetirePriorTo {
				break
			}
			// grab the successor before removing el from the list
			next = el.Next()
			h.queueControlFrame(&wire.RetireConnectionIDFrame{
				SequenceNumber: el.Value.SequenceNumber,
			})
			h.queue.Remove(el)
		}
		h.highestRetired = f.RetirePriorTo
	}

	// a duplicate of the active connection ID: nothing to store
	if f.SequenceNumber == h.activeSequenceNumber {
		return nil
	}

	if err := h.addConnectionID(f.SequenceNumber, f.ConnectionID, f.StatelessResetToken); err != nil {
		return err
	}

	// Retire the active connection ID, if necessary.
	if h.activeSequenceNumber < f.RetirePriorTo {
		// The queue is guaranteed to have at least one element at this point.
		h.updateConnectionID()
	}
	return nil
}
// addConnectionID inserts a connection ID into the queue, keeping it sorted by sequence
// number. If the sequence number is already queued, it verifies that the connection ID
// and stateless reset token match and returns an error if they don't.
func (h *connIDManager) addConnectionID(seq uint64, connID protocol.ConnectionID, resetToken protocol.StatelessResetToken) error {
	entry := newConnID{
		SequenceNumber:      seq,
		ConnectionID:        connID,
		StatelessResetToken: resetToken,
	}
	// fast path: the new sequence number is larger than everything queued so far
	if h.queue.Len() == 0 || h.queue.Back().Value.SequenceNumber < seq {
		h.queue.PushBack(entry)
		return nil
	}
	// slow path: find either the duplicate or the first larger sequence number
	for el := h.queue.Front(); el != nil; el = el.Next() {
		switch existing := el.Value; {
		case existing.SequenceNumber == seq:
			if existing.ConnectionID != connID {
				return fmt.Errorf("received conflicting connection IDs for sequence number %d", seq)
			}
			if existing.StatelessResetToken != resetToken {
				return fmt.Errorf("received conflicting stateless reset tokens for sequence number %d", seq)
			}
			return nil
		case existing.SequenceNumber > seq:
			h.queue.InsertBefore(entry, el)
			return nil
		}
	}
	return nil
}
// updateConnectionID retires the active connection ID and makes the front of the queue
// the new active connection ID. Retiring means queuing a RETIRE_CONNECTION_ID frame and
// removing its stateless reset token, if one was set. The packet counter is reset and a
// new randomized rotation threshold is drawn.
// Callers must make sure the queue is non-empty: its front element is removed without a check.
// Panics if the manager was closed.
func (h *connIDManager) updateConnectionID() {
	h.assertNotClosed()
	h.queueControlFrame(&wire.RetireConnectionIDFrame{
		SequenceNumber: h.activeSequenceNumber,
	})
	h.highestRetired = max(h.highestRetired, h.activeSequenceNumber)
	if h.activeStatelessResetToken != nil {
		h.removeStatelessResetToken(*h.activeStatelessResetToken)
	}

	front := h.queue.Remove(h.queue.Front())
	h.activeSequenceNumber = front.SequenceNumber
	h.activeConnectionID = front.ConnectionID
	h.activeStatelessResetToken = &front.StatelessResetToken
	h.packetsSinceLastChange = 0
	// Randomized threshold around PacketsPerConnectionID. Assuming Int31n(n) returns a
	// value in [0, n), this lies in [PacketsPerConnectionID/2, 3*PacketsPerConnectionID/2).
	h.packetsPerConnectionID = protocol.PacketsPerConnectionID/2 + uint32(h.rand.Int31n(protocol.PacketsPerConnectionID))
	h.addStatelessResetToken(*h.activeStatelessResetToken)
}
// Close marks the connIDManager as closed and removes the stateless reset token of the
// active connection ID, if one was set. After Close, methods calling assertNotClosed
// (e.g. Get, SetStatelessResetToken) panic, so no new token can be registered.
// Calling Close more than once is a no-op: the token is removed only once.
func (h *connIDManager) Close() {
	if h.closed {
		return
	}
	h.closed = true
	if h.activeStatelessResetToken != nil {
		h.removeStatelessResetToken(*h.activeStatelessResetToken)
	}
}
// ChangeInitialConnID replaces the active connection ID while it is still the initial
// one (sequence number 0). It is called when the server performs a Retry and when the
// server changes the connection ID in the first Initial sent.
// Panics if a later connection ID is already active.
func (h *connIDManager) ChangeInitialConnID(newConnID protocol.ConnectionID) {
	if h.activeSequenceNumber == 0 {
		h.activeConnectionID = newConnID
		return
	}
	panic("expected first connection ID to have sequence number 0")
}
// SetStatelessResetToken sets the stateless reset token for the initial connection ID
// (sequence number 0) and registers it via the addStatelessResetToken callback.
// It is called when the server provides a stateless reset token in the transport parameters.
// Panics if the manager is closed or a later connection ID is already active.
func (h *connIDManager) SetStatelessResetToken(token protocol.StatelessResetToken) {
	h.assertNotClosed()
	if seq := h.activeSequenceNumber; seq != 0 {
		panic("expected first connection ID to have sequence number 0")
	}
	stored := token
	h.activeStatelessResetToken = &stored
	h.addStatelessResetToken(stored)
}
// SentPacket records that a packet was sent. The counter is compared against
// packetsPerConnectionID to decide when to rotate the connection ID.
func (h *connIDManager) SentPacket() {
	h.packetsSinceLastChange++
}
// shouldUpdateConnID reports whether the active connection ID should be rotated now.
// Rotation never happens before handshake completion. The first rotation (away from
// sequence number 0) happens as soon as another connection ID is available. Later
// rotations require the queue to be at least half full and at least
// packetsPerConnectionID packets to have been sent since the last change.
func (h *connIDManager) shouldUpdateConnID() bool {
	switch {
	case !h.handshakeComplete:
		return false
	case h.activeSequenceNumber == 0 && h.queue.Len() > 0:
		return true
	}
	queueHalfFull := 2*h.queue.Len() >= protocol.MaxActiveConnectionIDs
	sentEnough := h.packetsSinceLastChange >= h.packetsPerConnectionID
	return queueHalfFull && sentEnough
}
// Get returns the destination connection ID to use for the next packet. It first
// rotates to the next queued connection ID if shouldUpdateConnID says so.
// Panics if the manager was closed.
func (h *connIDManager) Get() protocol.ConnectionID {
	h.assertNotClosed()
	if !h.shouldUpdateConnID() {
		return h.activeConnectionID
	}
	h.updateConnectionID()
	return h.activeConnectionID
}
// SetHandshakeComplete records handshake completion. Until then, shouldUpdateConnID
// never allows rotating the connection ID.
func (h *connIDManager) SetHandshakeComplete() {
	h.handshakeComplete = true
}
// GetConnIDForPath retrieves a connection ID for a new path (i.e. not the active one).
// Once a connection ID is allocated for a path, it cannot be used for a different path.
// When called with the same pathID, it will return the same connection ID,
// unless the peer requested that this connection ID be retired.
// The boolean is false if no unused connection ID is available. With zero-length
// connection IDs, an empty connection ID is returned for every path.
func (h *connIDManager) GetConnIDForPath(id pathID) (protocol.ConnectionID, bool) {
	h.assertNotClosed()
	// zero-length connection IDs never need to be changed
	if h.activeConnectionID.Len() == 0 {
		return protocol.ConnectionID{}, true
	}
	if h.pathProbing == nil {
		h.pathProbing = make(map[pathID]newConnID)
	}
	entry, found := h.pathProbing[id]
	if !found {
		if h.queue.Len() == 0 {
			return protocol.ConnectionID{}, false
		}
		// allocate the lowest queued sequence number to this path
		entry = h.queue.Remove(h.queue.Front())
		h.pathProbing[id] = entry
		h.highestProbingID = entry.SequenceNumber
	}
	return entry.ConnectionID, true
}
// RetireConnIDForPath retires the connection ID allocated to the given path, if any,
// by queuing a RETIRE_CONNECTION_ID frame and forgetting the allocation.
// It is a no-op when zero-length connection IDs are in use.
func (h *connIDManager) RetireConnIDForPath(pathID pathID) {
	h.assertNotClosed()
	// zero-length connection IDs are never changed, so there is nothing to retire
	if h.activeConnectionID.Len() == 0 {
		return
	}
	if entry, ok := h.pathProbing[pathID]; ok {
		h.queueControlFrame(&wire.RetireConnectionIDFrame{SequenceNumber: entry.SequenceNumber})
		delete(h.pathProbing, pathID)
	}
}
// assertNotClosed panics if Close has been called.
// Using the connIDManager after it has been closed can have disastrous effects:
// If the connection ID is rotated, a new entry would be inserted into the packet handler map,
// leading to a memory leak of the connection struct.
// See https://github.com/quic-go/quic-go/pull/4852 for more details.
func (h *connIDManager) assertNotClosed() {
	if !h.closed {
		return
	}
	panic("connection ID manager is closed")
}
golang-github-lucas-clemente-quic-go-0.50.0/conn_id_manager_test.go 0000664 0000000 0000000 00000033100 14765760516 0025273 0 ustar 00root root 0000000 0000000 package quic
import (
"crypto/rand"
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
func TestConnIDManagerInitialConnID(t *testing.T) {
m := newConnIDManager(protocol.ParseConnectionID([]byte{1, 2, 3, 4}), nil, nil, nil)
require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get())
require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get())
m.ChangeInitialConnID(protocol.ParseConnectionID([]byte{5, 6, 7, 8}))
require.Equal(t, protocol.ParseConnectionID([]byte{5, 6, 7, 8}), m.Get())
}
// TestConnIDManagerAddConnIDs tests that connection IDs are used in sequence number
// order, that reordered and duplicate NEW_CONNECTION_ID frames are accepted, and that
// frames conflicting with an earlier frame for the same sequence number are rejected.
func TestConnIDManagerAddConnIDs(t *testing.T) {
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(protocol.StatelessResetToken) {},
		func(protocol.StatelessResetToken) {},
		func(wire.Frame) {},
	)
	f1 := &wire.NewConnectionIDFrame{
		SequenceNumber:      1,
		ConnectionID:        protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}),
		StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe},
	}
	f2 := &wire.NewConnectionIDFrame{
		SequenceNumber:      2,
		ConnectionID:        protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 0x0d}),
		StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe},
	}
	require.NoError(t, m.Add(f2))
	require.NoError(t, m.Add(f1)) // receiving reordered frames is fine
	require.NoError(t, m.Add(f2)) // receiving a duplicate is fine

	// the handshake isn't complete, so Get doesn't rotate on its own
	require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get())
	// rotating manually uses the queued connection IDs in sequence number order
	m.updateConnectionID()
	require.Equal(t, protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef}), m.Get())
	m.updateConnectionID()
	require.Equal(t, protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 0x0d}), m.Get())
	require.NoError(t, m.Add(f2)) // receiving a duplicate for the current connection ID is fine as well
	require.Equal(t, protocol.ParseConnectionID([]byte{0xba, 0xad, 0xf0, 0x0d}), m.Get())

	// receiving mismatching connection IDs is not fine:
	// first, a connection ID for sequence number 3 is received (and accepted) ...
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      3,
		ConnectionID:        protocol.ParseConnectionID([]byte{1, 2, 3, 4}), // first frame for sequence number 3
		StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe},
	}))
	// ... then a different connection ID for the same sequence number is rejected
	require.EqualError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      3,
		ConnectionID:        protocol.ParseConnectionID([]byte{2, 3, 4, 5}), // mismatching connection ID
		StatelessResetToken: protocol.StatelessResetToken{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe},
	}), "received conflicting connection IDs for sequence number 3")

	// receiving mismatching stateless reset tokens is not fine either
	require.EqualError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      3,
		ConnectionID:        protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 0xa, 0xb, 0xc, 0xd, 0xe, 0},
	}), "received conflicting stateless reset tokens for sequence number 3")
}
func TestConnIDManagerLimit(t *testing.T) {
m := newConnIDManager(
protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
func(protocol.StatelessResetToken) {},
func(protocol.StatelessResetToken) {},
func(f wire.Frame) {},
)
for i := uint8(1); i < protocol.MaxActiveConnectionIDs; i++ {
require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: uint64(i),
ConnectionID: protocol.ParseConnectionID([]byte{i, i, i, i}),
StatelessResetToken: protocol.StatelessResetToken{i, i, i, i, i, i, i, i, i, i, i, i, i, i, i, i},
}))
}
require.Equal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}, m.Add(&wire.NewConnectionIDFrame{
SequenceNumber: uint64(9999),
ConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
StatelessResetToken: protocol.StatelessResetToken{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
}))
}
// TestConnIDManagerRetiringConnectionIDs tests that RetirePriorTo retires queued connection
// IDs and the active one, and that a reordered connection ID below the retirement
// threshold is retired immediately.
func TestConnIDManagerRetiringConnectionIDs(t *testing.T) {
	var frameQueue []wire.Frame
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(protocol.StatelessResetToken) {},
		func(protocol.StatelessResetToken) {},
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
	)
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber: 10,
		ConnectionID:   protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
	}))
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber: 13,
		ConnectionID:   protocol.ParseConnectionID([]byte{2, 3, 4, 5}),
	}))
	require.Empty(t, frameQueue)

	// RetirePriorTo 14 retires the queued IDs (10, 13) first, then the active one (0)
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		RetirePriorTo:  14,
		SequenceNumber: 17,
		ConnectionID:   protocol.ParseConnectionID([]byte{3, 4, 5, 6}),
	}))
	require.Equal(t, []wire.Frame{
		&wire.RetireConnectionIDFrame{SequenceNumber: 10},
		&wire.RetireConnectionIDFrame{SequenceNumber: 13},
		&wire.RetireConnectionIDFrame{SequenceNumber: 0},
	}, frameQueue)
	require.Equal(t, protocol.ParseConnectionID([]byte{3, 4, 5, 6}), m.Get())
	frameQueue = nil

	// a reordered connection ID is immediately retired
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber: 12,
		ConnectionID:   protocol.ParseConnectionID([]byte{5, 6, 7, 8}),
	}))
	require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 12}}, frameQueue)
	require.Equal(t, protocol.ParseConnectionID([]byte{3, 4, 5, 6}), m.Get())
}
// TestConnIDManagerHandshakeCompletion tests that the connection ID is only rotated after
// handshake completion. Rotating retires connection ID 0 and removes its stateless reset token.
func TestConnIDManagerHandshakeCompletion(t *testing.T) {
	var frameQueue []wire.Frame
	var addedTokens, removedTokens []protocol.StatelessResetToken
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) },
		func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) },
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
	)
	m.SetStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})
	require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, addedTokens)
	require.Empty(t, removedTokens)

	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      1,
		ConnectionID:        protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	// no rotation before the handshake completes
	require.Equal(t, protocol.ParseConnectionID([]byte{1, 2, 3, 4}), m.Get())
	m.SetHandshakeComplete()
	// the first rotation happens as soon as the handshake is complete
	require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), m.Get())
	require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 0}}, frameQueue)
	require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, removedTokens)
}
// TestConnIDManagerConnIDRotation tests that, after the first rotation, the connection
// ID is rotated after a randomized number of packets in
// [PacketsPerConnectionID/2, 3*PacketsPerConnectionID/2]. It also tests that a connection
// ID arriving after its successors were used is retired immediately.
func TestConnIDManagerConnIDRotation(t *testing.T) {
	var frameQueue []wire.Frame
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(protocol.StatelessResetToken) {},
		func(protocol.StatelessResetToken) {},
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
	)
	// the first connection ID is used as soon as the handshake is complete
	m.SetHandshakeComplete()
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      1,
		ConnectionID:        protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), m.Get())
	frameQueue = nil

	// Note that we're missing the connection ID with sequence number 2.
	// It will be received later.
	var queuedConnIDs []protocol.ConnectionID
	for i := 0; i < protocol.MaxActiveConnectionIDs-1; i++ {
		b := make([]byte, 4)
		rand.Read(b)
		connID := protocol.ParseConnectionID(b)
		queuedConnIDs = append(queuedConnIDs, connID)
		require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
			SequenceNumber:      uint64(3 + i),
			ConnectionID:        connID,
			StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
		}))
	}

	// send packets until the connection ID changes, counting how many it took
	var counter int
	for {
		require.Empty(t, frameQueue)
		m.SentPacket()
		counter++
		if m.Get() != protocol.ParseConnectionID([]byte{4, 3, 2, 1}) {
			require.Equal(t, queuedConnIDs[0], m.Get())
			require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 1}}, frameQueue)
			break
		}
	}
	require.GreaterOrEqual(t, counter, protocol.PacketsPerConnectionID/2)
	require.LessOrEqual(t, counter, protocol.PacketsPerConnectionID*3/2)
	frameQueue = nil

	// now receive connection ID 2
	// it is below the active sequence number (3), so it is retired right away
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber: 2,
		ConnectionID:   protocol.ParseConnectionID([]byte{2, 3, 4, 5}),
	}))
	require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 2}}, frameQueue)
}
// TestConnIDManagerPathMigration tests connection ID allocation for probing paths:
// - each path gets its own connection ID and keeps it across calls,
// - RetirePriorTo revokes path allocations,
// - RetireConnIDForPath queues a RETIRE_CONNECTION_ID frame only for allocated paths.
func TestConnIDManagerPathMigration(t *testing.T) {
	var frameQueue []wire.Frame
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(protocol.StatelessResetToken) {},
		func(protocol.StatelessResetToken) {},
		func(f wire.Frame) { frameQueue = append(frameQueue, f) },
	)

	// no connection ID available yet
	_, ok := m.GetConnIDForPath(1)
	require.False(t, ok)

	// add a connection ID
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      1,
		ConnectionID:        protocol.ParseConnectionID([]byte{4, 3, 2, 1}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      2,
		ConnectionID:        protocol.ParseConnectionID([]byte{5, 4, 3, 2}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	connID, ok := m.GetConnIDForPath(1)
	require.True(t, ok)
	require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID)
	connID, ok = m.GetConnIDForPath(2)
	require.True(t, ok)
	require.Equal(t, protocol.ParseConnectionID([]byte{5, 4, 3, 2}), connID)
	// asking for the connection for path 1 again returns the same connection ID
	connID, ok = m.GetConnIDForPath(1)
	require.True(t, ok)
	require.Equal(t, protocol.ParseConnectionID([]byte{4, 3, 2, 1}), connID)

	// if the connection ID is retired, the path will use another connection ID
	// (RetirePriorTo 2 retires path 1's connection ID (seq 1) and the active one (seq 0))
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      3,
		RetirePriorTo:       2,
		ConnectionID:        protocol.ParseConnectionID([]byte{6, 5, 4, 3}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	require.Len(t, frameQueue, 2)
	frameQueue = nil
	require.Equal(t, protocol.ParseConnectionID([]byte{6, 5, 4, 3}), m.Get())
	// the new connection ID became the active one, so it is not available for new paths
	_, ok = m.GetConnIDForPath(3)
	require.False(t, ok)

	// Manually retiring the connection ID does nothing.
	// Path 1 doesn't have a connection ID anymore.
	m.RetireConnIDForPath(1)
	require.Empty(t, frameQueue)
	_, ok = m.GetConnIDForPath(1)
	require.False(t, ok)

	// only after a new connection ID is added, it will be used for path 1
	require.NoError(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      4,
		ConnectionID:        protocol.ParseConnectionID([]byte{7, 6, 5, 4}),
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}))
	connID, ok = m.GetConnIDForPath(1)
	require.True(t, ok)
	require.Equal(t, protocol.ParseConnectionID([]byte{7, 6, 5, 4}), connID)

	// a RETIRE_CONNECTION_ID frame for path 1 is queued when retiring the connection ID
	m.RetireConnIDForPath(1)
	require.Equal(t, []wire.Frame{&wire.RetireConnectionIDFrame{SequenceNumber: 4}}, frameQueue)
}
// TestConnIDManagerZeroLengthConnectionID tests behavior when the peer uses zero-length
// connection IDs:
// - the connection ID never rotates,
// - path probing and retirement are no-ops,
// - NEW_CONNECTION_ID frames are rejected as a protocol violation.
func TestConnIDManagerZeroLengthConnectionID(t *testing.T) {
	m := newConnIDManager(
		protocol.ConnectionID{},
		func(protocol.StatelessResetToken) {},
		func(protocol.StatelessResetToken) {},
		func(f wire.Frame) {},
	)
	require.Equal(t, protocol.ConnectionID{}, m.Get())
	for i := 0; i < 5*protocol.PacketsPerConnectionID; i++ {
		m.SentPacket()
		require.Equal(t, protocol.ConnectionID{}, m.Get())
	}

	// for path probing, we don't need to change the connection ID
	for id := pathID(1); id < 10; id++ {
		connID, ok := m.GetConnIDForPath(id)
		require.True(t, ok)
		require.Equal(t, protocol.ConnectionID{}, connID)
	}
	// retiring a connection ID for a path is also a no-op
	for id := pathID(1); id < 20; id++ {
		m.RetireConnIDForPath(id)
	}

	require.ErrorIs(t, m.Add(&wire.NewConnectionIDFrame{
		SequenceNumber:      1,
		ConnectionID:        protocol.ConnectionID{},
		StatelessResetToken: protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1},
	}), &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
}
// TestConnIDManagerClose tests that Close removes the active stateless reset token and
// that the manager panics when used afterwards.
func TestConnIDManagerClose(t *testing.T) {
	var addedTokens, removedTokens []protocol.StatelessResetToken
	m := newConnIDManager(
		protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
		func(token protocol.StatelessResetToken) { addedTokens = append(addedTokens, token) },
		func(token protocol.StatelessResetToken) { removedTokens = append(removedTokens, token) },
		func(f wire.Frame) {},
	)
	m.SetStatelessResetToken(protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})
	require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, addedTokens)
	require.Empty(t, removedTokens)
	m.Close()
	require.Equal(t, []protocol.StatelessResetToken{{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}}, removedTokens)
	// using the manager after Close panics
	require.Panics(t, func() { m.Get() })
	require.Panics(t, func() { m.SetStatelessResetToken(protocol.StatelessResetToken{}) })
}
golang-github-lucas-clemente-quic-go-0.50.0/connection.go 0000664 0000000 0000000 00000244151 14765760516 0023302 0 ustar 00root root 0000000 0000000 package quic
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"reflect"
"sync"
"sync/atomic"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/flowcontrol"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// unpacker unpacks (i.e. removes header protection from and decrypts) received
// long-header and short-header packets.
// NOTE(review): "decrypts" is inferred from the constructor being given the crypto setup (cs) — confirm.
type unpacker interface {
	UnpackLongHeader(hdr *wire.Header, data []byte) (*unpackedPacket, error)
	// UnpackShortHeader returns the packet number, its length, the key phase and the payload.
	UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error)
}
// streamManager opens, accepts and tracks the streams of a connection, and applies the
// stream limits from the peer's transport parameters and MAX_STREAMS frames.
type streamManager interface {
	GetOrOpenSendStream(protocol.StreamID) (sendStreamI, error)
	GetOrOpenReceiveStream(protocol.StreamID) (receiveStreamI, error)
	OpenStream() (Stream, error)
	OpenUniStream() (SendStream, error)
	OpenStreamSync(context.Context) (Stream, error)
	OpenUniStreamSync(context.Context) (SendStream, error)
	AcceptStream(context.Context) (Stream, error)
	AcceptUniStream(context.Context) (ReceiveStream, error)
	DeleteStream(protocol.StreamID) error
	UpdateLimits(*wire.TransportParameters)
	HandleMaxStreamsFrame(*wire.MaxStreamsFrame)
	CloseWithError(error)
	ResetFor0RTT()
	UseResetMaps()
}
// cryptoStreamHandler drives the TLS handshake. It consumes handshake messages per
// encryption level, emits handshake events, and exposes the resulting connection state.
// On the server it is created by handshake.NewCryptoSetupServer.
type cryptoStreamHandler interface {
	StartHandshake(context.Context) error
	ChangeConnectionID(protocol.ConnectionID)
	SetLargest1RTTAcked(protocol.PacketNumber) error
	SetHandshakeConfirmed()
	GetSessionTicket() ([]byte, error)
	NextEvent() handshake.Event
	DiscardInitialKeys()
	HandleMessage([]byte, protocol.EncryptionLevel) error
	io.Closer
	ConnectionState() handshake.ConnectionState
}
// receivedPacket is a UDP datagram received for this connection, together with its
// metadata.
type receivedPacket struct {
	buffer *packetBuffer // the buffer data points into (presumably returned to a pool later — confirm)

	remoteAddr net.Addr
	rcvTime    time.Time
	data       []byte

	ecn protocol.ECN

	info packetInfo // only valid if the contained IP address is valid
}
// Size returns the length of the packet data in bytes.
func (p *receivedPacket) Size() protocol.ByteCount {
	n := len(p.data)
	return protocol.ByteCount(n)
}
// Clone returns a shallow copy of the packet: the data slice and the buffer
// pointer are shared with the original.
func (p *receivedPacket) Clone() *receivedPacket {
	clone := *p
	return &clone
}
// connRunner is how a connection registers and unregisters its connection IDs and
// stateless reset tokens (newConnection wires these methods into the connection ID
// generator and manager).
type connRunner interface {
	Add(protocol.ConnectionID, packetHandler) bool
	Retire(protocol.ConnectionID)
	Remove(protocol.ConnectionID)
	ReplaceWithClosed([]protocol.ConnectionID, []byte)
	AddResetToken(protocol.StatelessResetToken, packetHandler)
	RemoveResetToken(protocol.StatelessResetToken)
}
// closeError wraps the error a connection is closed with.
// NOTE(review): immediate presumably means the connection is torn down without sending
// a CONNECTION_CLOSE frame — confirm against its users.
type closeError struct {
	err       error
	immediate bool
}
// errCloseForRecreating is used to close a connection in order to recreate it.
// It carries the packet number and QUIC version the new connection should use.
// NOTE(review): presumably used by the client after version negotiation — confirm.
type errCloseForRecreating struct {
	nextPacketNumber protocol.PacketNumber
	nextVersion      protocol.Version
}
// Error implements the error interface.
func (e *errCloseForRecreating) Error() string {
	const msg = "closing connection in order to recreate it"
	return msg
}
// connTracingID is the counter nextConnTracingID uses to hand out ConnectionTracingIDs.
var connTracingID atomic.Uint64 // to be accessed atomically
// nextConnTracingID atomically increments the global counter and returns the new
// value. The first ID handed out is 1.
func nextConnTracingID() ConnectionTracingID {
	id := connTracingID.Add(1)
	return ConnectionTracingID(id)
}
// A Connection is a QUIC connection.
// connection implements the Connection, EarlyConnection and streamSender interfaces.
type connection struct {
	// Destination connection ID used during the handshake.
	// Used to check source connection ID on incoming packets.
	handshakeDestConnID protocol.ConnectionID
	// Set for the client. Destination connection ID used on the first Initial sent.
	origDestConnID protocol.ConnectionID
	retrySrcConnID *protocol.ConnectionID // only set for the client (and if a Retry was performed)

	srcConnIDLen int

	perspective protocol.Perspective
	version     protocol.Version
	config      *Config

	conn      sendConn
	sendQueue sender

	// lazily initialized: most connections never migrate
	pathManager *pathManager

	largestRcvdAppData protocol.PacketNumber

	streamsMap      streamManager
	connIDManager   *connIDManager
	connIDGenerator *connIDGenerator

	rttStats *utils.RTTStats

	cryptoStreamManager   *cryptoStreamManager
	sentPacketHandler     ackhandler.SentPacketHandler
	receivedPacketHandler ackhandler.ReceivedPacketHandler
	retransmissionQueue   *retransmissionQueue
	framer                *framer
	connFlowController    flowcontrol.ConnectionFlowController
	tokenStoreKey         string                    // only set for the client
	tokenGenerator        *handshake.TokenGenerator // only set for the server

	unpacker      unpacker
	frameParser   wire.FrameParser
	packer        packer
	mtuDiscoverer mtuDiscoverer // initialized when the transport parameters are received

	currentMTUEstimate atomic.Uint32

	initialStream       *cryptoStream
	handshakeStream     *cryptoStream
	oneRTTStream        *cryptoStream // only set for the server
	cryptoStreamHandler cryptoStreamHandler

	notifyReceivedPacket chan struct{}
	sendingScheduled     chan struct{}
	receivedPacketMx     sync.Mutex
	receivedPackets      ringbuffer.RingBuffer[receivedPacket]

	// closeChan is used to notify the run loop that it should terminate
	closeChan chan struct{}
	closeErr  atomic.Pointer[closeError]

	ctx                   context.Context
	ctxCancel             context.CancelCauseFunc
	handshakeCompleteChan chan struct{}

	undecryptablePackets          []receivedPacket // undecryptable packets, waiting for a change in encryption level
	undecryptablePacketsToProcess []receivedPacket

	earlyConnReadyChan chan struct{}
	sentFirstPacket    bool
	droppedInitialKeys bool
	handshakeComplete  bool
	handshakeConfirmed bool

	receivedRetry       bool
	versionNegotiated   bool
	receivedFirstPacket bool

	// the minimum of the max_idle_timeout values advertised by both endpoints
	idleTimeout  time.Duration
	creationTime time.Time
	// The idle timeout is set based on the max of the time we received the last packet...
	lastPacketReceivedTime time.Time
	// ... and the time we sent a new ack-eliciting packet after receiving a packet.
	firstAckElicitingPacketAfterIdleSentTime time.Time
	// pacingDeadline is the time when the next packet should be sent
	pacingDeadline time.Time

	peerParams *wire.TransportParameters

	timer connectionTimer
	// keepAlivePingSent stores whether a keep alive PING is in flight.
	// It is reset as soon as we receive a packet from the peer.
	keepAlivePingSent bool
	keepAliveInterval time.Duration

	datagramQueue *datagramQueue

	connStateMutex sync.Mutex
	connState      ConnectionState

	logID  string
	tracer *logging.ConnectionTracer
	logger utils.Logger
}
// Compile-time assertions that *connection implements these interfaces.
var (
	_ Connection      = &connection{}
	_ EarlyConnection = &connection{}
	_ streamSender    = &connection{}
)
// newConnection creates a server-side connection (perspective protocol.PerspectiveServer).
// It is declared as a variable, presumably so tests can replace it.
//   - origDestConnID: if non-empty, used as the log ID and sent as
//     OriginalDestinationConnectionID in our transport parameters.
//   - clientDestConnID: passed to the connection ID generator and the server crypto setup.
//   - destConnID: the initial destination connection ID for the packets we send.
//   - srcConnID: our initial source connection ID; its stateless reset token is sent in
//     the transport parameters.
//   - clientAddressValidated: passed to the ack handler.
var newConnection = func(
	ctx context.Context,
	ctxCancel context.CancelCauseFunc,
	conn sendConn,
	runner connRunner,
	origDestConnID protocol.ConnectionID,
	retrySrcConnID *protocol.ConnectionID,
	clientDestConnID protocol.ConnectionID,
	destConnID protocol.ConnectionID,
	srcConnID protocol.ConnectionID,
	connIDGenerator ConnectionIDGenerator,
	statelessResetter *statelessResetter,
	conf *Config,
	tlsConf *tls.Config,
	tokenGenerator *handshake.TokenGenerator,
	clientAddressValidated bool,
	tracer *logging.ConnectionTracer,
	logger utils.Logger,
	v protocol.Version,
) quicConn {
	s := &connection{
		ctx:                 ctx,
		ctxCancel:           ctxCancel,
		conn:                conn,
		config:              conf,
		handshakeDestConnID: destConnID,
		srcConnIDLen:        srcConnID.Len(),
		tokenGenerator:      tokenGenerator,
		oneRTTStream:        newCryptoStream(),
		perspective:         protocol.PerspectiveServer,
		tracer:              tracer,
		logger:              logger,
		version:             v,
	}
	// prefer the client's original destination connection ID as the log ID, if there is one
	if origDestConnID.Len() > 0 {
		s.logID = origDestConnID.String()
	} else {
		s.logID = destConnID.String()
	}
	// peer-issued connection IDs (used as destination connection IDs)
	s.connIDManager = newConnIDManager(
		destConnID,
		func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
		runner.RemoveResetToken,
		s.queueControlFrame,
	)
	// our own connection IDs, registered with the runner
	s.connIDGenerator = newConnIDGenerator(
		srcConnID,
		&clientDestConnID,
		func(connID protocol.ConnectionID) { runner.Add(connID, s) },
		statelessResetter,
		runner.Remove,
		runner.Retire,
		runner.ReplaceWithClosed,
		s.queueControlFrame,
		connIDGenerator,
	)
	// NOTE(review): preSetup is defined elsewhere; s.rttStats, s.initialStream, etc. used below
	// are presumably initialized there — confirm.
	s.preSetup()
	s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
		0,
		protocol.ByteCount(s.config.InitialPacketSize),
		s.rttStats,
		clientAddressValidated,
		s.conn.capabilities().ECN,
		s.perspective,
		s.tracer,
		s.logger,
	)
	s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
	statelessResetToken := statelessResetter.GetStatelessResetToken(srcConnID)
	// the transport parameters we advertise to the client
	params := &wire.TransportParameters{
		InitialMaxStreamDataBidiLocal:   protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxStreamDataBidiRemote:  protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxStreamDataUni:         protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		InitialMaxData:                  protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
		MaxIdleTimeout:                  s.config.MaxIdleTimeout,
		MaxBidiStreamNum:                protocol.StreamNum(s.config.MaxIncomingStreams),
		MaxUniStreamNum:                 protocol.StreamNum(s.config.MaxIncomingUniStreams),
		MaxAckDelay:                     protocol.MaxAckDelayInclGranularity,
		AckDelayExponent:                protocol.AckDelayExponent,
		MaxUDPPayloadSize:               protocol.MaxPacketBufferSize,
		DisableActiveMigration:          true,
		StatelessResetToken:             &statelessResetToken,
		OriginalDestinationConnectionID: origDestConnID,
		// For interoperability with quic-go versions before May 2023, this value must be set to a value
		// different from protocol.DefaultActiveConnectionIDLimit.
		// If set to the default value, it will be omitted from the transport parameters, which will make
		// old quic-go versions interpret it as 0, instead of the default value of 2.
		// See https://github.com/quic-go/quic-go/pull/3806.
		ActiveConnectionIDLimit:   protocol.MaxActiveConnectionIDs,
		InitialSourceConnectionID: srcConnID,
		RetrySourceConnectionID:   retrySrcConnID,
	}
	if s.config.EnableDatagrams {
		params.MaxDatagramFrameSize = wire.MaxDatagramSize
	} else {
		params.MaxDatagramFrameSize = protocol.InvalidByteCount
	}
	if s.tracer != nil && s.tracer.SentTransportParameters != nil {
		s.tracer.SentTransportParameters(params)
	}
	cs := handshake.NewCryptoSetupServer(
		clientDestConnID,
		conn.LocalAddr(),
		conn.RemoteAddr(),
		params,
		tlsConf,
		conf.Allow0RTT,
		s.rttStats,
		tracer,
		logger,
		s.version,
	)
	s.cryptoStreamHandler = cs
	// the packer obtains the current destination connection ID from the connIDManager for every packet
	s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
	s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
	s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, s.oneRTTStream)
	return s
}
// newClientConnection creates the client side of a new QUIC connection.
// It is declared as a variable, such that we can mock it in the tests.
//
// destConnID is the Destination Connection ID the client uses for its first packets: it is stored
// as both origDestConnID and handshakeDestConnID, is used as the log ID, and is passed to the
// client crypto setup. initialPacketNumber is the first packet number handed to the ack handler,
// enable0RTT is passed to the crypto setup, and hasNegotiatedVersion is stored as versionNegotiated.
var newClientConnection = func(
ctx context.Context,
conn sendConn,
runner connRunner,
destConnID protocol.ConnectionID,
srcConnID protocol.ConnectionID,
connIDGenerator ConnectionIDGenerator,
statelessResetter *statelessResetter,
conf *Config,
tlsConf *tls.Config,
initialPacketNumber protocol.PacketNumber,
enable0RTT bool,
hasNegotiatedVersion bool,
tracer *logging.ConnectionTracer,
logger utils.Logger,
v protocol.Version,
) quicConn {
s := &connection{
conn: conn,
config: conf,
origDestConnID: destConnID,
handshakeDestConnID: destConnID,
srcConnIDLen: srcConnID.Len(),
perspective: protocol.PerspectiveClient,
logID: destConnID.String(),
logger: logger,
tracer: tracer,
versionNegotiated: hasNegotiatedVersion,
version: v,
}
s.connIDManager = newConnIDManager(
destConnID,
func(token protocol.StatelessResetToken) { runner.AddResetToken(token, s) },
runner.RemoveResetToken,
s.queueControlFrame,
)
s.connIDGenerator = newConnIDGenerator(
srcConnID,
nil, // the server passes the client's Destination Connection ID here; the client has none to pass
func(connID protocol.ConnectionID) { runner.Add(connID, s) },
statelessResetter,
runner.Remove,
runner.Retire,
runner.ReplaceWithClosed,
s.queueControlFrame,
connIDGenerator,
)
// s.ctx must be set before preSetup, which passes it to newStreamsMap.
s.ctx, s.ctxCancel = context.WithCancelCause(ctx)
s.preSetup()
// preSetup created s.rttStats, which is handed to the ack handler here.
s.sentPacketHandler, s.receivedPacketHandler = ackhandler.NewAckHandler(
initialPacketNumber,
protocol.ByteCount(s.config.InitialPacketSize),
s.rttStats,
false, // has no effect
s.conn.capabilities().ECN,
s.perspective,
s.tracer,
s.logger,
)
// The initial MTU estimate is derived from the configured initial packet size.
s.currentMTUEstimate.Store(uint32(estimateMaxPayloadSize(protocol.ByteCount(s.config.InitialPacketSize))))
// Unlike the server, the client does not store the 1-RTT crypto stream on the connection;
// it is only handed to the crypto stream manager below.
oneRTTStream := newCryptoStream()
// The transport parameters the client sends (they are reported via SentTransportParameters).
params := &wire.TransportParameters{
InitialMaxStreamDataBidiRemote: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataBidiLocal: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxStreamDataUni: protocol.ByteCount(s.config.InitialStreamReceiveWindow),
InitialMaxData: protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
MaxIdleTimeout: s.config.MaxIdleTimeout,
MaxBidiStreamNum: protocol.StreamNum(s.config.MaxIncomingStreams),
MaxUniStreamNum: protocol.StreamNum(s.config.MaxIncomingUniStreams),
MaxAckDelay: protocol.MaxAckDelayInclGranularity,
MaxUDPPayloadSize: protocol.MaxPacketBufferSize,
AckDelayExponent: protocol.AckDelayExponent,
DisableActiveMigration: true,
// For interoperability with quic-go versions before May 2023, this value must be set to a value
// different from protocol.DefaultActiveConnectionIDLimit.
// If set to the default value, it will be omitted from the transport parameters, which will make
// old quic-go versions interpret it as 0, instead of the default value of 2.
// See https://github.com/quic-go/quic-go/pull/3806.
ActiveConnectionIDLimit: protocol.MaxActiveConnectionIDs,
InitialSourceConnectionID: srcConnID,
}
// Only advertise DATAGRAM frame support if it is enabled in the config.
if s.config.EnableDatagrams {
params.MaxDatagramFrameSize = wire.MaxDatagramSize
} else {
params.MaxDatagramFrameSize = protocol.InvalidByteCount
}
if s.tracer != nil && s.tracer.SentTransportParameters != nil {
s.tracer.SentTransportParameters(params)
}
cs := handshake.NewCryptoSetupClient(
destConnID,
params,
tlsConf,
enable0RTT,
s.rttStats,
tracer,
logger,
s.version,
)
s.cryptoStreamHandler = cs
s.cryptoStreamManager = newCryptoStreamManager(s.initialStream, s.handshakeStream, oneRTTStream)
s.unpacker = newPacketUnpacker(cs, s.srcConnIDLen)
s.packer = newPacketPacker(srcConnID, s.connIDManager.Get, s.initialStream, s.handshakeStream, s.sentPacketHandler, s.retransmissionQueue, cs, s.framer, s.receivedPacketHandler, s.datagramQueue, s.perspective)
// Tokens are stored under the TLS server name if one is set, otherwise under the remote address.
if len(tlsConf.ServerName) > 0 {
s.tokenStoreKey = tlsConf.ServerName
} else {
s.tokenStoreKey = conn.RemoteAddr().String()
}
// If a token was previously stored under this key, hand it to the packer.
if s.config.TokenStore != nil {
if token := s.config.TokenStore.Pop(s.tokenStoreKey); token != nil {
s.packer.SetToken(token.data)
}
}
return s
}
// preSetup initializes the connection state that client and server connections share:
// crypto streams, queues, frame parser, RTT stats, flow control, the streams map, the framer,
// the notification channels, and the creation / last-receive timestamps.
// It reads s.config, s.conn, s.ctx, s.logger, s.perspective and s.version, so those must be set first.
func (s *connection) preSetup() {
	s.largestRcvdAppData = protocol.InvalidPacketNumber
	s.initialStream = newCryptoStream()
	s.handshakeStream = newCryptoStream()
	s.sendQueue = newSendQueue(s.conn)
	s.retransmissionQueue = newRetransmissionQueue()
	s.frameParser = *wire.NewFrameParser(s.config.EnableDatagrams)
	// The RTT stats are used by the connection flow controller, so create them first.
	s.rttStats = &utils.RTTStats{}

	// Increasing the connection flow control window is allowed unless the application
	// installed a callback that says otherwise.
	allowWindowIncrease := func(size protocol.ByteCount) bool {
		if cb := s.config.AllowConnectionWindowIncrease; cb != nil {
			return cb(s, uint64(size))
		}
		return true
	}
	s.connFlowController = flowcontrol.NewConnectionFlowController(
		protocol.ByteCount(s.config.InitialConnectionReceiveWindow),
		protocol.ByteCount(s.config.MaxConnectionReceiveWindow),
		allowWindowIncrease,
		s.rttStats,
		s.logger,
	)
	s.earlyConnReadyChan = make(chan struct{})
	s.streamsMap = newStreamsMap(
		s.ctx,
		s,
		s.queueControlFrame,
		s.newFlowController,
		uint64(s.config.MaxIncomingStreams),
		uint64(s.config.MaxIncomingUniStreams),
		s.perspective,
	)
	s.framer = newFramer(s.connFlowController)
	s.receivedPackets.Init(8)

	// Notification channels used by the run loop.
	s.notifyReceivedPacket = make(chan struct{}, 1)
	s.closeChan = make(chan struct{}, 1)
	s.sendingScheduled = make(chan struct{}, 1)
	s.handshakeCompleteChan = make(chan struct{})

	// Creation time and last-receive time start out identical.
	start := time.Now()
	s.lastPacketReceivedTime, s.creationTime = start, start

	s.datagramQueue = newDatagramQueue(s.scheduleSending, s.logger)
	s.connState.Version = s.version
}
// run runs the connection's main loop until the connection is closed, and returns the close error.
// It starts the handshake and the send queue goroutine. Then it loops:
//  1. processes previously undecryptable packets queued in s.undecryptablePacketsToProcess,
//  2. processes received packets,
//  3. waits for a close request, timer, send request, send-queue space or new packets (unless it can proceed immediately),
//  4. handles loss detection, keep-alive, handshake timeout and idle timeout,
//  5. sends packets.
//
// After the loop, it closes the crypto setup and the send queue, and handles the close error.
func (s *connection) run() (err error) {
// Deferred functions run in reverse order: first the queued packets are drained,
// then the connection context is canceled with the returned error as its cause.
defer func() { s.ctxCancel(err) }()
defer func() {
// drain queued packets that will never be processed
s.receivedPacketMx.Lock()
defer s.receivedPacketMx.Unlock()
for !s.receivedPackets.Empty() {
p := s.receivedPackets.PopFront()
p.buffer.Decrement()
p.buffer.MaybeRelease()
}
}()
s.timer = *newTimer()
if err := s.cryptoStreamHandler.StartHandshake(s.ctx); err != nil {
return err
}
if err := s.handleHandshakeEvents(time.Now()); err != nil {
return err
}
// The send queue runs in its own goroutine; if it fails, the connection is destroyed.
go func() {
if err := s.sendQueue.Run(); err != nil {
s.destroyImpl(err)
}
}()
if s.perspective == protocol.PerspectiveClient {
s.scheduleSending() // so the ClientHello actually gets sent
}
// Non-nil only while the send queue is full; receiving from it waits until there's space again.
var sendQueueAvailable <-chan struct{}
runLoop:
for {
// Close the connection with an INTERNAL_ERROR if too many control frames are queued.
if s.framer.QueuedTooManyControlFrames() {
s.setCloseError(&closeError{err: &qerr.TransportError{ErrorCode: InternalError}})
break runLoop
}
// Close immediately if requested
select {
case <-s.closeChan:
break runLoop
default:
}
// no need to set a timer if we can send packets immediately
if s.pacingDeadline != deadlineSendImmediately {
s.maybeResetTimer()
}
// 1st: handle undecryptable packets, if any.
// This can only occur before completion of the handshake.
if len(s.undecryptablePacketsToProcess) > 0 {
var processedUndecryptablePacket bool
queue := s.undecryptablePacketsToProcess
s.undecryptablePacketsToProcess = nil
for _, p := range queue {
processed, err := s.handleOnePacket(p)
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
if processed {
processedUndecryptablePacket = true
}
}
if processedUndecryptablePacket {
// if we processed any undecryptable packets, jump to the resetting of the timers directly
continue
}
}
// 2nd: receive packets.
processed, err := s.handlePackets() // don't check receivedPackets.Len() in the run loop to avoid locking the mutex
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
// We don't need to wait for new events if:
// * we processed packets: we probably need to send an ACK, and potentially more data
// * the pacer allows us to send more packets immediately
shouldProceedImmediately := sendQueueAvailable == nil && (processed || s.pacingDeadline == deadlineSendImmediately)
if !shouldProceedImmediately {
// 3rd: wait for something to happen:
// * closing of the connection
// * timer firing
// * sending scheduled
// * send queue available
// * received packets
select {
case <-s.closeChan:
break runLoop
case <-s.timer.Chan():
s.timer.SetRead()
case <-s.sendingScheduled:
case <-sendQueueAvailable:
case <-s.notifyReceivedPacket:
wasProcessed, err := s.handlePackets()
if err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
// if no packet was processed, skip the timeout and sending logic below and restart the loop
if !wasProcessed {
continue
}
}
}
// Check for loss detection timeout.
// This could cause packets to be declared lost, and retransmissions to be enqueued.
now := time.Now()
if timeout := s.sentPacketHandler.GetLossDetectionTimeout(); !timeout.IsZero() && timeout.Before(now) {
if err := s.sentPacketHandler.OnLossDetectionTimeout(now); err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
}
// Keep-alive takes precedence; otherwise check the handshake timeout, then the idle timeout.
if keepAliveTime := s.nextKeepAliveTime(); !keepAliveTime.IsZero() && !now.Before(keepAliveTime) {
// send a PING frame since there is no activity in the connection
s.logger.Debugf("Sending a keep-alive PING to keep the connection alive.")
s.framer.QueueControlFrame(&wire.PingFrame{})
s.keepAlivePingSent = true
} else if !s.handshakeComplete && now.Sub(s.creationTime) >= s.config.handshakeTimeout() {
s.destroyImpl(qerr.ErrHandshakeTimeout)
break runLoop
} else {
idleTimeoutStartTime := s.idleTimeoutStartTime()
if (!s.handshakeComplete && now.Sub(idleTimeoutStartTime) >= s.config.HandshakeIdleTimeout) ||
(s.handshakeComplete && now.After(s.nextIdleTimeoutTime())) {
s.destroyImpl(qerr.ErrIdleTimeout)
break runLoop
}
}
if s.sendQueue.WouldBlock() {
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
sendQueueAvailable = s.sendQueue.Available()
// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
s.pacingDeadline = time.Time{}
continue
}
// Stop if a close error has been set in the meantime.
if s.closeErr.Load() != nil {
break runLoop
}
if err := s.triggerSending(now); err != nil {
s.setCloseError(&closeError{err: err})
break runLoop
}
if s.sendQueue.WouldBlock() {
// The send queue is still busy sending out packets. Wait until there's space to enqueue new packets.
sendQueueAvailable = s.sendQueue.Available()
// Cancel the pacing timer, as we can't send any more packets until the send queue is available again.
s.pacingDeadline = time.Time{}
} else {
sendQueueAvailable = nil
}
}
// The loop was exited: tear the connection down.
closeErr := s.closeErr.Load()
s.cryptoStreamHandler.Close()
s.sendQueue.Close() // close the send queue before sending the CONNECTION_CLOSE
s.handleCloseError(closeErr)
// Don't report the close to the tracer if the connection is only closed in order to be
// recreated (e.g. after version negotiation).
if s.tracer != nil && s.tracer.Close != nil {
if e := (&errCloseForRecreating{}); !errors.As(closeErr.err, &e) {
s.tracer.Close()
}
}
s.logger.Infof("Connection %s closed.", s.logID)
s.timer.Stop()
return closeErr.err
}
// earlyConnReady returns the channel that signals when the early connection can be used.
// Receiving from it blocks until then.
func (s *connection) earlyConnReady() <-chan struct{} {
	ready := s.earlyConnReadyChan
	return ready
}
// HandshakeComplete returns a channel that is closed once the handshake has completed.
func (s *connection) HandshakeComplete() <-chan struct{} {
	done := s.handshakeCompleteChan
	return done
}
// Context returns the connection's context.
// It is canceled, with the close error as its cause, when the run loop returns.
func (s *connection) Context() context.Context {
	ctx := s.ctx
	return ctx
}
// supportsDatagrams reports whether the peer advertised a positive max_datagram_frame_size.
// It dereferences s.peerParams, so it must only be called once the peer's parameters are known.
func (s *connection) supportsDatagrams() bool {
	if s.peerParams.MaxDatagramFrameSize <= 0 {
		return false
	}
	return true
}
// ConnectionState returns a snapshot of the connection state, refreshing the TLS state,
// the 0-RTT flag and the GSO capability under connStateMutex before returning it.
func (s *connection) ConnectionState() ConnectionState {
	s.connStateMutex.Lock()
	defer s.connStateMutex.Unlock()

	tlsState := s.cryptoStreamHandler.ConnectionState()
	s.connState.TLS, s.connState.Used0RTT = tlsState.ConnectionState, tlsState.Used0RTT
	s.connState.GSO = s.conn.capabilities().GSO
	return s.connState
}
// nextIdleTimeoutTime returns the time at which the connection times out due to inactivity.
// The effective idle timeout is never shorter than three PTOs.
func (s *connection) nextIdleTimeoutTime() time.Time {
	timeout := s.idleTimeout
	if minTimeout := 3 * s.rttStats.PTO(true); minTimeout > timeout {
		timeout = minTimeout
	}
	return s.idleTimeoutStartTime().Add(timeout)
}
// nextKeepAliveTime returns the time when the next keep-alive PING should be sent.
// It returns the zero time if keep-alives are disabled (KeepAlivePeriod == 0), or if a keep-alive
// PING was already sent and no packet has been received since.
// The interval is never shorter than 1.5 PTOs.
func (s *connection) nextKeepAliveTime() time.Time {
	keepAliveDisabled := s.config.KeepAlivePeriod == 0
	if keepAliveDisabled || s.keepAlivePingSent {
		return time.Time{}
	}
	interval := s.keepAliveInterval
	if floor := s.rttStats.PTO(true) * 3 / 2; floor > interval {
		interval = floor
	}
	return s.lastPacketReceivedTime.Add(interval)
}
// maybeResetTimer re-arms the connection timer. It passes the following to s.timer.SetTimer:
//   - the connection deadline,
//   - the ACK alarm,
//   - the loss detection timeout,
//   - the pacing deadline.
//
// Before handshake completion, the connection deadline is the earlier of the handshake timeout and
// the handshake idle timeout. Afterwards, it is the next keep-alive time if there is one, otherwise
// the idle timeout.
func (s *connection) maybeResetTimer() {
	var deadline time.Time
	switch {
	case s.handshakeComplete:
		// A pending keep-alive is used instead of the idle timeout.
		if keepAlive := s.nextKeepAliveTime(); keepAlive.IsZero() {
			deadline = s.nextIdleTimeoutTime()
		} else {
			deadline = keepAlive
		}
	default:
		// Use whichever of the two handshake deadlines comes first.
		deadline = s.creationTime.Add(s.config.handshakeTimeout())
		if idleDeadline := s.idleTimeoutStartTime().Add(s.config.HandshakeIdleTimeout); idleDeadline.Before(deadline) {
			deadline = idleDeadline
		}
	}
	s.timer.SetTimer(
		deadline,
		s.receivedPacketHandler.GetAlarmTimeout(),
		s.sentPacketHandler.GetLossDetectionTimeout(),
		s.pacingDeadline,
	)
}
// idleTimeoutStartTime returns the later of two times:
//   - the time the last packet was received,
//   - firstAckElicitingPacketAfterIdleSentTime (this is reset to the zero time whenever a packet is received).
func (s *connection) idleTimeoutStartTime() time.Time {
	lastReceived := s.lastPacketReceivedTime
	firstSent := s.firstAckElicitingPacketAfterIdleSentTime
	if !firstSent.After(lastReceived) {
		return lastReceived
	}
	return firstSent
}
// handleHandshakeComplete runs once the handshake has completed.
// HandshakeComplete's channel is closed on every return path.
//
// The client applies the peer's transport parameters here and is done. The server additionally:
//   - confirms the handshake,
//   - sends the session ticket (if any) in CRYPTO frames,
//   - sends a NEW_TOKEN frame for the peer's address,
//   - sends a HANDSHAKE_DONE frame.
func (s *connection) handleHandshakeComplete(now time.Time) error {
	defer close(s.handshakeCompleteChan)

	// With 1-RTT keys derived, there's no point in queueing undecryptable packets anymore.
	s.undecryptablePackets = nil
	s.connIDManager.SetHandshakeComplete()
	s.connIDGenerator.SetHandshakeComplete()
	if tr := s.tracer; tr != nil && tr.ChoseALPN != nil {
		tr.ChoseALPN(s.cryptoStreamHandler.ConnectionState().NegotiatedProtocol)
	}

	// The server applies transport parameters right away; the client has to wait until now.
	// During a 0-RTT connection, the client may only use the new transport parameters for 1-RTT packets.
	if s.perspective == protocol.PerspectiveClient {
		s.applyTransportParameters()
		return nil
	}

	// Server side only from here on.
	if err := s.handleHandshakeConfirmed(now); err != nil {
		return err
	}
	sessionTicket, err := s.cryptoStreamHandler.GetSessionTicket()
	if err != nil {
		return err
	}
	// sessionTicket is nil if session tickets are disabled via tls.Config.SessionTicketsDisabled.
	if sessionTicket != nil {
		s.oneRTTStream.Write(sessionTicket)
		for s.oneRTTStream.HasData() {
			s.queueControlFrame(s.oneRTTStream.PopCryptoFrame(protocol.MaxPostHandshakeCryptoFrameSize))
		}
	}
	addrToken, err := s.tokenGenerator.NewToken(s.conn.RemoteAddr())
	if err != nil {
		return err
	}
	s.queueControlFrame(&wire.NewTokenFrame{Token: addrToken})
	s.queueControlFrame(&wire.HandshakeDoneFrame{})
	return nil
}
// handleHandshakeConfirmed drops the Handshake encryption level and marks the handshake as confirmed.
// If path MTU discovery is enabled and the connection has the DF capability, it also starts MTU discovery.
func (s *connection) handleHandshakeConfirmed(now time.Time) error {
	if err := s.dropEncryptionLevel(protocol.EncryptionHandshake, now); err != nil {
		return err
	}
	s.handshakeConfirmed = true
	s.cryptoStreamHandler.SetHandshakeConfirmed()

	mtuDiscoveryPossible := !s.config.DisablePathMTUDiscovery && s.conn.capabilities().DF
	if mtuDiscoveryPossible {
		s.mtuDiscoverer.Start(now)
	}
	return nil
}
// handlePackets processes packets queued in s.receivedPackets and reports whether any of them was processed.
//
// It handles at most as many packets as were queued when it was called. Before the handshake
// completes, it handles only a single packet per call. If packets remain queued afterwards, it
// sends a notification on s.notifyReceivedPacket (non-blocking), so the run loop calls it again.
// On error it returns immediately; remaining packets stay queued.
func (s *connection) handlePackets() (wasProcessed bool, _ error) {
// Now process all packets in the receivedPackets queue.
// Limit the number of packets to the current length of the receivedPackets queue,
// so we eventually get a chance to send out an ACK when receiving a lot of packets.
s.receivedPacketMx.Lock()
numPackets := s.receivedPackets.Len()
if numPackets == 0 {
s.receivedPacketMx.Unlock()
return false, nil
}
var hasMorePackets bool
for i := 0; i < numPackets; i++ {
// For the first iteration, the mutex is still held from above; re-acquire it for later ones.
if i > 0 {
s.receivedPacketMx.Lock()
}
p := s.receivedPackets.PopFront()
hasMorePackets = !s.receivedPackets.Empty()
// The mutex is not held while the packet is handled
// (presumably so packets can be enqueued concurrently).
s.receivedPacketMx.Unlock()
processed, err := s.handleOnePacket(p)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
if !hasMorePackets {
break
}
// only process a single packet at a time before handshake completion
if !s.handshakeComplete {
break
}
}
// Wake up the run loop again if there are packets left, without blocking if a notification is already pending.
if hasMorePackets {
select {
case s.notifyReceivedPacket <- struct{}{}:
default:
}
}
return wasProcessed, nil
}
// handleOnePacket handles a single received datagram, which may contain several coalesced QUIC packets.
// It reports whether any of the contained packets was processed.
//
// Version Negotiation packets are handled separately. Otherwise, the parts are parsed one by one.
// Parsing stops at the first part that:
//   - can't be parsed,
//   - has an unexpected version,
//   - carries a different destination connection ID than the previous part.
//
// A short header packet always ends processing of the datagram.
func (s *connection) handleOnePacket(rp receivedPacket) (wasProcessed bool, _ error) {
// report the size of the received datagram to the sent packet handler
s.sentPacketHandler.ReceivedBytes(rp.Size(), rp.rcvTime)
if wire.IsVersionNegotiationPacket(rp.data) {
s.handleVersionNegotiationPacket(rp)
return false, nil
}
var counter uint8
var lastConnID protocol.ConnectionID
data := rp.data
p := rp
for len(data) > 0 {
// Every part after the first works on a copy of the packet, and must be sent to the
// same destination connection ID as the previous part.
if counter > 0 {
p = *(p.Clone())
p.data = data
destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("error parsing packet, couldn't parse connection ID: %s", err)
break
}
if destConnID != lastConnID {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("coalesced packet has different destination connection ID: %s, expected %s", destConnID, lastConnID)
break
}
}
if wire.IsLongHeaderPacket(p.data[0]) {
// rest is whatever follows this long header packet in the datagram.
hdr, packetData, rest, err := wire.ParsePacket(p.data)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
dropReason := logging.PacketDropHeaderParseError
if err == wire.ErrUnsupportedVersion {
dropReason = logging.PacketDropUnsupportedVersion
}
s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), dropReason)
}
s.logger.Debugf("error parsing packet: %s", err)
break
}
lastConnID = hdr.DestConnectionID
if hdr.Version != s.version {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedVersion)
}
s.logger.Debugf("Dropping packet with version %x. Expected %x.", hdr.Version, s.version)
break
}
// Additional parts share the packet buffer. NOTE(review): Split presumably takes an additional
// reference, which the handler's deferred Decrement releases — confirm.
if counter > 0 {
p.buffer.Split()
}
counter++
// only log if this actually a coalesced packet
if s.logger.Debug() && (counter > 1 || len(rest) > 0) {
s.logger.Debugf("Parsed a coalesced packet. Part %d: %d bytes. Remaining: %d bytes.", counter, len(packetData), len(rest))
}
p.data = packetData
processed, err := s.handleLongHeaderPacket(p, hdr)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
data = rest
} else {
if counter > 0 {
p.buffer.Split()
}
processed, err := s.handleShortHeaderPacket(p)
if err != nil {
return false, err
}
if processed {
wasProcessed = true
}
break
}
}
// presumably returns the buffer to the pool once all references have been released
p.buffer.MaybeRelease()
return wasProcessed, nil
}
// handleShortHeaderPacket unpacks and processes a 1-RTT (short header) packet, and reports whether
// it was processed.
//
// If the packet can't be decrypted yet, it is queued for later decryption (see handleUnpackError)
// and its buffer is not released here. Otherwise, the packet buffer is released on return.
//
// On the server, consider a packet with the largest 1-RTT packet number received so far that
// arrives from a new remote address. It is handed to the path manager, which may request a path
// probe packet and/or a switch to the new path.
func (s *connection) handleShortHeaderPacket(p receivedPacket) (wasProcessed bool, _ error) {
	var wasQueued bool
	defer func() {
		// Put back the packet buffer if the packet wasn't queued for later decryption.
		if !wasQueued {
			p.buffer.Decrement()
		}
	}()

	destConnID, err := wire.ParseConnectionID(p.data, s.srcConnIDLen)
	if err != nil {
		// The tracer is optional: calling through a nil tracer would panic,
		// so guard it like every other tracer call does.
		if s.tracer != nil && s.tracer.DroppedPacket != nil {
			s.tracer.DroppedPacket(logging.PacketType1RTT, protocol.InvalidPacketNumber, protocol.ByteCount(len(p.data)), logging.PacketDropHeaderParseError)
		}
		return false, nil
	}
	pn, pnLen, keyPhase, data, err := s.unpacker.UnpackShortHeader(p.rcvTime, p.data)
	if err != nil {
		wasQueued, err = s.handleUnpackError(err, p, logging.PacketType1RTT)
		return false, err
	}
	s.largestRcvdAppData = max(s.largestRcvdAppData, pn)

	if s.logger.Debug() {
		s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, 1-RTT", pn, p.Size(), destConnID)
		wire.LogShortHeader(s.logger, destConnID, pn, pnLen, keyPhase)
	}

	if s.receivedPacketHandler.IsPotentiallyDuplicate(pn, protocol.Encryption1RTT) {
		s.logger.Debugf("Dropping (potentially) duplicate packet.")
		if s.tracer != nil && s.tracer.DroppedPacket != nil {
			s.tracer.DroppedPacket(logging.PacketType1RTT, pn, p.Size(), logging.PacketDropDuplicate)
		}
		return false, nil
	}

	var log func([]logging.Frame)
	if s.tracer != nil && s.tracer.ReceivedShortHeaderPacket != nil {
		log = func(frames []logging.Frame) {
			s.tracer.ReceivedShortHeaderPacket(
				&logging.ShortHeader{
					DestConnectionID: destConnID,
					PacketNumber:     pn,
					PacketNumberLen:  pnLen,
					KeyPhase:         keyPhase,
				},
				p.Size(),
				p.ecn,
				frames,
			)
		}
	}
	isNonProbing, err := s.handleUnpackedShortHeaderPacket(destConnID, pn, data, p.ecn, p.rcvTime, log)
	if err != nil {
		return false, err
	}

	// In RFC 9000, only the client can migrate between paths.
	if s.perspective == protocol.PerspectiveClient {
		return true, nil
	}

	// Only the packet with the largest packet number received so far can trigger path validation / migration.
	var shouldSwitchPath bool
	if pn == s.largestRcvdAppData && !addrsEqual(p.remoteAddr, s.RemoteAddr()) {
		if s.pathManager == nil {
			// created lazily, when the first packet from a different address arrives
			s.pathManager = newPathManager(
				s.connIDManager.GetConnIDForPath,
				s.connIDManager.RetireConnIDForPath,
				s.logger,
			)
		}
		// probeConnID is the connection ID used for the path probe packet.
		// It is distinct from destConnID, the connection ID the received packet was sent to.
		var probeConnID protocol.ConnectionID
		var pathChallenge ackhandler.Frame
		probeConnID, pathChallenge, shouldSwitchPath = s.pathManager.HandlePacket(p, isNonProbing)
		if pathChallenge.Frame != nil {
			probe, buf, err := s.packer.PackPathProbePacket(probeConnID, pathChallenge, s.version)
			if err != nil {
				return false, err
			}
			s.logger.Debugf("sending path probe packet to %s", p.remoteAddr)
			s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, protocol.ECNNon, buf.Len(), false)
			s.registerPackedShortHeaderPacket(probe, protocol.ECNNon, p.rcvTime)
			s.sendQueue.SendProbe(buf, p.remoteAddr)
		}
	}

	if shouldSwitchPath {
		s.pathManager.SwitchToPath(p.remoteAddr)
		s.sentPacketHandler.MigratedPath(p.rcvTime, protocol.ByteCount(s.config.InitialPacketSize))
		// Restart MTU discovery on the new path, capped at the peer's max_udp_payload_size if that is smaller.
		maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
		if s.peerParams.MaxUDPPayloadSize > 0 && s.peerParams.MaxUDPPayloadSize < maxPacketSize {
			maxPacketSize = s.peerParams.MaxUDPPayloadSize
		}
		s.mtuDiscoverer.Reset(
			p.rcvTime,
			protocol.ByteCount(s.config.InitialPacketSize),
			maxPacketSize,
		)
		s.conn.ChangeRemoteAddr(p.remoteAddr, p.info)
	}
	return true, nil
}
// handleLongHeaderPacket unpacks and processes a single long header packet, and reports whether it was processed.
// Retry packets are passed to handleRetryPacket. If the packet can't be decrypted yet, it is queued
// for later (see handleUnpackError) and its buffer is not released here.
func (s *connection) handleLongHeaderPacket(p receivedPacket, hdr *wire.Header) (wasProcessed bool, _ error) {
var wasQueued bool
defer func() {
// Put back the packet buffer if the packet wasn't queued for later decryption.
if !wasQueued {
p.buffer.Decrement()
}
}()
if hdr.Type == protocol.PacketTypeRetry {
return s.handleRetryPacket(hdr, p.data, p.rcvTime), nil
}
// Once the first packet has been processed, s.handshakeDestConnID holds the peer's connection ID
// (it is updated from the first packet in handleUnpackedLongHeaderPacket).
// After that, Initial packets with a different Source Connection ID are ignored.
if s.receivedFirstPacket && hdr.Type == protocol.PacketTypeInitial && hdr.SrcConnectionID != s.handshakeDestConnID {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeInitial, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnknownConnectionID)
}
s.logger.Debugf("Dropping Initial packet (%d bytes) with unexpected source connection ID: %s (expected %s)", p.Size(), hdr.SrcConnectionID, s.handshakeDestConnID)
return false, nil
}
// drop 0-RTT packets, if we are a client
if s.perspective == protocol.PerspectiveClient && hdr.Type == protocol.PacketType0RTT {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
}
return false, nil
}
packet, err := s.unpacker.UnpackLongHeader(hdr, p.data)
if err != nil {
wasQueued, err = s.handleUnpackError(err, p, logging.PacketTypeFromHeader(hdr))
return false, err
}
if s.logger.Debug() {
s.logger.Debugf("<- Reading packet %d (%d bytes) for connection %s, %s", packet.hdr.PacketNumber, p.Size(), hdr.DestConnectionID, packet.encryptionLevel)
packet.hdr.Log(s.logger)
}
if pn := packet.hdr.PacketNumber; s.receivedPacketHandler.IsPotentiallyDuplicate(pn, packet.encryptionLevel) {
s.logger.Debugf("Dropping (potentially) duplicate packet.")
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeFromHeader(hdr), pn, p.Size(), logging.PacketDropDuplicate)
}
return false, nil
}
if err := s.handleUnpackedLongHeaderPacket(packet, p.ecn, p.rcvTime, p.Size()); err != nil {
return false, err
}
return true, nil
}
// handleUnpackError decides what happens to a packet that failed to unpack.
// It reports whether the packet was queued for later decryption; in that case the caller must not
// release its buffer. It returns a non-nil error for failures that the caller has to propagate:
// invalid reserved bits (PROTOCOL_VIOLATION) and unexpected AEAD errors.
func (s *connection) handleUnpackError(err error, p receivedPacket, pt logging.PacketType) (wasQueued bool, _ error) {
	traceDrops := s.tracer != nil && s.tracer.DroppedPacket != nil

	if err == handshake.ErrKeysDropped {
		if traceDrops {
			s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropKeyUnavailable)
		}
		s.logger.Debugf("Dropping %s packet (%d bytes) because we already dropped the keys.", pt, p.Size())
		return false, nil
	}
	if err == handshake.ErrKeysNotYetAvailable {
		// The keys for this encryption level are not available yet. Try again later.
		s.tryQueueingUndecryptablePacket(p, pt)
		return true, nil
	}
	if err == wire.ErrInvalidReservedBits {
		return false, &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: err.Error(),
		}
	}
	if err == handshake.ErrDecryptionFailed {
		// This might be a packet injected by an attacker. Drop it.
		if traceDrops {
			s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropPayloadDecryptError)
		}
		s.logger.Debugf("Dropping %s packet (%d bytes) that could not be unpacked. Error: %s", pt, p.Size(), err)
		return false, nil
	}

	// A header that couldn't be unpacked might also come from an injected packet. Drop it.
	var headerErr *headerParseError
	if errors.As(err, &headerErr) {
		if traceDrops {
			s.tracer.DroppedPacket(pt, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
		}
		s.logger.Debugf("Dropping %s packet (%d bytes) for which we couldn't unpack the header. Error: %s", pt, p.Size(), err)
		return false, nil
	}

	// Any other error was returned by the AEAD (other than ErrDecryptionFailed),
	// for example a PROTOCOL_VIOLATION due to key updates. Return it to the caller.
	return false, err
}
// handleRetryPacket processes a Retry packet and reports whether it was accepted.
// data is the complete Retry packet; its last 16 bytes are the Retry Integrity Tag.
// A Retry is only accepted by a client, and only if all of the following hold:
//   - no other packet has been received yet,
//   - the server changed the Source Connection ID,
//   - no Retry was accepted before,
//   - the integrity tag verifies.
//
// NOTE(review): data[:len(data)-16] assumes Retry packets are at least 16 bytes long;
// presumably wire.ParsePacket guarantees this. Confirm.
func (s *connection) handleRetryPacket(hdr *wire.Header, data []byte, rcvTime time.Time) bool /* was this a valid Retry */ {
if s.perspective == protocol.PerspectiveServer {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
s.logger.Debugf("Ignoring Retry.")
return false
}
if s.receivedFirstPacket {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
s.logger.Debugf("Ignoring Retry, since we already received a packet.")
return false
}
destConnID := s.connIDManager.Get()
if hdr.SrcConnectionID == destConnID {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropUnexpectedPacket)
}
s.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
return false
}
// If we already accepted a Retry from the server, ignore this one.
if s.receivedRetry {
s.logger.Debugf("Ignoring Retry, since a Retry was already received.")
return false
}
// Verify the integrity tag (the last 16 bytes), computed over the rest of the packet
// using the destination connection ID we are currently using.
tag := handshake.GetRetryIntegrityTag(data[:len(data)-16], destConnID, hdr.Version)
if !bytes.Equal(data[len(data)-16:], tag[:]) {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, protocol.ByteCount(len(data)), logging.PacketDropPayloadDecryptError)
}
s.logger.Debugf("Ignoring spoofed Retry. Integrity Tag doesn't match.")
return false
}
// Accept the Retry: switch to the connection ID chosen by the server, remember it as the
// Retry Source Connection ID, and use the Retry token for the following packets.
newDestConnID := hdr.SrcConnectionID
s.receivedRetry = true
s.sentPacketHandler.ResetForRetry(rcvTime)
s.handshakeDestConnID = newDestConnID
s.retrySrcConnID = &newDestConnID
s.cryptoStreamHandler.ChangeConnectionID(newDestConnID)
s.packer.SetToken(hdr.Token)
s.connIDManager.ChangeInitialConnID(newDestConnID)
if s.logger.Debug() {
s.logger.Debugf("<- Received Retry:")
(&wire.ExtendedHeader{Header: *hdr}).Log(s.logger)
s.logger.Debugf("Switching destination connection ID to: %s", hdr.SrcConnectionID)
}
if s.tracer != nil && s.tracer.ReceivedRetry != nil {
s.tracer.ReceivedRetry(hdr)
}
s.scheduleSending()
return true
}
// handleVersionNegotiationPacket processes a Version Negotiation packet.
//
// The packet is ignored if:
//   - we are a server,
//   - a packet was already received,
//   - a version was already negotiated,
//   - it can't be parsed,
//   - it lists the version we are currently using.
//
// Otherwise, if a version from s.config.Versions is also offered by the server, the connection is
// destroyed with an errCloseForRecreating carrying the next Initial packet number and the new
// version (presumably so the caller can recreate the connection). If no common version exists,
// it is destroyed with a VersionNegotiationError.
func (s *connection) handleVersionNegotiationPacket(p receivedPacket) {
if s.perspective == protocol.PerspectiveServer || // servers never receive version negotiation packets
s.receivedFirstPacket || s.versionNegotiated { // ignore delayed / duplicated version negotiation packets
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
}
return
}
src, dest, supportedVersions, err := wire.ParseVersionNegotiationPacket(p.data)
if err != nil {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
}
s.logger.Debugf("Error parsing Version Negotiation packet: %s", err)
return
}
for _, v := range supportedVersions {
if v == s.version {
if s.tracer != nil && s.tracer.DroppedPacket != nil {
s.tracer.DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedVersion)
}
// The Version Negotiation packet contains the version that we offered.
// This might be a packet sent by an attacker, or it was corrupted.
return
}
}
s.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", supportedVersions)
if s.tracer != nil && s.tracer.ReceivedVersionNegotiationPacket != nil {
s.tracer.ReceivedVersionNegotiationPacket(dest, src, supportedVersions)
}
newVersion, ok := protocol.ChooseSupportedVersion(s.config.Versions, supportedVersions)
if !ok {
s.destroyImpl(&VersionNegotiationError{
Ours: s.config.Versions,
Theirs: supportedVersions,
})
s.logger.Infof("No compatible QUIC version found.")
return
}
if s.tracer != nil && s.tracer.NegotiatedVersion != nil {
s.tracer.NegotiatedVersion(newVersion, s.config.Versions, supportedVersions)
}
s.logger.Infof("Switching to QUIC version %s.", newVersion)
// Carry the next Initial packet number over to the recreated connection.
nextPN, _ := s.sentPacketHandler.PeekPacketNumber(protocol.EncryptionInitial)
s.destroyImpl(&errCloseForRecreating{
nextPacketNumber: nextPN,
nextVersion: newVersion,
})
}
// handleUnpackedLongHeaderPacket processes a successfully unpacked long header packet. It:
//   - on the first packet, traces the negotiated version (and, on the server, the connection start)
//     and adopts the peer's Source Connection ID as s.handshakeDestConnID,
//   - on the server, drops the Initial keys when the first Handshake packet arrives,
//   - resets the idle-timeout and keep-alive state,
//   - handles the packet's frames,
//   - records the packet with the received packet handler.
func (s *connection) handleUnpackedLongHeaderPacket(
packet *unpackedPacket,
ecn protocol.ECN,
rcvTime time.Time,
packetSize protocol.ByteCount, // only for logging
) error {
if !s.receivedFirstPacket {
s.receivedFirstPacket = true
// Without explicit version negotiation, trace the version in use as the negotiated one.
if !s.versionNegotiated && s.tracer != nil && s.tracer.NegotiatedVersion != nil {
var clientVersions, serverVersions []protocol.Version
switch s.perspective {
case protocol.PerspectiveClient:
clientVersions = s.config.Versions
case protocol.PerspectiveServer:
serverVersions = s.config.Versions
}
s.tracer.NegotiatedVersion(s.version, clientVersions, serverVersions)
}
// The server can change the source connection ID with the first Handshake packet.
if s.perspective == protocol.PerspectiveClient && packet.hdr.SrcConnectionID != s.handshakeDestConnID {
cid := packet.hdr.SrcConnectionID
s.logger.Debugf("Received first packet. Switching destination connection ID to: %s", cid)
s.handshakeDestConnID = cid
s.connIDManager.ChangeInitialConnID(cid)
}
// We create the connection as soon as we receive the first packet from the client.
// We do that before authenticating the packet.
// That means that if the source connection ID was corrupted,
// we might have created a connection with an incorrect source connection ID.
// Once we authenticate the first packet, we need to update it.
if s.perspective == protocol.PerspectiveServer {
if packet.hdr.SrcConnectionID != s.handshakeDestConnID {
s.handshakeDestConnID = packet.hdr.SrcConnectionID
s.connIDManager.ChangeInitialConnID(packet.hdr.SrcConnectionID)
}
if s.tracer != nil && s.tracer.StartedConnection != nil {
s.tracer.StartedConnection(
s.conn.LocalAddr(),
s.conn.RemoteAddr(),
packet.hdr.SrcConnectionID,
packet.hdr.DestConnectionID,
)
}
}
}
if s.perspective == protocol.PerspectiveServer && packet.encryptionLevel == protocol.EncryptionHandshake &&
!s.droppedInitialKeys {
// On the server side, Initial keys are dropped as soon as the first Handshake packet is received.
// See Section 4.9.1 of RFC 9001.
if err := s.dropEncryptionLevel(protocol.EncryptionInitial, rcvTime); err != nil {
return err
}
}
// Receiving a packet resets the idle timeout start and re-enables keep-alive PINGs.
s.lastPacketReceivedTime = rcvTime
s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
s.keepAlivePingSent = false
// 0-RTT packets count towards the largest received application data packet number.
if packet.hdr.Type == protocol.PacketType0RTT {
s.largestRcvdAppData = max(s.largestRcvdAppData, packet.hdr.PacketNumber)
}
var log func([]logging.Frame)
if s.tracer != nil && s.tracer.ReceivedLongHeaderPacket != nil {
log = func(frames []logging.Frame) {
s.tracer.ReceivedLongHeaderPacket(packet.hdr, packetSize, ecn, frames)
}
}
isAckEliciting, _, err := s.handleFrames(packet.data, packet.hdr.DestConnectionID, packet.encryptionLevel, log, rcvTime)
if err != nil {
return err
}
return s.receivedPacketHandler.ReceivedPacket(packet.hdr.PacketNumber, ecn, packet.encryptionLevel, rcvTime, isAckEliciting)
}
// handleUnpackedShortHeaderPacket processes the frames of an unpacked 1-RTT packet and
// registers the packet with the received packet handler.
// It reports whether the packet contained any non-probing frame.
// log, if non-nil, receives the parsed frames for tracing.
func (s *connection) handleUnpackedShortHeaderPacket(
	destConnID protocol.ConnectionID,
	pn protocol.PacketNumber,
	data []byte,
	ecn protocol.ECN,
	rcvTime time.Time,
	log func([]logging.Frame),
) (isNonProbing bool, _ error) {
	// Record the receive time and reset the idle / keep-alive state.
	s.lastPacketReceivedTime = rcvTime
	s.firstAckElicitingPacketAfterIdleSentTime = time.Time{}
	s.keepAlivePingSent = false

	ackEliciting, nonProbing, err := s.handleFrames(data, destConnID, protocol.Encryption1RTT, log, rcvTime)
	if err == nil {
		err = s.receivedPacketHandler.ReceivedPacket(pn, ecn, protocol.Encryption1RTT, rcvTime, ackEliciting)
	}
	if err != nil {
		return false, err
	}
	return nonProbing, nil
}
// handleFrames parses all frames in data and handles them one by one.
//
// It reports whether any frame was ack-eliciting and whether any frame was non-probing.
// A parse error aborts immediately. Without tracing (log == nil), the first handling error
// aborts immediately as well. With tracing, the remaining frames are still parsed (but no longer
// handled) so the tracer sees the complete packet, and the first handling error is returned
// after log has been called.
// If the handshake completed while handling these frames, handleHandshakeComplete is invoked
// only after all frames were processed.
func (s *connection) handleFrames(
	data []byte,
	destConnID protocol.ConnectionID,
	encLevel protocol.EncryptionLevel,
	log func([]logging.Frame),
	rcvTime time.Time,
) (isAckEliciting, isNonProbing bool, _ error) {
	tracing := log != nil
	// Frames are only collected for the tracer; without tracing this slice stays nil.
	var traced []logging.Frame
	if tracing {
		traced = make([]logging.Frame, 0, 4)
	}
	completeBefore := s.handshakeComplete
	var firstHandleErr error
	for len(data) != 0 {
		n, frame, err := s.frameParser.ParseNext(data, encLevel, s.version)
		if err != nil {
			return false, false, err
		}
		data = data[n:]
		if frame == nil {
			break
		}
		if ackhandler.IsFrameAckEliciting(frame) {
			isAckEliciting = true
		}
		if !wire.IsProbingFrame(frame) {
			isNonProbing = true
		}
		if tracing {
			traced = append(traced, toLoggingFrame(frame))
		}
		if firstHandleErr != nil {
			// A previous frame failed to be handled: keep parsing for the trace, but don't handle.
			continue
		}
		if err := s.handleFrame(frame, encLevel, destConnID, rcvTime); err != nil {
			if !tracing {
				return false, false, err
			}
			firstHandleErr = err
		}
	}
	if tracing {
		log(traced)
		if firstHandleErr != nil {
			return false, false, firstHandleErr
		}
	}
	// Handle completion of the handshake after processing all the frames.
	// On the server, a Handshake packet may carry the CRYPTO frame completing the handshake
	// followed by an ACK frame; that ACK should still be processed first.
	if !completeBefore && s.handshakeComplete {
		if err := s.handleHandshakeComplete(rcvTime); err != nil {
			return false, false, err
		}
	}
	return isAckEliciting, isNonProbing, nil
}
// handleFrame logs a single parsed frame and dispatches it to its type-specific handler.
//
// encLevel is the encryption level of the packet that carried the frame, destConnID is that
// packet's destination connection ID, and rcvTime is the time the packet was received.
// DATA_BLOCKED, STREAMS_BLOCKED and PING frames are accepted without further action.
// Any other unknown frame type yields an error naming the frame's concrete type.
func (s *connection) handleFrame(
	f wire.Frame,
	encLevel protocol.EncryptionLevel,
	destConnID protocol.ConnectionID,
	rcvTime time.Time,
) error {
	var err error
	wire.LogFrame(s.logger, f, false)
	switch frame := f.(type) {
	case *wire.CryptoFrame:
		err = s.handleCryptoFrame(frame, encLevel, rcvTime)
	case *wire.StreamFrame:
		err = s.handleStreamFrame(frame, rcvTime)
	case *wire.AckFrame:
		err = s.handleAckFrame(frame, encLevel, rcvTime)
	case *wire.ConnectionCloseFrame:
		err = s.handleConnectionCloseFrame(frame)
	case *wire.ResetStreamFrame:
		err = s.handleResetStreamFrame(frame, rcvTime)
	case *wire.MaxDataFrame:
		s.handleMaxDataFrame(frame)
	case *wire.MaxStreamDataFrame:
		err = s.handleMaxStreamDataFrame(frame)
	case *wire.MaxStreamsFrame:
		s.handleMaxStreamsFrame(frame)
	case *wire.DataBlockedFrame:
	case *wire.StreamDataBlockedFrame:
		err = s.handleStreamDataBlockedFrame(frame)
	case *wire.StreamsBlockedFrame:
	case *wire.StopSendingFrame:
		err = s.handleStopSendingFrame(frame)
	case *wire.PingFrame:
	case *wire.PathChallengeFrame:
		s.handlePathChallengeFrame(frame)
	case *wire.PathResponseFrame:
		err = s.handlePathResponseFrame(frame)
	case *wire.NewTokenFrame:
		err = s.handleNewTokenFrame(frame)
	case *wire.NewConnectionIDFrame:
		err = s.handleNewConnectionIDFrame(frame)
	case *wire.RetireConnectionIDFrame:
		err = s.handleRetireConnectionIDFrame(frame, destConnID)
	case *wire.HandshakeDoneFrame:
		err = s.handleHandshakeDoneFrame(rcvTime)
	case *wire.DatagramFrame:
		err = s.handleDatagramFrame(frame)
	default:
		// Report the dynamic type of the frame. Reflecting on &frame yields the static
		// interface type (wire.Frame), which would make every message read "Frame".
		typeName := "<nil>"
		if t := reflect.TypeOf(f); t != nil {
			if t.Kind() == reflect.Pointer {
				t = t.Elem()
			}
			typeName = t.Name()
		}
		err = fmt.Errorf("unexpected frame type: %s", typeName)
	}
	return err
}
// handlePacket is called by the server with a new packet.
// It appends the packet to the connection's receive queue and signals notifyReceivedPacket
// without blocking. Once protocol.MaxConnUnprocessedPackets packets are queued, further
// packets are dropped and reported to the tracer as PacketDropDOSPrevention.
func (s *connection) handlePacket(p receivedPacket) {
	s.receivedPacketMx.Lock()
	queueFull := s.receivedPackets.Len() >= protocol.MaxConnUnprocessedPackets
	if !queueFull {
		s.receivedPackets.PushBack(p)
	} else if s.tracer != nil && s.tracer.DroppedPacket != nil {
		s.tracer.DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropDOSPrevention)
	}
	s.receivedPacketMx.Unlock()
	if queueFull {
		return
	}
	// Non-blocking notification: if a notification is already pending, there's nothing to do.
	select {
	case s.notifyReceivedPacket <- struct{}{}:
	default:
	}
}
// handleConnectionCloseFrame converts a received CONNECTION_CLOSE frame into the error
// that closes the connection: a remote TransportError for transport-level closes, or a
// remote ApplicationError when the frame carries an application error.
func (s *connection) handleConnectionCloseFrame(frame *wire.ConnectionCloseFrame) error {
	if !frame.IsApplicationError {
		return &qerr.TransportError{
			Remote:       true,
			ErrorCode:    qerr.TransportErrorCode(frame.ErrorCode),
			FrameType:    frame.FrameType,
			ErrorMessage: frame.ReasonPhrase,
		}
	}
	return &qerr.ApplicationError{
		Remote:       true,
		ErrorCode:    qerr.ApplicationErrorCode(frame.ErrorCode),
		ErrorMessage: frame.ReasonPhrase,
	}
}
// handleCryptoFrame adds the CRYPTO frame's data to the crypto stream of encLevel.
// It then feeds every message that became available to the crypto stream handler,
// and finally processes the resulting handshake events.
func (s *connection) handleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
	if err := s.cryptoStreamManager.HandleCryptoFrame(frame, encLevel); err != nil {
		return err
	}
	for msg := s.cryptoStreamManager.GetCryptoData(encLevel); msg != nil; msg = s.cryptoStreamManager.GetCryptoData(encLevel) {
		if err := s.cryptoStreamHandler.HandleMessage(msg, encLevel); err != nil {
			return err
		}
	}
	return s.handleHandshakeEvents(rcvTime)
}
// handleHandshakeEvents drains the crypto stream handler's event queue, acting on each event.
// It returns nil once EventNoEvent is reached, or the first error produced while handling an event.
func (s *connection) handleHandshakeEvents(now time.Time) error {
	for {
		ev := s.cryptoStreamHandler.NextEvent()
		switch ev.Kind {
		case handshake.EventNoEvent:
			return nil
		case handshake.EventHandshakeComplete:
			// Don't call handleHandshakeComplete yet.
			// It's advantageous to process ACK frames that might be serialized after the CRYPTO frame first.
			s.handshakeComplete = true
		case handshake.EventReceivedTransportParameters:
			if err := s.handleTransportParameters(ev.TransportParameters); err != nil {
				return err
			}
		case handshake.EventRestoredTransportParameters:
			s.restoreTransportParameters(ev.TransportParameters)
			close(s.earlyConnReadyChan)
		case handshake.EventReceivedReadKeys:
			// New read keys are available: queue all previously undecryptable packets for processing.
			s.undecryptablePacketsToProcess = append(s.undecryptablePacketsToProcess, s.undecryptablePackets...)
			s.undecryptablePackets = nil
		case handshake.EventDiscard0RTTKeys:
			if err := s.dropEncryptionLevel(protocol.Encryption0RTT, now); err != nil {
				return err
			}
		case handshake.EventWriteInitialData:
			if _, err := s.initialStream.Write(ev.Data); err != nil {
				return err
			}
		case handshake.EventWriteHandshakeData:
			if _, err := s.handshakeStream.Write(ev.Data); err != nil {
				return err
			}
		}
	}
}
// handleStreamFrame delivers a STREAM frame to its receive stream, opening the stream if needed.
// Frames for streams that were already closed and garbage collected are ignored.
func (s *connection) handleStreamFrame(frame *wire.StreamFrame, rcvTime time.Time) error {
	str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
	switch {
	case err != nil:
		return err
	case str == nil:
		// stream was already closed and garbage collected
		return nil
	default:
		return str.handleStreamFrame(frame, rcvTime)
	}
}
// handleMaxDataFrame passes the peer's connection-level MAX_DATA limit to the connection flow controller.
func (s *connection) handleMaxDataFrame(frame *wire.MaxDataFrame) {
	limit := frame.MaximumData
	s.connFlowController.UpdateSendWindow(limit)
}
// handleMaxStreamDataFrame updates the send window of the referenced send stream,
// opening the stream if needed. Streams that were already closed and garbage collected are ignored.
func (s *connection) handleMaxStreamDataFrame(frame *wire.MaxStreamDataFrame) error {
	str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
	if err != nil {
		return err
	}
	// A nil stream means it is closed and already garbage collected.
	if str != nil {
		str.updateSendWindow(frame.MaximumStreamData)
	}
	return nil
}
// handleStreamDataBlockedFrame validates the stream ID of a STREAM_DATA_BLOCKED frame.
// No action is needed in response to the frame itself, but an invalid stream ID is still an error.
func (s *connection) handleStreamDataBlockedFrame(frame *wire.StreamDataBlockedFrame) error {
	if _, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID); err != nil {
		return err
	}
	return nil
}
// handleMaxStreamsFrame forwards a MAX_STREAMS frame to the streams map.
func (s *connection) handleMaxStreamsFrame(frame *wire.MaxStreamsFrame) {
	streams := s.streamsMap
	streams.HandleMaxStreamsFrame(frame)
}
// handleResetStreamFrame delivers a RESET_STREAM frame to its receive stream, opening it if needed.
// Frames for streams that were already closed and garbage collected are ignored.
func (s *connection) handleResetStreamFrame(frame *wire.ResetStreamFrame, rcvTime time.Time) error {
	str, err := s.streamsMap.GetOrOpenReceiveStream(frame.StreamID)
	switch {
	case err != nil:
		return err
	case str == nil:
		// stream is closed and already garbage collected
		return nil
	default:
		return str.handleResetStreamFrame(frame, rcvTime)
	}
}
// handleStopSendingFrame delivers a STOP_SENDING frame to its send stream, opening it if needed.
// Frames for streams that were already closed and garbage collected are ignored.
func (s *connection) handleStopSendingFrame(frame *wire.StopSendingFrame) error {
	str, err := s.streamsMap.GetOrOpenSendStream(frame.StreamID)
	if err != nil {
		return err
	}
	// A nil stream means it is closed and already garbage collected.
	if str != nil {
		str.handleStopSendingFrame(frame)
	}
	return nil
}
// handlePathChallengeFrame answers a PATH_CHALLENGE by queueing a PATH_RESPONSE echoing its data.
func (s *connection) handlePathChallengeFrame(f *wire.PathChallengeFrame) {
	response := &wire.PathResponseFrame{Data: f.Data}
	s.queueControlFrame(response)
}
// handlePathResponseFrame forwards a PATH_RESPONSE frame to the path manager.
// Without a path manager no PATH_CHALLENGE was sent, so a PATH_RESPONSE is a protocol violation.
func (s *connection) handlePathResponseFrame(f *wire.PathResponseFrame) error {
	s.logger.Debugf("received PATH_RESPONSE frame: %v", f.Data)
	if s.pathManager != nil {
		s.pathManager.HandlePathResponseFrame(f)
		return nil
	}
	// since we didn't send PATH_CHALLENGEs yet, we don't expect PATH_RESPONSEs
	return &qerr.TransportError{
		ErrorCode:    qerr.ProtocolViolation,
		ErrorMessage: "unexpected PATH_RESPONSE frame",
	}
}
// handleNewTokenFrame stores a token received in a NEW_TOKEN frame.
// On the client, the token is put into the configured TokenStore (if any) under s.tokenStoreKey.
// A server receiving NEW_TOKEN treats it as a protocol violation.
func (s *connection) handleNewTokenFrame(frame *wire.NewTokenFrame) error {
	if s.perspective == protocol.PerspectiveServer {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "received NEW_TOKEN frame from the client",
		}
	}
	if store := s.config.TokenStore; store != nil {
		store.Put(s.tokenStoreKey, &ClientToken{data: frame.Token})
	}
	return nil
}
// handleNewConnectionIDFrame hands a NEW_CONNECTION_ID frame to the connection ID manager.
func (s *connection) handleNewConnectionIDFrame(f *wire.NewConnectionIDFrame) error {
	manager := s.connIDManager
	return manager.Add(f)
}
// handleRetireConnectionIDFrame retires the connection ID with the frame's sequence number.
// destConnID is the destination connection ID of the packet that carried the frame.
func (s *connection) handleRetireConnectionIDFrame(f *wire.RetireConnectionIDFrame, destConnID protocol.ConnectionID) error {
	seq := f.SequenceNumber
	return s.connIDGenerator.Retire(seq, destConnID)
}
// handleHandshakeDoneFrame confirms the handshake on the client when HANDSHAKE_DONE is received.
// A server receiving HANDSHAKE_DONE treats it as a protocol violation.
// A repeated frame after confirmation is ignored.
func (s *connection) handleHandshakeDoneFrame(rcvTime time.Time) error {
	switch {
	case s.perspective == protocol.PerspectiveServer:
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "received a HANDSHAKE_DONE frame",
		}
	case s.handshakeConfirmed:
		return nil
	default:
		return s.handleHandshakeConfirmed(rcvTime)
	}
}
// handleAckFrame processes an ACK frame received at encryption level encLevel.
//
// The frame is passed to the sent packet handler together with rcvTime, the receive time of
// the packet that carried it. If the ACK acknowledged a 1-RTT packet:
//   - on the client, this confirms the handshake (if not yet confirmed);
//   - the MTU estimate is raised if the MTU discoverer now reports a larger size;
//   - the largest acknowledged packet number is passed to the crypto stream handler.
func (s *connection) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel, rcvTime time.Time) error {
	// Use the receive time handed in by the caller instead of s.lastPacketReceivedTime.
	// The packet handlers currently set lastPacketReceivedTime = rcvTime before handling frames,
	// so the values are identical, but using the parameter removes that hidden ordering dependency.
	acked1RTTPacket, err := s.sentPacketHandler.ReceivedAck(frame, encLevel, rcvTime)
	if err != nil {
		return err
	}
	if !acked1RTTPacket {
		return nil
	}
	// On the client side: If the packet acknowledged a 1-RTT packet, this confirms the handshake.
	// This is only possible if the ACK was sent in a 1-RTT packet.
	// This is an optimization over simply waiting for a HANDSHAKE_DONE frame, see section 4.1.2 of RFC 9001.
	if s.perspective == protocol.PerspectiveClient && !s.handshakeConfirmed {
		if err := s.handleHandshakeConfirmed(rcvTime); err != nil {
			return err
		}
	}
	// If one of the acknowledged packets was a Path MTU probe packet, this might have increased the Path MTU estimate.
	if s.mtuDiscoverer != nil {
		if mtu := s.mtuDiscoverer.CurrentSize(); mtu > protocol.ByteCount(s.currentMTUEstimate.Load()) {
			s.currentMTUEstimate.Store(uint32(mtu))
			s.sentPacketHandler.SetMaxDatagramSize(mtu)
		}
	}
	return s.cryptoStreamHandler.SetLargest1RTTAcked(frame.LargestAcked())
}
// handleDatagramFrame passes a DATAGRAM frame to the datagram queue.
// Frames whose encoded length exceeds wire.MaxDatagramSize are a protocol violation.
func (s *connection) handleDatagramFrame(f *wire.DatagramFrame) error {
	if f.Length(s.version) <= wire.MaxDatagramSize {
		s.datagramQueue.HandleDatagramFrame(f)
		return nil
	}
	return &qerr.TransportError{
		ErrorCode:    qerr.ProtocolViolation,
		ErrorMessage: "DATAGRAM frame too large",
	}
}
// setCloseError records e as the connection's close error unless one was already recorded
// (first writer wins), then signals closeChan without blocking.
func (s *connection) setCloseError(e *closeError) {
	s.closeErr.CompareAndSwap(nil, e)
	// If a signal is already pending, there's no need to send another one.
	select {
	case s.closeChan <- struct{}{}:
	default:
	}
}
// closeLocal closes the connection and sends a CONNECTION_CLOSE containing the error.
// It only records the (non-immediate) close error; it does not wait for the connection to shut down.
func (s *connection) closeLocal(e error) {
	ce := &closeError{err: e, immediate: false}
	s.setCloseError(ce)
}
// destroy closes the connection without sending the error on the wire,
// and blocks until the connection's context is done.
func (s *connection) destroy(e error) {
	s.destroyImpl(e)
	// Wait for the connection to finish shutting down.
	<-s.ctx.Done()
}
// destroyImpl records e as an immediate close error (no CONNECTION_CLOSE is sent) without waiting.
func (s *connection) destroyImpl(e error) {
	ce := &closeError{err: e, immediate: true}
	s.setCloseError(ce)
}
// CloseWithError closes the connection with an application error carrying code and desc.
// The error is sent to the peer in a CONNECTION_CLOSE. The call blocks until the
// connection's context is done, and always returns nil.
func (s *connection) CloseWithError(code ApplicationErrorCode, desc string) error {
	appErr := &qerr.ApplicationError{
		ErrorCode:    code,
		ErrorMessage: desc,
	}
	s.closeLocal(appErr)
	<-s.ctx.Done()
	return nil
}
// closeWithTransportError closes the connection with a transport error carrying code,
// and blocks until the connection's context is done.
func (s *connection) closeWithTransportError(code TransportErrorCode) {
	transportErr := &qerr.TransportError{ErrorCode: code}
	s.closeLocal(transportErr)
	<-s.ctx.Done()
}
// handleCloseError tears down the connection for the given close error.
//
// It logs the error and normalizes it:
//   - a nil error becomes an empty ApplicationError;
//   - idle / handshake timeouts, stateless resets, version negotiation and recreation errors,
//     and application / transport errors are kept as-is;
//   - any other error is kept for immediate closes, and otherwise wrapped in an
//     InternalError TransportError.
//
// The normalized error is then used to close the streams map and datagram queue, and is
// reported to the tracer unless the connection is being recreated.
// A CONNECTION_CLOSE packet is only sent for a local, non-immediate close, and not on the
// client before it has sent its first packet.
// If closeErr.err was non-nil, it is overwritten with the normalized error when this function returns.
func (s *connection) handleCloseError(closeErr *closeError) {
	if closeErr.immediate {
		if nerr, ok := closeErr.err.(net.Error); ok && nerr.Timeout() {
			s.logger.Errorf("Destroying connection: %s", closeErr.err)
		} else {
			s.logger.Errorf("Destroying connection with error: %s", closeErr.err)
		}
	} else {
		if closeErr.err == nil {
			s.logger.Infof("Closing connection.")
		} else {
			s.logger.Errorf("Closing connection with error: %s", closeErr.err)
		}
	}
	e := closeErr.err
	if e == nil {
		// A nil error is treated as an application close with a zero-value error.
		e = &qerr.ApplicationError{}
	} else {
		// Write the (possibly rewritten) error back to closeErr when we return.
		// Deferred calls run LIFO, so this runs after the connIDManager.Close deferred below.
		defer func() { closeErr.err = e }()
	}
	var (
		statelessResetErr     *StatelessResetError
		versionNegotiationErr *VersionNegotiationError
		recreateErr           *errCloseForRecreating
		applicationErr        *ApplicationError
		transportErr          *TransportError
	)
	var isRemoteClose bool
	switch {
	// These errors are passed through unchanged and never count as a remote close.
	case errors.Is(e, qerr.ErrIdleTimeout),
		errors.Is(e, qerr.ErrHandshakeTimeout),
		errors.As(e, &statelessResetErr),
		errors.As(e, &versionNegotiationErr),
		errors.As(e, &recreateErr):
	case errors.As(e, &applicationErr):
		isRemoteClose = applicationErr.Remote
	case errors.As(e, &transportErr):
		isRemoteClose = transportErr.Remote
	case closeErr.immediate:
		e = closeErr.err
	default:
		// Any other error is surfaced (and later sent to the peer) as an InternalError.
		e = &qerr.TransportError{
			ErrorCode:    qerr.InternalError,
			ErrorMessage: e.Error(),
		}
	}
	s.streamsMap.CloseWithError(e)
	if s.datagramQueue != nil {
		s.datagramQueue.CloseWithError(e)
	}
	// In rare instances, the connection ID manager might switch to a new connection ID
	// when sending the CONNECTION_CLOSE frame.
	// The connection ID manager removes the active stateless reset token from the packet
	// handler map when it is closed, so we need to make sure that this happens last.
	defer s.connIDManager.Close()
	// Recreating the connection (after version negotiation) is not reported as a closed connection.
	if s.tracer != nil && s.tracer.ClosedConnection != nil && !errors.As(e, &recreateErr) {
		s.tracer.ClosedConnection(e)
	}
	// If this is a remote close we're done here
	if isRemoteClose {
		s.connIDGenerator.ReplaceWithClosed(nil)
		return
	}
	// Immediate closes (destroy) never send a CONNECTION_CLOSE.
	if closeErr.immediate {
		s.connIDGenerator.RemoveAll()
		return
	}
	// Don't send out any CONNECTION_CLOSE if this is an error that occurred
	// before we even sent out the first packet.
	if s.perspective == protocol.PerspectiveClient && !s.sentFirstPacket {
		s.connIDGenerator.RemoveAll()
		return
	}
	// A failure to send CONNECTION_CLOSE is only logged; the connection is closed regardless.
	connClosePacket, err := s.sendConnectionClose(e)
	if err != nil {
		s.logger.Debugf("Error sending CONNECTION_CLOSE: %s", err)
	}
	s.connIDGenerator.ReplaceWithClosed(connClosePacket)
}
// dropEncryptionLevel discards all state associated with encLevel.
//
// It notifies the tracer and drops the level's packets from the sent and received packet handlers.
// For 0-RTT, the streams map, framer and connection flow controller are reset, and the
// flow controller's reset error is returned; the crypto stream manager is not involved.
// For every other level, the level is dropped from the crypto stream manager; Initial
// additionally marks the Initial keys as dropped and discards them in the crypto stream handler.
func (s *connection) dropEncryptionLevel(encLevel protocol.EncryptionLevel, now time.Time) error {
	if s.tracer != nil && s.tracer.DroppedEncryptionLevel != nil {
		s.tracer.DroppedEncryptionLevel(encLevel)
	}
	s.sentPacketHandler.DropPackets(encLevel, now)
	s.receivedPacketHandler.DropPackets(encLevel)

	if encLevel == protocol.Encryption0RTT {
		s.streamsMap.ResetFor0RTT()
		s.framer.Handle0RTTRejection()
		return s.connFlowController.Reset()
	}
	if encLevel == protocol.EncryptionInitial {
		s.droppedInitialKeys = true
		s.cryptoStreamHandler.DiscardInitialKeys()
	}
	return s.cryptoStreamManager.Drop(encLevel)
}
// restoreTransportParameters is called on the client when restoring transport parameters saved for 0-RTT.
// It stores them as the peer's parameters and applies the connection ID limit, the initial
// connection-level send window and the stream limits, then records datagram support in the
// connection state.
func (s *connection) restoreTransportParameters(params *wire.TransportParameters) {
	if s.logger.Debug() {
		s.logger.Debugf("Restoring Transport Parameters: %s", params)
	}
	s.peerParams = params
	s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
	s.connFlowController.UpdateSendWindow(params.InitialMaxData)
	s.streamsMap.UpdateLimits(params)

	s.connStateMutex.Lock()
	defer s.connStateMutex.Unlock()
	s.connState.SupportsDatagrams = s.supportsDatagrams()
}
// handleTransportParameters validates and stores the peer's transport parameters.
//
// Invalid parameters yield a TransportParameterError. A client that used 0-RTT rejects
// parameters that are not a valid update of the previously stored ones (ProtocolViolation).
// The server applies the parameters immediately and closes earlyConnReadyChan; the client
// defers applying them. In both cases datagram support is recorded in the connection state.
func (s *connection) handleTransportParameters(params *wire.TransportParameters) error {
	if s.tracer != nil && s.tracer.ReceivedTransportParameters != nil {
		s.tracer.ReceivedTransportParameters(params)
	}
	if err := s.checkTransportParameters(params); err != nil {
		return &qerr.TransportError{
			ErrorCode:    qerr.TransportParameterError,
			ErrorMessage: err.Error(),
		}
	}
	isClient := s.perspective == protocol.PerspectiveClient
	reducedAfter0RTT := isClient &&
		s.peerParams != nil &&
		s.ConnectionState().Used0RTT &&
		!params.ValidForUpdate(s.peerParams)
	if reducedAfter0RTT {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: "server sent reduced limits after accepting 0-RTT data",
		}
	}
	s.peerParams = params
	// The client has to wait for handshake completion: during a 0-RTT connection,
	// the new transport parameters may only be used for 1-RTT packets.
	if s.perspective == protocol.PerspectiveServer {
		s.applyTransportParameters()
		// On the server, the early connection is ready as soon as the client's
		// transport parameters were processed.
		close(s.earlyConnReadyChan)
	}
	s.connStateMutex.Lock()
	s.connState.SupportsDatagrams = s.supportsDatagrams()
	s.connStateMutex.Unlock()
	return nil
}
// checkTransportParameters verifies the connection ID related transport parameters.
//
// Both sides check initial_source_connection_id against the handshake destination connection ID.
// The client additionally checks original_destination_connection_id. It also requires
// retry_source_connection_id to be present and to match exactly when a Retry was performed,
// and to be absent otherwise.
func (s *connection) checkTransportParameters(params *wire.TransportParameters) error {
	if s.logger.Debug() {
		s.logger.Debugf("Processed Transport Parameters: %s", params)
	}
	if params.InitialSourceConnectionID != s.handshakeDestConnID {
		return fmt.Errorf("expected initial_source_connection_id to equal %s, is %s", s.handshakeDestConnID, params.InitialSourceConnectionID)
	}
	if s.perspective == protocol.PerspectiveServer {
		return nil
	}
	if params.OriginalDestinationConnectionID != s.origDestConnID {
		return fmt.Errorf("expected original_destination_connection_id to equal %s, is %s", s.origDestConnID, params.OriginalDestinationConnectionID)
	}
	retryPerformed := s.retrySrcConnID != nil
	switch {
	case !retryPerformed && params.RetrySourceConnectionID != nil:
		return errors.New("received retry_source_connection_id, although no Retry was performed")
	case retryPerformed && params.RetrySourceConnectionID == nil:
		return errors.New("missing retry_source_connection_id")
	case retryPerformed && *params.RetrySourceConnectionID != *s.retrySrcConnID:
		return fmt.Errorf("expected retry_source_connection_id to equal %s, is %s", s.retrySrcConnID, *params.RetrySourceConnectionID)
	}
	return nil
}
// applyTransportParameters applies the stored peer transport parameters (s.peerParams).
//
// The idle timeout is our configured MaxIdleTimeout, lowered to the peer's value if the peer
// advertised one. The keep-alive interval is capped at half the idle timeout.
// It also creates the MTU discoverer, with the maximum packet size capped by the peer's
// max_udp_payload_size when that is non-zero and below protocol.MaxPacketBufferSize.
func (s *connection) applyTransportParameters() {
	params := s.peerParams

	// Our local idle timeout will always be > 0.
	idleTimeout := s.config.MaxIdleTimeout
	if params.MaxIdleTimeout > 0 {
		idleTimeout = min(idleTimeout, params.MaxIdleTimeout)
	}
	s.idleTimeout = idleTimeout
	s.keepAliveInterval = min(s.config.KeepAlivePeriod, idleTimeout/2)

	s.streamsMap.UpdateLimits(params)
	s.frameParser.SetAckDelayExponent(params.AckDelayExponent)
	s.connFlowController.UpdateSendWindow(params.InitialMaxData)
	s.rttStats.SetMaxAckDelay(params.MaxAckDelay)
	s.connIDGenerator.SetMaxActiveConnIDs(params.ActiveConnectionIDLimit)
	if token := params.StatelessResetToken; token != nil {
		s.connIDManager.SetStatelessResetToken(*token)
	}
	// Connection migration isn't supported yet, so the preferred address itself is not used.
	// Its connection ID and stateless reset token are still handed to the connection ID manager.
	if pa := params.PreferredAddress; pa != nil {
		s.connIDManager.AddFromPreferredAddress(pa.ConnectionID, pa.StatelessResetToken)
	}

	maxPacketSize := protocol.ByteCount(protocol.MaxPacketBufferSize)
	if limit := params.MaxUDPPayloadSize; limit > 0 && limit < maxPacketSize {
		maxPacketSize = limit
	}
	s.mtuDiscoverer = newMTUDiscoverer(
		s.rttStats,
		protocol.ByteCount(s.config.InitialPacketSize),
		maxPacketSize,
		s.tracer,
	)
}
// triggerSending sends whatever the sent packet handler's current send mode allows.
//
// After sending a PTO probe packet the send mode is re-evaluated, unless the send queue
// would block (in which case sending is rescheduled). When pacing-limited, the pacing
// deadline is set and an ACK-only packet may still be sent.
func (s *connection) triggerSending(now time.Time) error {
	for {
		s.pacingDeadline = time.Time{}
		sendMode := s.sentPacketHandler.SendMode(now)
		//nolint:exhaustive // Unexpected send modes are reported by the default case.
		switch sendMode {
		case ackhandler.SendAny:
			return s.sendPackets(now)
		case ackhandler.SendNone:
			return nil
		case ackhandler.SendPacingLimited:
			s.resetPacingDeadline()
			// Allow sending of an ACK if we're pacing limited.
			// This makes sure that a peer that is mostly receiving data (and thus has an inaccurate cwnd estimate)
			// sends enough ACKs to allow its peer to utilize the bandwidth.
			fallthrough
		case ackhandler.SendAck:
			// We can at most send a single ACK only packet.
			// There will only be a new ACK after receiving new packets.
			// SendAck is only returned when we're congestion limited, so we don't need to set the pacing timer.
			return s.maybeSendAckOnlyPacket(now)
		case ackhandler.SendPTOInitial, ackhandler.SendPTOHandshake, ackhandler.SendPTOAppData:
			if err := s.sendProbePacket(sendMode, now); err != nil {
				return err
			}
			if s.sendQueue.WouldBlock() {
				s.scheduleSending()
				return nil
			}
			// Re-evaluate the send mode on the next loop iteration.
		default:
			return fmt.Errorf("BUG: invalid send mode %d", sendMode)
		}
	}
}
// sendPackets sends packets while the send mode allows it.
//
// Once the handshake is confirmed, a pending Path MTU probe takes precedence and is sent on its own.
// Any pending MAX_DATA update and post-handshake CRYPTO data are queued as control frames.
// Before handshake confirmation a single coalesced packet is sent; afterwards, short header
// packets are sent, using GSO if the underlying conn supports it.
func (s *connection) sendPackets(now time.Time) error {
	// Path MTU Discovery
	// Can't use GSO, since we need to send a single packet that's larger than our current maximum size.
	// Performance-wise, this doesn't matter, since we only send a very small (<10) number of
	// MTU probe packets per connection.
	if s.handshakeConfirmed && s.mtuDiscoverer != nil && s.mtuDiscoverer.ShouldSendProbe(now) {
		ping, probeSize := s.mtuDiscoverer.GetPing(now)
		probe, buf, err := s.packer.PackMTUProbePacket(ping, probeSize, s.version)
		if err != nil {
			return err
		}
		ecn := s.sentPacketHandler.ECNMode(true)
		s.logShortHeaderPacket(probe.DestConnID, probe.Ack, probe.Frames, probe.StreamFrames, probe.PacketNumber, probe.PacketNumberLen, probe.KeyPhase, ecn, buf.Len(), false)
		s.registerPackedShortHeaderPacket(probe, ecn, now)
		s.sendQueue.Send(buf, 0, ecn)
		// There's (likely) more data to send. Loop around again.
		s.scheduleSending()
		return nil
	}

	if offset := s.connFlowController.GetWindowUpdate(now); offset > 0 {
		s.framer.QueueControlFrame(&wire.MaxDataFrame{MaximumData: offset})
	}
	if cf := s.cryptoStreamManager.GetPostHandshakeData(protocol.MaxPostHandshakeCryptoFrameSize); cf != nil {
		s.queueControlFrame(cf)
	}

	if s.handshakeConfirmed {
		if s.conn.capabilities().GSO {
			return s.sendPacketsWithGSO(now)
		}
		return s.sendPacketsWithoutGSO(now)
	}

	coalesced, err := s.packer.PackCoalescedPacket(false, s.maxPacketSize(), now, s.version)
	if err != nil || coalesced == nil {
		return err
	}
	s.sentFirstPacket = true
	ecn := s.sentPacketHandler.ECNMode(coalesced.IsOnlyShortHeaderPacket())
	if err := s.sendPackedCoalescedPacket(coalesced, ecn, now); err != nil {
		return err
	}
	if mode := s.sentPacketHandler.SendMode(now); mode == ackhandler.SendPacingLimited {
		s.resetPacingDeadline()
	} else if mode == ackhandler.SendAny {
		s.pacingDeadline = deadlineSendImmediately
	}
	return nil
}
// sendPacketsWithoutGSO sends short header packets one buffer at a time.
// It stops when there's nothing left to pack, the send queue would block, the send mode no
// longer allows sending (setting the pacing deadline when pacing-limited), or received packets
// are waiting to be processed (in which case sending resumes immediately afterwards).
func (s *connection) sendPacketsWithoutGSO(now time.Time) error {
	for {
		buf := getPacketBuffer()
		ecn := s.sentPacketHandler.ECNMode(true)
		_, err := s.appendOneShortHeaderPacket(buf, s.maxPacketSize(), ecn, now)
		if err == errNothingToPack {
			buf.Release()
			return nil
		}
		if err != nil {
			return err
		}
		s.sendQueue.Send(buf, 0, ecn)
		if s.sendQueue.WouldBlock() {
			return nil
		}

		if mode := s.sentPacketHandler.SendMode(now); mode != ackhandler.SendAny {
			if mode == ackhandler.SendPacingLimited {
				s.resetPacingDeadline()
			}
			return nil
		}

		// Prioritize receiving of packets over sending out more packets.
		s.receivedPacketMx.Lock()
		pending := !s.receivedPackets.Empty()
		s.receivedPacketMx.Unlock()
		if pending {
			s.pacingDeadline = deadlineSendImmediately
			return nil
		}
	}
}
// sendPacketsWithGSO batches short header packets into large buffers and hands each batch to
// the send queue with a segment size of maxSize (used when the conn supports GSO).
//
// Packets keep being appended to the current buffer while the send mode allows it, the last
// packet was full-size, the next packet gets the same ECN marking, and another full-size packet
// still fits. A batch is then flushed, and the function returns after flushing if nothing more
// may be sent, the send queue would block, or received packets are waiting to be processed.
func (s *connection) sendPacketsWithGSO(now time.Time) error {
	buf := getLargePacketBuffer()
	maxSize := s.maxPacketSize()
	ecn := s.sentPacketHandler.ECNMode(true)
	for {
		var dontSendMore bool
		size, err := s.appendOneShortHeaderPacket(buf, maxSize, ecn, now)
		if err != nil {
			if err != errNothingToPack {
				return err
			}
			// Nothing to pack: release an empty buffer, otherwise flush what we have.
			if buf.Len() == 0 {
				buf.Release()
				return nil
			}
			dontSendMore = true
		}
		if !dontSendMore {
			sendMode := s.sentPacketHandler.SendMode(now)
			if sendMode == ackhandler.SendPacingLimited {
				s.resetPacingDeadline()
			}
			if sendMode != ackhandler.SendAny {
				dontSendMore = true
			}
		}
		// Don't send more packets in this batch if they require a different ECN marking than the previous ones.
		nextECN := s.sentPacketHandler.ECNMode(true)
		// Append another packet if
		// 1. The congestion controller and pacer allow sending more
		// 2. The last packet appended was a full-size packet
		// 3. The next packet will have the same ECN marking
		// 4. We still have enough space for another full-size packet in the buffer
		if !dontSendMore && size == maxSize && nextECN == ecn && buf.Len()+maxSize <= buf.Cap() {
			continue
		}
		// Flush the batch; every segment except possibly the last one is maxSize bytes.
		s.sendQueue.Send(buf, uint16(maxSize), ecn)
		if dontSendMore {
			return nil
		}
		if s.sendQueue.WouldBlock() {
			return nil
		}
		// Prioritize receiving of packets over sending out more packets.
		s.receivedPacketMx.Lock()
		hasPackets := !s.receivedPackets.Empty()
		s.receivedPacketMx.Unlock()
		if hasPackets {
			s.pacingDeadline = deadlineSendImmediately
			return nil
		}
		// Start a new batch with the ECN marking the next packet requires.
		ecn = nextECN
		buf = getLargePacketBuffer()
	}
}
// resetPacingDeadline sets the pacing deadline to the time the sent packet handler allows
// the next send, or to deadlineSendImmediately if that time is zero.
func (s *connection) resetPacingDeadline() {
	next := s.sentPacketHandler.TimeUntilSend()
	if !next.IsZero() {
		s.pacingDeadline = next
		return
	}
	s.pacingDeadline = deadlineSendImmediately
}
// maybeSendAckOnlyPacket sends a single ACK-only packet if there's an ACK to send.
// Before the handshake is confirmed this goes through the coalesced packet packer;
// afterwards a 1-RTT ACK-only packet is packed. Nothing to send is not an error.
func (s *connection) maybeSendAckOnlyPacket(now time.Time) error {
	if s.handshakeConfirmed {
		ecn := s.sentPacketHandler.ECNMode(true)
		p, buf, err := s.packer.PackAckOnlyPacket(s.maxPacketSize(), now, s.version)
		if err == errNothingToPack {
			return nil
		}
		if err != nil {
			return err
		}
		s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, buf.Len(), false)
		s.registerPackedShortHeaderPacket(p, ecn, now)
		s.sendQueue.Send(buf, 0, ecn)
		return nil
	}

	ecn := s.sentPacketHandler.ECNMode(false)
	coalesced, err := s.packer.PackCoalescedPacket(true, s.maxPacketSize(), now, s.version)
	if err != nil || coalesced == nil {
		return err
	}
	return s.sendPackedCoalescedPacket(coalesced, ecn, now)
}
// sendProbePacket sends a PTO probe packet at the encryption level that matches sendMode.
//
// It keeps queueing probe packets until one can actually be packed. If the sent packet handler
// runs out of packets to queue, a PING is added to the retransmission queue and packing is
// retried once. Failing to pack anything is reported as a bug.
func (s *connection) sendProbePacket(sendMode ackhandler.SendMode, now time.Time) error {
	var encLevel protocol.EncryptionLevel
	//nolint:exhaustive // We only need to handle the PTO send modes here.
	switch sendMode {
	case ackhandler.SendPTOInitial:
		encLevel = protocol.EncryptionInitial
	case ackhandler.SendPTOHandshake:
		encLevel = protocol.EncryptionHandshake
	case ackhandler.SendPTOAppData:
		encLevel = protocol.Encryption1RTT
	default:
		return fmt.Errorf("connection BUG: unexpected send mode: %d", sendMode)
	}

	var probe *coalescedPacket
	for probe == nil && s.sentPacketHandler.QueueProbePacket(encLevel) {
		var err error
		probe, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version)
		if err != nil {
			return err
		}
	}
	if probe == nil {
		s.retransmissionQueue.AddPing(encLevel)
		var err error
		probe, err = s.packer.MaybePackPTOProbePacket(encLevel, s.maxPacketSize(), now, s.version)
		if err != nil {
			return err
		}
	}
	if probe == nil || (len(probe.longHdrPackets) == 0 && probe.shortHdrPacket == nil) {
		return fmt.Errorf("connection BUG: couldn't pack %s probe packet", encLevel)
	}
	ecn := s.sentPacketHandler.ECNMode(probe.IsOnlyShortHeaderPacket())
	return s.sendPackedCoalescedPacket(probe, ecn, now)
}
// appendOneShortHeaderPacket packs one short header packet and appends it to buf.
// It returns the number of bytes appended. If packing fails — including when there is
// nothing to pack — the packer's error is returned with a size of 0.
// A successfully packed packet is logged and registered with the sent packet handler.
func (s *connection) appendOneShortHeaderPacket(buf *packetBuffer, maxSize protocol.ByteCount, ecn protocol.ECN, now time.Time) (protocol.ByteCount, error) {
	before := buf.Len()
	p, err := s.packer.AppendPacket(buf, maxSize, now, s.version)
	if err != nil {
		return 0, err
	}
	written := buf.Len() - before
	s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, written, false)
	s.registerPackedShortHeaderPacket(p, ecn, now)
	return written, nil
}
// registerPackedShortHeaderPacket registers a sent 1-RTT packet with the sent packet handler.
//
// Path probe packets are registered without a largest-acknowledged packet number. They neither
// start the idle-timer bookkeeping nor notify the connection ID manager. All other packets record
// the send time of the first ack-eliciting packet after idle, pass the largest acknowledged
// packet number (if they carry an ACK), and notify the connection ID manager.
func (s *connection) registerPackedShortHeaderPacket(p shortHeaderPacket, ecn protocol.ECN, now time.Time) {
	isPathProbe := p.IsPathProbePacket
	largestAcked := protocol.InvalidPacketNumber
	if !isPathProbe {
		ackEliciting := len(p.StreamFrames) > 0 || ackhandler.HasAckElicitingFrames(p.Frames)
		if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && ackEliciting {
			s.firstAckElicitingPacketAfterIdleSentTime = now
		}
		if p.Ack != nil {
			largestAcked = p.Ack.LargestAcked()
		}
	}
	s.sentPacketHandler.SentPacket(
		now,
		p.PacketNumber,
		largestAcked,
		p.StreamFrames,
		p.Frames,
		protocol.Encryption1RTT,
		ecn,
		p.Length,
		p.IsPathMTUProbePacket,
		isPathProbe,
	)
	if !isPathProbe {
		s.connIDManager.SentPacket()
	}
}
// sendPackedCoalescedPacket logs, registers and sends a coalesced packet.
//
// Every long header packet and the optional short header packet is registered with the sent
// packet handler. On the client, sending the first Handshake packet drops the Initial keys
// (RFC 9001, Section 4.9.1). The connection ID manager is notified once and the whole buffer
// is queued for sending with the given ECN marking.
func (s *connection) sendPackedCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN, now time.Time) error {
	s.logCoalescedPacket(packet, ecn)
	for _, lp := range packet.longHdrPackets {
		if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && lp.IsAckEliciting() {
			s.firstAckElicitingPacketAfterIdleSentTime = now
		}
		lpLargestAcked := protocol.InvalidPacketNumber
		if lp.ack != nil {
			lpLargestAcked = lp.ack.LargestAcked()
		}
		s.sentPacketHandler.SentPacket(
			now,
			lp.header.PacketNumber,
			lpLargestAcked,
			lp.streamFrames,
			lp.frames,
			lp.EncryptionLevel(),
			ecn,
			lp.length,
			false,
			false,
		)
		isClient := s.perspective == protocol.PerspectiveClient
		if isClient && lp.EncryptionLevel() == protocol.EncryptionHandshake && !s.droppedInitialKeys {
			// On the client side, Initial keys are dropped as soon as the first Handshake packet is sent.
			// See Section 4.9.1 of RFC 9001.
			if err := s.dropEncryptionLevel(protocol.EncryptionInitial, now); err != nil {
				return err
			}
		}
	}
	if sp := packet.shortHdrPacket; sp != nil {
		if s.firstAckElicitingPacketAfterIdleSentTime.IsZero() && sp.IsAckEliciting() {
			s.firstAckElicitingPacketAfterIdleSentTime = now
		}
		spLargestAcked := protocol.InvalidPacketNumber
		if sp.Ack != nil {
			spLargestAcked = sp.Ack.LargestAcked()
		}
		s.sentPacketHandler.SentPacket(
			now,
			sp.PacketNumber,
			spLargestAcked,
			sp.StreamFrames,
			sp.Frames,
			protocol.Encryption1RTT,
			ecn,
			sp.Length,
			sp.IsPathMTUProbePacket,
			false,
		)
	}
	s.connIDManager.SentPacket()
	s.sendQueue.Send(packet.buffer, 0, ecn)
	return nil
}
// sendConnectionClose packs a CONNECTION_CLOSE for e and writes it directly to the conn.
//
// Transport errors and application errors are packed as such. Any other error is reported to
// the peer as an InternalError transport error mentioning e. It returns the raw packet bytes
// (for reuse by the caller) and the error from writing them.
func (s *connection) sendConnectionClose(e error) ([]byte, error) {
	var (
		packet         *coalescedPacket
		err            error
		transportErr   *qerr.TransportError
		applicationErr *qerr.ApplicationError
	)
	switch {
	case errors.As(e, &transportErr):
		packet, err = s.packer.PackConnectionClose(transportErr, s.maxPacketSize(), s.version)
	case errors.As(e, &applicationErr):
		packet, err = s.packer.PackApplicationClose(applicationErr, s.maxPacketSize(), s.version)
	default:
		internalErr := &qerr.TransportError{
			ErrorCode:    qerr.InternalError,
			ErrorMessage: fmt.Sprintf("connection BUG: unspecified error type (msg: %s)", e.Error()),
		}
		packet, err = s.packer.PackConnectionClose(internalErr, s.maxPacketSize(), s.version)
	}
	if err != nil {
		return nil, err
	}
	ecn := s.sentPacketHandler.ECNMode(packet.IsOnlyShortHeaderPacket())
	s.logCoalescedPacket(packet, ecn)
	raw := packet.buffer.Data
	return raw, s.conn.Write(raw, 0, ecn)
}
// maxPacketSize returns the maximum size of packets this connection currently sends.
// Once the MTU discoverer exists, its current size is used. Before that, clients use the
// configured InitialPacketSize and servers use protocol.MinInitialPacketSize.
func (s *connection) maxPacketSize() protocol.ByteCount {
	if s.mtuDiscoverer != nil {
		return s.mtuDiscoverer.CurrentSize()
	}
	if s.perspective != protocol.PerspectiveClient {
		// On the server side, there's no downside to using 1200 bytes until we received the client's transport
		// parameters:
		// * If the first packet didn't contain the entire ClientHello, all we can do is ACK that packet. We don't
		//   need a lot of bytes for that.
		// * If it did, we will have processed the transport parameters and initialized the MTU discoverer.
		return protocol.MinInitialPacketSize
	}
	// Use the configured packet size on the client side.
	// If the server sends a max_udp_payload_size that's smaller than this size, we can ignore this:
	// Apparently the server still processed the (fully padded) Initial packet anyway.
	return protocol.ByteCount(s.config.InitialPacketSize)
}
// AcceptStream returns the next bidirectional stream opened by the peer.
// It delegates to the streams map and passes ctx through unchanged.
func (s *connection) AcceptStream(ctx context.Context) (Stream, error) {
	return s.streamsMap.AcceptStream(ctx)
}
// AcceptUniStream returns the next unidirectional stream opened by the peer.
// The call is forwarded to the streams map together with ctx.
func (s *connection) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
	str, err := s.streamsMap.AcceptUniStream(ctx)
	return str, err
}
// OpenStream opens a new bidirectional stream without waiting.
// It forwards to the streams map and returns its stream and error as-is.
func (s *connection) OpenStream() (Stream, error) {
	str, err := s.streamsMap.OpenStream()
	return str, err
}
// OpenStreamSync opens a new bidirectional stream. It delegates to the streams map's
// OpenStreamSync and passes ctx through.
func (s *connection) OpenStreamSync(ctx context.Context) (Stream, error) {
	str, err := s.streamsMap.OpenStreamSync(ctx)
	return str, err
}
// OpenUniStream opens a new unidirectional (send-only) stream without waiting.
// It forwards to the streams map.
func (s *connection) OpenUniStream() (SendStream, error) {
	str, err := s.streamsMap.OpenUniStream()
	return str, err
}
// OpenUniStreamSync opens a new unidirectional (send-only) stream. It delegates to the
// streams map's OpenUniStreamSync and passes ctx through.
func (s *connection) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
	str, err := s.streamsMap.OpenUniStreamSync(ctx)
	return str, err
}
// newFlowController creates the stream-level flow controller for stream id.
// The initial send window is taken from the peer's transport parameters:
//   - unidirectional streams use InitialMaxStreamDataUni;
//   - bidirectional streams we initiated use InitialMaxStreamDataBidiRemote, because the peer's
//     parameters are expressed from its own point of view, where our streams are remote;
//   - bidirectional streams the peer initiated use InitialMaxStreamDataBidiLocal.
//
// The receive window settings come from the local config.
func (s *connection) newFlowController(id protocol.StreamID) flowcontrol.StreamFlowController {
	var sendWindow protocol.ByteCount
	switch {
	case id.Type() != protocol.StreamTypeBidi:
		sendWindow = s.peerParams.InitialMaxStreamDataUni
	case id.InitiatedBy() == s.perspective:
		sendWindow = s.peerParams.InitialMaxStreamDataBidiRemote
	default:
		sendWindow = s.peerParams.InitialMaxStreamDataBidiLocal
	}
	return flowcontrol.NewStreamFlowController(
		id,
		s.connFlowController,
		protocol.ByteCount(s.config.InitialStreamReceiveWindow),
		protocol.ByteCount(s.config.MaxStreamReceiveWindow),
		sendWindow,
		s.rttStats,
		s.logger,
	)
}
// scheduleSending signals that we have data for sending.
// The signal is sent without blocking: if the send on s.sendingScheduled cannot proceed
// immediately, the default case silently drops it.
// NOTE(review): presumably sendingScheduled is buffered, so a dropped signal means one is
// already pending — confirm where the channel is created.
func (s *connection) scheduleSending() {
select {
case s.sendingScheduled <- struct{}{}:
default:
}
}
// tryQueueingUndecryptablePacket queues a packet for which we're missing the decryption keys.
// The logging.PacketType is only used for logging purposes.
// At most protocol.MaxUndecryptablePackets packets are held. Once the queue is full, further
// packets are dropped and reported to the tracer with logging.PacketDropDOSPrevention.
// It panics if called after the handshake has completed.
func (s *connection) tryQueueingUndecryptablePacket(p receivedPacket, pt logging.PacketType) {
	if s.handshakeComplete {
		panic("shouldn't queue undecryptable packets after handshake completion")
	}
	size := p.Size()
	if len(s.undecryptablePackets) >= protocol.MaxUndecryptablePackets {
		if tr := s.tracer; tr != nil && tr.DroppedPacket != nil {
			tr.DroppedPacket(pt, protocol.InvalidPacketNumber, size, logging.PacketDropDOSPrevention)
		}
		s.logger.Infof("Dropping undecryptable packet (%d bytes). Undecryptable packet queue full.", size)
		return
	}
	s.logger.Infof("Queueing packet (%d bytes) for later decryption", size)
	if tr := s.tracer; tr != nil && tr.BufferedPacket != nil {
		tr.BufferedPacket(pt, size)
	}
	s.undecryptablePackets = append(s.undecryptablePackets, p)
}
// queueControlFrame hands the control frame f to the framer, then schedules sending.
func (s *connection) queueControlFrame(f wire.Frame) {
s.framer.QueueControlFrame(f)
s.scheduleSending()
}
// onHasConnectionData is a callback signaling that connection-level data is ready;
// it only schedules sending.
func (s *connection) onHasConnectionData() {
	s.scheduleSending()
}
// onHasStreamData is a callback signaling that stream id has data to send.
// It registers str as an active stream with the framer and schedules sending.
func (s *connection) onHasStreamData(id protocol.StreamID, str sendStreamI) {
s.framer.AddActiveStream(id, str)
s.scheduleSending()
}
// onHasStreamControlFrame is a callback signaling that stream id has control frames to send.
// It registers str with the framer as a source of control frames and schedules sending.
func (s *connection) onHasStreamControlFrame(id protocol.StreamID, str streamControlFrameGetter) {
s.framer.AddStreamWithControlFrames(id, str)
s.scheduleSending()
}
// onStreamCompleted removes stream id from the streams map and from the framer's active streams.
// If deleting the stream from the streams map fails, the error is passed to closeLocal.
// The framer cleanup is performed in either case.
func (s *connection) onStreamCompleted(id protocol.StreamID) {
	err := s.streamsMap.DeleteStream(id)
	if err != nil {
		s.closeLocal(err)
	}
	s.framer.RemoveActiveStream(id)
}
// SendDatagram queues p to be sent in a DATAGRAM frame.
// It fails if supportsDatagrams reports no datagram support. It returns a *DatagramTooLargeError
// if p is larger than the estimated maximum payload. That limit is the smaller of the size
// derived from the peer's MaxDatagramFrameSize and the current MTU estimate.
// p is copied, so the caller may reuse it after SendDatagram returns.
func (s *connection) SendDatagram(p []byte) error {
	if !s.supportsDatagrams() {
		return errors.New("datagram support disabled")
	}
	frame := &wire.DatagramFrame{DataLenPresent: true}
	// The payload size estimate is conservative.
	// Under many circumstances we could send a few more bytes.
	limitFromPeer := frame.MaxDataLen(s.peerParams.MaxDatagramFrameSize, s.version)
	limitFromMTU := protocol.ByteCount(s.currentMTUEstimate.Load())
	maxDataLen := min(limitFromPeer, limitFromMTU)
	if payloadLen := protocol.ByteCount(len(p)); payloadLen > maxDataLen {
		return &DatagramTooLargeError{MaxDatagramPayloadSize: int64(maxDataLen)}
	}
	frame.Data = make([]byte, len(p))
	copy(frame.Data, p)
	return s.datagramQueue.Add(frame)
}
// ReceiveDatagram returns the next datagram from the datagram queue, passing ctx through.
// It fails immediately if datagrams are not enabled in the local config.
func (s *connection) ReceiveDatagram(ctx context.Context) ([]byte, error) {
	if s.config.EnableDatagrams {
		return s.datagramQueue.Receive(ctx)
	}
	return nil, errors.New("datagram support disabled")
}
// LocalAddr returns the local address as reported by the underlying send connection.
func (s *connection) LocalAddr() net.Addr {
	return s.conn.LocalAddr()
}
// RemoteAddr returns the peer's address as reported by the underlying send connection.
func (s *connection) RemoteAddr() net.Addr {
	return s.conn.RemoteAddr()
}
// NextConnection waits until the handshake completes, the connection's context is done,
// or ctx is done.
// If ctx is done first, it returns nil and ctx's cancellation cause. Otherwise it returns the
// connection itself. Only on handshake completion does it first call UseResetMaps on the
// streams map.
func (s *connection) NextConnection(ctx context.Context) (Connection, error) {
	select {
	case <-ctx.Done():
		return nil, context.Cause(ctx)
	case <-s.Context().Done():
		// The handshake might fail after the server rejected 0-RTT.
		// This could happen if the Finished message is malformed or never received.
		return s, nil
	case <-s.HandshakeComplete():
		s.streamsMap.UseResetMaps()
		return s, nil
	}
}
// estimateMaxPayloadSize estimates the maximum payload size of a short header packet that fits
// into mtu bytes. The estimate is deliberately simple: it subtracts the first (type) byte,
// the maximum connection ID length and the AEAD tag size. The packet number length is not
// subtracted.
func estimateMaxPayloadSize(mtu protocol.ByteCount) protocol.ByteCount {
	const (
		typeByteLen  = 1  // first byte of the short header
		maxConnIDLen = 20 // maximum connection ID length
		aeadTagLen   = 16 // size of the encryption tag
	)
	return mtu - (typeByteLen + maxConnIDLen + aeadTagLen)
}
golang-github-lucas-clemente-quic-go-0.50.0/connection_logging.go 0000664 0000000 0000000 00000011627 14765760516 0025010 0 ustar 00root root 0000000 0000000 package quic
import (
"slices"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
)
// toLoggingFrame converts a wire.Frame into a logging.Frame.
// This makes it possible for external packages to access the frames.
// ACK frames are copied via toLoggingAckFrame, including their ACK ranges, because ACK frames
// are pooled. CRYPTO, STREAM and DATAGRAM frames are converted into logging frames that only
// carry the length of their data, not the data itself. All other frames are passed through
// unchanged.
func toLoggingFrame(frame wire.Frame) logging.Frame {
	switch f := frame.(type) {
	case *wire.AckFrame:
		// We use a pool for ACK frames.
		// Implementations of the tracer interface may hold on to frames, so we need to make a copy here.
		return toLoggingAckFrame(f)
	case *wire.CryptoFrame:
		return &logging.CryptoFrame{
			Offset: f.Offset,
			Length: protocol.ByteCount(len(f.Data)),
		}
	case *wire.StreamFrame:
		return &logging.StreamFrame{
			StreamID: f.StreamID,
			Offset:   f.Offset,
			Length:   f.DataLen(),
			Fin:      f.Fin,
		}
	case *wire.DatagramFrame:
		return &logging.DatagramFrame{
			Length: logging.ByteCount(len(f.Data)),
		}
	default:
		return logging.Frame(frame)
	}
}
// toLoggingAckFrame copies an ACK frame into a logging.AckFrame.
// The ACK ranges are cloned, so later modifications of f (for example after it is returned to
// its pool) don't affect the returned frame.
func toLoggingAckFrame(f *wire.AckFrame) *logging.AckFrame {
	return &logging.AckFrame{
		AckRanges: slices.Clone(f.AckRanges),
		DelayTime: f.DelayTime,
		ECNCE:     f.ECNCE,
		ECT0:      f.ECT0,
		ECT1:      f.ECT1,
	}
}
// logLongHeaderPacket logs a long header packet that is about to be sent.
// With debug logging enabled, it logs the header, the ACK frame (if any), the control frames
// and the STREAM frames. If the tracer has a SentLongHeaderPacket callback, that callback
// receives the frames converted by toLoggingFrame and the ACK frame converted by
// toLoggingAckFrame.
func (s *connection) logLongHeaderPacket(p *longHeaderPacket, ecn protocol.ECN) {
	// quic-go logging
	if s.logger.Debug() {
		p.header.Log(s.logger)
		if p.ack != nil {
			wire.LogFrame(s.logger, p.ack, true)
		}
		for _, frame := range p.frames {
			wire.LogFrame(s.logger, frame.Frame, true)
		}
		for _, frame := range p.streamFrames {
			wire.LogFrame(s.logger, frame.Frame, true)
		}
	}

	// tracing
	if s.tracer != nil && s.tracer.SentLongHeaderPacket != nil {
		// Size the slice for both the control frames and the STREAM frames, so the second
		// loop doesn't force a reallocation (same as logShortHeaderPacket).
		frames := make([]logging.Frame, 0, len(p.frames)+len(p.streamFrames))
		for _, f := range p.frames {
			frames = append(frames, toLoggingFrame(f.Frame))
		}
		for _, f := range p.streamFrames {
			frames = append(frames, toLoggingFrame(f.Frame))
		}
		var ack *logging.AckFrame
		if p.ack != nil {
			ack = toLoggingAckFrame(p.ack)
		}
		s.tracer.SentLongHeaderPacket(p.header, p.length, ecn, ack, frames)
	}
}
// logShortHeaderPacket logs a short header (1-RTT) packet that is about to be sent.
// With debug logging enabled, it logs the header and all frames. Unless the packet is part of a
// coalesced packet (isCoalesced), a "Sending packet" summary line precedes them.
// If the tracer has a SentShortHeaderPacket callback, that callback receives the header
// fields, the frames converted by toLoggingFrame, and the ACK frame converted by
// toLoggingAckFrame.
func (s *connection) logShortHeaderPacket(
	destConnID protocol.ConnectionID,
	ackFrame *wire.AckFrame,
	frames []ackhandler.Frame,
	streamFrames []ackhandler.StreamFrame,
	pn protocol.PacketNumber,
	pnLen protocol.PacketNumberLen,
	kp protocol.KeyPhaseBit,
	ecn protocol.ECN,
	size protocol.ByteCount,
	isCoalesced bool,
) {
	// quic-go logging
	if s.logger.Debug() {
		if !isCoalesced {
			s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, 1-RTT (ECN: %s)", pn, size, s.logID, ecn)
		}
		wire.LogShortHeader(s.logger, destConnID, pn, pnLen, kp)
		if ackFrame != nil {
			wire.LogFrame(s.logger, ackFrame, true)
		}
		for i := range frames {
			wire.LogFrame(s.logger, frames[i].Frame, true)
		}
		for i := range streamFrames {
			wire.LogFrame(s.logger, streamFrames[i].Frame, true)
		}
	}

	// tracing
	if s.tracer == nil || s.tracer.SentShortHeaderPacket == nil {
		return
	}
	loggingFrames := make([]logging.Frame, 0, len(frames)+len(streamFrames))
	for i := range frames {
		loggingFrames = append(loggingFrames, toLoggingFrame(frames[i].Frame))
	}
	for i := range streamFrames {
		loggingFrames = append(loggingFrames, toLoggingFrame(streamFrames[i].Frame))
	}
	var loggingAck *logging.AckFrame
	if ackFrame != nil {
		loggingAck = toLoggingAckFrame(ackFrame)
	}
	hdr := &logging.ShortHeader{DestConnectionID: destConnID, PacketNumber: pn, PacketNumberLen: pnLen, KeyPhase: kp}
	s.tracer.SentShortHeaderPacket(hdr, size, ecn, loggingAck, loggingFrames)
}
// logCoalescedPacket logs a coalesced packet that is about to be sent, to the debug logger and
// the tracer.
// With debug logging enabled, a coalesced packet consisting of only a short header packet is
// logged as a regular 1-RTT packet. Otherwise a summary line is emitted before each long
// header packet and the optional short header packet are logged individually.
func (s *connection) logCoalescedPacket(packet *coalescedPacket, ecn protocol.ECN) {
	if s.logger.Debug() {
		// There's a short period between dropping both Initial and Handshake keys and completion of the handshake,
		// during which we might call PackCoalescedPacket but just pack a short header packet.
		if len(packet.longHdrPackets) == 0 && packet.shortHdrPacket != nil {
			s.logShortHeaderPacket(
				packet.shortHdrPacket.DestConnID,
				packet.shortHdrPacket.Ack,
				packet.shortHdrPacket.Frames,
				packet.shortHdrPacket.StreamFrames,
				packet.shortHdrPacket.PacketNumber,
				packet.shortHdrPacket.PacketNumberLen,
				packet.shortHdrPacket.KeyPhase,
				ecn,
				packet.shortHdrPacket.Length,
				false,
			)
			return
		}
		if len(packet.longHdrPackets) > 1 {
			s.logger.Debugf("-> Sending coalesced packet (%d parts, %d bytes) for connection %s", len(packet.longHdrPackets), packet.buffer.Len(), s.logID)
		} else if len(packet.longHdrPackets) == 1 {
			// Guarded explicitly: with no long header packets (and no short header packet),
			// indexing longHdrPackets[0] would panic.
			s.logger.Debugf("-> Sending packet %d (%d bytes) for connection %s, %s", packet.longHdrPackets[0].header.PacketNumber, packet.buffer.Len(), s.logID, packet.longHdrPackets[0].EncryptionLevel())
		}
	}
	for _, p := range packet.longHdrPackets {
		s.logLongHeaderPacket(p, ecn)
	}
	if p := packet.shortHdrPacket; p != nil {
		s.logShortHeaderPacket(p.DestConnID, p.Ack, p.Frames, p.StreamFrames, p.PacketNumber, p.PacketNumberLen, p.KeyPhase, ecn, p.Length, true)
	}
}
golang-github-lucas-clemente-quic-go-0.50.0/connection_logging_test.go 0000664 0000000 0000000 00000003157 14765760516 0026046 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/require"
)
func TestConnectionLoggingCryptoFrame(t *testing.T) {
f := toLoggingFrame(&wire.CryptoFrame{
Offset: 1234,
Data: []byte("foobar"),
})
require.Equal(t, &logging.CryptoFrame{
Offset: 1234,
Length: 6,
}, f)
}
func TestConnectionLoggingStreamFrame(t *testing.T) {
f := toLoggingFrame(&wire.StreamFrame{
StreamID: 42,
Offset: 1234,
Data: []byte("foo"),
Fin: true,
})
require.Equal(t, &logging.StreamFrame{
StreamID: 42,
Offset: 1234,
Length: 3,
Fin: true,
}, f)
}
// TestConnectionLoggingAckFrame checks that an ACK frame is copied field by field, and that its
// ACK ranges are cloned rather than shared with the original frame.
func TestConnectionLoggingAckFrame(t *testing.T) {
	original := &wire.AckFrame{
		AckRanges: []wire.AckRange{
			{Smallest: 1, Largest: 3},
			{Smallest: 6, Largest: 7},
		},
		DelayTime: 42,
		ECNCE:     123,
		ECT0:      456,
		ECT1:      789,
	}
	converted := toLoggingFrame(original)
	// mutate the original after conversion; the converted frame must not observe this
	original.AckRanges[0].Smallest = 2
	want := &logging.AckFrame{
		AckRanges: []wire.AckRange{
			{Smallest: 1, Largest: 3}, // unchanged, since the ACK ranges were cloned
			{Smallest: 6, Largest: 7},
		},
		DelayTime: 42,
		ECNCE:     123,
		ECT0:      456,
		ECT1:      789,
	}
	require.Equal(t, want, converted)
}
func TestConnectionLoggingDatagramFrame(t *testing.T) {
f := toLoggingFrame(&wire.DatagramFrame{Data: []byte("foobar")})
require.Equal(t, &logging.DatagramFrame{Length: 6}, f)
}
func TestConnectionLoggingOtherFrames(t *testing.T) {
f := toLoggingFrame(&wire.MaxDataFrame{MaximumData: 1234})
require.Equal(t, &logging.MaxDataFrame{MaximumData: 1234}, f)
}
golang-github-lucas-clemente-quic-go-0.50.0/connection_test.go 0000664 0000000 0000000 00000325130 14765760516 0024336 0 ustar 00root root 0000000 0000000 package quic
import (
"bytes"
"context"
"crypto/rand"
"crypto/tls"
"errors"
"net"
"net/netip"
"strconv"
"testing"
"time"
"github.com/quic-go/quic-go/internal/ackhandler"
"github.com/quic-go/quic-go/internal/handshake"
"github.com/quic-go/quic-go/internal/mocks"
mockackhandler "github.com/quic-go/quic-go/internal/mocks/ackhandler"
mocklogging "github.com/quic-go/quic-go/internal/mocks/logging"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/quic-go/quic-go/logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
// testConnectionOpt customizes a connection created by newServerTestConnection or
// newClientTestConnection. Options run after the connection and its mock packer are set up.
type testConnectionOpt func(*connection)

// connectionOptCryptoSetup replaces the connection's crypto stream handler.
func connectionOptCryptoSetup(cs *mocks.MockCryptoSetup) testConnectionOpt {
	return func(c *connection) { c.cryptoStreamHandler = cs }
}

// connectionOptStreamManager replaces the connection's streams map.
func connectionOptStreamManager(sm *MockStreamManager) testConnectionOpt {
	return func(c *connection) { c.streamsMap = sm }
}

// connectionOptConnFlowController replaces the connection-level flow controller.
func connectionOptConnFlowController(cfc *mocks.MockConnectionFlowController) testConnectionOpt {
	return func(c *connection) { c.connFlowController = cfc }
}

// connectionOptTracer installs tr as the connection tracer.
func connectionOptTracer(tr *logging.ConnectionTracer) testConnectionOpt {
	return func(c *connection) { c.tracer = tr }
}

// connectionOptSentPacketHandler replaces the sent packet handler.
func connectionOptSentPacketHandler(sph ackhandler.SentPacketHandler) testConnectionOpt {
	return func(c *connection) { c.sentPacketHandler = sph }
}

// connectionOptReceivedPacketHandler replaces the received packet handler.
func connectionOptReceivedPacketHandler(rph ackhandler.ReceivedPacketHandler) testConnectionOpt {
	return func(c *connection) { c.receivedPacketHandler = rph }
}

// connectionOptUnpacker replaces the packet unpacker.
func connectionOptUnpacker(u unpacker) testConnectionOpt {
	return func(c *connection) { c.unpacker = u }
}

// connectionOptSender replaces the send queue.
func connectionOptSender(s sender) testConnectionOpt {
	return func(c *connection) { c.sendQueue = s }
}

// connectionOptHandshakeConfirmed marks the handshake as both complete and confirmed.
func connectionOptHandshakeConfirmed() testConnectionOpt {
	return func(c *connection) {
		c.handshakeComplete = true
		c.handshakeConfirmed = true
	}
}

// connectionOptRTT installs RTT statistics that were updated once with rtt.
// The statistics are allocated when the option is created, so every connection the returned
// option is applied to shares the same RTTStats.
func connectionOptRTT(rtt time.Duration) testConnectionOpt {
	stats := &utils.RTTStats{}
	stats.UpdateRTT(rtt, 0)
	return func(c *connection) { c.rttStats = stats }
}

// connectionOptRetrySrcConnID sets the source connection ID of a (simulated) Retry packet.
func connectionOptRetrySrcConnID(rcid protocol.ConnectionID) testConnectionOpt {
	return func(c *connection) { c.retrySrcConnID = &rcid }
}
// testConnection bundles a connection under test with the mocks it was constructed with.
type testConnection struct {
conn       *connection
connRunner *MockConnRunner
sendConn   *MockSendConn
packer     *MockPacker
destConnID protocol.ConnectionID
srcConnID  protocol.ConnectionID
// remoteAddr is the address the mocked send conn reports as RemoteAddr.
remoteAddr *net.UDPAddr
}
// newServerTestConnection creates a server-side connection for tests.
// The send conn, connection runner and packer are gomock mocks created on mockCtrl, or on a new
// controller for t if mockCtrl is nil. The send conn reports the given GSO capability, local
// address 127.0.0.1:1234 and remote address 1.2.3.4:4321.
// The original destination and the source connection IDs are random 6-byte IDs.
// If config is nil, a config with path MTU discovery disabled is used.
// opts are applied after the packer has been replaced by the mock.
func newServerTestConnection(
t *testing.T,
mockCtrl *gomock.Controller,
config *Config,
gso bool,
opts ...testConnectionOpt,
) *testConnection {
if mockCtrl == nil {
mockCtrl = gomock.NewController(t)
}
remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
connRunner := NewMockConnRunner(mockCtrl)
sendConn := NewMockSendConn(mockCtrl)
sendConn.EXPECT().capabilities().Return(connCapabilities{GSO: gso}).AnyTimes()
sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
packer := NewMockPacker(mockCtrl)
// 12 random bytes: the first 6 form the original destination connection ID, the last 6 the source connection ID.
b := make([]byte, 12)
rand.Read(b)
origDestConnID := protocol.ParseConnectionID(b[:6])
srcConnID := protocol.ParseConnectionID(b[6:12])
ctx, cancel := context.WithCancelCause(context.Background())
if config == nil {
config = &Config{DisablePathMTUDiscovery: true}
}
conn := newConnection(
ctx,
cancel,
sendConn,
connRunner,
origDestConnID,
nil,
protocol.ConnectionID{},
protocol.ConnectionID{},
srcConnID,
&protocol.DefaultConnectionIDGenerator{},
newStatelessResetter(nil),
populateConfig(config),
&tls.Config{},
handshake.NewTokenGenerator(handshake.TokenProtectorKey{}),
false,
nil,
utils.DefaultLogger,
protocol.Version1,
).(*connection)
// Replace the real packer so tests can set expectations on packing.
conn.packer = packer
for _, opt := range opts {
opt(conn)
}
return &testConnection{
conn:       conn,
connRunner: connRunner,
sendConn:   sendConn,
packer:     packer,
destConnID: origDestConnID,
srcConnID:  srcConnID,
remoteAddr: remoteAddr,
}
}
// newClientTestConnection creates a client-side connection for tests.
// The send conn, connection runner and packer are gomock mocks created on mockCtrl, or on a new
// controller for t if mockCtrl is nil. The send conn reports no special capabilities, local
// address 127.0.0.1:1234 and remote address 1.2.3.4:4321.
// The destination and source connection IDs are random 6-byte IDs.
// If config is nil, a config with path MTU discovery disabled is used.
// enable0RTT is passed on to newClientConnection. opts are applied after the packer has been
// replaced by the mock.
// The returned fixture records remoteAddr, matching newServerTestConnection, so helpers that
// build received packets from tc.remoteAddr use the peer's real address.
func newClientTestConnection(
	t *testing.T,
	mockCtrl *gomock.Controller,
	config *Config,
	enable0RTT bool,
	opts ...testConnectionOpt,
) *testConnection {
	if mockCtrl == nil {
		mockCtrl = gomock.NewController(t)
	}
	remoteAddr := &net.UDPAddr{IP: net.IPv4(1, 2, 3, 4), Port: 4321}
	localAddr := &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 1234}
	connRunner := NewMockConnRunner(mockCtrl)
	sendConn := NewMockSendConn(mockCtrl)
	sendConn.EXPECT().capabilities().Return(connCapabilities{}).AnyTimes()
	sendConn.EXPECT().RemoteAddr().Return(remoteAddr).AnyTimes()
	sendConn.EXPECT().LocalAddr().Return(localAddr).AnyTimes()
	packer := NewMockPacker(mockCtrl)
	// 12 random bytes: the first 6 form the destination connection ID, the last 6 the source connection ID.
	b := make([]byte, 12)
	rand.Read(b)
	destConnID := protocol.ParseConnectionID(b[:6])
	srcConnID := protocol.ParseConnectionID(b[6:12])
	if config == nil {
		config = &Config{DisablePathMTUDiscovery: true}
	}
	conn := newClientConnection(
		context.Background(),
		sendConn,
		connRunner,
		destConnID,
		srcConnID,
		&protocol.DefaultConnectionIDGenerator{},
		newStatelessResetter(nil),
		populateConfig(config),
		&tls.Config{ServerName: "quic-go.net"},
		0,
		enable0RTT,
		false,
		nil,
		utils.DefaultLogger,
		protocol.Version1,
	).(*connection)
	// Replace the real packer so tests can set expectations on packing.
	conn.packer = packer
	for _, opt := range opts {
		opt(conn)
	}
	return &testConnection{
		conn:       conn,
		connRunner: connRunner,
		sendConn:   sendConn,
		packer:     packer,
		destConnID: destConnID,
		srcConnID:  srcConnID,
		remoteAddr: remoteAddr,
	}
}
// TestConnectionHandleReceiveStreamFrames checks how frames for the receive side of a stream
// (STREAM, RESET_STREAM, STREAM_DATA_BLOCKED) are dispatched via GetOrOpenReceiveStream.
// They are delivered to existing or new streams, succeed without error for closed (nil)
// streams, and fail with the streams map's error for invalid streams.
func TestConnectionHandleReceiveStreamFrames(t *testing.T) {
	const streamID protocol.StreamID = 5
	now := time.Now()
	connID := protocol.ConnectionID{}
	f := &wire.StreamFrame{StreamID: streamID, Data: []byte("foobar")}
	rsf := &wire.ResetStreamFrame{StreamID: streamID, ErrorCode: 42, FinalSize: 1337}
	sdbf := &wire.StreamDataBlockedFrame{StreamID: streamID, MaximumStreamData: 1337}

	// setup creates a server connection backed by a mocked streams map, bound to the (sub)test t.
	setup := func(t *testing.T) (*gomock.Controller, *MockStreamManager, *testConnection) {
		mockCtrl := gomock.NewController(t)
		streamsMap := NewMockStreamManager(mockCtrl)
		return mockCtrl, streamsMap, newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	}
	// handle passes frame to the connection as a 1-RTT frame.
	handle := func(tc *testConnection, frame wire.Frame) error {
		return tc.conn.handleFrame(frame, protocol.Encryption1RTT, connID, now)
	}

	t.Run("for existing and new streams", func(t *testing.T) {
		mockCtrl, streamsMap, tc := setup(t)
		str := NewMockReceiveStreamI(mockCtrl)
		// STREAM frame
		streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
		str.EXPECT().handleStreamFrame(f, now)
		require.NoError(t, handle(tc, f))
		// RESET_STREAM frame
		streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
		str.EXPECT().handleResetStreamFrame(rsf, now)
		require.NoError(t, handle(tc, rsf))
		// STREAM_DATA_BLOCKED frames are not passed to the stream
		streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(str, nil)
		require.NoError(t, handle(tc, sdbf))
	})

	t.Run("for closed streams", func(t *testing.T) {
		_, streamsMap, tc := setup(t)
		for _, frame := range []wire.Frame{f, rsf, sdbf} {
			streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, nil)
			require.NoError(t, handle(tc, frame))
		}
	})

	t.Run("for invalid streams", func(t *testing.T) {
		_, streamsMap, tc := setup(t)
		testErr := errors.New("test err")
		for _, frame := range []wire.Frame{f, rsf, sdbf} {
			streamsMap.EXPECT().GetOrOpenReceiveStream(streamID).Return(nil, testErr)
			require.ErrorIs(t, handle(tc, frame), testErr)
		}
	})
}
// TestConnectionHandleSendStreamFrames checks how frames for the send side of a stream
// (STOP_SENDING, MAX_STREAM_DATA) are dispatched via GetOrOpenSendStream.
// They are delivered to existing or new streams, succeed without error for closed (nil)
// streams, and fail with the streams map's error for invalid streams.
func TestConnectionHandleSendStreamFrames(t *testing.T) {
	const streamID protocol.StreamID = 3
	now := time.Now()
	connID := protocol.ConnectionID{}
	ss := &wire.StopSendingFrame{StreamID: streamID, ErrorCode: 42}
	msd := &wire.MaxStreamDataFrame{StreamID: streamID, MaximumStreamData: 1337}

	// setup creates a server connection backed by a mocked streams map, bound to the (sub)test t.
	setup := func(t *testing.T) (*gomock.Controller, *MockStreamManager, *testConnection) {
		mockCtrl := gomock.NewController(t)
		streamsMap := NewMockStreamManager(mockCtrl)
		return mockCtrl, streamsMap, newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	}
	// handle passes frame to the connection as a 1-RTT frame.
	handle := func(tc *testConnection, frame wire.Frame) error {
		return tc.conn.handleFrame(frame, protocol.Encryption1RTT, connID, now)
	}

	t.Run("for existing and new streams", func(t *testing.T) {
		mockCtrl, streamsMap, tc := setup(t)
		str := NewMockSendStreamI(mockCtrl)
		// STOP_SENDING frame
		streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
		str.EXPECT().handleStopSendingFrame(ss)
		require.NoError(t, handle(tc, ss))
		// MAX_STREAM_DATA frame
		streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(str, nil)
		str.EXPECT().updateSendWindow(msd.MaximumStreamData)
		require.NoError(t, handle(tc, msd))
	})

	t.Run("for closed streams", func(t *testing.T) {
		_, streamsMap, tc := setup(t)
		for _, frame := range []wire.Frame{ss, msd} {
			streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, nil)
			require.NoError(t, handle(tc, frame))
		}
	})

	t.Run("for invalid streams", func(t *testing.T) {
		_, streamsMap, tc := setup(t)
		testErr := errors.New("test err")
		for _, frame := range []wire.Frame{ss, msd} {
			streamsMap.EXPECT().GetOrOpenSendStream(streamID).Return(nil, testErr)
			require.ErrorIs(t, handle(tc, frame), testErr)
		}
	})
}
// TestConnectionHandleStreamNumFrames checks the handling of frames carrying stream limits.
// MAX_STREAMS frames are passed to the streams map. STREAMS_BLOCKED frames must be handled
// without error. Previously the returned error was silently ignored, so a regression here
// went unnoticed.
func TestConnectionHandleStreamNumFrames(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	streamsMap := NewMockStreamManager(mockCtrl)
	tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	now := time.Now()
	connID := protocol.ConnectionID{}
	// MAX_STREAMS frame
	msf := &wire.MaxStreamsFrame{Type: protocol.StreamTypeBidi, MaxStreamNum: 10}
	streamsMap.EXPECT().HandleMaxStreamsFrame(msf)
	require.NoError(t, tc.conn.handleFrame(msf, protocol.Encryption1RTT, connID, now))
	// STREAMS_BLOCKED frame: no expectation is set on the streams map mock,
	// so any call into it fails the test.
	sbf := &wire.StreamsBlockedFrame{Type: protocol.StreamTypeBidi, StreamLimit: 1}
	require.NoError(t, tc.conn.handleFrame(sbf, protocol.Encryption1RTT, connID, now))
}
// TestConnectionHandleConnectionFlowControlFrames checks that a MAX_DATA frame updates the
// connection-level send window, and that a DATA_BLOCKED frame is handled without error.
func TestConnectionHandleConnectionFlowControlFrames(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	connFC := mocks.NewMockConnectionFlowController(mockCtrl)
	tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptConnFlowController(connFC))
	now := time.Now()
	connID := protocol.ConnectionID{}

	maxData := &wire.MaxDataFrame{MaximumData: 1337}
	connFC.EXPECT().UpdateSendWindow(protocol.ByteCount(1337))
	require.NoError(t, tc.conn.handleFrame(maxData, protocol.Encryption1RTT, connID, now))

	dataBlocked := &wire.DataBlockedFrame{MaximumData: 1337}
	require.NoError(t, tc.conn.handleFrame(dataBlocked, protocol.Encryption1RTT, connID, now))
}
// TestConnectionOpenStreams checks that all four stream-opening methods forward to the
// streams map and return its stream unchanged.
func TestConnectionOpenStreams(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	streamsMap := NewMockStreamManager(mockCtrl)
	tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	ctx := context.Background()
	mstr := NewMockStreamI(mockCtrl)

	// OpenStream
	streamsMap.EXPECT().OpenStream().Return(mstr, nil)
	bidi, err := tc.conn.OpenStream()
	require.NoError(t, err)
	require.Equal(t, mstr, bidi)

	// OpenStreamSync
	streamsMap.EXPECT().OpenStreamSync(ctx).Return(mstr, nil)
	bidi, err = tc.conn.OpenStreamSync(ctx)
	require.NoError(t, err)
	require.Equal(t, mstr, bidi)

	// OpenUniStream
	streamsMap.EXPECT().OpenUniStream().Return(mstr, nil)
	uni, err := tc.conn.OpenUniStream()
	require.NoError(t, err)
	require.Equal(t, mstr, uni)

	// OpenUniStreamSync
	streamsMap.EXPECT().OpenUniStreamSync(ctx).Return(mstr, nil)
	uni, err = tc.conn.OpenUniStreamSync(ctx)
	require.NoError(t, err)
	require.Equal(t, mstr, uni)
}
// TestConnectionAcceptStreams checks that AcceptStream and AcceptUniStream forward the context
// to the streams map and return its stream unchanged.
func TestConnectionAcceptStreams(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	streamsMap := NewMockStreamManager(mockCtrl)
	tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptStreamManager(streamsMap))
	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
	defer cancel()
	mstr := NewMockStreamI(mockCtrl)

	// bidirectional streams
	streamsMap.EXPECT().AcceptStream(ctx).Return(mstr, nil)
	bidi, err := tc.conn.AcceptStream(ctx)
	require.NoError(t, err)
	require.Equal(t, mstr, bidi)

	// unidirectional streams
	streamsMap.EXPECT().AcceptUniStream(ctx).Return(mstr, nil)
	uni, err := tc.conn.AcceptUniStream(ctx)
	require.NoError(t, err)
	require.Equal(t, mstr, uni)
}
// TestConnectionServerInvalidFrames checks that a server rejects each of these frames with a
// PROTOCOL_VIOLATION transport error.
func TestConnectionServerInvalidFrames(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tc := newServerTestConnection(t, mockCtrl, nil, false)
	cases := []struct {
		name  string
		frame wire.Frame
	}{
		{name: "NEW_TOKEN", frame: &wire.NewTokenFrame{Token: []byte("foobar")}},
		{name: "HANDSHAKE_DONE", frame: &wire.HandshakeDoneFrame{}},
		{name: "PATH_RESPONSE", frame: &wire.PathResponseFrame{Data: [8]byte{1, 2, 3, 4, 5, 6, 7, 8}}},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			err := tc.conn.handleFrame(tt.frame, protocol.Encryption1RTT, protocol.ConnectionID{}, time.Now())
			require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
		})
	}
}
// TestConnectionTransportError checks closing a running connection with a transport error.
// The packet built by PackConnectionClose for that error must be written to the send conn.
// The tracer must see ClosedConnection followed by Close, and run must return the error.
func TestConnectionTransportError(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
expectedErr := &qerr.TransportError{
ErrorCode:    1337,
FrameType:    42,
ErrorMessage: "test error",
}
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
b := getPacketBuffer()
b.Data = append(b.Data, []byte("connection close")...)
// The packer must be asked to pack exactly expectedErr; its bytes must reach the send conn.
tc.packer.EXPECT().PackConnectionClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.closeLocal(expectedErr)
select {
case err := <-errChan:
require.ErrorIs(t, err, expectedErr)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// further calls to CloseWithError don't do anything
tc.conn.CloseWithError(42, "another error")
}
// TestConnectionApplicationClose checks closing a running connection via CloseWithError.
// CloseWithError must produce an application error that is packed by PackApplicationClose and
// written to the send conn. The tracer must see ClosedConnection followed by Close, and run must
// return the application error.
func TestConnectionApplicationClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
expectedErr := &qerr.ApplicationError{
ErrorCode:    1337,
ErrorMessage: "test error",
}
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
b := getPacketBuffer()
b.Data = append(b.Data, []byte("connection close")...)
// CloseWithError(1337, "test error") is expected to be translated into expectedErr.
tc.packer.EXPECT().PackApplicationClose(expectedErr, gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: b}, nil)
tc.sendConn.EXPECT().Write([]byte("connection close"), gomock.Any(), gomock.Any())
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(expectedErr),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.CloseWithError(1337, "test error")
select {
case err := <-errChan:
require.ErrorIs(t, err, expectedErr)
case <-time.After(time.Second):
t.Fatal("timeout")
}
// further calls to CloseWithError don't do anything
tc.conn.CloseWithError(42, "another error")
}
// TestConnectionStatelessReset checks destroying a running connection with a StatelessResetError.
// The tracer must see ClosedConnection followed by Close, and run must return the error.
// No packer or Write expectations are set, so any attempt to send a close packet fails the test.
func TestConnectionStatelessReset(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
errChan := make(chan error, 1)
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
gomock.InOrder(
tracer.EXPECT().ClosedConnection(&StatelessResetError{}),
tracer.EXPECT().Close(),
)
go func() { errChan <- tc.conn.run() }()
tc.conn.destroy(&StatelessResetError{})
select {
case err := <-errChan:
require.ErrorIs(t, err, &StatelessResetError{})
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
// getLongHeaderPacket builds a receivedPacket from remoteAddr.
// The packet data is extHdr serialized for QUIC v1, followed by data. The packet gets a fresh
// packet buffer and the current time as its receive time.
func getLongHeaderPacket(t *testing.T, remoteAddr net.Addr, extHdr *wire.ExtendedHeader, data []byte) receivedPacket {
	t.Helper()
	hdr, err := extHdr.Append(nil, protocol.Version1)
	require.NoError(t, err)
	pkt := receivedPacket{remoteAddr: remoteAddr}
	pkt.data = append(hdr, data...)
	pkt.buffer = getPacketBuffer()
	pkt.rcvTime = time.Now()
	return pkt
}
// getShortHeaderPacket builds a receivedPacket from remoteAddr.
// The packet data is a short header (connID, packet number pn encoded in 2 bytes, key phase
// one), followed by data. The packet gets a fresh packet buffer and the current time as its
// receive time.
func getShortHeaderPacket(t *testing.T, remoteAddr net.Addr, connID protocol.ConnectionID, pn protocol.PacketNumber, data []byte) receivedPacket {
	t.Helper()
	hdr, err := wire.AppendShortHeader(nil, connID, pn, protocol.PacketNumberLen2, protocol.KeyPhaseOne)
	require.NoError(t, err)
	pkt := receivedPacket{remoteAddr: remoteAddr}
	pkt.data = append(hdr, data...)
	pkt.buffer = getPacketBuffer()
	pkt.rcvTime = time.Now()
	return pkt
}
// TestConnectionServerInvalidPackets checks that a server connection refuses four
// kinds of packets without returning an error:
//   - a Retry packet,
//   - a Version Negotiation packet,
//   - a long header packet with an unsupported version,
//   - a packet whose header fails to parse.
//
// For each one, handleOnePacket reports it as not processed, and the tracer is
// told that the packet was dropped, along with the reason.
func TestConnectionServerInvalidPackets(t *testing.T) {
	t.Run("Retry", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
		tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
		p := getLongHeaderPacket(t,
			tc.remoteAddr,
			&wire.ExtendedHeader{Header: wire.Header{
				Type:             protocol.PacketTypeRetry,
				DestConnectionID: tc.conn.origDestConnID,
				SrcConnectionID:  tc.srcConnID,
				Version:          tc.conn.version,
				Token:            []byte("foobar"),
			}},
			make([]byte, 16), /* Retry integrity tag */
		)
		tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnexpectedPacket)
		wasProcessed, err := tc.conn.handleOnePacket(p)
		require.NoError(t, err)
		require.False(t, wasProcessed)
	})

	t.Run("version negotiation", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
		tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
		b := wire.ComposeVersionNegotiation(
			protocol.ArbitraryLenConnectionID(tc.srcConnID.Bytes()),
			protocol.ArbitraryLenConnectionID(tc.conn.origDestConnID.Bytes()),
			[]Version{Version1},
		)
		tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, protocol.InvalidPacketNumber, protocol.ByteCount(len(b)), logging.PacketDropUnexpectedPacket)
		wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{data: b, buffer: getPacketBuffer()})
		require.NoError(t, err)
		require.False(t, wasProcessed)
	})

	t.Run("unsupported version", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
		tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
		p := getLongHeaderPacket(t,
			tc.remoteAddr,
			&wire.ExtendedHeader{
				Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: 1234},
				PacketNumberLen: protocol.PacketNumberLen2,
			},
			nil,
		)
		tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropUnsupportedVersion)
		wasProcessed, err := tc.conn.handleOnePacket(p)
		require.NoError(t, err)
		require.False(t, wasProcessed)
	})

	t.Run("invalid header", func(t *testing.T) {
		mockCtrl := gomock.NewController(t)
		tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
		tc := newServerTestConnection(t, mockCtrl, nil, false, connectionOptTracer(tr))
		p := getLongHeaderPacket(t,
			tc.remoteAddr,
			&wire.ExtendedHeader{
				Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: Version1},
				PacketNumberLen: protocol.PacketNumberLen2,
			},
			nil,
		)
		// Clear the QUIC bit (0x40 in the first byte) so header parsing fails.
		// &^= clears the bit unconditionally. ^= would only toggle it, so it
		// would silently set the bit if the serialized header ever lacked it.
		p.data[0] &^= 0x40
		tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, p.Size(), logging.PacketDropHeaderParseError)
		wasProcessed, err := tc.conn.handleOnePacket(p)
		require.NoError(t, err)
		require.False(t, wasProcessed)
	})
}
// TestConnectionClientDrop0RTT checks that a client connection does not process
// a 0-RTT packet. The packet is reported to the tracer as unexpected and dropped,
// and handleOnePacket returns no error.
func TestConnectionClientDrop0RTT(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	tc := newClientTestConnection(t, ctrl, nil, false, connectionOptTracer(tr))

	zeroRTTHdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketType0RTT, Length: 2, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	pkt := getLongHeaderPacket(t, tc.remoteAddr, zeroRTTHdr, nil)
	tracer.EXPECT().DroppedPacket(logging.PacketType0RTT, protocol.InvalidPacketNumber, pkt.Size(), logging.PacketDropUnexpectedPacket)

	processed, handleErr := tc.conn.handleOnePacket(pkt)
	require.NoError(t, handleErr)
	require.False(t, processed)
}
// TestConnectionUnpacking feeds three packets to a server connection through
// handleOnePacket. The unpacker, the received packet handler and the tracer are
// mocked. The three packets are:
//   - an Initial (long header) packet. It is processed and handed to the received
//     packet handler together with its ECN marking and receive time.
//   - a second copy of that packet. The received packet handler flags it as a
//     potential duplicate, so it is dropped with PacketDropDuplicate.
//   - a short header packet. It is processed at Encryption1RTT.
func TestConnectionUnpacking(t *testing.T) {
mockCtrl := gomock.NewController(t)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptReceivedPacketHandler(rph),
connectionOptUnpacker(unpacker),
connectionOptTracer(tr),
)
// receive a long header packet
hdr := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 0x37,
PacketNumberLen: protocol.PacketNumberLen1,
}
// The mocked unpacker reports packet number 0x1337, while the header on the wire
// only carries the 1-byte value 0x37. The expectations below check that the
// packet number returned by the unpacker is the one passed on.
unpackedHdr := *hdr
unpackedHdr.PacketNumber = 0x1337
packet := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
packet.ecn = protocol.ECNCE
// Use a receive time in the past, so the expectations can tell the packet's
// rcvTime apart from the current time.
rcvTime := time.Now().Add(-10 * time.Second)
packet.rcvTime = rcvTime
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame
}, nil)
// The duplicate check must happen before the packet is registered. The last
// ReceivedPacket argument is presumably the ack-eliciting flag (false here,
// since the packet contains only PADDING). TODO confirm.
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECNCE, protocol.EncryptionInitial, rcvTime, false),
)
// processing the first packet is expected to log the version and connection start
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
// the PADDING frame does not appear in the logged frame list
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECNCE, []logging.Frame{})
wasProcessed, err := tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
// all expectations for the first packet must have been met before continuing
require.True(t, mockCtrl.Satisfied())
// receive a duplicate of this packet
packet = getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.EncryptionInitial).Return(true)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr,
data: []byte{0}, // one PADDING frame
}, nil)
// the duplicate is logged with the unpacked packet number and its full size
tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, protocol.PacketNumber(0x1337), protocol.ByteCount(len(packet.data)), logging.PacketDropDuplicate)
wasProcessed, err = tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a short header packet
packet = getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x37, nil)
packet.ecn = protocol.ECT1
packet.rcvTime = rcvTime
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(0x1337), protocol.Encryption1RTT),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(0x1337), protocol.ECT1, protocol.Encryption1RTT, rcvTime, false),
)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(0x1337), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil,
)
tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{})
wasProcessed, err = tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
// TestConnectionUnpackCoalescedPacket checks how a server connection handles a
// single datagram containing three coalesced long header packets:
//   - an Initial packet (PADDING only) and a Handshake packet (PING) are unpacked
//     and registered with the received packet handler. Both use the datagram's
//     ECN marking and receive time.
//   - a trailing Handshake packet, whose destination connection ID doesn't match
//     the connection, is dropped with PacketDropUnknownConnectionID.
//
// The test also expects the Initial encryption level to be dropped. This is
// presumably triggered by the server receiving a Handshake packet; TODO confirm.
func TestConnectionUnpackCoalescedPacket(t *testing.T) {
mockCtrl := gomock.NewController(t)
rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptReceivedPacketHandler(rph),
connectionOptUnpacker(unpacker),
connectionOptTracer(tr),
)
hdr1 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeInitial,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 37,
PacketNumberLen: protocol.PacketNumberLen1,
}
hdr2 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: tc.srcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 38,
PacketNumberLen: protocol.PacketNumberLen1,
}
// Add a packet whose destination connection ID differs from the connection ID
// the other packets are addressed to (tc.srcConnID).
incorrectSrcConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc})
hdr3 := &wire.ExtendedHeader{
Header: wire.Header{
Type: protocol.PacketTypeHandshake,
DestConnectionID: incorrectSrcConnID,
Version: protocol.Version1,
Length: 1,
},
PacketNumber: 0x42,
PacketNumberLen: protocol.PacketNumberLen1,
}
// the mocked unpacker reports packet numbers that differ from the ones on the wire
unpackedHdr1 := *hdr1
unpackedHdr1.PacketNumber = 1337
unpackedHdr2 := *hdr2
unpackedHdr2.PacketNumber = 1338
// coalesce the three packets into the data of a single received datagram
packet := getLongHeaderPacket(t, tc.remoteAddr, hdr1, nil)
packet2 := getLongHeaderPacket(t, tc.remoteAddr, hdr2, nil)
packet3 := getLongHeaderPacket(t, tc.remoteAddr, hdr3, nil)
packet.data = append(packet.data, packet2.data...)
packet.data = append(packet.data, packet3.data...)
packet.ecn = protocol.ECT1
rcvTime := time.Now()
packet.rcvTime = rcvTime
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionInitial,
hdr: &unpackedHdr1,
data: []byte{0}, // one PADDING frame
}, nil)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(&unpackedPacket{
encryptionLevel: protocol.EncryptionHandshake,
hdr: &unpackedHdr2,
data: []byte{1}, // one PING frame
}, nil)
// The last ReceivedPacket argument is false for the PADDING-only packet and true
// for the PING packet. It is presumably the ack-eliciting flag; TODO confirm.
gomock.InOrder(
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1337), protocol.EncryptionInitial),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1337), protocol.ECT1, protocol.EncryptionInitial, rcvTime, false),
rph.EXPECT().IsPotentiallyDuplicate(protocol.PacketNumber(1338), protocol.EncryptionHandshake),
rph.EXPECT().ReceivedPacket(protocol.PacketNumber(1338), protocol.ECT1, protocol.EncryptionHandshake, rcvTime, true),
)
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().DroppedEncryptionLevel(protocol.EncryptionInitial)
rph.EXPECT().DropPackets(protocol.EncryptionInitial)
// the third packet is dropped after the first two have been logged
gomock.InOrder(
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{}),
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), logging.ECT1, []logging.Frame{&wire.PingFrame{}}),
tracer.EXPECT().DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, protocol.ByteCount(len(packet3.data)), logging.PacketDropUnknownConnectionID),
)
wasProcessed, err := tc.conn.handleOnePacket(packet)
require.NoError(t, err)
require.True(t, wasProcessed)
}
func TestConnectionUnpackFailuresFatal(t *testing.T) {
t.Run("other errors", func(t *testing.T) {
require.ErrorIs(t,
testConnectionUnpackFailureFatal(t, &qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError}),
&qerr.TransportError{ErrorCode: qerr.ConnectionIDLimitError},
)
})
t.Run("invalid reserved bits", func(t *testing.T) {
require.ErrorIs(t,
testConnectionUnpackFailureFatal(t, wire.ErrInvalidReservedBits),
&qerr.TransportError{ErrorCode: qerr.ProtocolViolation},
)
})
}
// testConnectionUnpackFailureFatal runs a server connection and hands it a short
// header packet whose unpacking fails with unpackErr. It returns the error that
// run() terminates with.
//
// The connection is expected to close:
//   - it is replaced with a closed connection at the connection runner,
//   - a connection close packet is packed (PackConnectionClose),
//   - that packet is written to the send connection.
//
// The test fails if run() doesn't return within one second.
func testConnectionUnpackFailureFatal(t *testing.T, unpackErr error) error {
mockCtrl := gomock.NewController(t)
unpacker := NewMockUnpacker(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptUnpacker(unpacker),
)
tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
tc.conn.handlePacket(getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x42, nil))
select {
case err := <-errChan:
require.Error(t, err)
return err
case <-time.After(time.Second):
t.Fatal("timeout")
}
// unreachable: t.Fatal stops this goroutine, but the compiler requires a return
return nil
}
func TestConnectionUnpackFailureDropped(t *testing.T) {
t.Run("keys dropped", func(t *testing.T) {
testConnectionUnpackFailureDropped(t, handshake.ErrKeysDropped, logging.PacketDropKeyUnavailable)
})
t.Run("decryption failed", func(t *testing.T) {
testConnectionUnpackFailureDropped(t, handshake.ErrDecryptionFailed, logging.PacketDropPayloadDecryptError)
})
t.Run("header parse error", func(t *testing.T) {
testErr := errors.New("foo")
testConnectionUnpackFailureDropped(t, &headerParseError{err: testErr}, logging.PacketDropHeaderParseError)
})
}
// testConnectionUnpackFailureDropped runs a server connection and hands it a
// short header packet whose unpacking fails with unpackErr. It then waits until
// the tracer reports the packet as dropped with packetDropReason.
//
// Unlike in testConnectionUnpackFailureFatal, no connection close is expected.
// The connection is torn down explicitly via destroy(nil) at the end.
func testConnectionUnpackFailureDropped(t *testing.T, unpackErr error, packetDropReason logging.PacketDropReason) {
mockCtrl := gomock.NewController(t)
unpacker := NewMockUnpacker(mockCtrl)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptUnpacker(unpacker),
connectionOptTracer(tr),
)
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(0), protocol.PacketNumberLen(0), protocol.KeyPhaseBit(0), nil, unpackErr)
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
// closed by the tracer mock once the drop is reported, so the test can wait for it
done := make(chan struct{})
tracer.EXPECT().DroppedPacket(gomock.Any(), protocol.InvalidPacketNumber, gomock.Any(), packetDropReason).Do(
func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
close(done)
},
)
tc.conn.handlePacket(getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 0x42, nil))
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// test teardown
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case <-errChan:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
// TestConnectionMaxUnprocessedPackets checks that handlePacket never blocks, even
// when nothing consumes the packets. run() is never started here.
//
// Once MaxConnUnprocessedPackets packets have been handed over, the next packet
// is dropped, and the tracer is told the reason is DoS prevention.
func TestConnectionMaxUnprocessedPackets(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false, connectionOptTracer(tr))
	dropped := make(chan struct{})

	// fill up to the limit; none of these calls may block
	for n := protocol.PacketNumber(0); n < protocol.MaxConnUnprocessedPackets; n++ {
		tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})
	}

	// one more packet is one too many
	tracer.EXPECT().
		DroppedPacket(logging.PacketTypeNotDetermined, protocol.InvalidPacketNumber, logging.ByteCount(6), logging.PacketDropDOSPrevention).
		Do(func(logging.PacketType, logging.PacketNumber, logging.ByteCount, logging.PacketDropReason) { close(dropped) })
	tc.conn.handlePacket(receivedPacket{data: []byte("foobar")})

	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("timeout")
	case <-dropped:
	}
}
// TestConnectionRemoteClose checks that a CONNECTION_CLOSE frame received from
// the peer terminates the connection. run() returns the peer's error, marked as
// remote, and the same error is passed to the tracer and to the stream manager's
// CloseWithError.
func TestConnectionRemoteClose(t *testing.T) {
	ctrl := gomock.NewController(t)
	streamManager := NewMockStreamManager(ctrl)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	unpacker := NewMockUnpacker(ctrl)
	tc := newServerTestConnection(t,
		ctrl,
		nil,
		false,
		connectionOptStreamManager(streamManager),
		connectionOptTracer(tr),
		connectionOptUnpacker(unpacker),
	)

	// the unpacked payload is a single CONNECTION_CLOSE frame
	closeFrame := &wire.ConnectionCloseFrame{
		ErrorCode:    uint64(qerr.StreamLimitError),
		ReasonPhrase: "foobar",
	}
	payload, err := closeFrame.Append(nil, protocol.Version1)
	require.NoError(t, err)
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(protocol.PacketNumber(1), protocol.PacketNumberLen2, protocol.KeyPhaseBit(0), payload, nil)
	tracer.EXPECT().ReceivedShortHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())

	wantErr := &qerr.TransportError{ErrorCode: qerr.StreamLimitError, Remote: true}
	tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
	streamErrs := make(chan error, 1)
	streamManager.EXPECT().CloseWithError(gomock.Any()).Do(func(e error) { streamErrs <- e })
	tracerErrs := make(chan error, 1)
	tracer.EXPECT().ClosedConnection(gomock.Any()).Do(func(e error) { tracerErrs <- e })
	tracer.EXPECT().Close()

	runErrs := make(chan error, 1)
	go func() { runErrs <- tc.conn.run() }()
	p := getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, 1, []byte("encrypted"))
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})

	// next waits up to one second for a value on ch and fails the test on timeout
	next := func(ch <-chan error) error {
		t.Helper()
		select {
		case e := <-ch:
			return e
		case <-time.After(time.Second):
			t.Fatal("timeout")
			return nil
		}
	}
	require.ErrorIs(t, next(runErrs), wantErr)
	require.ErrorIs(t, next(tracerErrs), wantErr)
	require.ErrorIs(t, next(streamErrs), wantErr)
}
// TestConnectionIdleTimeoutDuringHandshake checks that a server connection that
// receives no packets during the handshake is closed with an IdleTimeoutError.
// The close happens once the short HandshakeIdleTimeout expires, and the tracer
// sees ClosedConnection, then Close.
func TestConnectionIdleTimeoutDuringHandshake(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	conf := &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}
	tc := newServerTestConnection(t, ctrl, conf, false, connectionOptTracer(tr))
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	gomock.InOrder(
		tracer.EXPECT().ClosedConnection(&IdleTimeoutError{}),
		tracer.EXPECT().Close(),
	)

	runResult := make(chan error, 1)
	go func() { runResult <- tc.conn.run() }()

	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("timeout")
	case runErr := <-runResult:
		require.ErrorIs(t, runErr, &IdleTimeoutError{})
	}
}
// TestConnectionHandshakeIdleTimeout checks that a server connection whose
// creation time lies 10 seconds in the past is closed with a
// HandshakeTimeoutError rather than an IdleTimeoutError.
//
// Presumably this is because the overall handshake deadline, which is derived
// from creationTime and HandshakeIdleTimeout, has already passed. TODO confirm.
func TestConnectionHandshakeIdleTimeout(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	conf := &Config{HandshakeIdleTimeout: scaleDuration(25 * time.Millisecond)}
	backdateCreation := func(c *connection) { c.creationTime = time.Now().Add(-10 * time.Second) }
	tc := newServerTestConnection(t, ctrl, conf, false, connectionOptTracer(tr), backdateCreation)
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).AnyTimes()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	gomock.InOrder(
		tracer.EXPECT().ClosedConnection(&HandshakeTimeoutError{}),
		tracer.EXPECT().Close(),
	)

	runResult := make(chan error, 1)
	go func() { runResult <- tc.conn.run() }()

	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("timeout")
	case runErr := <-runResult:
		require.ErrorIs(t, runErr, &HandshakeTimeoutError{})
	}
}
// TestConnectionTransportParameters checks what a server connection does with
// the peer's transport parameters:
//   - they are logged to the tracer,
//   - they are passed to the stream manager's UpdateLimits,
//   - InitialMaxData becomes the connection flow controller's send window.
func TestConnectionTransportParameters(t *testing.T) {
	ctrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(ctrl)
	streamManager := NewMockStreamManager(ctrl)
	connFC := mocks.NewMockConnectionFlowController(ctrl)
	tc := newServerTestConnection(t, ctrl, nil, false,
		connectionOptTracer(tr),
		connectionOptStreamManager(streamManager),
		connectionOptConnFlowController(connFC),
	)
	tracer.EXPECT().ReceivedTransportParameters(gomock.Any())

	tp := &wire.TransportParameters{
		MaxIdleTimeout:                90 * time.Second,
		InitialMaxStreamDataBidiLocal: 0x5000,
		InitialMaxData:                0x5000,
		ActiveConnectionIDLimit:       3,
		// marshaling always sets it to this value
		MaxUDPPayloadSize:               protocol.MaxPacketBufferSize,
		OriginalDestinationConnectionID: tc.destConnID,
	}
	streamManager.EXPECT().UpdateLimits(tp)
	connFC.EXPECT().UpdateSendWindow(tp.InitialMaxData)

	handleErr := tc.conn.handleTransportParameters(tp)
	require.NoError(t, handleErr)
}
// TestConnectionTransportParameterValidationFailureServer checks that a server
// rejects client transport parameters whose initial_source_connection_id doesn't
// match the expected value. The error must be a TRANSPORT_PARAMETER_ERROR.
func TestConnectionTransportParameterValidationFailureServer(t *testing.T) {
	tc := newServerTestConnection(t, nil, nil, false)
	bogusSrcConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
	params := &wire.TransportParameters{InitialSourceConnectionID: bogusSrcConnID}

	validationErr := tc.conn.handleTransportParameters(params)
	assert.ErrorIs(t, validationErr, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
	assert.ErrorContains(t, validationErr, "expected initial_source_connection_id to equal")
}
func TestConnectionTransportParameterValidationFailureClient(t *testing.T) {
t.Run("initial_source_connection_id", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
InitialSourceConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
})
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected initial_source_connection_id to equal")
})
t.Run("original_destination_connection_id", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
err := tc.conn.handleTransportParameters(&wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: protocol.ParseConnectionID([]byte{1, 2, 3, 4}),
})
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected original_destination_connection_id to equal")
})
t.Run("retry_source_connection_id if no retry", func(t *testing.T) {
tc := newClientTestConnection(t, nil, nil, false)
rcid := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
RetrySourceConnectionID: &rcid,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "received retry_source_connection_id, although no Retry was performed")
})
t.Run("retry_source_connection_id missing", func(t *testing.T) {
tc := newClientTestConnection(t,
nil,
nil,
false,
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
)
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "missing retry_source_connection_id")
})
t.Run("retry_source_connection_id incorrect", func(t *testing.T) {
tc := newClientTestConnection(t,
nil,
nil,
false,
connectionOptRetrySrcConnID(protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})),
)
wrongCID := protocol.ParseConnectionID([]byte{1, 2, 3, 4})
params := &wire.TransportParameters{
InitialSourceConnectionID: tc.destConnID,
OriginalDestinationConnectionID: tc.destConnID,
RetrySourceConnectionID: &wrongCID,
}
err := tc.conn.handleTransportParameters(params)
assert.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.TransportParameterError})
assert.ErrorContains(t, err, "expected retry_source_connection_id to equal")
})
}
// TestConnectionHandshakeServer drives a server connection to handshake
// completion, with the crypto setup and unpacker mocked.
//
// A received Handshake packet carrying a CRYPTO frame is passed to the crypto
// setup, which then reports that the handshake completed. The test then checks
// that the connection queued three frames in its framer:
//   - a CRYPTO frame with the session ticket,
//   - a HANDSHAKE_DONE frame,
//   - a NEW_TOKEN frame.
func TestConnectionHandshakeServer(t *testing.T) {
mockCtrl := gomock.NewController(t)
cs := mocks.NewMockCryptoSetup(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tc := newServerTestConnection(
t,
mockCtrl,
nil,
false,
connectionOptCryptoSetup(cs),
connectionOptUnpacker(unpacker),
)
// the state transition is driven by processing of a CRYPTO frame
hdr := &wire.ExtendedHeader{
Header: wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
PacketNumberLen: protocol.PacketNumberLen2,
}
data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
require.NoError(t, err)
// expected once each during the test; not ordered relative to the sequence below
cs.EXPECT().DiscardInitialKeys()
tc.connRunner.EXPECT().Retire(gomock.Any())
// On the server, handshake completion is immediately followed by handshake
// confirmation and fetching a session ticket.
gomock.InOrder(
cs.EXPECT().StartHandshake(gomock.Any()),
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
),
cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
cs.EXPECT().SetHandshakeConfirmed(),
cs.EXPECT().GetSessionTicket().Return([]byte("session ticket"), nil),
)
// The mocked packer never packs anything, so the queued frames presumably stay
// in the framer for inspection below.
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
select {
case <-tc.conn.HandshakeComplete():
case <-tc.conn.Context().Done():
t.Fatal("connection context done")
case <-time.After(time.Second):
t.Fatal("timeout")
}
// drain the framer and check which frames were queued
var foundSessionTicket, foundHandshakeDone, foundNewToken bool
frames, _, _ := tc.conn.framer.Append(nil, nil, protocol.MaxByteCount, time.Now(), protocol.Version1)
for _, frame := range frames {
switch f := frame.Frame.(type) {
case *wire.CryptoFrame:
assert.Equal(t, []byte("session ticket"), f.Data)
foundSessionTicket = true
case *wire.HandshakeDoneFrame:
foundHandshakeDone = true
case *wire.NewTokenFrame:
assert.NotEmpty(t, f.Token)
foundNewToken = true
}
}
assert.True(t, foundSessionTicket)
assert.True(t, foundHandshakeDone)
assert.True(t, foundNewToken)
// test teardown: destroy(nil) is expected to make run() return without an error
cs.EXPECT().Close()
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
// TestConnectionHandshakeClient runs the client handshake scenario twice: once
// without and once with a preferred address in the server's transport
// parameters.
func TestConnectionHandshakeClient(t *testing.T) {
	variants := []struct {
		name                string
		usePreferredAddress bool
	}{
		{name: "without preferred address", usePreferredAddress: false},
		{name: "with preferred address", usePreferredAddress: true},
	}
	for _, v := range variants {
		t.Run(v.name, func(t *testing.T) {
			testConnectionHandshakeClient(t, v.usePreferredAddress)
		})
	}
}
// testConnectionHandshakeClient drives a client connection through the
// handshake. The crypto setup, unpacker and packer are mocked. The steps are:
//  1. run() starts the handshake and packs a first packet. After that, the
//     Initial keys are expected to be discarded.
//  2. A received Handshake packet with a CRYPTO frame is passed to the crypto
//     setup. The crypto setup then reports the server's transport parameters
//     and handshake completion.
//  3. A received packet with a HANDSHAKE_DONE frame confirms the handshake.
//
// If usePreferredAddress is set, the server's transport parameters include a
// preferred address. The test then checks three things:
//   - the preferred address's connection ID is returned as the next connection ID,
//   - its stateless reset token is added,
//   - that token is removed again on teardown.
func testConnectionHandshakeClient(t *testing.T, usePreferredAddress bool) {
	mockCtrl := gomock.NewController(t)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// the state transition is driven by processing of a CRYPTO frame
	hdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	tp := &wire.TransportParameters{
		OriginalDestinationConnectionID: tc.destConnID,
		MaxIdleTimeout:                  time.Hour,
	}
	preferredAddressConnID := protocol.ParseConnectionID([]byte{10, 8, 6, 4})
	preferredAddressResetToken := protocol.StatelessResetToken{16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1}
	if usePreferredAddress {
		tp.PreferredAddress = &wire.PreferredAddress{
			IPv4:                netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), 42),
			IPv6:                netip.AddrPortFrom(netip.AddrFrom16([16]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}), 13),
			ConnectionID:        preferredAddressConnID,
			StatelessResetToken: preferredAddressResetToken,
		}
	}
	packedFirstPacket := make(chan struct{})
	gomock.InOrder(
		cs.EXPECT().StartHandshake(gomock.Any()),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		// The closure's parameters are unnamed so that none of them shadows the
		// test's *testing.T (previously the time.Time argument was called t).
		tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				close(packedFirstPacket)
				return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
			},
		),
		// initial keys are dropped when the first handshake packet is sent
		cs.EXPECT().DiscardInitialKeys(),
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
		),
		cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: tp}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventHandshakeComplete}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
	)
	// No more data to send. This AnyTimes expectation stays active for the rest
	// of the test.
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-packedFirstPacket:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case <-tc.conn.HandshakeComplete():
	case <-tc.conn.Context().Done():
		t.Fatal("connection context done")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	require.True(t, mockCtrl.Satisfied())
	// the handshake isn't confirmed until we receive a HANDSHAKE_DONE frame from the server
	data, err = (&wire.HandshakeDoneFrame{}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	done := make(chan struct{})
	gomock.InOrder(
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.Encryption1RTT, data: data}, nil,
		),
		cs.EXPECT().SetHandshakeConfirmed(),
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
				close(done)
				return shortHeaderPacket{}, errNothingToPack
			},
		),
	)
	// after the first (signalling) call, there's nothing more to pack
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack).AnyTimes()
	p = getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	if usePreferredAddress {
		tc.connRunner.EXPECT().AddResetToken(preferredAddressResetToken, gomock.Any())
	}
	nextConnID := tc.conn.connIDManager.Get()
	if usePreferredAddress {
		require.Equal(t, preferredAddressConnID, nextConnID)
	}
	// test teardown
	cs.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	if usePreferredAddress {
		tc.connRunner.EXPECT().RemoveResetToken(preferredAddressResetToken)
	}
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnection0RTTTransportParameters checks that a client that used 0-RTT
// rejects server transport parameters that reduce a limit below the restored
// value.
//
// The restored parameters (EventRestoredTransportParameters) are what the client
// assumed for 0-RTT. Here the server lowers MaxBidiStreamNum by one, so run()
// must fail with a PROTOCOL_VIOLATION.
func TestConnection0RTTTransportParameters(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t, mockCtrl, nil, false, connectionOptCryptoSetup(cs), connectionOptUnpacker(unpacker))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// the state transition is driven by processing of a CRYPTO frame
	hdr := &wire.ExtendedHeader{
		Header:          wire.Header{Type: protocol.PacketTypeHandshake, Version: protocol.Version1},
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	data, err := (&wire.CryptoFrame{Data: []byte("foobar")}).Append(nil, protocol.Version1)
	require.NoError(t, err)
	// the transport parameters restored for 0-RTT
	restored := &wire.TransportParameters{
		ActiveConnectionIDLimit:        3,
		InitialMaxData:                 0x5000,
		InitialMaxStreamDataBidiLocal:  0x5000,
		InitialMaxStreamDataBidiRemote: 1000,
		InitialMaxStreamDataUni:        1000,
		MaxBidiStreamNum:               500,
		MaxUniStreamNum:                500,
	}
	// The transport parameters the server actually sends. This was previously
	// named "new", which shadowed the builtin.
	reduced := *restored
	reduced.MaxBidiStreamNum-- // the server is not allowed to reduce the limit
	reduced.OriginalDestinationConnectionID = tc.destConnID
	packedFirstPacket := make(chan struct{})
	gomock.InOrder(
		cs.EXPECT().StartHandshake(gomock.Any()),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventRestoredTransportParameters, TransportParameters: restored}),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		// unnamed parameters: none of them may shadow the test's *testing.T
		tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				close(packedFirstPacket)
				return &coalescedPacket{buffer: getPacketBuffer(), longHdrPackets: []*longHeaderPacket{{header: hdr}}}, nil
			},
		),
		// initial keys are dropped when the first handshake packet is sent
		cs.EXPECT().DiscardInitialKeys(),
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{hdr: hdr, encryptionLevel: protocol.EncryptionHandshake, data: data}, nil,
		),
		cs.EXPECT().HandleMessage([]byte("foobar"), protocol.EncryptionHandshake),
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedTransportParameters, TransportParameters: &reduced}),
		cs.EXPECT().ConnectionState().Return(handshake.ConnectionState{Used0RTT: true}),
		// The reduced limits are rejected. The crypto setup is then closed instead
		// of being polled for further events.
		cs.EXPECT().Close(),
	)
	// no more data to send
	tc.packer.EXPECT().PackCoalescedPacket(false, gomock.Any(), gomock.Any(), protocol.Version1).Return(nil, nil).AnyTimes()
	tc.packer.EXPECT().PackConnectionClose(gomock.Any(), gomock.Any(), protocol.Version1).Return(&coalescedPacket{buffer: getPacketBuffer()}, nil)
	tc.connRunner.EXPECT().ReplaceWithClosed(gomock.Any(), gomock.Any())
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-packedFirstPacket:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	p := getLongHeaderPacket(t, tc.remoteAddr, hdr, nil)
	tc.conn.handlePacket(receivedPacket{data: p.data, buffer: p.buffer, rcvTime: time.Now()})
	select {
	case err := <-errChan:
		require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
		require.ErrorContains(t, err, "server sent reduced limits after accepting 0-RTT data")
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
func TestConnectionReceivePrioritization(t *testing.T) {
t.Run("handshake complete", func(t *testing.T) {
events := testConnectionReceivePrioritization(t, true, 5)
require.Equal(t, []string{"unpack", "unpack", "unpack", "unpack", "unpack", "pack"}, events)
})
// before handshake completion, we trigger packing of a new packet every time we receive a packet
t.Run("handshake not complete", func(t *testing.T) {
events := testConnectionReceivePrioritization(t, false, 5)
require.Equal(t, []string{
"unpack", "pack",
"unpack", "pack",
"unpack", "pack",
"unpack", "pack",
"unpack", "pack",
}, events)
})
}
// testConnectionReceivePrioritization queues numPackets short header packets on a server
// connection, starts the run loop and records, in order, every unpack of a received packet
// ("unpack") and every attempt to pack an outgoing packet ("pack").
// If handshakeComplete is set, the connection is created with a confirmed handshake and packing
// is expected via AppendPacket; otherwise packing is expected via PackCoalescedPacket.
// The recorded events are returned after the connection has been destroyed and run has returned.
func testConnectionReceivePrioritization(t *testing.T, handshakeComplete bool, numPackets int) []string {
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	opts := []testConnectionOpt{connectionOptUnpacker(unpacker)}
	if handshakeComplete {
		opts = append(opts, connectionOptHandshakeConfirmed())
	}
	tc := newServerTestConnection(t, mockCtrl, nil, false, opts...)

	// The following variables are only touched by the mock callbacks, which are invoked from the
	// run loop goroutine. events is read by the test goroutine only after run has returned.
	var events []string
	var counter int
	var testDone bool
	var doneClosed bool
	done := make(chan struct{})
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
		func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
			counter++
			if counter == numPackets {
				testDone = true
			}
			events = append(events, "unpack")
			return protocol.PacketNumber(counter), protocol.PacketNumberLen2, protocol.KeyPhaseZero, []byte{0, 1} /* PADDING, PING */, nil
		},
	).Times(numPackets)

	// recordPack logs a pack attempt. The first attempt after the last packet has been unpacked
	// signals the test. doneClosed prevents a double close (and thus a panic) if the run loop
	// packs again before the connection is destroyed.
	recordPack := func() {
		events = append(events, "pack")
		if testDone && !doneClosed {
			doneClosed = true
			close(done)
		}
	}
	if handshakeComplete {
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
				recordPack()
				return shortHeaderPacket{}, errNothingToPack
			},
		).AnyTimes()
	} else {
		tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(bool, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
				recordPack()
				return nil, nil
			},
		).AnyTimes()
	}

	// queue all packets before the run loop starts, so they're all pending at once
	for i := range numPackets {
		tc.conn.handlePacket(getShortHeaderPacket(t, tc.remoteAddr, tc.srcConnID, protocol.PacketNumber(i), []byte("foobar")))
	}
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	return events
}
// TestConnectionPacketBuffering checks that Handshake packets which can't be unpacked yet
// (the unpacker returns handshake.ErrKeysNotYetAvailable) are buffered, and that the buffered
// packets are unpacked in the order they were received once a later packet leads the crypto
// setup to report EventReceivedReadKeys.
func TestConnectionPacketBuffering(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	cs := mocks.NewMockCryptoSetup(mockCtrl)
	tracer, tr := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptUnpacker(unpacker),
		connectionOptCryptoSetup(cs),
		connectionOptTracer(tracer),
	)
	tr.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
	tr.EXPECT().StartedConnection(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	tr.EXPECT().DroppedEncryptionLevel(gomock.Any())
	cs.EXPECT().DiscardInitialKeys()
	// hdr1 and hdr2 describe two Handshake packets that differ only in their packet number
	hdr1 := wire.ExtendedHeader{
		Header: wire.Header{
			Type:             protocol.PacketTypeHandshake,
			DestConnectionID: tc.srcConnID,
			SrcConnectionID:  tc.destConnID,
			Length:           8,
			Version:          protocol.Version1,
		},
		PacketNumberLen: protocol.PacketNumberLen1,
		PacketNumber:    1,
	}
	hdr2 := hdr1
	hdr2.PacketNumber = 2
	cs.EXPECT().StartHandshake(gomock.Any())
	buffered := make(chan struct{})
	// Both packets fail to unpack because the keys are not yet available, and are reported to the
	// tracer as buffered. The second BufferedPacket call signals the test.
	gomock.InOrder(
		cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent}),
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
		tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()),
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(nil, handshake.ErrKeysNotYetAvailable),
		tr.EXPECT().BufferedPacket(logging.PacketTypeHandshake, gomock.Any()).Do(
			func(logging.PacketType, logging.ByteCount) { close(buffered) },
		),
	)
	tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr1, []byte("packet1")))
	tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr2, []byte("packet2")))
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	select {
	case <-buffered:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// Now send another packet.
	// In reality, this packet would contain a CRYPTO frame that advances the TLS handshake
	// such that new keys become available.
	// packets records the last 7 bytes (the payload, e.g. "packet3") of every packet handed to
	// the unpacker. It is appended to from the run loop goroutine and only read after unpacked
	// has been closed.
	var packets []string
	hdr3 := hdr1
	hdr3.PacketNumber = 3
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	unpacked := make(chan struct{})
	// after handling the CRYPTO frame, the crypto setup reports that read keys are available
	cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventReceivedReadKeys})
	cs.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
	gomock.InOrder(
		// packet 3 contains a CRYPTO frame and triggers the keys to become available
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				cf := &wire.CryptoFrame{Data: []byte("foobar")}
				b, _ := cf.Append(nil, protocol.Version1)
				return &unpackedPacket{hdr: &hdr3, encryptionLevel: protocol.EncryptionHandshake, data: b}, nil
			},
		),
		cs.EXPECT().HandleMessage(gomock.Any(), gomock.Any()),
		tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		// packet 1 dequeued from the buffer
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				return &unpackedPacket{hdr: &hdr1, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
			},
		),
		tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		// packet 2 dequeued from the buffer
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).DoAndReturn(
			func(hdr *wire.Header, data []byte) (*unpackedPacket, error) {
				packets = append(packets, string(data[len(data)-7:]))
				close(unpacked)
				return &unpackedPacket{hdr: &hdr2, encryptionLevel: protocol.EncryptionHandshake, data: []byte{0} /* PADDING */}, nil
			},
		),
		tr.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
	)
	tc.conn.handlePacket(getLongHeaderPacket(t, tc.remoteAddr, &hdr3, []byte("packet3")))
	select {
	case <-unpacked:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// packet3 triggered the keys to become available
	// packet1 and packet2 are processed from the buffer in order
	require.Equal(t, []string{"packet3", "packet1", "packet2"}, packets)
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	cs.EXPECT().Close()
	tr.EXPECT().ClosedConnection(gomock.Any())
	tr.EXPECT().Close()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionPacketPacing checks that the run loop respects the pacer:
//  1. while SendMode returns SendAny, packets are sent back-to-back,
//  2. when SendMode returns SendPacingLimited, the run loop waits until the time returned by
//     TimeUntilSend before trying to send again,
//  3. if it is still pacing limited after waking up, it only sends an ACK-only packet.
func TestConnectionPacketPacing(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	sender := NewMockSender(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptSentPacketHandler(sph),
		connectionOptSender(sender),
		connectionOptHandshakeConfirmed(),
		// set a fixed RTT, so that the idle timeout doesn't interfere with this test
		connectionOptRTT(10*time.Second),
	)
	sender.EXPECT().Run()
	// step is the pacing delay. It is already scaled; don't pass it to scaleDuration again.
	step := scaleDuration(50 * time.Millisecond)
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
	gomock.InOrder(
		// 1. allow 2 packets to be sent
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 2. become pacing limited for one step
		sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
		// 3. send another packet
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 4. become pacing limited for one step...
		sph.EXPECT().TimeUntilSend().DoAndReturn(func() time.Time { return time.Now().Add(step) }),
		// ... but this time we're still pacing limited when waking up.
		// In this case, we can only send an ACK.
		sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited),
		// 5. stop the test by becoming pacing limited forever
		sph.EXPECT().TimeUntilSend().Return(time.Now().Add(time.Hour)),
		sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()),
	)
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	// the three regular packets carry the payloads "packet1", "packet2" and "packet3"
	for i := 0; i < 3; i++ {
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), Version1).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				buf.Data = append(buf.Data, []byte("packet"+strconv.Itoa(i+1))...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i + 1)}, nil
			},
		)
	}
	tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
			buf := getPacketBuffer()
			buf.Data = []byte("ack")
			return shortHeaderPacket{PacketNumber: 1}, buf, nil
		},
	)
	sender.EXPECT().WouldBlock().AnyTimes()
	type sentPacket struct {
		time time.Time
		data []byte
	}
	sendChan := make(chan sentPacket, 10)
	sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(b *packetBuffer, _ uint16, _ protocol.ECN) {
		sendChan <- sentPacket{time: time.Now(), data: b.Data}
	}).Times(4)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	var times []time.Time
	for i := 0; i < 3; i++ {
		select {
		case b := <-sendChan:
			require.Equal(t, []byte("packet"+strconv.Itoa(i+1)), b.data)
			times = append(times, b.time)
		case <-time.After(scaleDuration(time.Second)):
			t.Fatal("timeout")
		}
	}
	select {
	case b := <-sendChan:
		require.Equal(t, []byte("ack"), b.data)
		times = append(times, b.time)
	case <-time.After(scaleDuration(time.Second)):
		t.Fatal("timeout")
	}
	// the first two packets are sent back-to-back, the third packet and the ACK one step later each
	require.InDelta(t, times[0].Sub(times[1]).Seconds(), 0, scaleDuration(10*time.Millisecond).Seconds())
	require.InDelta(t, times[2].Sub(times[1]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
	require.InDelta(t, times[3].Sub(times[2]).Seconds(), step.Seconds(), scaleDuration(20*time.Millisecond).Seconds())
	time.Sleep(step) // make sure that no more packets are sent
	require.True(t, mockCtrl.Satisfied())
	// test teardown
	sender.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case <-sendChan:
		t.Fatal("should not have sent any more packets")
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// When the send queue blocks, we need to reset the pacing timer, otherwise the run loop might busy-loop.
// See https://github.com/quic-go/quic-go/pull/4943 for more details.
//
// Setup: the sent packet handler always reports SendPacingLimited, with a pacing deadline that
// has already passed. The sender's queue reports not-blocked once, then blocked.
// The test counts how often WouldBlock reports a blocked queue while the connection runs for
// 10ms (scaled), and fails if that happened 3 or more times.
func TestConnectionPacingAndSendQueue(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	sender := NewMockSender(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptSentPacketHandler(sph),
		connectionOptSender(sender),
		connectionOptHandshakeConfirmed(),
		// set a fixed RTT, so that the idle timeout doesn't interfere with this test
		connectionOptRTT(10*time.Second),
	)
	sender.EXPECT().Run()
	// never closed or written to: the send queue never becomes available again during the test
	sendQueueAvailable := make(chan struct{})
	// a deadline in the past
	pacingDeadline := time.Now().Add(-time.Millisecond)
	// counter is incremented on the run loop goroutine and read only after run has returned
	var counter int
	// allow exactly one packet to be sent, then become blocked
	// (gomock uses the Return(false) expectation first, since it was registered first)
	sender.EXPECT().WouldBlock().Return(false)
	sender.EXPECT().WouldBlock().DoAndReturn(func() bool { counter++; return true }).AnyTimes()
	sender.EXPECT().Available().Return(sendQueueAvailable).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Now().Add(time.Hour)).AnyTimes()
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendPacingLimited).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(pacingDeadline).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECNNon).AnyTimes()
	// while pacing limited, the connection tries to send an ACK-only packet; there's nothing to send
	tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).Return(
		shortHeaderPacket{}, nil, errNothingToPack,
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	time.Sleep(scaleDuration(10 * time.Millisecond))
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	sender.EXPECT().Close()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// make sure the run loop didn't do too many iterations
	require.Less(t, counter, 3)
}
// TestConnectionIdleTimeout sends a single packet on a handshake-confirmed server connection and
// checks that run then returns an IdleTimeoutError roughly idleTimeout after that packet was
// packed.
func TestConnectionIdleTimeout(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second},
		false,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
		connectionOptRTT(time.Millisecond),
	)
	// the idle timeout is set when the transport parameters are received
	// NOTE(review): the assertion below assumes this (smaller) value, not the 1s from the Config,
	// is the effective idle timeout.
	idleTimeout := scaleDuration(50 * time.Millisecond)
	require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
		MaxIdleTimeout: idleTimeout,
	}))
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	// lastSendTime is written on the run loop goroutine and read only after run has returned
	var lastSendTime time.Time
	// send exactly one packet (containing a PING frame), then there's nothing more to send
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
			buf.Data = append(buf.Data, []byte("foobar")...)
			lastSendTime = time.Now()
			return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
		},
	)
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case err := <-errChan:
		require.ErrorIs(t, err, &IdleTimeoutError{})
		require.NotZero(t, lastSendTime)
		require.InDelta(t,
			time.Since(lastSendTime).Seconds(),
			idleTimeout.Seconds(),
			scaleDuration(10*time.Millisecond).Seconds(),
		)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionKeepAlive runs the keep-alive test twice: with keep-alives enabled (a keep-alive
// packet is expected) and with keep-alives disabled (the connection is expected to time out).
func TestConnectionKeepAlive(t *testing.T) {
	for _, tt := range []struct {
		name    string
		enabled bool
	}{
		{name: "enabled", enabled: true},
		{name: "disabled", enabled: false},
	} {
		t.Run(tt.name, func(t *testing.T) {
			testConnectionKeepAlive(t, tt.enabled, tt.enabled)
		})
	}
}
// testConnectionKeepAlive delivers a single packet to a handshake-confirmed server connection,
// which starts the keep-alive timer, and then checks what happens while the connection is idle.
// enable sets Config.KeepAlivePeriod to one second (zero otherwise). The idle timeout is
// idleTimeout, set via the peer's transport parameters.
// If expectKeepAlive is set, the test expects the run loop to attempt to send a packet about
// idleTimeout/2 after the received packet was unpacked, and then destroys the connection.
// Otherwise it expects run to return an IdleTimeoutError.
func testConnectionKeepAlive(t *testing.T, enable, expectKeepAlive bool) {
	var keepAlivePeriod time.Duration
	if enable {
		keepAlivePeriod = time.Second
	}
	mockCtrl := gomock.NewController(t)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second, KeepAlivePeriod: keepAlivePeriod},
		false,
		connectionOptUnpacker(unpacker),
		connectionOptHandshakeConfirmed(),
		connectionOptRTT(time.Millisecond),
	)
	// the idle timeout is set when the transport parameters are received
	idleTimeout := scaleDuration(50 * time.Millisecond)
	require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{
		MaxIdleTimeout: idleTimeout,
	}))

	// Build a packet. Receiving it starts the keep-alive timer.
	buf := getPacketBuffer()
	var err error
	buf.Data, err = wire.AppendShortHeader(buf.Data, tc.srcConnID, 1, protocol.PacketNumberLen1, protocol.KeyPhaseZero)
	require.NoError(t, err)
	buf.Data = append(buf.Data, []byte("packet")...)

	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()

	// unpackTime and packTime are written on the run loop goroutine and read only after done has
	// been closed.
	var unpackTime, packTime time.Time
	done := make(chan struct{})
	unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).DoAndReturn(
		func(time.Time, []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) {
			unpackTime = time.Now()
			return protocol.PacketNumber(1), protocol.PacketNumberLen1, protocol.KeyPhaseZero, []byte{0} /* PADDING */, nil
		},
	)
	// the send attempt following the received packet finds nothing to send
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)

	if expectKeepAlive {
		// record the time when the keep-alive is sent
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
				packTime = time.Now()
				close(done)
				return shortHeaderPacket{}, errNothingToPack
			},
		)
		tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now(), remoteAddr: tc.remoteAddr})
		select {
		case <-done:
			// the keep-alive packet should be sent after half the idle timeout
			diff := packTime.Sub(unpackTime)
			require.InDelta(t, diff.Seconds(), idleTimeout.Seconds()/2, scaleDuration(10*time.Millisecond).Seconds())
		case <-time.After(idleTimeout):
			t.Fatal("timeout")
		}
		// test teardown
		tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
		tc.conn.destroy(nil)
	} else {
		// If keep-alives are disabled, the connection runs into the idle timeout.
		// This is asserted by waiting for run to return below; no extra sleep is needed.
		tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
		tc.conn.handlePacket(receivedPacket{data: buf.Data, buffer: buf, rcvTime: time.Now(), remoteAddr: tc.remoteAddr})
	}

	select {
	case err := <-errChan:
		if expectKeepAlive {
			require.NoError(t, err)
		} else {
			require.ErrorIs(t, err, &IdleTimeoutError{})
		}
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionACKTimer checks that the run loop wakes up when the alarm returned by the
// received packet handler's GetAlarmTimeout expires. The first packet is sent right away; the
// second packet should be sent alarmTimeout later.
func TestConnectionACKTimer(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	rph := mockackhandler.NewMockReceivedPacketHandler(mockCtrl)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		&Config{MaxIdleTimeout: time.Second},
		false,
		connectionOptHandshakeConfirmed(),
		connectionOptReceivedPacketHandler(rph),
		connectionOptSentPacketHandler(sph),
		connectionOptRTT(10*time.Second),
	)
	alarmTimeout := scaleDuration(50 * time.Millisecond)
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
	rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour))
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	// times is appended to on the run loop goroutine. It's only read after the second value was
	// received from done, which is sent after the second append.
	var times []time.Time
	done := make(chan struct{}, 5)
	var calls []any
	// Each round sends one packet, then finds nothing more to send, then queries the alarm.
	// After the first round the alarm is set to expire after alarmTimeout.
	for i := 0; i < 2; i++ {
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				buf.Data = append(buf.Data, []byte("foobar")...)
				times = append(times, time.Now())
				return shortHeaderPacket{Frames: []ackhandler.Frame{{Frame: &wire.PingFrame{}}}, Length: 6}, nil
			},
		))
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buf *packetBuffer, _ protocol.ByteCount, _ time.Time, _ protocol.Version) (shortHeaderPacket, error) {
				done <- struct{}{}
				return shortHeaderPacket{}, errNothingToPack
			},
		))
		if i == 0 {
			calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(alarmTimeout)))
		} else {
			calls = append(calls, rph.EXPECT().GetAlarmTimeout().Return(time.Now().Add(time.Hour)).MaxTimes(1))
		}
	}
	gomock.InOrder(calls...)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	for i := 0; i < 2; i++ {
		select {
		case <-done:
		case <-time.After(3 * time.Second):
			t.Fatal("timeout")
		}
	}
	// require (not assert): indexing times[1] below would panic if fewer than two packets were sent
	require.Len(t, times, 2)
	require.InDelta(t, times[1].Sub(times[0]).Seconds(), alarmTimeout.Seconds(), scaleDuration(10*time.Millisecond).Seconds())
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// Send a GSO batch, until we have no more data to send.
//
// TestConnectionGSOBatch packs four full-size (maxPacketSize) packets and expects them to be
// written in a single Write call, with maxPacketSize as the size argument and the ECN marking
// reported by the sent packet handler.
func TestConnectionGSOBatch(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true, // presumably enables GSO (the other tests pass false) — confirm against newServerTestConnection
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// allow packets to be sent
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
	maxPacketSize := tc.conn.maxPacketSize()
	// expectedData is the concatenation of all four packets, each filled with its index byte.
	// The expectations have identical matchers, so gomock consumes them in registration order.
	var expectedData []byte
	for i := 0; i < 4; i++ {
		data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		expectedData = append(expectedData, data...)
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
			},
		)
	}
	done := make(chan struct{})
	tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack)
	tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
		func([]byte, uint16, protocol.ECN) error { close(done); return nil },
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// Send a GSO batch, until a packet smaller than the maximum size is packed
//
// TestConnectionGSOBatchPacketSize packs three full-size packets followed by one packet that is a
// single byte shorter. All four are expected in the first Write; the next packet ("foobar") is
// expected in a second Write.
func TestConnectionGSOBatchPacketSize(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true, // presumably enables GSO (the other tests pass false) — confirm against newServerTestConnection
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// allow packets to be sent
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).Return(protocol.ECT1).AnyTimes()
	maxPacketSize := tc.conn.maxPacketSize()
	var expectedData []byte
	var calls []any
	for i := 0; i < 4; i++ {
		var data []byte
		// the fourth packet is one byte short of the maximum packet size
		if i == 3 {
			data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize-1))
		} else {
			data = bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		}
		expectedData = append(expectedData, data...)
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(10 + i)}, nil
			},
		))
	}
	// The smaller (fourth) packet concluded this GSO batch, but the send loop will immediately start composing the next batch.
	// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, []byte("foobar")...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(14)}, nil
			},
		),
	)
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
	)
	gomock.InOrder(calls...)
	done := make(chan struct{})
	gomock.InOrder(
		tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
		tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECT1).DoAndReturn(
			func([]byte, uint16, protocol.ECN) error { close(done); return nil },
		),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionGSOBatchECN checks that a change of the ECN marking concludes a GSO batch.
// Each Write call carries a single ECN marking. The three full-size packets packed while the
// marking is ECT1 are expected in the first Write. The next packet, packed after the marking
// switched to ECN-CE, is expected in a second Write.
func TestConnectionGSOBatchECN(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		true,
		connectionOptHandshakeConfirmed(),
		connectionOptSentPacketHandler(sph),
	)
	// ecnMode is the marking reported by ECNMode. It's only accessed from mock callbacks, which
	// are invoked on the run loop goroutine.
	ecnMode := protocol.ECT1
	// allow packets to be sent
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
	sph.EXPECT().TimeUntilSend().Return(time.Time{}).AnyTimes()
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
	sph.EXPECT().GetLossDetectionTimeout().Return(time.Time{}).AnyTimes()
	sph.EXPECT().ECNMode(gomock.Any()).DoAndReturn(func(bool) protocol.ECN { return ecnMode }).AnyTimes()
	// Send a GSO batch of full-size packets, until the ECN marking changes.
	var expectedData []byte
	var calls []any
	maxPacketSize := tc.conn.maxPacketSize()
	for i := 0; i < 3; i++ {
		data := bytes.Repeat([]byte{byte(i)}, int(maxPacketSize))
		expectedData = append(expectedData, data...)
		calls = append(calls, tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, data...)
				if i == 2 {
					// switch the ECN marking once the third packet has been packed
					ecnMode = protocol.ECNCE
				}
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(20 + i)}, nil
			},
		))
	}
	// The change of the ECN marking concluded this GSO batch, but the send loop will immediately start composing the next batch.
	// We therefore send a "foobar", so we can check that we're actually generating two GSO batches.
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
			func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
				buffer.Data = append(buffer.Data, []byte("foobar")...)
				return shortHeaderPacket{PacketNumber: protocol.PacketNumber(24)}, nil
			},
		),
	)
	calls = append(calls,
		tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(shortHeaderPacket{}, errNothingToPack),
	)
	gomock.InOrder(calls...)
	done := make(chan struct{})
	// the ECT1 batch must be written before the ECN-CE packet
	gomock.InOrder(
		tc.sendConn.EXPECT().Write(expectedData, uint16(maxPacketSize), protocol.ECT1),
		tc.sendConn.EXPECT().Write([]byte("foobar"), uint16(maxPacketSize), protocol.ECNCE).DoAndReturn(
			func([]byte, uint16, protocol.ECN) error { close(done); return nil },
		),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
func TestConnectionPTOProbePackets(t *testing.T) {
t.Run("Initial", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.EncryptionInitial)
})
t.Run("Handshake", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.EncryptionHandshake)
})
t.Run("1-RTT", func(t *testing.T) {
testConnectionPTOProbePackets(t, protocol.Encryption1RTT)
})
}
// testConnectionPTOProbePackets makes the sent packet handler return the PTO send mode for
// encLevel once (and SendNone afterwards), with QueueProbePacket returning false. It expects the
// packer to be asked for a PTO probe packet at encLevel, and waits until that packet is written.
func testConnectionPTOProbePackets(t *testing.T, encLevel protocol.EncryptionLevel) {
	mockCtrl := gomock.NewController(t)
	sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
	tc := newServerTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptSentPacketHandler(sph),
	)
	// the PTO send mode belonging to each encryption level
	ptoSendModes := map[protocol.EncryptionLevel]ackhandler.SendMode{
		protocol.EncryptionInitial:   ackhandler.SendPTOInitial,
		protocol.EncryptionHandshake: ackhandler.SendPTOHandshake,
		protocol.Encryption1RTT:      ackhandler.SendPTOAppData,
	}
	sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
	sph.EXPECT().TimeUntilSend().AnyTimes()
	// one probe packet may be sent, then nothing more
	sph.EXPECT().SendMode(gomock.Any()).Return(ptoSendModes[encLevel])
	sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
	sph.EXPECT().ECNMode(gomock.Any())
	sph.EXPECT().QueueProbePacket(encLevel).Return(false)
	sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
	tc.packer.EXPECT().MaybePackPTOProbePacket(encLevel, gomock.Any(), gomock.Any(), protocol.Version1).DoAndReturn(
		func(protocol.EncryptionLevel, protocol.ByteCount, time.Time, protocol.Version) (*coalescedPacket, error) {
			probe := &coalescedPacket{
				buffer:         getPacketBuffer(),
				shortHdrPacket: &shortHeaderPacket{PacketNumber: 1},
			}
			return probe, nil
		},
	)
	written := make(chan struct{})
	tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func([]byte, uint16, protocol.ECN) error { close(written); return nil },
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.scheduleSending()
	select {
	case <-written:
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(3 * time.Second):
		t.Fatal("timeout")
	}
}
// TestConnectionCongestionControl checks that the send loop obeys the send mode
// reported by the SentPacketHandler, in three phases:
//  1. SendAny: regular packets are packed (AppendPacket) and written.
//  2. SendAck: only an ack-only packet may be packed (PackAckOnlyPacket).
//  3. SendNone: nothing is packed at all.
func TestConnectionCongestionControl(t *testing.T) {
mockCtrl := gomock.NewController(t)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
false,
connectionOptHandshakeConfirmed(),
connectionOptSentPacketHandler(sph),
// NOTE(review): a long RTT presumably keeps RTT-based timers from firing during the test
connectionOptRTT(10*time.Second),
)
sph.EXPECT().TimeUntilSend().AnyTimes()
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().ECNMode(true).AnyTimes()
// phase 1: two packets may be sent, then the handler switches to ack-only
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).Times(2)
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck).MaxTimes(1)
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2)
// Since we're already sending out packets, we don't expect any calls to PackAckOnlyPacket
for i := 0; i < 2; i++ {
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(buffer *packetBuffer, count protocol.ByteCount, t time.Time, version protocol.Version) (shortHeaderPacket, error) {
buffer.Data = append(buffer.Data, []byte("foobar")...)
return shortHeaderPacket{PacketNumber: protocol.PacketNumber(i)}, nil
},
)
}
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any())
done1 := make(chan struct{})
tc.sendConn.EXPECT().Write(gomock.Any(), gomock.Any(), gomock.Any()).Do(
func([]byte, uint16, protocol.ECN) error { close(done1); return nil },
)
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
tc.conn.scheduleSending()
select {
case <-done1:
case <-time.After(time.Second):
t.Fatal("timeout")
}
require.True(t, mockCtrl.Satisfied())
// Now that we're congestion limited, we can only send an ack-only packet
done2 := make(chan struct{})
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAck)
tc.packer.EXPECT().PackAckOnlyPacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
close(done2)
return shortHeaderPacket{}, nil, errNothingToPack
},
)
tc.conn.scheduleSending()
select {
case <-done2:
case <-time.After(time.Second):
t.Fatal("timeout")
}
require.True(t, mockCtrl.Satisfied())
// If the send mode is "none", we can't even send an ack-only packet
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendNone)
tc.conn.scheduleSending()
time.Sleep(scaleDuration(10 * time.Millisecond)) // make sure there are no calls to the packer
// test teardown
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(3 * time.Second):
t.Fatal("timeout")
}
}
// TestConnectionSendQueue exercises send-queue backpressure handling,
// once with GSO enabled and once without.
func TestConnectionSendQueue(t *testing.T) {
	for _, variant := range []struct {
		name string
		gso  bool
	}{
		{name: "with GSO", gso: true},
		{name: "without GSO", gso: false},
	} {
		t.Run(variant.name, func(t *testing.T) {
			testConnectionSendQueue(t, variant.gso)
		})
	}
}
// testConnectionSendQueue checks send-queue backpressure. After one packet has
// been handed to the sender, the sender reports that further sends would block.
// The connection then waits on sender.Available(). Once a value is received on
// that channel, it resumes packing.
func testConnectionSendQueue(t *testing.T, enableGSO bool) {
mockCtrl := gomock.NewController(t)
sph := mockackhandler.NewMockSentPacketHandler(mockCtrl)
sender := NewMockSender(mockCtrl)
tc := newServerTestConnection(t,
mockCtrl,
nil,
enableGSO,
connectionOptSender(sender),
connectionOptHandshakeConfirmed(),
connectionOptSentPacketHandler(sph),
)
sender.EXPECT().Run().MaxTimes(1)
// the first WouldBlock check reports free space, the next two report a full queue
sender.EXPECT().WouldBlock()
sender.EXPECT().WouldBlock().Return(true).Times(2)
available := make(chan struct{})
blocked := make(chan struct{})
// when the connection starts waiting for the queue, signal the test via blocked
sender.EXPECT().Available().DoAndReturn(
func() <-chan struct{} {
close(blocked)
return available
},
)
sph.EXPECT().GetLossDetectionTimeout().AnyTimes()
sph.EXPECT().SentPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
sph.EXPECT().SendMode(gomock.Any()).Return(ackhandler.SendAny).AnyTimes()
sph.EXPECT().ECNMode(gomock.Any()).AnyTimes()
// exactly one packet is packed and sent before the queue blocks
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{PacketNumber: protocol.PacketNumber(1)}, nil,
)
sender.EXPECT().Send(gomock.Any(), gomock.Any(), gomock.Any())
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
tc.conn.scheduleSending()
select {
case <-blocked:
case <-time.After(time.Second):
t.Fatal("timeout")
}
require.True(t, mockCtrl.Satisfied())
// now make room in the send queue
sender.EXPECT().WouldBlock().AnyTimes()
unblocked := make(chan struct{})
// after unblocking, the connection packs again; returning errNothingToPack ends the attempt
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(*packetBuffer, protocol.ByteCount, time.Time, protocol.Version) (shortHeaderPacket, error) {
close(unblocked)
return shortHeaderPacket{}, errNothingToPack
},
)
available <- struct{}{}
select {
case <-unblocked:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// test teardown
sender.EXPECT().Close()
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case err := <-errChan:
require.NoError(t, err)
case <-time.After(3 * time.Second):
t.Fatal("timeout")
}
}
// getVersionNegotiationPacket composes a Version Negotiation packet from src
// to dest that offers versions. It wraps the packet in a receivedPacket stamped
// with the current time and backed by a fresh packet buffer.
func getVersionNegotiationPacket(src, dest protocol.ConnectionID, versions []protocol.Version) receivedPacket {
	srcID := protocol.ArbitraryLenConnectionID(src.Bytes())
	destID := protocol.ArbitraryLenConnectionID(dest.Bytes())
	packet := wire.ComposeVersionNegotiation(srcID, destID, versions)
	return receivedPacket{
		data:    packet,
		rcvTime: time.Now(),
		buffer:  getPacketBuffer(),
	}
}
// TestConnectionVersionNegotiation checks what happens when a client receives a
// Version Negotiation packet offering Version2 (alongside the unknown version 1234).
// The tracer must be told about the packet and the negotiated version. The
// connection must then be removed, and run must return an errCloseForRecreating
// whose nextVersion is Version2.
func TestConnectionVersionNegotiation(t *testing.T) {
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
	)
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	var tracerVersions []logging.Version
	gomock.InOrder(
		tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) {
			tracerVersions = versions
		}),
		tracer.EXPECT().NegotiatedVersion(protocol.Version2, gomock.Any(), gomock.Any()),
		tc.connRunner.EXPECT().Remove(gomock.Any()),
	)
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tc.conn.handlePacket(getVersionNegotiationPacket(
		tc.destConnID,
		tc.srcConnID,
		[]protocol.Version{1234, protocol.Version2},
	))
	select {
	case err := <-errChan:
		var rerr *errCloseForRecreating
		require.ErrorAs(t, err, &rerr)
		// testify expects (expected, actual); the original had them swapped,
		// which produced inverted failure messages.
		require.Equal(t, protocol.Version2, rerr.nextVersion)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// the tracer must have seen every version offered in the packet
	require.Contains(t, tracerVersions, protocol.Version(1234))
	require.Contains(t, tracerVersions, protocol.Version2)
}
// TestConnectionVersionNegotiationNoMatch checks the no-overlap case. A client
// configured for Version1 only receives a Version Negotiation packet offering
// only Version2. It must close the connection and return a
// VersionNegotiationError whose Theirs field contains the offered version.
func TestConnectionVersionNegotiationNoMatch(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newClientTestConnection(t,
mockCtrl,
&Config{Versions: []protocol.Version{protocol.Version1}},
false,
connectionOptTracer(tr),
)
tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
var tracerVersions []logging.Version
tracer.EXPECT().ReceivedVersionNegotiationPacket(gomock.Any(), gomock.Any(), gomock.Any()).Do(
func(_, _ protocol.ArbitraryLenConnectionID, versions []logging.Version) { tracerVersions = versions },
)
// no common version: the connection is closed and the tracer is shut down
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
tc.connRunner.EXPECT().Remove(gomock.Any())
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
tc.conn.handlePacket(getVersionNegotiationPacket(
tc.destConnID,
tc.srcConnID,
[]protocol.Version{protocol.Version2},
))
select {
case err := <-errChan:
var verr *VersionNegotiationError
require.ErrorAs(t, err, &verr)
require.Contains(t, verr.Theirs, protocol.Version2)
case <-time.After(time.Second):
t.Fatal("timeout")
}
require.Contains(t, tracerVersions, protocol.Version2)
}
// TestConnectionVersionNegotiationInvalidPackets checks that two kinds of
// Version Negotiation packets are dropped (and reported to the tracer) without
// error:
//   - one that offers the version currently in use (PacketDropUnexpectedVersion);
//   - one that is truncated (PacketDropHeaderParseError).
//
// Packets are fed directly to handleOnePacket; the run loop is not started.
func TestConnectionVersionNegotiationInvalidPackets(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
tc := newClientTestConnection(t,
mockCtrl,
nil,
false,
connectionOptTracer(tr),
)
// offers the current version
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedVersion)
vnp := getVersionNegotiationPacket(
tc.destConnID,
tc.srcConnID,
[]protocol.Version{1234, protocol.Version1},
)
wasProcessed, err := tc.conn.handleOnePacket(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// unparseable, since it's missing 2 bytes
tracer.EXPECT().DroppedPacket(logging.PacketTypeVersionNegotiation, gomock.Any(), gomock.Any(), logging.PacketDropHeaderParseError)
vnp.data = vnp.data[:len(vnp.data)-2]
wasProcessed, err = tc.conn.handleOnePacket(vnp)
require.NoError(t, err)
require.False(t, wasProcessed)
}
// getRetryPacket builds a Version1 Retry packet from src to dest carrying token.
// It appends the Retry integrity tag computed against origDest and wraps the
// result in a receivedPacket stamped with the current time. Serialization
// failures fail the test.
func getRetryPacket(t *testing.T, src, dest, origDest protocol.ConnectionID, token []byte) receivedPacket {
	extHdr := &wire.ExtendedHeader{
		Header: wire.Header{
			Type:             protocol.PacketTypeRetry,
			SrcConnectionID:  src,
			DestConnectionID: dest,
			Token:            token,
			Version:          protocol.Version1,
		},
	}
	raw, err := extHdr.Append(nil, protocol.Version1)
	require.NoError(t, err)
	// the integrity tag covers the serialized header and depends on the original destination connection ID
	integrityTag := handshake.GetRetryIntegrityTag(raw, origDest, protocol.Version1)
	raw = append(raw, integrityTag[:]...)
	return receivedPacket{
		rcvTime: time.Now(),
		data:    raw,
		buffer:  getPacketBuffer(),
	}
}
// TestConnectionRetryDrops checks that a client drops two kinds of Retry packets
// without error:
//   - one with a corrupted integrity tag (PacketDropPayloadDecryptError);
//   - one whose source connection ID equals the current destination connection
//     ID (PacketDropUnexpectedPacket).
func TestConnectionRetryDrops(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tc := newClientTestConnection(t,
mockCtrl,
nil,
false,
connectionOptTracer(tr),
connectionOptUnpacker(unpacker),
)
newConnID := protocol.ParseConnectionID([]byte{0xde, 0xad, 0xbe, 0xef})
// invalid integrity tag
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropPayloadDecryptError)
retry := getRetryPacket(t, newConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
// flip the last byte, which belongs to the integrity tag
retry.data[len(retry.data)-1]++
wasProcessed, err := tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
require.True(t, mockCtrl.Satisfied())
// receive a retry that doesn't change the connection ID
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
retry = getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
wasProcessed, err = tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}
// TestConnectionRetryAfterReceivedPacket checks that a client that has already
// processed an Initial packet drops a subsequent Retry packet as unexpected.
func TestConnectionRetryAfterReceivedPacket(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
unpacker := NewMockUnpacker(mockCtrl)
tc := newClientTestConnection(t,
mockCtrl,
nil,
false,
connectionOptTracer(tr),
connectionOptUnpacker(unpacker),
)
// receive a regular packet
tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
regular := getPacketWithPacketType(t, tc.srcConnID, protocol.PacketTypeInitial, 200)
unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
&unpackedPacket{
hdr: &wire.ExtendedHeader{Header: wire.Header{Type: protocol.PacketTypeInitial}},
encryptionLevel: protocol.EncryptionInitial,
}, nil,
)
wasProcessed, err := tc.conn.handleOnePacket(receivedPacket{
data: regular,
buffer: getPacketBuffer(),
rcvTime: time.Now(),
remoteAddr: tc.remoteAddr,
})
require.NoError(t, err)
require.True(t, wasProcessed)
// receive a retry
retry := getRetryPacket(t, tc.destConnID, tc.srcConnID, tc.destConnID, []byte("foobar"))
tracer.EXPECT().DroppedPacket(logging.PacketTypeRetry, gomock.Any(), gomock.Any(), logging.PacketDropUnexpectedPacket)
wasProcessed, err = tc.conn.handleOnePacket(retry)
require.NoError(t, err)
require.False(t, wasProcessed)
}
// TestConnectionConnectionIDChanges runs the connection-ID change test twice:
// once preceded by a Retry and once without one.
func TestConnectionConnectionIDChanges(t *testing.T) {
	cases := []struct {
		name      string
		sendRetry bool
	}{
		{name: "with retry", sendRetry: true},
		{name: "without retry", sendRetry: false},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			testConnectionConnectionIDChanges(t, c.sendRetry)
		})
	}
}
// testConnectionConnectionIDChanges checks how a client tracks the server's
// connection ID:
//  1. Optionally, a Retry switches the destination connection ID to retryConnID.
//  2. The first Initial packet from the server switches it to newConnID.
//  3. A second Initial packet using yet another source connection ID
//     (newConnID2) is dropped, and the destination connection ID stays newConnID.
func testConnectionConnectionIDChanges(t *testing.T, sendRetry bool) {
	// makeInitialPacket serializes hdr and appends (Length - PacketNumberLen)
	// zero bytes as the payload.
	makeInitialPacket := func(t *testing.T, hdr *wire.ExtendedHeader) []byte {
		t.Helper()
		data, err := hdr.Append(nil, protocol.Version1)
		require.NoError(t, err)
		data = append(data, make([]byte, hdr.Length-protocol.ByteCount(hdr.PacketNumberLen))...)
		return data
	}
	mockCtrl := gomock.NewController(t)
	tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
	unpacker := NewMockUnpacker(mockCtrl)
	tc := newClientTestConnection(t,
		mockCtrl,
		nil,
		false,
		connectionOptTracer(tr),
		connectionOptUnpacker(unpacker),
	)
	dstConnID := tc.destConnID
	// Three 10-byte connection IDs are carved out of one random buffer:
	// b[0:10] -> newConnID, b[10:20] -> newConnID2, b[20:30] -> retryConnID.
	// (Previously the first two were sliced as b[:11] and b[11:20], which was
	// inconsistent with the 3*10 sizing and the b[20:30] retry ID.)
	b := make([]byte, 3*10)
	rand.Read(b)
	newConnID := protocol.ParseConnectionID(b[:10])
	newConnID2 := protocol.ParseConnectionID(b[10:20])
	errChan := make(chan error, 1)
	go func() { errChan <- tc.conn.run() }()
	tracer.EXPECT().NegotiatedVersion(gomock.Any(), gomock.Any(), gomock.Any())
	tc.packer.EXPECT().PackCoalescedPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes()
	require.Equal(t, dstConnID, tc.conn.connIDManager.Get())
	var retryConnID protocol.ConnectionID
	if sendRetry {
		retryConnID = protocol.ParseConnectionID(b[20:30])
		hdrChan := make(chan *wire.Header)
		tracer.EXPECT().ReceivedRetry(gomock.Any()).Do(func(hdr *wire.Header) { hdrChan <- hdr })
		tc.packer.EXPECT().SetToken([]byte("foobar"))
		tc.conn.handlePacket(getRetryPacket(t, retryConnID, tc.srcConnID, tc.destConnID, []byte("foobar")))
		select {
		case hdr := <-hdrChan:
			assert.Equal(t, retryConnID, hdr.SrcConnectionID)
			assert.Equal(t, []byte("foobar"), hdr.Token)
			require.Equal(t, retryConnID, tc.conn.connIDManager.Get())
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
	// Send the first packet. The server changes the connection ID to newConnID.
	hdr1 := wire.ExtendedHeader{
		Header: wire.Header{
			SrcConnectionID:  newConnID,
			DestConnectionID: tc.srcConnID,
			Type:             protocol.PacketTypeInitial,
			Length:           200,
			Version:          protocol.Version1,
		},
		PacketNumber:    1,
		PacketNumberLen: protocol.PacketNumberLen2,
	}
	// hdr2 is identical to hdr1 except for yet another source connection ID
	hdr2 := hdr1
	hdr2.SrcConnectionID = newConnID2
	receivedFirst := make(chan struct{})
	gomock.InOrder(
		unpacker.EXPECT().UnpackLongHeader(gomock.Any(), gomock.Any()).Return(
			&unpackedPacket{
				hdr:             &hdr1,
				encryptionLevel: protocol.EncryptionInitial,
			}, nil,
		),
		tracer.EXPECT().ReceivedLongHeaderPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Do(
			func(*wire.ExtendedHeader, protocol.ByteCount, protocol.ECN, []logging.Frame) { close(receivedFirst) },
		),
	)
	tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr1), buffer: getPacketBuffer(), rcvTime: time.Now(), remoteAddr: tc.remoteAddr})
	select {
	case <-receivedFirst:
		require.Equal(t, newConnID, tc.conn.connIDManager.Get())
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// Send the second packet. We refuse to accept it, because the connection ID is changed again.
	dropped := make(chan struct{})
	tracer.EXPECT().DroppedPacket(logging.PacketTypeInitial, gomock.Any(), gomock.Any(), logging.PacketDropUnknownConnectionID).Do(
		func(logging.PacketType, protocol.PacketNumber, protocol.ByteCount, logging.PacketDropReason) {
			close(dropped)
		},
	)
	tc.conn.handlePacket(receivedPacket{data: makeInitialPacket(t, &hdr2), buffer: getPacketBuffer(), rcvTime: time.Now(), remoteAddr: tc.remoteAddr})
	select {
	case <-dropped:
		// the connection ID should not have changed
		require.Equal(t, newConnID, tc.conn.connIDManager.Get())
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
	// test teardown
	tracer.EXPECT().ClosedConnection(gomock.Any())
	tracer.EXPECT().Close()
	tc.connRunner.EXPECT().Remove(gomock.Any())
	tc.conn.destroy(nil)
	select {
	case err := <-errChan:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// When the connection is closed before sending the first packet,
// we don't send a CONNECTION_CLOSE.
// This can happen if there's something wrong with the tls.Config, and
// crypto/tls refuses to start the handshake.
func TestConnectionEarlyClose(t *testing.T) {
mockCtrl := gomock.NewController(t)
tr, tracer := mocklogging.NewMockConnectionTracer(mockCtrl)
cryptoSetup := mocks.NewMockCryptoSetup(mockCtrl)
tc := newClientTestConnection(t,
mockCtrl,
nil,
false,
connectionOptTracer(tr),
connectionOptCryptoSetup(cryptoSetup),
)
// pretend that no packet has been sent yet
tc.conn.sentFirstPacket = false
tracer.EXPECT().ClosedConnection(gomock.Any())
tracer.EXPECT().Close()
// close the connection locally from within StartHandshake, i.e. before anything was sent
cryptoSetup.EXPECT().StartHandshake(gomock.Any()).Do(func(context.Context) error {
tc.conn.closeLocal(errors.New("early error"))
return nil
})
cryptoSetup.EXPECT().NextEvent().Return(handshake.Event{Kind: handshake.EventNoEvent})
cryptoSetup.EXPECT().Close()
tc.connRunner.EXPECT().Remove(gomock.Any())
// No packer or sendConn expectations are registered, so any attempt to pack or
// write a CONNECTION_CLOSE would be an unexpected mock call and fail the test.
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
select {
case err := <-errChan:
require.Error(t, err)
require.ErrorContains(t, err, "early error")
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
// TestConnectionPathValidation covers both ways a peer can move to a new address:
// an unintentional NAT rebinding and an intentional connection migration.
func TestConnectionPathValidation(t *testing.T) {
	scenarios := []struct {
		name          string
		natRebinding bool
	}{
		{name: "NAT rebinding", natRebinding: true},
		{name: "intentional migration", natRebinding: false},
	}
	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			testConnectionPathValidation(t, sc.natRebinding)
		})
	}
}
// testConnectionPathValidation checks path validation when a packet arrives
// from a new remote address. The server sends a PATH_CHALLENGE probe to the new
// address. When the matching PATH_RESPONSE arrives, the connection switches
// paths (ChangeRemoteAddr) in one of two ways:
//   - NAT rebinding: immediately, because the first packet on the new path
//     contained a PING, which is a non-probing frame.
//   - Intentional migration: only after a further non-probing packet arrives on
//     the new path, because the first packet there contained only PADDING,
//     which is a probing frame.
func testConnectionPathValidation(t *testing.T, isNATRebinding bool) {
mockCtrl := gomock.NewController(t)
unpacker := NewMockUnpacker(mockCtrl)
tc := newServerTestConnection(
t,
mockCtrl,
nil,
false,
connectionOptUnpacker(unpacker),
connectionOptHandshakeConfirmed(),
connectionOptRTT(time.Second),
)
require.NoError(t, tc.conn.handleTransportParameters(&wire.TransportParameters{MaxUDPPayloadSize: 1456}))
newRemoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 1, 1), Port: 1234}
require.NotEqual(t, tc.remoteAddr, newRemoteAddr)
errChan := make(chan error, 1)
go func() { errChan <- tc.conn.run() }()
probeSent := make(chan struct{})
var pathChallenge *wire.PathChallengeFrame
// PADDING is a probing frame, PING is not (RFC 9000, Section 9.1)
payload := []byte{0} // PADDING frame
if isNATRebinding {
payload = []byte{1} // PING frame
}
// receiving a packet from the new address triggers a PATH_CHALLENGE sent to that address
gomock.InOrder(
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(10), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
),
tc.packer.EXPECT().PackPathProbePacket(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(_ protocol.ConnectionID, f ackhandler.Frame, _ protocol.Version) (shortHeaderPacket, *packetBuffer, error) {
pathChallenge = f.Frame.(*wire.PathChallengeFrame)
return shortHeaderPacket{IsPathProbePacket: true}, getPacketBuffer(), nil
},
),
tc.sendConn.EXPECT().WriteTo(gomock.Any(), newRemoteAddr).DoAndReturn(
func([]byte, net.Addr) error { close(probeSent); return nil },
),
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{}, errNothingToPack,
),
)
tc.conn.handlePacket(receivedPacket{
data: make([]byte, 10),
buffer: getPacketBuffer(),
remoteAddr: newRemoteAddr,
rcvTime: time.Now(),
})
select {
case <-probeSent:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// Receive a packet containing a PATH_RESPONSE frame.
// Only if the first packet received on the path was a non-probing packet
// (i.e. we're dealing with a NAT rebinding), this makes us switch to the new path.
migrated := make(chan struct{})
data, err := (&wire.PathResponseFrame{Data: pathChallenge.Data}).Append(nil, protocol.Version1)
require.NoError(t, err)
calls := []any{
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(11), protocol.PacketNumberLen2, protocol.KeyPhaseZero, data, nil,
),
}
if isNATRebinding {
calls = append(calls,
tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
func(net.Addr, packetInfo) { close(migrated) },
),
)
}
calls = append(calls,
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{}, errNothingToPack,
),
)
gomock.InOrder(calls...)
require.Equal(t, tc.remoteAddr, tc.conn.RemoteAddr())
// the PATH_RESPONSE can be sent on the old path, if the client is just probing the new path
addr := tc.remoteAddr
if isNATRebinding {
addr = newRemoteAddr
}
tc.conn.handlePacket(receivedPacket{
data: make([]byte, 100),
buffer: getPacketBuffer(),
remoteAddr: addr,
rcvTime: time.Now(),
})
if !isNATRebinding {
// If the first packet was a probing packet, we only switch to the new path when we
// receive a non-probing packet on that path.
select {
case <-migrated:
t.Fatal("didn't expect a migration yet")
case <-time.After(scaleDuration(10 * time.Millisecond)):
}
payload := []byte{1} // PING frame
gomock.InOrder(
unpacker.EXPECT().UnpackShortHeader(gomock.Any(), gomock.Any()).Return(
protocol.PacketNumber(12), protocol.PacketNumberLen2, protocol.KeyPhaseZero, payload, nil,
),
tc.sendConn.EXPECT().ChangeRemoteAddr(newRemoteAddr, gomock.Any()).Do(
func(net.Addr, packetInfo) { close(migrated) },
),
tc.packer.EXPECT().AppendPacket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(
shortHeaderPacket{}, errNothingToPack,
),
)
tc.conn.handlePacket(receivedPacket{
data: make([]byte, 100),
buffer: getPacketBuffer(),
remoteAddr: newRemoteAddr,
rcvTime: time.Now(),
})
}
select {
case <-migrated:
case <-time.After(time.Second):
t.Fatal("timeout")
}
// test teardown
tc.connRunner.EXPECT().Remove(gomock.Any()).AnyTimes()
tc.conn.destroy(nil)
select {
case <-errChan:
case <-time.After(time.Second):
t.Fatal("timeout")
}
}
golang-github-lucas-clemente-quic-go-0.50.0/connection_timer.go 0000664 0000000 0000000 00000002565 14765760516 0024503 0 ustar 00root root 0000000 0000000 package quic
import (
"time"
"github.com/quic-go/quic-go/internal/utils"
)
var deadlineSendImmediately = time.Time{}.Add(42 * time.Millisecond) // any value > time.Time{} and before time.Now() is fine
// connectionTimer wraps a utils.Timer and remembers the last deadline that was
// read, so that SetTimer can avoid re-arming the ACK and loss timers for
// deadlines that have already fired.
type connectionTimer struct {
timer *utils.Timer
// last is the deadline recorded by the most recent SetRead call
// (deadlineSendImmediately is never recorded here).
last time.Time
}
// newTimer creates a connectionTimer backed by a fresh utils.Timer.
func newTimer() *connectionTimer {
	ct := &connectionTimer{}
	ct.timer = utils.NewTimer()
	return ct
}
// SetRead marks the timer's current deadline as consumed. It remembers that
// deadline in t.last unless it is the deadlineSendImmediately sentinel, then
// forwards the call to the underlying timer.
func (t *connectionTimer) SetRead() {
	current := t.timer.Deadline()
	isSentinel := current == deadlineSendImmediately
	if !isSentinel {
		t.last = current
	}
	t.timer.SetRead()
}
// Chan returns the channel of the underlying timer, which the caller can
// select on to be notified when the deadline fires.
func (t *connectionTimer) Chan() <-chan time.Time {
return t.timer.Chan()
}
// SetTimer re-arms the timer for the earliest of the given deadlines.
// The idle-timeout/keep-alive deadline is always a candidate. The ACK alarm and
// loss time are only considered if they lie strictly after the last deadline
// that was read, which prevents busy-looping when the timer fires but no packet
// can actually be sent. The pacing deadline is exempt from that rule, since it
// may legitimately be set to deadlineSendImmediately over and over.
// Zero-valued deadlines are ignored.
func (t *connectionTimer) SetTimer(idleTimeoutOrKeepAlive, ackAlarm, lossTime, pacing time.Time) {
	// precedes reports whether candidate is set and strictly earlier than current.
	precedes := func(candidate, current time.Time) bool {
		return !candidate.IsZero() && candidate.Before(current)
	}
	next := idleTimeoutOrKeepAlive
	if precedes(ackAlarm, next) && ackAlarm.After(t.last) {
		next = ackAlarm
	}
	if precedes(lossTime, next) && lossTime.After(t.last) {
		next = lossTime
	}
	if precedes(pacing, next) {
		next = pacing
	}
	t.timer.Reset(next)
}
// Stop stops the underlying timer.
func (t *connectionTimer) Stop() {
t.timer.Stop()
}
golang-github-lucas-clemente-quic-go-0.50.0/connection_timer_test.go 0000664 0000000 0000000 00000002734 14765760516 0025540 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"time"
"github.com/stretchr/testify/require"
)
// Deadline is a test-only accessor exposing the underlying timer's current deadline.
func (t *connectionTimer) Deadline() time.Time { return t.timer.Deadline() }
// TestConnectionTimerModes checks that SetTimer picks the earliest non-zero
// deadline as progressively earlier ACK, loss and pacing deadlines are supplied.
func TestConnectionTimerModes(t *testing.T) {
	now := time.Now()
	var unset time.Time
	cases := []struct {
		name                    string
		idle, ack, loss, pacing time.Time
		want                    time.Time
	}{
		{"idle timeout", now.Add(time.Hour), unset, unset, unset, now.Add(time.Hour)},
		{"ACK timer", now.Add(time.Hour), now.Add(time.Minute), unset, unset, now.Add(time.Minute)},
		{"loss timer", now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), unset, now.Add(time.Second)},
		{"pacing timer", now.Add(time.Hour), now.Add(time.Minute), now.Add(time.Second), now.Add(time.Millisecond), now.Add(time.Millisecond)},
	}
	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			timer := newTimer()
			timer.SetTimer(c.idle, c.ack, c.loss, c.pacing)
			require.Equal(t, c.want, timer.Deadline())
		})
	}
}
// TestConnectionTimerReset checks that once an ACK deadline has been read,
// passing the same ACK deadline again does not re-arm the timer for it.
func TestConnectionTimerReset(t *testing.T) {
	now := time.Now()
	idle := now.Add(time.Hour)
	ack := now.Add(time.Minute)
	timer := newTimer()

	// the ACK alarm is the earliest deadline
	timer.SetTimer(idle, ack, time.Time{}, time.Time{})
	require.Equal(t, ack, timer.Deadline())

	// after reading it, the same ACK alarm is ignored and the idle timeout wins
	timer.SetRead()
	timer.SetTimer(idle, ack, time.Time{}, time.Time{})
	require.Equal(t, idle, timer.Deadline())
}
golang-github-lucas-clemente-quic-go-0.50.0/crypto_stream.go 0000664 0000000 0000000 00000004415 14765760516 0024033 0 ustar 00root root 0000000 0000000 package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
)
// cryptoStream holds the data of one encryption level's crypto stream: data
// received in CRYPTO frames, and data queued to be sent in CRYPTO frames.
type cryptoStream struct {
// queue collects received CRYPTO frame data (Push on receive, Pop on read)
queue frameSorter
// highestOffset is the largest end offset (Offset+len(Data)) received so far
highestOffset protocol.ByteCount
// finished is set by Finish; afterwards only data below highestOffset is tolerated
finished bool
// writeOffset is the stream offset of the next CRYPTO frame to be sent
writeOffset protocol.ByteCount
// writeBuf holds data written via Write that has not yet been popped into a frame
writeBuf []byte
}
// newCryptoStream returns an empty cryptoStream with a freshly initialized frame sorter.
func newCryptoStream() *cryptoStream {
	sorter := newFrameSorter()
	return &cryptoStream{queue: *sorter}
}
// HandleCryptoFrame processes a received CRYPTO frame:
//   - Data ending beyond protocol.MaxCryptoStreamOffset is rejected with
//     CRYPTO_BUFFER_EXCEEDED.
//   - After Finish has been called, data extending past the highest offset seen
//     is a PROTOCOL_VIOLATION. Anything at or below it (e.g. a retransmission)
//     is silently ignored.
//   - Otherwise the data is queued for reading.
func (s *cryptoStream) HandleCryptoFrame(f *wire.CryptoFrame) error {
	end := f.Offset + protocol.ByteCount(len(f.Data))
	if end > protocol.MaxCryptoStreamOffset {
		return &qerr.TransportError{
			ErrorCode:    qerr.CryptoBufferExceeded,
			ErrorMessage: fmt.Sprintf("received invalid offset %d on crypto stream, maximum allowed %d", end, protocol.MaxCryptoStreamOffset),
		}
	}
	if !s.finished {
		if end > s.highestOffset {
			s.highestOffset = end
		}
		return s.queue.Push(f.Data, f.Offset, nil)
	}
	// The stream is finished: tolerate old data, reject anything new.
	if end <= s.highestOffset {
		return nil
	}
	return &qerr.TransportError{
		ErrorCode:    qerr.ProtocolViolation,
		ErrorMessage: "received crypto data after change of encryption level",
	}
}
// GetCryptoData retrieves data that was received in CRYPTO frames.
// It returns whatever the frame sorter's Pop yields; the offset and the
// third Pop return value are discarded.
func (s *cryptoStream) GetCryptoData() []byte {
_, data, _ := s.queue.Pop()
return data
}
// Finish marks the stream as finished when its encryption level is dropped.
// It fails with PROTOCOL_VIOLATION if received data is still waiting to be
// read; in that case the stream is not marked finished.
func (s *cryptoStream) Finish() error {
	if !s.queue.HasMoreData() {
		s.finished = true
		return nil
	}
	return &qerr.TransportError{
		ErrorCode:    qerr.ProtocolViolation,
		ErrorMessage: "encryption level changed, but crypto stream has more data to read",
	}
}
// Write appends p to the buffer of data that should be sent out in CRYPTO frames.
// It always consumes all of p and never returns an error.
func (s *cryptoStream) Write(p []byte) (int, error) {
s.writeBuf = append(s.writeBuf, p...)
return len(p), nil
}
// HasData reports whether written data is still waiting to be sent in CRYPTO frames.
func (s *cryptoStream) HasData() bool {
return len(s.writeBuf) > 0
}
// PopCryptoFrame builds a CRYPTO frame at the current write offset. It takes as
// much buffered data as fits into a frame of at most maxLen bytes, removes that
// data from the write buffer, and advances the write offset accordingly.
// The frame's Data aliases the write buffer's backing array.
func (s *cryptoStream) PopCryptoFrame(maxLen protocol.ByteCount) *wire.CryptoFrame {
	frame := &wire.CryptoFrame{Offset: s.writeOffset}
	n := protocol.ByteCount(len(s.writeBuf))
	if room := frame.MaxDataLen(maxLen); room < n {
		n = room
	}
	frame.Data, s.writeBuf = s.writeBuf[:n], s.writeBuf[n:]
	s.writeOffset += n
	return frame
}
golang-github-lucas-clemente-quic-go-0.50.0/crypto_stream_manager.go 0000664 0000000 0000000 00000004210 14765760516 0025516 0 ustar 00root root 0000000 0000000 package quic
import (
"fmt"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
)
// cryptoStreamManager routes crypto data to the crypto stream of the matching
// encryption level (Initial, Handshake or 1-RTT).
type cryptoStreamManager struct {
initialStream *cryptoStream
handshakeStream *cryptoStream
oneRTTStream *cryptoStream
}
// newCryptoStreamManager creates a cryptoStreamManager over the given
// per-encryption-level crypto streams.
func newCryptoStreamManager(
initialStream *cryptoStream,
handshakeStream *cryptoStream,
oneRTTStream *cryptoStream,
) *cryptoStreamManager {
return &cryptoStreamManager{
initialStream: initialStream,
handshakeStream: handshakeStream,
oneRTTStream: oneRTTStream,
}
}
// HandleCryptoFrame passes frame to the crypto stream for encLevel. Encryption
// levels without a crypto stream (0-RTT) yield an error instead of a panic,
// since the frame came from the peer.
func (m *cryptoStreamManager) HandleCryptoFrame(frame *wire.CryptoFrame, encLevel protocol.EncryptionLevel) error {
	//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
	switch encLevel {
	case protocol.EncryptionInitial:
		return m.initialStream.HandleCryptoFrame(frame)
	case protocol.EncryptionHandshake:
		return m.handshakeStream.HandleCryptoFrame(frame)
	case protocol.Encryption1RTT:
		return m.oneRTTStream.HandleCryptoFrame(frame)
	default:
		return fmt.Errorf("received CRYPTO frame with unexpected encryption level: %s", encLevel)
	}
}
// GetCryptoData returns data received in CRYPTO frames from the crypto stream
// of encLevel (Initial, Handshake or 1-RTT). It panics for any other encryption
// level; presumably that is only reachable through a caller bug.
func (m *cryptoStreamManager) GetCryptoData(encLevel protocol.EncryptionLevel) []byte {
	var str *cryptoStream
	//nolint:exhaustive // CRYPTO frames cannot be sent in 0-RTT packets.
	switch encLevel {
	case protocol.EncryptionInitial:
		str = m.initialStream
	case protocol.EncryptionHandshake:
		str = m.handshakeStream
	case protocol.Encryption1RTT:
		str = m.oneRTTStream
	default:
		// The previous message ("received CRYPTO frame ...") was copied from
		// HandleCryptoFrame and misdescribed this failure: nothing was received here,
		// crypto data was requested for an encryption level without a crypto stream.
		panic(fmt.Sprintf("requested crypto data for unexpected encryption level: %s", encLevel))
	}
	return str.GetCryptoData()
}
// GetPostHandshakeData pops the next CRYPTO frame (of at most maxSize bytes)
// from the 1-RTT crypto stream, or returns nil if nothing is waiting to be sent.
func (m *cryptoStreamManager) GetPostHandshakeData(maxSize protocol.ByteCount) *wire.CryptoFrame {
	if str := m.oneRTTStream; str.HasData() {
		return str.PopCryptoFrame(maxSize)
	}
	return nil
}
// Drop finishes the crypto stream of a dropped encryption level. It returns the
// stream's Finish error, which occurs if unread data remains. Only Initial and
// Handshake may be dropped; any other level panics.
func (m *cryptoStreamManager) Drop(encLevel protocol.EncryptionLevel) error {
//nolint:exhaustive // 1-RTT keys should never get dropped.
switch encLevel {
case protocol.EncryptionInitial:
return m.initialStream.Finish()
case protocol.EncryptionHandshake:
return m.handshakeStream.Finish()
default:
panic(fmt.Sprintf("dropped unexpected encryption level: %s", encLevel))
}
}
golang-github-lucas-clemente-quic-go-0.50.0/crypto_stream_manager_test.go 0000664 0000000 0000000 00000005425 14765760516 0026566 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
func TestCryptoStreamManager(t *testing.T) {
t.Run("Initial", func(t *testing.T) {
testCryptoStreamManager(t, protocol.EncryptionInitial)
})
t.Run("Handshake", func(t *testing.T) {
testCryptoStreamManager(t, protocol.EncryptionHandshake)
})
t.Run("1-RTT", func(t *testing.T) {
testCryptoStreamManager(t, protocol.Encryption1RTT)
})
}
// testCryptoStreamManager feeds two consecutive CRYPTO frames ("foo" at offset 0,
// "bar" at offset 3) for encLevel. It then drains GetCryptoData until it returns
// no data and expects the concatenation to be "foobar".
func testCryptoStreamManager(t *testing.T, encLevel protocol.EncryptionLevel) {
	initialStream := newCryptoStream()
	handshakeStream := newCryptoStream()
	oneRTTStream := newCryptoStream()
	csm := newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream)

	frames := []*wire.CryptoFrame{
		{Data: []byte("foo")},
		{Data: []byte("bar"), Offset: 3},
	}
	for _, f := range frames {
		require.NoError(t, csm.HandleCryptoFrame(f, encLevel))
	}

	var received []byte
	for chunk := csm.GetCryptoData(encLevel); len(chunk) > 0; chunk = csm.GetCryptoData(encLevel) {
		received = append(received, chunk...)
	}
	require.Equal(t, []byte("foobar"), received)
}
// TestCryptoStreamManagerInvalidEncryptionLevel checks that a CRYPTO frame at
// the 0-RTT level is rejected with an error rather than being buffered.
func TestCryptoStreamManagerInvalidEncryptionLevel(t *testing.T) {
	csm := newCryptoStreamManager(nil, nil, nil)
	err := csm.HandleCryptoFrame(&wire.CryptoFrame{}, protocol.Encryption0RTT)
	require.ErrorContains(t, err, "received CRYPTO frame with unexpected encryption level")
}
func TestCryptoStreamManagerDropEncryptionLevel(t *testing.T) {
t.Run("Initial", func(t *testing.T) {
testCryptoStreamManagerDropEncryptionLevel(t, protocol.EncryptionInitial)
})
t.Run("Handshake", func(t *testing.T) {
testCryptoStreamManagerDropEncryptionLevel(t, protocol.EncryptionHandshake)
})
}
// testCryptoStreamManagerDropEncryptionLevel checks that dropping encLevel
// fails while received data is still unread, and succeeds once it was read.
func testCryptoStreamManagerDropEncryptionLevel(t *testing.T, encLevel protocol.EncryptionLevel) {
	csm := newCryptoStreamManager(newCryptoStream(), newCryptoStream(), newCryptoStream())
	require.NoError(t, csm.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")}, encLevel))

	// "foo" has not been consumed yet, so dropping must be refused.
	dropErr := csm.Drop(encLevel)
	require.ErrorContains(t, dropErr, "encryption level changed, but crypto stream has more data to read")

	// After reading the pending data, the level can be dropped cleanly.
	require.Equal(t, []byte("foo"), csm.GetCryptoData(encLevel))
	require.NoError(t, csm.Drop(encLevel))
}
// TestCryptoStreamManagerPostHandshake checks that data written to the 1-RTT
// crypto stream is returned by GetPostHandshakeData as a single frame.
func TestCryptoStreamManagerPostHandshake(t *testing.T) {
	initialStream, handshakeStream, oneRTTStream := newCryptoStream(), newCryptoStream(), newCryptoStream()
	csm := newCryptoStreamManager(initialStream, handshakeStream, oneRTTStream)

	for _, chunk := range []string{"foo", "bar"} {
		_, err := oneRTTStream.Write([]byte(chunk))
		require.NoError(t, err)
	}

	frame := csm.GetPostHandshakeData(protocol.ByteCount(10))
	require.Equal(t, &wire.CryptoFrame{Data: []byte("foobar")}, frame)
}
golang-github-lucas-clemente-quic-go-0.50.0/crypto_stream_test.go 0000664 0000000 0000000 00000006500 14765760516 0025067 0 ustar 00root root 0000000 0000000 package quic
import (
"testing"
"github.com/quic-go/quic-go/internal/protocol"
"github.com/quic-go/quic-go/internal/qerr"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// TestCryptoStreamDataAssembly checks that out-of-order and retransmitted
// CRYPTO frames are reassembled into the original byte sequence.
func TestCryptoStreamDataAssembly(t *testing.T) {
	str := newCryptoStream()
	frames := []*wire.CryptoFrame{
		{Data: []byte("bar"), Offset: 3},
		{Data: []byte("foo")},
		{Data: []byte("bar"), Offset: 3}, // a retransmission of the first frame
	}
	for _, f := range frames {
		require.NoError(t, str.HandleCryptoFrame(f))
	}

	var data []byte
	for chunk := str.GetCryptoData(); chunk != nil; chunk = str.GetCryptoData() {
		data = append(data, chunk...)
	}
	require.Equal(t, []byte("foobar"), data)
}
// TestCryptoStreamMaxOffset checks that data reaching past
// protocol.MaxCryptoStreamOffset triggers a CRYPTO_BUFFER_EXCEEDED error.
func TestCryptoStreamMaxOffset(t *testing.T) {
	str := newCryptoStream()

	// Ends below the limit: accepted.
	belowLimit := &wire.CryptoFrame{Offset: protocol.MaxCryptoStreamOffset - 5, Data: []byte("foo")}
	require.NoError(t, str.HandleCryptoFrame(belowLimit))

	// Extends past the limit: rejected.
	pastLimit := &wire.CryptoFrame{Offset: protocol.MaxCryptoStreamOffset - 2, Data: []byte("bar")}
	err := str.HandleCryptoFrame(pastLimit)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.CryptoBufferExceeded})
}
// TestCryptoStreamFinishWithQueuedData checks that Finish reports a protocol
// violation whenever received data is still buffered, regardless of offset.
func TestCryptoStreamFinishWithQueuedData(t *testing.T) {
	t.Run("with data at current offset", func(t *testing.T) {
		str := newCryptoStream()
		require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foo")}))
		require.Equal(t, []byte("foo"), str.GetCryptoData())

		// "bar" continues exactly where reading stopped but is never consumed.
		pending := &wire.CryptoFrame{Data: []byte("bar"), Offset: 3}
		require.NoError(t, str.HandleCryptoFrame(pending))

		finishErr := str.Finish()
		require.ErrorIs(t, finishErr, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
	})

	t.Run("with data at a higher offset", func(t *testing.T) {
		str := newCryptoStream()
		// Leaves a gap in front of the buffered data.
		gapped := &wire.CryptoFrame{Data: []byte("foobar"), Offset: 20}
		require.NoError(t, str.HandleCryptoFrame(gapped))

		finishErr := str.Finish()
		require.ErrorIs(t, finishErr, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
	})
}
// TestCryptoStreamReceiveDataAfterFinish checks that after Finish, frames that
// only repeat already-received bytes are tolerated while new bytes are not.
func TestCryptoStreamReceiveDataAfterFinish(t *testing.T) {
	str := newCryptoStream()
	require.NoError(t, str.HandleCryptoFrame(&wire.CryptoFrame{Data: []byte("foobar")}))
	require.Equal(t, []byte("foobar"), str.GetCryptoData())
	require.NoError(t, str.Finish())

	retransmission := &wire.CryptoFrame{Data: []byte("bar"), Offset: 3}
	require.NoError(t, str.HandleCryptoFrame(retransmission))

	// Covers offset 6, which was never received before Finish.
	newData := &wire.CryptoFrame{Data: []byte("baz"), Offset: 4}
	err := str.HandleCryptoFrame(newData)
	require.ErrorIs(t, err, &qerr.TransportError{ErrorCode: qerr.ProtocolViolation})
}
// TestCryptoStreamWrite checks HasData bookkeeping and how PopCryptoFrame
// splits and coalesces buffered writes.
func TestCryptoStreamWrite(t *testing.T) {
	// frameOverhead is the Length of a CRYPTO frame at offset with an empty payload.
	frameOverhead := func(offset protocol.ByteCount) protocol.ByteCount {
		return (&wire.CryptoFrame{Offset: offset}).Length(protocol.Version1)
	}

	str := newCryptoStream()
	require.False(t, str.HasData())
	for i, chunk := range []string{"foo", "bar", "baz"} {
		_, err := str.Write([]byte(chunk))
		require.NoError(t, err)
		if i == 0 {
			require.True(t, str.HasData())
		}
	}
	require.True(t, str.HasData())

	// A limit of overhead+3 bytes yields exactly the first write...
	first := str.PopCryptoFrame(frameOverhead(0) + 3)
	require.Equal(t, &wire.CryptoFrame{Data: []byte("foo")}, first)
	require.True(t, str.HasData())

	// ...and an unlimited pop coalesces the two remaining writes into one frame.
	rest := str.PopCryptoFrame(protocol.MaxByteCount)
	require.Equal(t, &wire.CryptoFrame{Offset: 3, Data: []byte("barbaz")}, rest)
	require.False(t, str.HasData())
}
golang-github-lucas-clemente-quic-go-0.50.0/datagram_queue.go 0000664 0000000 0000000 00000005640 14765760516 0024125 0 ustar 00root root 0000000 0000000 package quic
import (
"context"
"sync"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/utils/ringbuffer"
"github.com/quic-go/quic-go/internal/wire"
)
// maxDatagramSendQueueLen is the number of outgoing DATAGRAM frames that Add
// queues before it starts blocking.
const maxDatagramSendQueueLen = 32

// maxDatagramRcvQueueLen is the number of received DATAGRAM frames buffered
// for Receive; HandleDatagramFrame discards frames beyond this limit.
const maxDatagramRcvQueueLen = 128
// datagramQueue buffers DATAGRAM frames in both directions. Outgoing frames
// are queued by Add and drained via Peek/Pop. Incoming frames are stored by
// HandleDatagramFrame and handed out by Receive. CloseWithError unblocks any
// Add or Receive call that is waiting.
type datagramQueue struct {
	// sendMx guards sendQueue.
	sendMx    sync.Mutex
	sendQueue ringbuffer.RingBuffer[*wire.DatagramFrame]
	sent      chan struct{} // used to notify Add that a datagram was dequeued

	// rcvMx guards rcvQueue.
	rcvMx    sync.Mutex
	rcvQueue [][]byte
	rcvd     chan struct{} // used to notify Receive that a new datagram was received

	// closeErr is set by CloseWithError before closed is closed. Add and
	// Receive return it once they observe the close.
	closeErr error
	closed   chan struct{}

	hasData func() // invoked by Add after a frame has been queued for sending
	logger  utils.Logger
}
// newDatagramQueue creates an empty datagramQueue. hasData is invoked each
// time a frame is queued for sending; logger is used for debug output.
func newDatagramQueue(hasData func(), logger utils.Logger) *datagramQueue {
	q := &datagramQueue{
		hasData: hasData,
		logger:  logger,
	}
	// The notification channels have capacity 1 so that senders can signal
	// with a non-blocking send without losing the wake-up.
	q.rcvd = make(chan struct{}, 1)
	q.sent = make(chan struct{}, 1)
	q.closed = make(chan struct{})
	return q
}
// Add queues a new DATAGRAM frame for sending and then calls hasData.
// At most maxDatagramSendQueueLen (32) frames are queued. While the queue is
// full, Add blocks until Pop frees a slot. If the queue is closed in the
// meantime, Add returns the error passed to CloseWithError.
func (h *datagramQueue) Add(f *wire.DatagramFrame) error {
	for {
		h.sendMx.Lock()
		if h.sendQueue.Len() < maxDatagramSendQueueLen {
			h.sendQueue.PushBack(f)
			h.sendMx.Unlock()
			h.hasData()
			return nil
		}

		// The queue is full. Discard any stale notification while still holding
		// the lock, so the wait below only wakes up for a Pop that happens after
		// this check.
		select {
		case <-h.sent:
		default:
		}
		h.sendMx.Unlock()

		select {
		case <-h.sent:
			// A slot may have been freed; re-check at the top of the loop.
		case <-h.closed:
			return h.closeErr
		}
	}
}
// Peek returns the DATAGRAM frame at the front of the send queue without
// removing it, or nil if the queue is empty.
// If that frame is actually sent out, Pop must be called before the next Peek.
func (h *datagramQueue) Peek() *wire.DatagramFrame {
	var next *wire.DatagramFrame
	h.sendMx.Lock()
	defer h.sendMx.Unlock()
	if !h.sendQueue.Empty() {
		next = h.sendQueue.PeekFront()
	}
	return next
}
// Pop removes the frame at the front of the send queue (the one last returned
// by Peek) and wakes up an Add call that may be waiting for space.
func (h *datagramQueue) Pop() {
	h.sendMx.Lock()
	defer h.sendMx.Unlock()

	h.sendQueue.PopFront()

	// Non-blocking notify: if a notification is already pending in the
	// 1-slot channel, the waiting Add will see that one.
	select {
	case h.sent <- struct{}{}:
	default:
	}
}
// HandleDatagramFrame stores a copy of a received DATAGRAM frame's payload for
// Receive. If maxDatagramRcvQueueLen payloads are already buffered, the frame
// is dropped, which is logged at debug level.
func (h *datagramQueue) HandleDatagramFrame(f *wire.DatagramFrame) {
	// Copy the payload so the buffered datagram does not alias f.Data.
	payload := make([]byte, len(f.Data))
	copy(payload, f.Data)

	queued := func() bool {
		h.rcvMx.Lock()
		defer h.rcvMx.Unlock()
		if len(h.rcvQueue) >= maxDatagramRcvQueueLen {
			return false
		}
		h.rcvQueue = append(h.rcvQueue, payload)
		// Wake up a waiting Receive without blocking if a wake-up is already pending.
		select {
		case h.rcvd <- struct{}{}:
		default:
		}
		return true
	}()

	if !queued && h.logger.Debug() {
		h.logger.Debugf("Discarding received DATAGRAM frame (%d bytes payload)", len(f.Data))
	}
}
// Receive returns the oldest received DATAGRAM payload. It blocks until a
// payload is available, the queue is closed (returning the error passed to
// CloseWithError), or ctx is done (returning ctx.Err()).
// Already-buffered payloads are still returned after the queue is closed,
// because the buffer is checked before waiting.
func (h *datagramQueue) Receive(ctx context.Context) ([]byte, error) {
	for {
		h.rcvMx.Lock()
		if len(h.rcvQueue) > 0 {
			data := h.rcvQueue[0]
			// Clear the slot before re-slicing. Otherwise the backing array keeps
			// the dequeued buffer reachable until the next append reallocates it.
			h.rcvQueue[0] = nil
			h.rcvQueue = h.rcvQueue[1:]
			h.rcvMx.Unlock()
			return data, nil
		}
		h.rcvMx.Unlock()

		select {
		case <-h.rcvd:
			// New data may be available; re-check the buffer.
		case <-h.closed:
			return nil, h.closeErr
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// CloseWithError closes the queue. Pending and future Add calls that find the
// send queue full, and Receive calls that find no buffered data, return e.
// It must be called at most once, since closing h.closed twice panics.
func (h *datagramQueue) CloseWithError(e error) {
	// The deferred close runs after closeErr is assigned, so any goroutine
	// that observes the closed channel also observes closeErr.
	defer close(h.closed)
	h.closeErr = e
}
golang-github-lucas-clemente-quic-go-0.50.0/datagram_queue_test.go 0000664 0000000 0000000 00000010371 14765760516 0025161 0 ustar 00root root 0000000 0000000 package quic
import (
"context"
"errors"
"testing"
"time"
"github.com/quic-go/quic-go/internal/utils"
"github.com/quic-go/quic-go/internal/wire"
"github.com/stretchr/testify/require"
)
// TestDatagramQueuePeekAndPop checks that Add triggers the hasData callback,
// Peek is idempotent, and Pop removes the peeked frame.
func TestDatagramQueuePeekAndPop(t *testing.T) {
	var hasDataCalls int
	queue := newDatagramQueue(func() { hasDataCalls++ }, utils.DefaultLogger)

	require.Nil(t, queue.Peek())
	require.Zero(t, hasDataCalls)

	require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte("foo")}))
	require.Equal(t, 1, hasDataCalls)

	want := &wire.DatagramFrame{Data: []byte("foo")}
	require.Equal(t, want, queue.Peek())
	require.Equal(t, want, queue.Peek()) // peeking twice returns the same datagram

	queue.Pop()
	require.Nil(t, queue.Peek())
}
// TestDatagramQueueSendQueueLength checks that Add blocks once the send queue
// is full and resumes only after Pop (not Peek) frees a slot.
func TestDatagramQueueSendQueueLength(t *testing.T) {
	queue := newDatagramQueue(func() {}, utils.DefaultLogger)

	// Fill the send queue to capacity.
	for n := 0; n < maxDatagramSendQueueLen; n++ {
		require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte{0}}))
	}

	// One more Add has to wait for a free slot.
	addErr := make(chan error, 1)
	go func() { addErr <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }()

	assertStillBlocked := func() {
		t.Helper()
		select {
		case <-addErr:
			t.Fatal("expected to not receive error")
		case <-time.After(scaleDuration(10 * time.Millisecond)):
		}
	}
	assertStillBlocked()

	// Peek leaves the datagram in the queue, so Add stays blocked...
	require.NotNil(t, queue.Peek())
	assertStillBlocked()

	// ...while Pop frees a slot and lets Add complete.
	queue.Pop()
	select {
	case err := <-addErr:
		require.NoError(t, err)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// Drain the remaining filler datagrams; the late one is now at the front.
	for n := 0; n < maxDatagramSendQueueLen-1; n++ {
		queue.Pop()
	}
	front := queue.Peek()
	require.NotNil(t, front)
	require.Equal(t, &wire.DatagramFrame{Data: []byte("foobar")}, front)
}
// TestDatagramQueueReceive checks that frames received before Receive is
// called are returned in arrival order.
func TestDatagramQueueReceive(t *testing.T) {
	queue := newDatagramQueue(func() {}, utils.DefaultLogger)
	payloads := []string{"foo", "bar"}

	for _, p := range payloads {
		queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte(p)})
	}
	for _, want := range payloads {
		got, err := queue.Receive(context.Background())
		require.NoError(t, err)
		require.Equal(t, []byte(want), got)
	}
}
// TestDatagramQueueReceiveBlocking checks that Receive blocks on an empty
// queue, returns when a frame arrives, and unblocks on context cancellation.
func TestDatagramQueueReceiveBlocking(t *testing.T) {
	queue := newDatagramQueue(func() {}, utils.DefaultLogger)
	shortWait := func() <-chan time.Time { return time.After(scaleDuration(10 * time.Millisecond)) }

	// Phase 1: block until a new frame is received.
	type result struct {
		data []byte
		err  error
	}
	results := make(chan result, 1)
	go func() {
		data, err := queue.Receive(context.Background())
		results <- result{data: data, err: err}
	}()

	select {
	case <-results:
		t.Fatal("expected to not receive result")
	case <-shortWait():
	}

	queue.HandleDatagramFrame(&wire.DatagramFrame{Data: []byte("foobar")})
	select {
	case res := <-results:
		require.NoError(t, res.err)
		require.Equal(t, []byte("foobar"), res.data)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}

	// Phase 2: unblock when the context is canceled.
	ctx, cancel := context.WithCancel(context.Background())
	receiveErr := make(chan error, 1)
	go func() {
		_, err := queue.Receive(ctx)
		receiveErr <- err
	}()

	select {
	case <-receiveErr:
		t.Fatal("expected to not receive error")
	case <-shortWait():
	}

	cancel()
	select {
	case err := <-receiveErr:
		require.ErrorIs(t, err, context.Canceled)
	case <-time.After(time.Second):
		t.Fatal("timeout")
	}
}
// TestDatagramQueueClose checks that CloseWithError unblocks both an Add that
// is waiting on a full send queue and a Receive waiting on an empty one.
func TestDatagramQueueClose(t *testing.T) {
	queue := newDatagramQueue(func() {}, utils.DefaultLogger)
	for n := 0; n < maxDatagramSendQueueLen; n++ {
		require.NoError(t, queue.Add(&wire.DatagramFrame{Data: []byte{0}}))
	}

	addErr := make(chan error, 1)
	go func() { addErr <- queue.Add(&wire.DatagramFrame{Data: []byte("foobar")}) }()

	receiveErr := make(chan error, 1)
	go func() {
		_, err := queue.Receive(context.Background())
		receiveErr <- err
	}()

	queue.CloseWithError(errors.New("test error"))

	for _, errs := range []chan error{addErr, receiveErr} {
		select {
		case err := <-errs:
			require.EqualError(t, err, "test error")
		case <-time.After(time.Second):
			t.Fatal("timeout")
		}
	}
}
golang-github-lucas-clemente-quic-go-0.50.0/docs/ 0000775 0000000 0000000 00000000000 14765760516 0021535 5 ustar 00root root 0000000 0000000 golang-github-lucas-clemente-quic-go-0.50.0/docs/quic.png 0000664 0000000 0000000 00000041623 14765760516 0023212 0 ustar 00root root 0000000 0000000 ‰PNG
IHDR ^ ø =ûè\ CZIDATxìÖÌØPFáÙ¶mÛ¶mÛ¶mÛ¶Íh¶mÛ6¾ÙûmïIž¨nïmëE)¥”RJ)¥”RJ9.¿È…šè‚q‚æ(…¸pfJ)¥”RÊJc>Ãpí?§| ºa"Ö`?Öc&ú¡‚â?J)¥”R¥pæ¯ÐþñßGi0·aŽðkQ>áÌT@DG2dAaTB´GO´A=”GdD"Do¸eJ)¥TBl†ý(xìx·DYKY¿™eéÚÏÒ·éb‰«Ô²¨Ùs›þ~ZP‘ADÄE*¤C|„G ¸uá0æ§PHyG\”B,Ã9¼…¹À+Æ,´A>D€k¤”RJåÃ#Øg~C„´ÔMÚXù5›Þ±+¿UkßiË9p”…O•îÇo×[4„kéQ}°ûq7ñÌ ßÕÛ8ÝXŽî(ð
çÖ ` 0\„?«Yº°ãgZ©EkðÔ–£ßpKÕ¸•…IšÂ¼xõú«sÝ„(ø”J†®Ø‰g0wt«QááÔ”RJ©zx
ƒyóéË’Õj`5vŸ0~¬œ$gÿ‘æ7xˆ¿UmáÔ¼"!c2¶â6Ì<ÄfŒ@uDpä9„@ÌüE¬ø¼•ŽºU6î³ô;›Ÿ ïØ;è(’î‹7– $DÐww‚»/îîn«ÖqY÷Å}qwwwww[ÇÞ¿ÞœjþõõŽôÔtÍDÞïœûI¦3ЗªW÷ç>Ÿ¶L$cªÉoÆ
&ˆ%zÍÍßp>œê
‚ ‚è%¾K²FB›5ÛÑH«ë¾Ó±T9ã;j˜æš¦®|vçÄ2äïØüv½aj18w>h