pax_global_header00006660000000000000000000000064143602160630014513gustar00rootroot0000000000000052 comment=c18816d341d933215c41476257bc2d6fe086cede sctp-1.8.6/000077500000000000000000000000001436021606300125005ustar00rootroot00000000000000sctp-1.8.6/.github/000077500000000000000000000000001436021606300140405ustar00rootroot00000000000000sctp-1.8.6/.github/generate-authors.sh000077500000000000000000000033471436021606300176630ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -e SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) if [ -z "${AUTHORS_PATH}" ]; then AUTHORS_PATH="$GITHUB_WORKSPACE/AUTHORS.txt" fi if [ -f ${SCRIPT_PATH}/.ci.conf ]; then . ${SCRIPT_PATH}/.ci.conf fi # # DO NOT EDIT THIS # EXCLUDED_CONTRIBUTORS+=('John R. Bradley' 'renovate[bot]' 'Renovate Bot' 'Pion Bot' 'pionbot') # If you want to exclude a name from all repositories, send a PR to # https://github.com/pion/.goassets instead of this repository. # If you want to exclude a name only from this repository, # add EXCLUDED_CONTRIBUTORS=('name') to .github/.ci.conf CONTRIBUTORS=() shouldBeIncluded () { for i in "${EXCLUDED_CONTRIBUTORS[@]}"; do if [[ $1 =~ "$i" ]]; then return 1 fi done return 0 } IFS=$'\n' #Only split on newline for CONTRIBUTOR in $( ( git log --format='%aN <%aE>' git log --format='%(trailers:key=Co-authored-by)' | sed -n 's/^[^:]*:\s*//p' ) | LC_ALL=C.UTF-8 sort -uf ); do if shouldBeIncluded ${CONTRIBUTOR}; then CONTRIBUTORS+=("${CONTRIBUTOR}") fi done unset IFS if [ ${#CONTRIBUTORS[@]} -ne 0 ]; then cat >${AUTHORS_PATH} <<-'EOH' # Thank you to everyone that made Pion possible. 
If you are interested in contributing # we would love to have you https://github.com/pion/webrtc/wiki/Contributing # # This file is auto generated, using git to list all individuals contributors. # see `.github/generate-authors.sh` for the scripting EOH for i in "${CONTRIBUTORS[@]}"; do echo "$i" >> ${AUTHORS_PATH} done exit 0 fi sctp-1.8.6/.github/hooks/000077500000000000000000000000001436021606300151635ustar00rootroot00000000000000sctp-1.8.6/.github/hooks/commit-msg.sh000077500000000000000000000002671436021606300176030ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE DIRECTLY # # It is automatically copied from https://github.com/pion/.goassets repository. # set -e .github/lint-commit-message.sh $1 sctp-1.8.6/.github/hooks/pre-commit.sh000077500000000000000000000004171436021606300176000ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE DIRECTLY # # It is automatically copied from https://github.com/pion/.goassets repository. # # Redirect output to stderr. exec 1>&2 .github/lint-disallowed-functions-in-library.sh .github/lint-no-trailing-newline-in-log-messages.sh sctp-1.8.6/.github/hooks/pre-push.sh000077500000000000000000000002571436021606300172710ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE DIRECTLY # # It is automatically copied from https://github.com/pion/.goassets repository. # set -e .github/generate-authors.sh exit 0 sctp-1.8.6/.github/install-hooks.sh000077500000000000000000000010521436021606300171640ustar00rootroot00000000000000#!/bin/bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) cp "${SCRIPT_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${SCRIPT_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" cp "${SCRIPT_PATH}/hooks/pre-push.sh" "${SCRIPT_PATH}/../.git/hooks/pre-push" sctp-1.8.6/.github/lint-commit-message.sh000077500000000000000000000035661436021606300202670ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -e display_commit_message_error() { cat << EndOfMessage $1 ------------------------------------------------- The preceding commit message is invalid it failed '$2' of the following checks * Separate subject from body with a blank line * Limit the subject line to 50 characters * Capitalize the subject line * Do not end the subject line with a period * Wrap the body at 72 characters EndOfMessage exit 1 } lint_commit_message() { if [[ "$(echo "$1" | awk 'NR == 2 {print $1;}' | wc -c)" -ne 1 ]]; then display_commit_message_error "$1" 'Separate subject from body with a blank line' fi if [[ "$(echo "$1" | head -n1 | awk '{print length}')" -gt 50 ]]; then display_commit_message_error "$1" 'Limit the subject line to 50 characters' fi if [[ ! $1 =~ ^[A-Z] ]]; then display_commit_message_error "$1" 'Capitalize the subject line' fi if [[ "$(echo "$1" | awk 'NR == 1 {print substr($0,length($0),1)}')" == "." ]]; then display_commit_message_error "$1" 'Do not end the subject line with a period' fi if [[ "$(echo "$1" | awk '{print length}' | sort -nr | head -1)" -gt 72 ]]; then display_commit_message_error "$1" 'Wrap the body at 72 characters' fi } if [ "$#" -eq 1 ]; then if [ ! 
-f "$1" ]; then echo "$0 was passed one argument, but was not a valid file" exit 1 fi lint_commit_message "$(sed -n '/# Please enter the commit message for your changes. Lines starting/q;p' "$1")" else for COMMIT in $(git rev-list --no-merges origin/master..); do lint_commit_message "$(git log --format="%B" -n 1 ${COMMIT})" done fi sctp-1.8.6/.github/lint-disallowed-functions-in-library.sh000077500000000000000000000024251436021606300235510ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -e # Disallow usages of functions that cause the program to exit in the library code SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) if [ -f ${SCRIPT_PATH}/.ci.conf ]; then . ${SCRIPT_PATH}/.ci.conf fi EXCLUDE_DIRECTORIES=${DISALLOWED_FUNCTIONS_EXCLUDED_DIRECTORIES:-"examples"} DISALLOWED_FUNCTIONS=('os.Exit(' 'panic(' 'Fatal(' 'Fatalf(' 'Fatalln(' 'fmt.Println(' 'fmt.Printf(' 'log.Print(' 'log.Println(' 'log.Printf(' 'print(' 'println(') FILES=$( find "${SCRIPT_PATH}/.." -name "*.go" \ | grep -v -e '^.*_test.go$' \ | while read FILE; do EXCLUDED=false for EXCLUDE_DIRECTORY in ${EXCLUDE_DIRECTORIES}; do if [[ ${FILE} == */${EXCLUDE_DIRECTORY}/* ]]; then EXCLUDED=true break fi done ${EXCLUDED} || echo "${FILE}" done ) for DISALLOWED_FUNCTION in "${DISALLOWED_FUNCTIONS[@]}"; do if grep -e "\s${DISALLOWED_FUNCTION}" ${FILES} | grep -v -e 'nolint'; then echo "${DISALLOWED_FUNCTION} may only be used in example code" exit 1 fi done sctp-1.8.6/.github/lint-filename.sh000077500000000000000000000012161436021606300171230ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -e SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) GO_REGEX="^[a-zA-Z][a-zA-Z0-9_]*\.go$" find "${SCRIPT_PATH}/.." -name "*.go" | while read FULLPATH; do FILENAME=$(basename -- "${FULLPATH}") if ! [[ ${FILENAME} =~ ${GO_REGEX} ]]; then echo "${FILENAME} is not a valid filename for Go code, only alpha, numbers and underscores are supported" exit 1 fi done sctp-1.8.6/.github/lint-no-trailing-newline-in-log-messages.sh000077500000000000000000000017121436021606300242160ustar00rootroot00000000000000#!/usr/bin/env bash # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -e # Disallow usages of functions that cause the program to exit in the library code SCRIPT_PATH=$( cd "$(dirname "${BASH_SOURCE[0]}")" ; pwd -P ) if [ -f ${SCRIPT_PATH}/.ci.conf ]; then . ${SCRIPT_PATH}/.ci.conf fi FILES=$( find "${SCRIPT_PATH}/.." -name "*.go" \ | while read FILE; do EXCLUDED=false for EXCLUDE_DIRECTORY in ${EXCLUDE_DIRECTORIES}; do if [[ $file == */${EXCLUDE_DIRECTORY}/* ]]; then EXCLUDED=true break fi done ${EXCLUDED} || echo "${FILE}" done ) if grep -E '\.(Trace|Debug|Info|Warn|Error)f?\("[^"]*\\n"\)?' 
${FILES} | grep -v -e 'nolint'; then echo "Log format strings should have trailing new-line" exit 1 fisctp-1.8.6/.github/workflows/000077500000000000000000000000001436021606300160755ustar00rootroot00000000000000sctp-1.8.6/.github/workflows/codeql-analysis.yml000066400000000000000000000015411436021606300217110ustar00rootroot00000000000000name: "CodeQL" on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: name: Analyze runs-on: ubuntu-latest permissions: actions: read contents: read security-events: write steps: - name: Checkout repo uses: actions/checkout@v3 # The code in examples/ might intentionally do things like log credentials # in order to show how the library is used, aid in debugging etc. We # should ignore those for CodeQL scanning, and only focus on the package # itself. - name: Remove example code run: | rm -rf examples/ - name: Initialize CodeQL uses: github/codeql-action/init@v2 with: languages: 'go' - name: CodeQL Analysis uses: github/codeql-action/analyze@v2 sctp-1.8.6/.github/workflows/generate-authors.yml000066400000000000000000000047721436021606300221070ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# name: generate-authors on: pull_request: jobs: checksecret: permissions: contents: none runs-on: ubuntu-latest outputs: is_PIONBOT_PRIVATE_KEY_set: ${{ steps.checksecret_job.outputs.is_PIONBOT_PRIVATE_KEY_set }} steps: - id: checksecret_job env: PIONBOT_PRIVATE_KEY: ${{ secrets.PIONBOT_PRIVATE_KEY }} run: | echo "is_PIONBOT_PRIVATE_KEY_set: ${{ env.PIONBOT_PRIVATE_KEY != '' }}" echo "::set-output name=is_PIONBOT_PRIVATE_KEY_set::${{ env.PIONBOT_PRIVATE_KEY != '' }}" generate-authors: permissions: contents: write needs: [checksecret] if: needs.checksecret.outputs.is_PIONBOT_PRIVATE_KEY_set == 'true' runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 with: ref: ${{ github.head_ref }} fetch-depth: 0 token: ${{ secrets.PIONBOT_PRIVATE_KEY }} - name: Generate the authors file run: .github/generate-authors.sh - name: Add the authors file to git run: git add AUTHORS.txt - name: Get last commit message id: last-commit-message run: | COMMIT_MSG=$(git log -1 --pretty=%B) COMMIT_MSG="${COMMIT_MSG//'%'/'%25'}" COMMIT_MSG="${COMMIT_MSG//$'\n'/'%0A'}" COMMIT_MSG="${COMMIT_MSG//$'\r'/'%0D'}" echo "::set-output name=msg::$COMMIT_MSG" - name: Get last commit author id: last-commit-author run: | echo "::set-output name=msg::$(git log -1 --pretty='%aN <%ae>')" - name: Check if AUTHORS.txt file has changed id: git-status-output run: | echo "::set-output name=msg::$(git status -s | wc -l)" - name: Commit and push if: ${{ steps.git-status-output.outputs.msg != '0' }} run: | git config user.email $(echo "${{ steps.last-commit-author.outputs.msg }}" | sed 's/\(.\+\) <\(\S\+\)>/\2/') git config user.name $(echo "${{ steps.last-commit-author.outputs.msg }}" | sed 's/\(.\+\) <\(\S\+\)>/\1/') git add AUTHORS.txt git commit --amend --no-edit git push --force https://github.com/${GITHUB_REPOSITORY} $(git symbolic-ref -q --short HEAD) sctp-1.8.6/.github/workflows/lint.yaml000066400000000000000000000026251436021606300177340ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is 
automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Lint on: pull_request: types: - opened - edited - synchronize permissions: contents: read jobs: lint-commit-message: name: Metadata runs-on: ubuntu-latest strategy: fail-fast: false steps: - uses: actions/checkout@v3 with: fetch-depth: 0 - name: Commit Message run: .github/lint-commit-message.sh - name: File names run: .github/lint-filename.sh - name: Functions run: .github/lint-disallowed-functions-in-library.sh - name: Logging messages should not have trailing newlines run: .github/lint-no-trailing-newline-in-log-messages.sh lint-go: name: Go permissions: contents: read pull-requests: read runs-on: ubuntu-latest strategy: fail-fast: false steps: - uses: actions/checkout@v3 - name: golangci-lint uses: golangci/golangci-lint-action@v3 with: version: v1.45.2 args: $GOLANGCI_LINT_EXRA_ARGS sctp-1.8.6/.github/workflows/release.yml000066400000000000000000000010071436021606300202360ustar00rootroot00000000000000name: release on: push: tags: - 'v*' jobs: release: permissions: contents: write runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 with: fetch-depth: 0 - uses: actions/setup-go@v3 with: go-version: '1.18' # auto-update/latest-go-version - name: Build and release uses: goreleaser/goreleaser-action@v4 with: version: latest args: release --rm-dist env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} sctp-1.8.6/.github/workflows/renovate-go-mod-fix.yaml000066400000000000000000000016071436021606300225540ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: go-mod-fix on: push: branches: - renovate/* permissions: contents: write jobs: go-mod-fix: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v3 with: fetch-depth: 2 - name: fix uses: at-wat/go-sum-fix-action@v0 with: git_user: Pion Bot git_email: 59523206+pionbot@users.noreply.github.com github_token: ${{ secrets.PIONBOT_PRIVATE_KEY }} commit_style: squash push: force sctp-1.8.6/.github/workflows/test.yaml000066400000000000000000000112631436021606300177430ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Test on: push: branches: - master pull_request: branches: - master permissions: contents: read jobs: test: runs-on: ubuntu-latest strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false name: Go ${{ matrix.go }} steps: - uses: actions/checkout@v3 - uses: actions/cache@v3 with: path: | ~/go/pkg/mod ~/go/bin ~/.cache key: ${{ runner.os }}-amd64-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-amd64-go- - name: Setup Go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - name: Setup go-acc run: go install github.com/ory/go-acc@latest - name: Set up gotestfmt uses: haveyoudebuggedit/gotestfmt-action@v2 with: token: ${{ secrets.GITHUB_TOKEN }} # Avoid getting rate limited - name: Run test run: | TEST_BENCH_OPTION="-bench=." if [ -f .github/.ci.conf ]; then . 
.github/.ci.conf; fi set -euo pipefail go-acc -o cover.out ./... -- \ ${TEST_BENCH_OPTION} \ -json \ -v -race 2>&1 | grep -v '^go: downloading' | tee /tmp/gotest.log | gotestfmt - name: Upload test log uses: actions/upload-artifact@v3 if: always() with: name: test-log-${{ matrix.go }} path: /tmp/gotest.log if-no-files-found: error - name: Run TEST_HOOK run: | if [ -f .github/.ci.conf ]; then . .github/.ci.conf; fi if [ -n "${TEST_HOOK}" ]; then ${TEST_HOOK}; fi - uses: codecov/codecov-action@v3 with: name: codecov-umbrella fail_ci_if_error: true flags: go test-i386: runs-on: ubuntu-latest strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false name: Go i386 ${{ matrix.go }} steps: - uses: actions/checkout@v3 - uses: actions/cache@v3 with: path: | ~/go/pkg/mod ~/.cache key: ${{ runner.os }}-i386-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-i386-go- - name: Run test run: | mkdir -p $HOME/go/pkg/mod $HOME/.cache docker run \ -u $(id -u):$(id -g) \ -e "GO111MODULE=on" \ -e "CGO_ENABLED=0" \ -v $GITHUB_WORKSPACE:/go/src/github.com/pion/$(basename $GITHUB_WORKSPACE) \ -v $HOME/go/pkg/mod:/go/pkg/mod \ -v $HOME/.cache:/.cache \ -w /go/src/github.com/pion/$(basename $GITHUB_WORKSPACE) \ i386/golang:${{matrix.go}}-alpine \ /usr/local/go/bin/go test \ ${TEST_EXTRA_ARGS:-} \ -v ./... 
test-wasm: runs-on: ubuntu-latest strategy: fail-fast: false name: WASM steps: - uses: actions/checkout@v3 - name: Use Node.js uses: actions/setup-node@v3 with: node-version: '16.x' - uses: actions/cache@v3 with: path: | ~/go/pkg/mod ~/.cache key: ${{ runner.os }}-wasm-go-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-wasm-go- - name: Download Go run: curl -sSfL https://dl.google.com/go/go${GO_VERSION}.linux-amd64.tar.gz | tar -C ~ -xzf - env: GO_VERSION: '1.19' # auto-update/latest-go-version - name: Set Go Root run: echo "GOROOT=${HOME}/go" >> $GITHUB_ENV - name: Set Go Path run: echo "GOPATH=${HOME}/go" >> $GITHUB_ENV - name: Set Go Path run: echo "GO_JS_WASM_EXEC=${GOROOT}/misc/wasm/go_js_wasm_exec" >> $GITHUB_ENV - name: Insall NPM modules run: yarn install - name: Run Tests run: | if [ -f .github/.ci.conf ]; then . .github/.ci.conf; fi GOOS=js GOARCH=wasm $GOPATH/bin/go test \ -coverprofile=cover.out -covermode=atomic \ -exec="${GO_JS_WASM_EXEC}" \ -v ./... - uses: codecov/codecov-action@v3 with: name: codecov-umbrella fail_ci_if_error: true flags: wasm sctp-1.8.6/.github/workflows/tidy-check.yaml000066400000000000000000000016751436021606300210160ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Go mod tidy on: pull_request: branches: - master push: branches: - master permissions: contents: read jobs: Check: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v3 - name: Setup Go uses: actions/setup-go@v3 with: go-version: '1.19' # auto-update/latest-go-version - name: check run: | go mod download go mod tidy if ! 
git diff --exit-code then echo "Not go mod tidied" exit 1 fi sctp-1.8.6/.gitignore000066400000000000000000000004661436021606300144760ustar00rootroot00000000000000### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js sctp-1.8.6/.golangci.yml000066400000000000000000000175411436021606300150740ustar00rootroot00000000000000linters-settings: govet: check-shadowing: true misspell: locale: US exhaustive: default-signifies-exhaustive: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - contextcheck # check the function whether use a non-inherited context - deadcode # Finds unused code - decorder # check declaration order and count of types, constants, variables and functions - depguard # Go linter that checks if package imports are in a list of acceptable packages - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. 
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - godox # Tool for detection of FIXME, TODO and other comment keywords - goerr113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. 
- importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - structcheck # Finds unused struct fields - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. - tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varcheck # Finds unused global variables and constants - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - containedctx # containedctx is a linter that detects struct contained context.Context field - cyclop # checks function and package cyclomatic complexity - exhaustivestruct # Checks if all struct's fields are initialized - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - gomnd # An analyzer to detect magic numbers. 
- ifshort # Checks that your code uses short syntax for if-statements whenever possible - ireturn # Accept Interfaces, Return Concrete Types - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. - makezero # Finds slice declarations with non-zero initial length - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - nestif # Reports deeply nested if statements - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - varnamelen # checks that the length of a variable's name matches its scope - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: exclude-use-default: false exclude-rules: # Allow complex tests, better to be self contained - path: _test\.go linters: - gocognit # Allow complex main function in examples - path: examples text: "of func `main` is high" linters: - gocognit run: skip-dirs-use-default: false sctp-1.8.6/.goreleaser.yml000066400000000000000000000000251436021606300154260ustar00rootroot00000000000000builds: - skip: true sctp-1.8.6/AUTHORS.txt000066400000000000000000000021261436021606300143670ustar00rootroot00000000000000# Thank you to everyone that made Pion possible. 
If you are interested in contributing # we would love to have you https://github.com/pion/webrtc/wiki/Contributing # # This file is auto generated, using git to list all individuals contributors. # see `.github/generate-authors.sh` for the scripting Aaron France Adrian Cable Atsushi Watanabe backkem Cecylia Bocovich chenkaiC4 Eric Daniels Hugo Arregui Hugo Arregui Jerko Steiner Jerry Tao John Bradley Konstantin Itskov Lukas Herman Luke Curley Michael MacDonald ronan Sam Lancia Sean DuBois Sean DuBois Teddy Yutaka Takeda ZHENK sctp-1.8.6/DESIGN.md000066400000000000000000000014451436021606300137770ustar00rootroot00000000000000

Design

### Portable Pion SCTP is written in Go and extremely portable. Anywhere Golang runs, Pion SCTP should work as well! Instead of dealing with complicated cross-compiling of multiple libraries, you now can run anywhere with one `go build` ### Simple API The API is based on an io.ReadWriteCloser. ### Readable If code comes from an RFC we try to make sure everything is commented with a link to the spec. This makes learning and debugging easier, this library was written to also serve as a guide for others. ### Tested Every commit is tested via travis-ci Go provides fantastic facilities for testing, and more will be added as time goes on. ### Shared libraries Every pion product is built using shared libraries, allowing others to review and reuse our libraries. sctp-1.8.6/LICENSE000066400000000000000000000020411436021606300135020ustar00rootroot00000000000000MIT License Copyright (c) 2018 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
sctp-1.8.6/README.md000066400000000000000000000046051436021606300137640ustar00rootroot00000000000000


Pion SCTP

A Go implementation of SCTP

Pion SCTP Slack Widget
Build Status GoDoc Coverage Status Go Report Card License: MIT


See [DESIGN.md](DESIGN.md) for an overview of features and future goals. ### Roadmap The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. ### Community Pion has an active community on the [Golang Slack](https://invite.slack.golangbridge.org/). Sign up and join the **#pion** channel for discussions and support. You can also use [Pion mailing list](https://groups.google.com/forum/#!forum/pion). We are always looking to support **your projects**. Please reach out if you have something to build! If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible: ### License MIT License - see [LICENSE](LICENSE) for full text sctp-1.8.6/ack_timer.go000066400000000000000000000035601436021606300147710ustar00rootroot00000000000000package sctp import ( "sync" "time" ) const ( ackInterval time.Duration = 200 * time.Millisecond ) // ackTimerObserver is the inteface to an ack timer observer. type ackTimerObserver interface { onAckTimeout() } // ackTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type ackTimer struct { observer ackTimerObserver interval time.Duration stopFunc stopAckTimerLoop closed bool mutex sync.RWMutex } type stopAckTimerLoop func() // newAckTimer creates a new acknowledgement timer used to enable delayed ack. func newAckTimer(observer ackTimerObserver) *ackTimer { return &ackTimer{ observer: observer, interval: ackInterval, } } // start starts the timer. 
func (t *ackTimer) start() bool { t.mutex.Lock() defer t.mutex.Unlock() // this timer is already closed if t.closed { return false } // this is a noop if the timer is already running if t.stopFunc != nil { return false } cancelCh := make(chan struct{}) go func() { timer := time.NewTimer(t.interval) select { case <-timer.C: t.stop() t.observer.onAckTimeout() case <-cancelCh: timer.Stop() } }() t.stopFunc = func() { close(cancelCh) } return true } // stops the timer. this is similar to stop() but subsequent start() call // will fail (the timer is no longer usable) func (t *ackTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() if t.stopFunc != nil { t.stopFunc() t.stopFunc = nil } } // closes the timer. this is similar to stop() but subsequent start() call // will fail (the timer is no longer usable) func (t *ackTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() if t.stopFunc != nil { t.stopFunc() t.stopFunc = nil } t.closed = true } // isRunning tests if the timer is running. // Debug purpose only func (t *ackTimer) isRunning() bool { t.mutex.RLock() defer t.mutex.RUnlock() return (t.stopFunc != nil) } sctp-1.8.6/ack_timer_test.go000066400000000000000000000044061436021606300160300ustar00rootroot00000000000000package sctp import ( "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) type onAckTO func() type testAckTimerObserver struct { onAckTO onAckTO } func (o *testAckTimerObserver) onAckTimeout() { o.onAckTO() } func TestAckTimer(t *testing.T) { t.Run("start and close", func(t *testing.T) { var nCbs uint32 rt := newAckTimer(&testAckTimerObserver{ onAckTO: func() { t.Log("ack timed out") atomic.AddUint32(&nCbs, 1) }, }) for i := 0; i < 2; i++ { // should start ok ok := rt.start() assert.True(t, ok, "start() should succeed") assert.True(t, rt.isRunning(), "should be running") // subsequent start is a noop ok = rt.start() assert.False(t, ok, "start() should NOT succeed once closed") assert.True(t, rt.isRunning(), "should be running") // Sleep 
more than 2 * 200msec interval to test if it times out only once time.Sleep(ackInterval*2 + 50*time.Millisecond) assert.Equal(t, uint32(1), atomic.LoadUint32(&nCbs), "should be called once (actual: %d)", atomic.LoadUint32(&nCbs)) atomic.StoreUint32(&nCbs, 0) } // should close ok rt.close() assert.False(t, rt.isRunning(), "should not be running") // once closed, it cannot start ok := rt.start() assert.False(t, ok, "start() should NOT succeed once closed") assert.False(t, rt.isRunning(), "should not be running") }) t.Run("start and stop", func(t *testing.T) { var nCbs uint32 rt := newAckTimer(&testAckTimerObserver{ onAckTO: func() { t.Log("ack timed out") atomic.AddUint32(&nCbs, 1) }, }) // should start ok ok := rt.start() assert.True(t, ok, "start() should succeed") assert.True(t, rt.isRunning(), "should be running") // stop immedidately rt.stop() assert.False(t, rt.isRunning(), "should not be running") // Sleep more than 200msec of interval to test if it never times out time.Sleep(ackInterval + 50*time.Millisecond) assert.Equal(t, uint32(0), atomic.LoadUint32(&nCbs), "should not be timed out (actual: %d)", atomic.LoadUint32(&nCbs)) // can start again ok = rt.start() assert.True(t, ok, "start() should succeed again") assert.True(t, rt.isRunning(), "should be running") // should close ok rt.close() assert.False(t, rt.isRunning(), "should not be running") }) } sctp-1.8.6/association.go000066400000000000000000002205501436021606300153470ustar00rootroot00000000000000package sctp import ( "bytes" "context" "errors" "fmt" "io" "math" "net" "sync" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/randutil" ) // Use global random generator to properly seed by crypto grade random. 
var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals // Association errors var ( ErrChunk = errors.New("abort chunk, with following errors") ErrShutdownNonEstablished = errors.New("shutdown called in non-established state") ErrAssociationClosedBeforeConn = errors.New("association closed before connecting") ErrSilentlyDiscard = errors.New("silently discard") ErrInitNotStoredToSend = errors.New("the init not stored to send") ErrCookieEchoNotStoredToSend = errors.New("cookieEcho not stored to send") ErrSCTPPacketSourcePortZero = errors.New("sctp packet must not have a source port of 0") ErrSCTPPacketDestinationPortZero = errors.New("sctp packet must not have a destination port of 0") ErrInitChunkBundled = errors.New("init chunk must not be bundled with any other chunk") ErrInitChunkVerifyTagNotZero = errors.New("init chunk expects a verification tag of 0 on the packet when out-of-the-blue") ErrHandleInitState = errors.New("todo: handle Init when in state") ErrInitAckNoCookie = errors.New("no cookie in InitAck") ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN") ErrTSNRequestNotExist = errors.New("requested non-existent TSN") ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state") ErrParamterType = errors.New("unexpected parameter type") ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state") ErrChunkTypeUnhandled = errors.New("unhandled chunk type") ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)") ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)") ) const ( receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) initialRecvBufSize uint32 = 1024 * 1024 commonHeaderSize uint32 = 12 dataChunkHeaderSize uint32 = 16 defaultMaxMessageSize uint32 = 65536 ) // association state enums const ( closed uint32 = iota cookieWait 
cookieEchoed established shutdownAckSent shutdownPending shutdownReceived shutdownSent ) // retransmission timer IDs const ( timerT1Init int = iota timerT1Cookie timerT2Shutdown timerT3RTX timerReconfig ) // ack mode (for testing) const ( ackModeNormal int = iota ackModeNoDelay ackModeAlwaysDelay ) // ack transmission state const ( ackStateIdle int = iota // ack timer is off ackStateImmediate // will send ack immediately ackStateDelay // ack timer is on (ack is being delayed) ) // other constants const ( acceptChSize = 16 ) func getAssociationStateString(a uint32) string { switch a { case closed: return "Closed" case cookieWait: return "CookieWait" case cookieEchoed: return "CookieEchoed" case established: return "Established" case shutdownPending: return "ShutdownPending" case shutdownSent: return "ShutdownSent" case shutdownReceived: return "ShutdownReceived" case shutdownAckSent: return "ShutdownAckSent" default: return fmt.Sprintf("Invalid association state %d", a) } } // Association represents an SCTP association // 13.2. Parameters Necessary per Association (i.e., the TCB) // Peer : Tag value to be sent in every packet and is received // Verification: in the INIT or INIT ACK chunk. // Tag : // // My : Tag expected in every inbound packet and sent in the // Verification: INIT or INIT ACK chunk. // // Tag : // State : A state variable indicating what state the association // : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED, // : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, // : SHUTDOWN-ACK-SENT. // // Note: No "CLOSED" state is illustrated since if a // association is "CLOSED" its TCB SHOULD be removed. 
type Association struct { bytesReceived uint64 bytesSent uint64 lock sync.RWMutex netConn net.Conn peerVerificationTag uint32 myVerificationTag uint32 state uint32 myNextTSN uint32 // nextTSN peerLastTSN uint32 // lastRcvdTSN minTSN2MeasureRTT uint32 // for RTT measurement willSendForwardTSN bool willRetransmitFast bool willRetransmitReconfig bool willSendShutdown bool willSendShutdownAck bool willSendShutdownComplete bool willSendAbort bool willSendAbortCause errorCause // Reconfig myNextRSN uint32 reconfigs map[uint32]*chunkReconfig reconfigRequests map[uint32]*paramOutgoingResetRequest // Non-RFC internal data sourcePort uint16 destinationPort uint16 myMaxNumInboundStreams uint16 myMaxNumOutboundStreams uint16 myCookie *paramStateCookie payloadQueue *payloadQueue inflightQueue *payloadQueue pendingQueue *pendingQueue controlQueue *controlQueue mtu uint32 maxPayloadSize uint32 // max DATA chunk payload size cumulativeTSNAckPoint uint32 advancedPeerTSNAckPoint uint32 useForwardTSN bool // Congestion control parameters maxReceiveBufferSize uint32 maxMessageSize uint32 cwnd uint32 // my congestion window size rwnd uint32 // calculated peer's receiver windows size ssthresh uint32 // slow start threshold partialBytesAcked uint32 inFastRecovery bool fastRecoverExitPoint uint32 // RTX & Ack timer rtoMgr *rtoManager t1Init *rtxTimer t1Cookie *rtxTimer t2Shutdown *rtxTimer t3RTX *rtxTimer tReconfig *rtxTimer ackTimer *ackTimer // Chunks stored for retransmission storedInit *chunkInit storedCookieEcho *chunkCookieEcho streams map[uint16]*Stream acceptCh chan *Stream readLoopCloseCh chan struct{} awakeWriteLoopCh chan struct{} closeWriteLoopCh chan struct{} handshakeCompletedCh chan error closeWriteLoopOnce sync.Once // local error silentError error ackState int ackMode int // for testing // stats stats *associationStats // per inbound packet context delayedAckTriggered bool immediateAckTriggered bool name string log logging.LeveledLogger } // Config collects the arguments 
to createAssociation construction into // a single structure type Config struct { NetConn net.Conn MaxReceiveBufferSize uint32 MaxMessageSize uint32 LoggerFactory logging.LoggerFactory } // Server accepts a SCTP stream over a conn func Server(config Config) (*Association, error) { a := createAssociation(config) a.init(false) select { case err := <-a.handshakeCompletedCh: if err != nil { return nil, err } return a, nil case <-a.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } // Client opens a SCTP stream over a conn func Client(config Config) (*Association, error) { a := createAssociation(config) a.init(true) select { case err := <-a.handshakeCompletedCh: if err != nil { return nil, err } return a, nil case <-a.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } func createAssociation(config Config) *Association { var maxReceiveBufferSize uint32 if config.MaxReceiveBufferSize == 0 { maxReceiveBufferSize = initialRecvBufSize } else { maxReceiveBufferSize = config.MaxReceiveBufferSize } var maxMessageSize uint32 if config.MaxMessageSize == 0 { maxMessageSize = defaultMaxMessageSize } else { maxMessageSize = config.MaxMessageSize } tsn := globalMathRandomGenerator.Uint32() a := &Association{ netConn: config.NetConn, maxReceiveBufferSize: maxReceiveBufferSize, maxMessageSize: maxMessageSize, myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, payloadQueue: newPayloadQueue(), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), controlQueue: newControlQueue(), mtu: initialMTU, maxPayloadSize: initialMTU - (commonHeaderSize + dataChunkHeaderSize), myVerificationTag: globalMathRandomGenerator.Uint32(), myNextTSN: tsn, myNextRSN: tsn, minTSN2MeasureRTT: tsn, state: closed, rtoMgr: newRTOManager(), streams: map[uint16]*Stream{}, reconfigs: map[uint32]*chunkReconfig{}, reconfigRequests: map[uint32]*paramOutgoingResetRequest{}, acceptCh: make(chan *Stream, acceptChSize), readLoopCloseCh: make(chan 
struct{}), awakeWriteLoopCh: make(chan struct{}, 1), closeWriteLoopCh: make(chan struct{}), handshakeCompletedCh: make(chan error), cumulativeTSNAckPoint: tsn - 1, advancedPeerTSNAckPoint: tsn - 1, silentError: ErrSilentlyDiscard, stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), } a.name = fmt.Sprintf("%p", a) // RFC 4690 Sec 7.2.1 // o The initial cwnd before DATA transmission or after a sufficiently // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 // bytes)). a.cwnd = min32(4*a.mtu, max32(2*a.mtu, 4380)) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) a.t1Init = newRTXTimer(timerT1Init, a, maxInitRetrans) a.t1Cookie = newRTXTimer(timerT1Cookie, a, maxInitRetrans) a.t2Shutdown = newRTXTimer(timerT2Shutdown, a, noMaxRetrans) // retransmit forever a.t3RTX = newRTXTimer(timerT3RTX, a, noMaxRetrans) // retransmit forever a.tReconfig = newRTXTimer(timerReconfig, a, noMaxRetrans) // retransmit forever a.ackTimer = newAckTimer(a) return a } func (a *Association) init(isClient bool) { a.lock.Lock() defer a.lock.Unlock() go a.readLoop() go a.writeLoop() if isClient { a.setState(cookieWait) init := &chunkInit{} init.initialTSN = a.myNextTSN init.numOutboundStreams = a.myMaxNumOutboundStreams init.numInboundStreams = a.myMaxNumInboundStreams init.initiateTag = a.myVerificationTag init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize setSupportedExtensions(&init.chunkInitCommon) a.storedInit = init err := a.sendInit() if err != nil { a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) } a.t1Init.start(a.rtoMgr.getRTO()) } } // caller must hold a.lock func (a *Association) sendInit() error { a.log.Debugf("[%s] sending INIT", a.name) if a.storedInit == nil { return ErrInitNotStoredToSend } outbound := &packet{} outbound.verificationTag = a.peerVerificationTag a.sourcePort = 5000 // Spec?? a.destinationPort = 5000 // Spec?? 
outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{a.storedInit} a.controlQueue.push(outbound) a.awakeWriteLoop() return nil } // caller must hold a.lock func (a *Association) sendCookieEcho() error { if a.storedCookieEcho == nil { return ErrCookieEchoNotStoredToSend } a.log.Debugf("[%s] sending COOKIE-ECHO", a.name) outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{a.storedCookieEcho} a.controlQueue.push(outbound) a.awakeWriteLoop() return nil } // Shutdown initiates the shutdown sequence. The method blocks until the // shutdown sequence is completed and the connection is closed, or until the // passed context is done, in which case the context's error is returned. func (a *Association) Shutdown(ctx context.Context) error { a.log.Debugf("[%s] closing association..", a.name) state := a.getState() if state != established { return fmt.Errorf("%w: shutdown %s", ErrShutdownNonEstablished, a.name) } // Attempt a graceful shutdown. a.setState(shutdownPending) a.lock.Lock() if a.inflightQueue.size() == 0 { // No more outstanding, send shutdown. 
a.willSendShutdown = true a.awakeWriteLoop() a.setState(shutdownSent) } a.lock.Unlock() select { case <-a.closeWriteLoopCh: return nil case <-ctx.Done(): return ctx.Err() } } // Close ends the SCTP Association and cleans up any state func (a *Association) Close() error { a.log.Debugf("[%s] closing association..", a.name) err := a.close() // Wait for readLoop to end <-a.readLoopCloseCh a.log.Debugf("[%s] association closed", a.name) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) return err } func (a *Association) close() error { a.log.Debugf("[%s] closing association..", a.name) a.setState(closed) err := a.netConn.Close() a.closeAllTimers() // awake writeLoop to exit a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) return err } // Abort sends the abort packet with user initiated abort and immediately // closes the connection. func (a *Association) Abort(reason string) { a.log.Debugf("[%s] aborting association: %s", a.name, reason) a.lock.Lock() a.willSendAbort = true a.willSendAbortCause = &errorCauseUserInitiatedAbort{ upperLayerAbortReason: []byte(reason), } a.lock.Unlock() a.awakeWriteLoop() // Wait for readLoop to end <-a.readLoopCloseCh } func (a *Association) closeAllTimers() { // Close all retransmission & ack timers a.t1Init.close() a.t1Cookie.close() a.t2Shutdown.close() a.t3RTX.close() a.tReconfig.close() a.ackTimer.close() } func (a *Association) readLoop() { var closeErr error defer func() { // also stop writeLoop, otherwise writeLoop can be leaked // if connection is lost when there is no writing packet. 
a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) a.lock.Lock() for _, s := range a.streams { a.unregisterStream(s, closeErr) } a.lock.Unlock() close(a.acceptCh) close(a.readLoopCloseCh) a.log.Debugf("[%s] association closed", a.name) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKs()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) }() a.log.Debugf("[%s] readLoop entered", a.name) buffer := make([]byte, receiveMTU) for { n, err := a.netConn.Read(buffer) if err != nil { closeErr = err break } // Make a buffer sized to what we read, then copy the data we // read from the underlying transport. We do this because the // user data is passed to the reassembly queue without // copying. inbound := make([]byte, n) copy(inbound, buffer[:n]) atomic.AddUint64(&a.bytesReceived, uint64(n)) if err = a.handleInbound(inbound); err != nil { closeErr = err break } } a.log.Debugf("[%s] readLoop exited %s", a.name, closeErr) } func (a *Association) writeLoop() { a.log.Debugf("[%s] writeLoop entered", a.name) defer a.log.Debugf("[%s] writeLoop exited", a.name) loop: for { rawPackets, ok := a.gatherOutbound() for _, raw := range rawPackets { _, err := a.netConn.Write(raw) if err != nil { if !errors.Is(err, io.EOF) { a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err) } a.log.Debugf("[%s] writeLoop ended", a.name) break loop } atomic.AddUint64(&a.bytesSent, uint64(len(raw))) } if !ok { if err := a.close(); err != nil { a.log.Warnf("[%s] failed to close association: %v", a.name, err) } return } select { case <-a.awakeWriteLoopCh: case <-a.closeWriteLoopCh: break loop } } a.setState(closed) a.closeAllTimers() } func (a *Association) awakeWriteLoop() { select { case 
a.awakeWriteLoopCh <- struct{}{}: default: } } // unregisterStream un-registers a stream from the association // The caller should hold the association write lock. func (a *Association) unregisterStream(s *Stream, err error) { s.lock.Lock() defer s.lock.Unlock() delete(a.streams, s.streamIdentifier) s.readErr = err s.readNotifier.Broadcast() } // handleInbound parses incoming raw packets func (a *Association) handleInbound(raw []byte) error { p := &packet{} if err := p.unmarshal(raw); err != nil { a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) return nil } if err := checkPacket(p); err != nil { a.log.Warnf("[%s] failed validating packet %s", a.name, err) return nil } a.handleChunkStart() for _, c := range p.chunks { if err := a.handleChunk(p, c); err != nil { return err } } a.handleChunkEnd() return nil } // The caller should hold the lock func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte) [][]byte { for _, p := range a.getDataPacketsToRetransmit() { raw, err := p.marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) continue } rawPackets = append(rawPackets, raw) } return rawPackets } // The caller should hold the lock func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte) [][]byte { // Pop unsent data chunks from the pending queue to send as much as // cwnd and rwnd allow. chunks, sisToReset := a.popPendingDataChunksToSend() if len(chunks) > 0 { // Start timer. 
(noop if already started) a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name) a.t3RTX.start(a.rtoMgr.getRTO()) for _, p := range a.bundleDataChunksIntoPackets(chunks) { raw, err := p.marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) continue } rawPackets = append(rawPackets, raw) } } if len(sisToReset) > 0 || a.willRetransmitReconfig { if a.willRetransmitReconfig { a.willRetransmitReconfig = false a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) for _, c := range a.reconfigs { p := a.createPacket([]chunk{c}) raw, err := p.marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name) } else { rawPackets = append(rawPackets, raw) } } } if len(sisToReset) > 0 { rsn := a.generateNextRSN() tsn := a.myNextTSN - 1 c := &chunkReconfig{ paramA: ¶mOutgoingResetRequest{ reconfigRequestSequenceNumber: rsn, senderLastTSN: tsn, streamIdentifiers: sisToReset, }, } a.reconfigs[rsn] = c // store in the map for retransmission a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v", a.name, rsn, a.myNextTSN-1, sisToReset) p := a.createPacket([]chunk{c}) raw, err := p.marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name) } else { rawPackets = append(rawPackets, raw) } } if len(a.reconfigs) > 0 { a.tReconfig.start(a.rtoMgr.getRTO()) } } return rawPackets } // The caller should hold the lock func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byte) [][]byte { if a.willRetransmitFast { a.willRetransmitFast = false toFastRetrans := []chunk{} fastRetransSize := commonHeaderSize for i := 0; ; i++ { c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) if !ok { break // end of pending data } if c.acked || c.abandoned() { continue } if c.nSent > 1 || c.missIndicator < 3 { continue } // RFC 4960 Sec 7.2.4 Fast Retransmit on Gap Reports // 3) Determine how many of the 
earliest (i.e., lowest TSN) DATA chunks // marked for retransmission will fit into a single packet, subject // to constraint of the path MTU of the destination transport // address to which the packet is being sent. Call this value K. // Retransmit those K DATA chunks in a single packet. When a Fast // Retransmit is being performed, the sender SHOULD ignore the value // of cwnd and SHOULD NOT delay retransmission for this single // packet. dataChunkSize := dataChunkHeaderSize + uint32(len(c.userData)) if a.mtu < fastRetransSize+dataChunkSize { break } fastRetransSize += dataChunkSize a.stats.incFastRetrans() c.nSent++ a.checkPartialReliabilityStatus(c) toFastRetrans = append(toFastRetrans, c) a.log.Tracef("[%s] fast-retransmit: tsn=%d sent=%d htna=%d", a.name, c.tsn, c.nSent, a.fastRecoverExitPoint) } if len(toFastRetrans) > 0 { raw, err := a.createPacket(toFastRetrans).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) } else { rawPackets = append(rawPackets, raw) } } } return rawPackets } // The caller should hold the lock func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { if a.ackState == ackStateImmediate { a.ackState = ackStateIdle sack := a.createSelectiveAckChunk() a.log.Debugf("[%s] sending SACK: %s", a.name, sack.String()) raw, err := a.createPacket([]chunk{sack}).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a SACK packet", a.name) } else { rawPackets = append(rawPackets, raw) } } return rawPackets } // The caller should hold the lock func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]byte { if a.willSendForwardTSN { a.willSendForwardTSN = false if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { fwdtsn := a.createForwardTSN() raw, err := a.createPacket([]chunk{fwdtsn}).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name) } else { rawPackets = append(rawPackets, raw) 
} } } return rawPackets } func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]byte, bool) { ok := true switch { case a.willSendShutdown: a.willSendShutdown = false shutdown := &chunkShutdown{ cumulativeTSNAck: a.cumulativeTSNAckPoint, } raw, err := a.createPacket([]chunk{shutdown}).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name) } else { a.t2Shutdown.start(a.rtoMgr.getRTO()) rawPackets = append(rawPackets, raw) } case a.willSendShutdownAck: a.willSendShutdownAck = false shutdownAck := &chunkShutdownAck{} raw, err := a.createPacket([]chunk{shutdownAck}).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name) } else { a.t2Shutdown.start(a.rtoMgr.getRTO()) rawPackets = append(rawPackets, raw) } case a.willSendShutdownComplete: a.willSendShutdownComplete = false shutdownComplete := &chunkShutdownComplete{} raw, err := a.createPacket([]chunk{shutdownComplete}).marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name) } else { rawPackets = append(rawPackets, raw) ok = false } } return rawPackets, ok } func (a *Association) gatherAbortPacket() ([]byte, error) { cause := a.willSendAbortCause a.willSendAbort = false a.willSendAbortCause = nil abort := &chunkAbort{} if cause != nil { abort.errorCauses = []errorCause{cause} } raw, err := a.createPacket([]chunk{abort}).marshal() return raw, err } // gatherOutbound gathers outgoing packets. The returned bool value set to // false means the association should be closed down after the final send. 
func (a *Association) gatherOutbound() ([][]byte, bool) { a.lock.Lock() defer a.lock.Unlock() if a.willSendAbort { pkt, err := a.gatherAbortPacket() if err != nil { a.log.Warnf("[%s] failed to serialize an abort packet", a.name) return nil, false } return [][]byte{pkt}, false } rawPackets := [][]byte{} if a.controlQueue.size() > 0 { for _, p := range a.controlQueue.popAll() { raw, err := p.marshal() if err != nil { a.log.Warnf("[%s] failed to serialize a control packet", a.name) continue } rawPackets = append(rawPackets, raw) } } state := a.getState() ok := true switch state { case established: rawPackets = a.gatherDataPacketsToRetransmit(rawPackets) rawPackets = a.gatherOutboundDataAndReconfigPackets(rawPackets) rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets) rawPackets = a.gatherOutboundSackPackets(rawPackets) rawPackets = a.gatherOutboundForwardTSNPackets(rawPackets) case shutdownPending, shutdownSent, shutdownReceived: rawPackets = a.gatherDataPacketsToRetransmit(rawPackets) rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets) rawPackets = a.gatherOutboundSackPackets(rawPackets) rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) case shutdownAckSent: rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) } return rawPackets, ok } func checkPacket(p *packet) error { // All packets must adhere to these rules // This is the SCTP sender's port number. It can be used by the // receiver in combination with the source IP address, the SCTP // destination port, and possibly the destination IP address to // identify the association to which this packet belongs. The port // number 0 MUST NOT be used. if p.sourcePort == 0 { return ErrSCTPPacketSourcePortZero } // This is the SCTP port number to which this packet is destined. // The receiving host will use this port number to de-multiplex the // SCTP packet to the correct receiving endpoint/application. The // port number 0 MUST NOT be used. 
if p.destinationPort == 0 { return ErrSCTPPacketDestinationPortZero } // Check values on the packet that are specific to a particular chunk type for _, c := range p.chunks { switch c.(type) { // nolint:gocritic case *chunkInit: // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. // They MUST be the only chunks present in the SCTP packets that carry // them. if len(p.chunks) != 1 { return ErrInitChunkBundled } // A packet containing an INIT chunk MUST have a zero Verification // Tag. if p.verificationTag != 0 { return ErrInitChunkVerifyTagNotZero } } } return nil } func min16(a, b uint16) uint16 { if a < b { return a } return b } func max32(a, b uint32) uint32 { if a > b { return a } return b } func min32(a, b uint32) uint32 { if a < b { return a } return b } // setState atomically sets the state of the Association. // The caller should hold the lock. func (a *Association) setState(newState uint32) { oldState := atomic.SwapUint32(&a.state, newState) if newState != oldState { a.log.Debugf("[%s] state change: '%s' => '%s'", a.name, getAssociationStateString(oldState), getAssociationStateString(newState)) } } // getState atomically returns the state of the Association. func (a *Association) getState() uint32 { return atomic.LoadUint32(&a.state) } // BytesSent returns the number of bytes sent func (a *Association) BytesSent() uint64 { return atomic.LoadUint64(&a.bytesSent) } // BytesReceived returns the number of bytes received func (a *Association) BytesReceived() uint64 { return atomic.LoadUint64(&a.bytesReceived) } func setSupportedExtensions(init *chunkInitCommon) { // nolint:godox // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 // An implementation supporting this (Supported Extensions Parameter) // extension MUST list the ASCONF, the ASCONF-ACK, and the AUTH chunks // in its INIT and INIT-ACK parameters. 
init.params = append(init.params, ¶mSupportedExtensions{ ChunkTypes: []chunkType{ctReconfig, ctForwardTSN}, }) } // The caller should hold the lock. func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) { state := a.getState() a.log.Debugf("[%s] chunkInit received in state '%s'", a.name, getAssociationStateString(state)) // https://tools.ietf.org/html/rfc4960#section-5.2.1 // Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST // respond with an INIT ACK using the same parameters it sent in its // original INIT chunk (including its Initiate Tag, unchanged). When // responding, the endpoint MUST send the INIT ACK back to the same // address that the original INIT (sent by this endpoint) was sent. if state != closed && state != cookieWait && state != cookieEchoed { // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, // COOKIE-WAIT, and SHUTDOWN-ACK-SENT return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state)) } // Should we be setting any of these permanently until we've ACKed further? a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = i.initiateTag a.sourcePort = p.destinationPort a.destinationPort = p.sourcePort // 13.2 This is the last TSN received in sequence. This value // is set initially by taking the peer's initial TSN, // received in the INIT or INIT ACK chunk, and // subtracting one from it. 
a.peerLastTSN = i.initialTSN - 1 for _, param := range i.params { switch v := param.(type) { // nolint:gocritic case *paramSupportedExtensions: for _, t := range v.ChunkTypes { if t == ctForwardTSN { a.log.Debugf("[%s] use ForwardTSN (on init)", a.name) a.useForwardTSN = true } } } } if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) } outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort initAck := &chunkInitAck{} initAck.initialTSN = a.myNextTSN initAck.numOutboundStreams = a.myMaxNumOutboundStreams initAck.numInboundStreams = a.myMaxNumInboundStreams initAck.initiateTag = a.myVerificationTag initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize if a.myCookie == nil { var err error if a.myCookie, err = newRandomStateCookie(); err != nil { return nil, err } } initAck.params = []param{a.myCookie} setSupportedExtensions(&initAck.chunkInitCommon) outbound.chunks = []chunk{initAck} return pack(outbound), nil } // The caller should hold the lock. func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error { state := a.getState() a.log.Debugf("[%s] chunkInitAck received in state '%s'", a.name, getAssociationStateString(state)) if state != cookieWait { // RFC 4960 // 5.2.3. Unexpected INIT ACK // If an INIT ACK is received by an endpoint in any state other than the // COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. // An unexpected INIT ACK usually indicates the processing of an old or // duplicated INIT chunk. 
return nil } a.myMaxNumInboundStreams = min16(i.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(i.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = i.initiateTag a.peerLastTSN = i.initialTSN - 1 if a.sourcePort != p.destinationPort || a.destinationPort != p.sourcePort { a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) return nil } a.rwnd = i.advertisedReceiverWindowCredit a.log.Debugf("[%s] initial rwnd=%d", a.name, a.rwnd) // RFC 4690 Sec 7.2.1 // o The initial value of ssthresh MAY be arbitrarily high (for // example, implementations MAY use the size of the receiver // advertised window). a.ssthresh = a.rwnd a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) a.t1Init.stop() a.storedInit = nil var cookieParam *paramStateCookie for _, param := range i.params { switch v := param.(type) { case *paramStateCookie: cookieParam = v case *paramSupportedExtensions: for _, t := range v.ChunkTypes { if t == ctForwardTSN { a.log.Debugf("[%s] use ForwardTSN (on initAck)", a.name) a.useForwardTSN = true } } } } if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) } if cookieParam == nil { return ErrInitAckNoCookie } a.storedCookieEcho = &chunkCookieEcho{} a.storedCookieEcho.cookie = cookieParam.cookie err := a.sendCookieEcho() if err != nil { a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) } a.t1Cookie.start(a.rtoMgr.getRTO()) a.setState(cookieEchoed) return nil } // The caller should hold the lock. 
func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { a.log.Tracef("[%s] chunkHeartbeat", a.name) hbi, ok := c.params[0].(*paramHeartbeatInfo) if !ok { a.log.Warnf("[%s] failed to handle Heartbeat, no ParamHeartbeatInfo", a.name) } return pack(&packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: []chunk{&chunkHeartbeatAck{ params: []param{ ¶mHeartbeatInfo{ heartbeatInformation: hbi.heartbeatInformation, }, }, }}, }) } // The caller should hold the lock. func (a *Association) handleCookieEcho(c *chunkCookieEcho) []*packet { state := a.getState() a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) if a.myCookie == nil { a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name) return nil } switch state { default: return nil case established: if !bytes.Equal(a.myCookie.cookie, c.cookie) { return nil } case closed, cookieWait, cookieEchoed: if !bytes.Equal(a.myCookie.cookie, c.cookie) { return nil } a.t1Init.stop() a.storedInit = nil a.t1Cookie.stop() a.storedCookieEcho = nil a.setState(established) a.handshakeCompletedCh <- nil } p := &packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: []chunk{&chunkCookieAck{}}, } return pack(p) } // The caller should hold the lock. func (a *Association) handleCookieAck() { state := a.getState() a.log.Debugf("[%s] COOKIE-ACK received in state '%s'", a.name, getAssociationStateString(state)) if state != cookieEchoed { // RFC 4960 // 5.2.5. Handle Duplicate COOKIE-ACK. // At any state other than COOKIE-ECHOED, an endpoint should silently // discard a received COOKIE ACK chunk. return } a.t1Cookie.stop() a.storedCookieEcho = nil a.setState(established) a.handshakeCompletedCh <- nil } // The caller should hold the lock. 
func (a *Association) handleData(d *chunkPayloadData) []*packet { a.log.Tracef("[%s] DATA: tsn=%d immediateSack=%v len=%d", a.name, d.tsn, d.immediateSack, len(d.userData)) a.stats.incDATAs() canPush := a.payloadQueue.canPush(d, a.peerLastTSN) if canPush { s := a.getOrCreateStream(d.streamIdentifier, true, PayloadTypeUnknown) if s == nil { // silentely discard the data. (sender will retry on T3-rtx timeout) // see pion/sctp#30 a.log.Debugf("discard %d", d.streamSequenceNumber) return nil } if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives a.payloadQueue.push(d, a.peerLastTSN) s.handleData(d) } else { // Receive buffer is full lastTSN, ok := a.payloadQueue.getLastTSNReceived() if ok && sna32LT(d.tsn, lastTSN) { a.log.Debugf("[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) a.payloadQueue.push(d, a.peerLastTSN) s.handleData(d) } else { a.log.Debugf("[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, d.tsn, d.streamSequenceNumber) } } } return a.handlePeerLastTSNAndAcknowledgement(d.immediateSack) } // A common routine for handleData and handleForwardTSN routines // The caller should hold the lock. func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { var reply []*packet // Try to advance peerLastTSN // From RFC 3758 Sec 3.6: // .. and then MUST further advance its cumulative TSN point locally // if possible // Meaning, if peerLastTSN+1 points to a chunk that is received, // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. 
for { if _, popOk := a.payloadQueue.pop(a.peerLastTSN + 1); !popOk { break } a.peerLastTSN++ for _, rstReq := range a.reconfigRequests { resp := a.resetStreamsIfAny(rstReq) if resp != nil { a.log.Debugf("[%s] RESET RESPONSE: %+v", a.name, resp) reply = append(reply, resp) } } } hasPacketLoss := (a.payloadQueue.size() > 0) if hasPacketLoss { a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString(a.peerLastTSN)) } if (a.ackState != ackStateImmediate && !sackImmediately && !hasPacketLoss && a.ackMode == ackModeNormal) || a.ackMode == ackModeAlwaysDelay { if a.ackState == ackStateIdle { a.delayedAckTriggered = true } else { a.immediateAckTriggered = true } } else { a.immediateAckTriggered = true } return reply } // The caller should hold the lock. func (a *Association) getMyReceiverWindowCredit() uint32 { var bytesQueued uint32 for _, s := range a.streams { bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) } if bytesQueued >= a.maxReceiveBufferSize { return 0 } return a.maxReceiveBufferSize - bytesQueued } // OpenStream opens a stream func (a *Association) OpenStream(streamIdentifier uint16, defaultPayloadType PayloadProtocolIdentifier) (*Stream, error) { a.lock.Lock() defer a.lock.Unlock() return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil } // AcceptStream accepts a stream func (a *Association) AcceptStream() (*Stream, error) { s, ok := <-a.acceptCh if !ok { return nil, io.EOF // no more incoming streams } return s, nil } // createStream creates a stream. The caller should hold the lock and check no stream exists for this id. 
func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream {
	s := &Stream{
		association:      a,
		streamIdentifier: streamIdentifier,
		reassemblyQueue:  newReassemblyQueue(streamIdentifier),
		log:              a.log,
		name:             fmt.Sprintf("%d:%s", streamIdentifier, a.name),
	}

	s.readNotifier = sync.NewCond(&s.lock)

	if accept {
		// Non-blocking hand-off to acceptCh: the stream is registered only
		// when the channel takes it; otherwise it is dropped and nil returned.
		select {
		case a.acceptCh <- s:
			a.streams[streamIdentifier] = s
			a.log.Debugf("[%s] accepted a new stream (streamIdentifier: %d)",
				a.name, streamIdentifier)
		default:
			a.log.Debugf("[%s] dropped a new stream (acceptCh size: %d)",
				a.name, len(a.acceptCh))
			return nil
		}
	} else {
		a.streams[streamIdentifier] = s
	}

	return s
}

// getOrCreateStream gets or creates a stream. The caller should hold the lock.
// The default payload type is applied to the stream whether it already
// existed or was just created. Returns nil when createStream returns nil.
func (a *Association) getOrCreateStream(streamIdentifier uint16, accept bool, defaultPayloadType PayloadProtocolIdentifier) *Stream {
	if s, ok := a.streams[streamIdentifier]; ok {
		s.SetDefaultPayloadType(defaultPayloadType)
		return s
	}

	s := a.createStream(streamIdentifier, accept)
	if s != nil {
		s.SetDefaultPayloadType(defaultPayloadType)
	}

	return s
}

// processSelectiveAck applies a SACK to the inflight queue.
// It returns the number of newly acknowledged bytes per stream identifier,
// the highest TSN newly acknowledged (htna, starting from the SACK's
// cumulative TSN ack), and an error when a TSN the SACK refers to is not in
// the inflight queue.
// The caller should hold the lock.
func (a *Association) processSelectiveAck(d *chunkSelectiveAck) (map[uint16]int, uint32, error) { // nolint:gocognit
	bytesAckedPerStream := map[uint16]int{}

	// New ack point, so pop all ACKed packets from inflightQueue
	// We add 1 because the "currentAckPoint" has already been popped from the inflight queue
	// For the first SACK we take care of this by setting the ackpoint to cumAck - 1
	for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, d.cumulativeTSNAck); i++ {
		c, ok := a.inflightQueue.pop(i)
		if !ok {
			return nil, 0, fmt.Errorf("%w: %v", ErrInflightQueueTSNPop, i)
		}

		// Chunks already marked acked by an earlier gap ack block were
		// counted then; only count the ones acknowledged for the first time.
		if !c.acked {
			// RFC 4960 sec 6.3.2.  Retransmission Timer Rules
			//   R3)  Whenever a SACK is received that acknowledges the DATA chunk
			//        with the earliest outstanding TSN for that address, restart the
			//        T3-rtx timer for that address with its current RTO (if there is
			//        still outstanding data on that address).
			if i == a.cumulativeTSNAckPoint+1 {
				// T3 timer needs to be reset. Stop it for now.
				a.t3RTX.stop()
			}

			nBytesAcked := len(c.userData)

			// Sum the number of bytes acknowledged per stream
			if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok {
				bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked
			} else {
				bytesAckedPerStream[c.streamIdentifier] = nBytesAcked
			}

			// RFC 4960 sec 6.3.1.  RTO Calculation
			//   C4)  When data is in flight and when allowed by rule C5 below, a new
			//        RTT measurement MUST be made each round trip.  Furthermore, new
			//        RTT measurements SHOULD be made no more than once per round trip
			//        for a given destination transport address.
			//   C5)  Karn's algorithm: RTT measurements MUST NOT be made using
			//        packets that were retransmitted (and thus for which it is
			//        ambiguous whether the reply was for the first instance of the
			//        chunk or for a later instance)
			if c.nSent == 1 && sna32GTE(c.tsn, a.minTSN2MeasureRTT) {
				// Moving the threshold to myNextTSN makes every TSN assigned
				// so far fail the check above until newer data is acked.
				a.minTSN2MeasureRTT = a.myNextTSN
				rtt := time.Since(c.since).Seconds() * 1000.0 // milliseconds
				srtt := a.rtoMgr.setNewRTT(rtt)
				a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f",
					a.name, rtt, srtt, a.rtoMgr.getRTO())
			}
		}

		if a.inFastRecovery && c.tsn == a.fastRecoverExitPoint {
			a.log.Debugf("[%s] exit fast-recovery", a.name)
			a.inFastRecovery = false
		}
	}

	htna := d.cumulativeTSNAck

	// Mark selectively acknowledged chunks as "acked"
	for _, g := range d.gapAckBlocks {
		// g.start and g.end are offsets relative to the cumulative TSN ack.
		// NOTE(review): if g.end equals the maximum of its integer type,
		// i <= g.end can never become false and the loop only ends through
		// the "not in inflight queue" error below — confirm this is intended.
		for i := g.start; i <= g.end; i++ {
			tsn := d.cumulativeTSNAck + uint32(i)
			c, ok := a.inflightQueue.get(tsn)
			if !ok {
				return nil, 0, fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn)
			}

			if !c.acked {
				nBytesAcked := a.inflightQueue.markAsAcked(tsn)

				// Sum the number of bytes acknowledged per stream
				if amount, ok := bytesAckedPerStream[c.streamIdentifier]; ok {
					bytesAckedPerStream[c.streamIdentifier] = amount + nBytesAcked
				} else {
					bytesAckedPerStream[c.streamIdentifier] = nBytesAcked
				}

				a.log.Tracef("[%s] tsn=%d has been sacked", a.name, c.tsn)

				// NOTE(review): unlike the cumulative-ack path above, this
				// measurement is not gated on minTSN2MeasureRTT — confirm.
				if c.nSent == 1 {
					a.minTSN2MeasureRTT = a.myNextTSN
					rtt := time.Since(c.since).Seconds() * 1000.0 // milliseconds
					srtt := a.rtoMgr.setNewRTT(rtt)
					a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f",
						a.name, rtt, srtt, a.rtoMgr.getRTO())
				}

				// Track the highest TSN newly acknowledged by this SACK.
				if sna32LT(htna, tsn) {
					htna = tsn
				}
			}
		}
	}

	return bytesAckedPerStream, htna, nil
}

// onCumulativeTSNAckPointAdvanced restarts or stops the T3-rtx timer and
// updates cwnd / partialBytesAcked after a SACK advanced the cumulative TSN
// ack point. totalBytesAcked is the number of bytes newly acknowledged.
// The caller should hold the lock.
func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) {
	// RFC 4960, sec 6.3.2.  Retransmission Timer Rules
	//   R2)  Whenever all outstanding data sent to an address have been
	//        acknowledged, turn off the T3-rtx timer of that address.
	if a.inflightQueue.size() == 0 {
		a.log.Tracef("[%s] SACK: no more packet in-flight (pending=%d)", a.name, a.pendingQueue.size())
		a.t3RTX.stop()
	} else {
		a.log.Tracef("[%s] T3-rtx timer start (pt2)", a.name)
		a.t3RTX.start(a.rtoMgr.getRTO())
	}

	// Update congestion control parameters
	if a.cwnd <= a.ssthresh {
		// RFC 4960, sec 7.2.1.  Slow-Start
		//   o  When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST
		//		use the slow-start algorithm to increase cwnd only if the current
		//      congestion window is being fully utilized, an incoming SACK
		//      advances the Cumulative TSN Ack Point, and the data sender is not
		//      in Fast Recovery.  Only when these three conditions are met can
		//      the cwnd be increased; otherwise, the cwnd MUST not be increased.
		//		If these conditions are met, then cwnd MUST be increased by, at
		//      most, the lesser of 1) the total size of the previously
		//      outstanding DATA chunk(s) acknowledged, and 2) the destination's
		//      path MTU.
		// Here "window fully utilized" is approximated by "data is still
		// waiting in the pending queue".
		if !a.inFastRecovery &&
			a.pendingQueue.size() > 0 {
			a.cwnd += min32(uint32(totalBytesAcked), a.cwnd) // TCP way
			// a.cwnd += min32(uint32(totalBytesAcked), a.mtu) // SCTP way (slow)
			a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)",
				a.name, a.cwnd, a.ssthresh, totalBytesAcked)
		} else {
			a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d",
				a.name, a.cwnd, a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size())
		}
	} else {
		// RFC 4960, sec 7.2.2.  Congestion Avoidance
		//   o  Whenever cwnd is greater than ssthresh, upon each SACK arrival
		//      that advances the Cumulative TSN Ack Point, increase
		//      partial_bytes_acked by the total number of bytes of all new chunks
		//      acknowledged in that SACK including chunks acknowledged by the new
		//      Cumulative TSN Ack and by Gap Ack Blocks.
		a.partialBytesAcked += uint32(totalBytesAcked)

		//   o  When partial_bytes_acked is equal to or greater than cwnd and
		//      before the arrival of the SACK the sender had cwnd or more bytes
		//      of data outstanding (i.e., before arrival of the SACK, flight size
		//      was greater than or equal to cwnd), increase cwnd by MTU, and
		//      reset partial_bytes_acked to (partial_bytes_acked - cwnd).
		if a.partialBytesAcked >= a.cwnd && a.pendingQueue.size() > 0 {
			a.partialBytesAcked -= a.cwnd
			a.cwnd += a.mtu
			a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)",
				a.name, a.cwnd, a.ssthresh, totalBytesAcked)
		}
	}
}

// processFastRetransmission increments the miss indicators of unacked,
// non-abandoned inflight chunks and enters fast recovery when one of them
// reaches three. It returns an error when a TSN in the scanned range is not
// in the inflight queue.
// The caller should hold the lock.
func (a *Association) processFastRetransmission(cumTSNAckPoint, htna uint32, cumTSNAckPointAdvanced bool) error {
	// HTNA algorithm - RFC 4960 Sec 7.2.4
	// Increment missIndicator of each chunks that the SACK reported missing
	// when either of the following is met:
	// a)  Not in fast-recovery
	//     miss indications are incremented only for missing TSNs prior to the
	//     highest TSN newly acknowledged in the SACK.
	// b)  In fast-recovery AND the Cumulative TSN Ack Point advanced
	//     the miss indications are incremented for all TSNs reported missing
	//     in the SACK.
	if !a.inFastRecovery || (a.inFastRecovery && cumTSNAckPointAdvanced) {
		var maxTSN uint32
		if !a.inFastRecovery {
			// a) increment only for missing TSNs prior to the HTNA
			maxTSN = htna
		} else {
			// b) increment for all TSNs reported missing
			maxTSN = cumTSNAckPoint + uint32(a.inflightQueue.size()) + 1
		}

		for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ {
			c, ok := a.inflightQueue.get(tsn)
			if !ok {
				return fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn)
			}
			// The indicator is capped at 3, so the fast-recovery entry below
			// runs at most once per chunk.
			if !c.acked && !c.abandoned() && c.missIndicator < 3 {
				c.missIndicator++
				if c.missIndicator == 3 {
					if !a.inFastRecovery {
						// 2)  If not in Fast Recovery, adjust the ssthresh and cwnd of the
						//     destination address(es) to which the missing DATA chunks were
						//     last sent, according to the formula described in Section 7.2.3.
						a.inFastRecovery = true
						a.fastRecoverExitPoint = htna
						a.ssthresh = max32(a.cwnd/2, 4*a.mtu)
						a.cwnd = a.ssthresh
						a.partialBytesAcked = 0
						a.willRetransmitFast = true

						a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)",
							a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes())
					}
				}
			}
		}
	}

	if a.inFastRecovery && cumTSNAckPointAdvanced {
		a.willRetransmitFast = true
	}

	return nil
}

// handleSack processes a received SACK chunk: it updates the inflight queue,
// the cumulative TSN ack point, congestion control state, rwnd, fast
// retransmission state and, when ForwardTSN is in use, the advanced peer TSN
// ack point. SACKs are ignored outside the established, shutdownPending and
// shutdownReceived states.
// The caller should hold the lock.
func (a *Association) handleSack(d *chunkSelectiveAck) error {
	a.log.Tracef("[%s] SACK: cumTSN=%d a_rwnd=%d", a.name, d.cumulativeTSNAck, d.advertisedReceiverWindowCredit)
	state := a.getState()
	if state != established && state != shutdownPending && state != shutdownReceived {
		return nil
	}

	a.stats.incSACKs()

	if sna32GT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) {
		// RFC 4960 sec 6.2.1.  Processing a Received SACK
		// D)
		//   i) If Cumulative TSN Ack is less than the Cumulative TSN Ack
		//      Point, then drop the SACK.  Since Cumulative TSN Ack is
		//      monotonically increasing, a SACK whose Cumulative TSN Ack is
		//      less than the Cumulative TSN Ack Point indicates an out-of-
		//      order SACK.

		a.log.Debugf("[%s] SACK Cumulative ACK %v is older than ACK point %v",
			a.name,
			d.cumulativeTSNAck,
			a.cumulativeTSNAckPoint)

		return nil
	}

	// Process selective ack
	bytesAckedPerStream, htna, err := a.processSelectiveAck(d)
	if err != nil {
		return err
	}

	var totalBytesAcked int
	for _, nBytesAcked := range bytesAckedPerStream {
		totalBytesAcked += nBytesAcked
	}

	cumTSNAckPointAdvanced := false
	if sna32LT(a.cumulativeTSNAckPoint, d.cumulativeTSNAck) {
		a.log.Tracef("[%s] SACK: cumTSN advanced: %d -> %d",
			a.name,
			a.cumulativeTSNAckPoint,
			d.cumulativeTSNAck)

		a.cumulativeTSNAckPoint = d.cumulativeTSNAck
		cumTSNAckPointAdvanced = true
		a.onCumulativeTSNAckPointAdvanced(totalBytesAcked)
	}

	for si, nBytesAcked := range bytesAckedPerStream {
		if s, ok := a.streams[si]; ok {
			// The association lock is released around the stream callback
			// and re-acquired afterwards.
			a.lock.Unlock()
			s.onBufferReleased(nBytesAcked)
			a.lock.Lock()
		}
	}

	// New rwnd value
	// RFC 4960 sec 6.2.1.  Processing a Received SACK
	// D)
	//   ii) Set rwnd equal to the newly received a_rwnd minus the number
	//       of bytes still outstanding after processing the Cumulative
	//       TSN Ack and the Gap Ack Blocks.

	// bytes acked were already subtracted by markAsAcked() method
	bytesOutstanding := uint32(a.inflightQueue.getNumBytes())
	if bytesOutstanding >= d.advertisedReceiverWindowCredit {
		a.rwnd = 0
	} else {
		a.rwnd = d.advertisedReceiverWindowCredit - bytesOutstanding
	}

	err = a.processFastRetransmission(d.cumulativeTSNAck, htna, cumTSNAckPointAdvanced)
	if err != nil {
		return err
	}

	if a.useForwardTSN {
		// RFC 3758 Sec 3.5 C1
		if sna32LT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
			a.advancedPeerTSNAckPoint = a.cumulativeTSNAckPoint
		}

		// RFC 3758 Sec 3.5 C2
		// Advance over consecutive abandoned chunks; stop at the first TSN
		// that is missing from the inflight queue or not abandoned.
		for i := a.advancedPeerTSNAckPoint + 1; ; i++ {
			c, ok := a.inflightQueue.get(i)
			if !ok {
				break
			}
			if !c.abandoned() {
				break
			}
			a.advancedPeerTSNAckPoint = i
		}

		// RFC 3758 Sec 3.5 C3
		if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) {
			a.willSendForwardTSN = true
		}
		a.awakeWriteLoop()
	}

	a.postprocessSack(state, cumTSNAckPointAdvanced)

	return nil
}

// The caller must hold the lock. This method was only added because the
// linter was complaining about the "cognitive complexity" of handleSack.
// state is the association state sampled at the start of handleSack;
// shouldAwakeWriteLoop requests a write-loop wake-up and is forced to true
// when a SHUTDOWN or SHUTDOWN ACK is scheduled here.
func (a *Association) postprocessSack(state uint32, shouldAwakeWriteLoop bool) {
	switch {
	case a.inflightQueue.size() > 0:
		// Start timer. (noop if already started)
		a.log.Tracef("[%s] T3-rtx timer start (pt3)", a.name)
		a.t3RTX.start(a.rtoMgr.getRTO())
	case state == shutdownPending:
		// No more outstanding, send shutdown.
		shouldAwakeWriteLoop = true
		a.willSendShutdown = true
		a.setState(shutdownSent)
	case state == shutdownReceived:
		// No more outstanding, send shutdown ack.
		shouldAwakeWriteLoop = true
		a.willSendShutdownAck = true
		a.setState(shutdownAckSent)
	}

	if shouldAwakeWriteLoop {
		a.awakeWriteLoop()
	}
}

// handleShutdown processes a SHUTDOWN chunk. It acts only in the established
// and shutdownSent states; the chunk's contents are not used.
// The caller should hold the lock.
func (a *Association) handleShutdown(_ *chunkShutdown) {
	state := a.getState()

	switch state {
	case established:
		if a.inflightQueue.size() > 0 {
			// Data is still outstanding: the SHUTDOWN ACK is sent later from
			// postprocessSack once the inflight queue is empty.
			a.setState(shutdownReceived)
		} else {
			// No more outstanding, send shutdown ack.
			a.willSendShutdownAck = true
			a.setState(shutdownAckSent)

			a.awakeWriteLoop()
		}

		// a.cumulativeTSNAckPoint = c.cumulativeTSNAck
	case shutdownSent:
		a.willSendShutdownAck = true
		a.setState(shutdownAckSent)

		a.awakeWriteLoop()
	}
}

// handleShutdownAck processes a SHUTDOWN ACK chunk: in the shutdownSent or
// shutdownAckSent state it stops the T2 timer and schedules a SHUTDOWN
// COMPLETE. The chunk's contents are not used.
// The caller should hold the lock.
func (a *Association) handleShutdownAck(_ *chunkShutdownAck) {
	state := a.getState()
	if state == shutdownSent || state == shutdownAckSent {
		a.t2Shutdown.stop()
		a.willSendShutdownComplete = true

		a.awakeWriteLoop()
	}
}

// handleShutdownComplete processes a SHUTDOWN COMPLETE chunk: in the
// shutdownAckSent state it stops the T2 timer and closes the association,
// returning the result of close(); in any other state it does nothing.
func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error {
	state := a.getState()
	if state == shutdownAckSent {
		a.t2Shutdown.stop()

		return a.close()
	}

	return nil
}

// handleAbort processes an ABORT chunk: it closes the association (ignoring
// the close error) and always returns an error wrapping ErrChunk that lists
// the chunk's error causes.
func (a *Association) handleAbort(c *chunkAbort) error {
	var errStr string
	for _, e := range c.errorCauses {
		errStr += fmt.Sprintf("(%s)", e)
	}

	_ = a.close()

	return fmt.Errorf("[%s] %w: %s", a.name, ErrChunk, errStr)
}

// createForwardTSN generates ForwardTSN chunk.
// This method will be called if useForwardTSN is set to false.
// NOTE(review): "set to false" looks inverted — handleSack only requests a
// FORWARD TSN when useForwardTSN is true; confirm against the caller.
// The caller should hold the lock.
func (a *Association) createForwardTSN() *chunkForwardTSN { // RFC 3758 Sec 3.5 C4 streamMap := map[uint16]uint16{} // to report only once per SI for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, a.advancedPeerTSNAckPoint); i++ { c, ok := a.inflightQueue.get(i) if !ok { break } ssn, ok := streamMap[c.streamIdentifier] if !ok { streamMap[c.streamIdentifier] = c.streamSequenceNumber } else if sna16LT(ssn, c.streamSequenceNumber) { // to report only once with greatest SSN streamMap[c.streamIdentifier] = c.streamSequenceNumber } } fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.advancedPeerTSNAckPoint, streams: []chunkForwardTSNStream{}, } var streamStr string for si, ssn := range streamMap { streamStr += fmt.Sprintf("(si=%d ssn=%d)", si, ssn) fwdtsn.streams = append(fwdtsn.streams, chunkForwardTSNStream{ identifier: si, sequence: ssn, }) } a.log.Tracef("[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr) return fwdtsn } // createPacket wraps chunks in a packet. // The caller should hold the read lock. func (a *Association) createPacket(cs []chunk) *packet { return &packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: cs, } } // The caller should hold the lock. func (a *Association) handleReconfig(c *chunkReconfig) ([]*packet, error) { a.log.Tracef("[%s] handleReconfig", a.name) pp := make([]*packet, 0) p, err := a.handleReconfigParam(c.paramA) if err != nil { return nil, err } if p != nil { pp = append(pp, p) } if c.paramB != nil { p, err = a.handleReconfigParam(c.paramB) if err != nil { return nil, err } if p != nil { pp = append(pp, p) } } return pp, nil } // The caller should hold the lock. 
func (a *Association) handleForwardTSN(c *chunkForwardTSN) []*packet { a.log.Tracef("[%s] FwdTSN: %s", a.name, c.String()) if !a.useForwardTSN { a.log.Warn("[%s] received FwdTSN but not enabled") // Return an error chunk cerr := &chunkError{ errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}}, } outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{cerr} return []*packet{outbound} } // From RFC 3758 Sec 3.6: // Note, if the "New Cumulative TSN" value carried in the arrived // FORWARD TSN chunk is found to be behind or at the current cumulative // TSN point, the data receiver MUST treat this FORWARD TSN as out-of- // date and MUST NOT update its Cumulative TSN. The receiver SHOULD // send a SACK to its peer (the sender of the FORWARD TSN) since such a // duplicate may indicate the previous SACK was lost in the network. a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d", a.name, c.newCumulativeTSN, a.peerLastTSN) if sna32LTE(c.newCumulativeTSN, a.peerLastTSN) { a.log.Tracef("[%s] sending ack on Forward TSN", a.name) a.ackState = ackStateImmediate a.ackTimer.stop() a.awakeWriteLoop() return nil } // From RFC 3758 Sec 3.6: // the receiver MUST perform the same TSN handling, including duplicate // detection, gap detection, SACK generation, cumulative TSN // advancement, etc. as defined in RFC 2960 [2]---with the following // exceptions and additions. // When a FORWARD TSN chunk arrives, the data receiver MUST first update // its cumulative TSN point to the value carried in the FORWARD TSN // chunk, // Advance peerLastTSN for sna32LT(a.peerLastTSN, c.newCumulativeTSN) { a.payloadQueue.pop(a.peerLastTSN + 1) // may not exist a.peerLastTSN++ } // Report new peerLastTSN value and abandoned largest SSN value to // corresponding streams so that the abandoned chunks can be removed // from the reassemblyQueue. 
for _, forwarded := range c.streams { if s, ok := a.streams[forwarded.identifier]; ok { s.handleForwardTSNForOrdered(forwarded.sequence) } } // TSN may be forewared for unordered chunks. ForwardTSN chunk does not // report which stream identifier it skipped for unordered chunks. // Therefore, we need to broadcast this event to all existing streams for // unordered chunks. // See https://github.com/pion/sctp/issues/106 for _, s := range a.streams { s.handleForwardTSNForUnordered(c.newCumulativeTSN) } return a.handlePeerLastTSNAndAcknowledgement(false) } func (a *Association) sendResetRequest(streamIdentifier uint16) error { a.lock.Lock() defer a.lock.Unlock() state := a.getState() if state != established { return fmt.Errorf("%w: state=%s", ErrResetPacketInStateNotExist, getAssociationStateString(state)) } // Create DATA chunk which only contains valid stream identifier with // nil userData and use it as a EOS from the stream. c := &chunkPayloadData{ streamIdentifier: streamIdentifier, beginningFragment: true, endingFragment: true, userData: nil, } a.pendingQueue.push(c) a.awakeWriteLoop() return nil } // The caller should hold the lock. func (a *Association) handleReconfigParam(raw param) (*packet, error) { switch p := raw.(type) { case *paramOutgoingResetRequest: a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name) a.reconfigRequests[p.reconfigRequestSequenceNumber] = p resp := a.resetStreamsIfAny(p) if resp != nil { return resp, nil } return nil, nil //nolint:nilnil case *paramReconfigResponse: a.log.Tracef("[%s] handleReconfigParam (ReconfigResponse)", a.name) delete(a.reconfigs, p.reconfigResponseSequenceNumber) if len(a.reconfigs) == 0 { a.tReconfig.stop() } return nil, nil //nolint:nilnil default: return nil, fmt.Errorf("%w: %t", ErrParamterType, p) } } // The caller should hold the lock. 
func (a *Association) resetStreamsIfAny(p *paramOutgoingResetRequest) *packet { result := reconfigResultSuccessPerformed if sna32LTE(p.senderLastTSN, a.peerLastTSN) { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", a.name, p.senderLastTSN, a.peerLastTSN) for _, id := range p.streamIdentifiers { s, ok := a.streams[id] if !ok { continue } a.lock.Unlock() s.onInboundStreamReset() a.lock.Lock() a.log.Debugf("[%s] deleting stream %d", a.name, id) delete(a.streams, s.streamIdentifier) } delete(a.reconfigRequests, p.reconfigRequestSequenceNumber) } else { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", a.name, p.senderLastTSN, a.peerLastTSN) result = reconfigResultInProgress } return a.createPacket([]chunk{&chunkReconfig{ paramA: ¶mReconfigResponse{ reconfigResponseSequenceNumber: p.reconfigRequestSequenceNumber, result: result, }, }}) } // Move the chunk peeked with a.pendingQueue.peek() to the inflightQueue. // The caller should hold the lock. func (a *Association) movePendingDataChunkToInflightQueue(c *chunkPayloadData) { if err := a.pendingQueue.pop(c); err != nil { a.log.Errorf("[%s] failed to pop from pending queue: %s", a.name, err.Error()) } // Mark all fragements are in-flight now if c.endingFragment { c.setAllInflight() } // Assign TSN c.tsn = a.generateNextTSN() c.since = time.Now() // use to calculate RTT and also for maxPacketLifeTime c.nSent = 1 // being sent for the first time a.checkPartialReliabilityStatus(c) a.log.Tracef("[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", a.name, c.payloadType, c.tsn, c.streamSequenceNumber, c.nSent, len(c.userData), c.beginningFragment, c.endingFragment) a.inflightQueue.pushNoCheck(c) } // popPendingDataChunksToSend pops chunks from the pending queues as many as // the cwnd and rwnd allows to send. // The caller should hold the lock. 
func (a *Association) popPendingDataChunksToSend() ([]*chunkPayloadData, []uint16) { chunks := []*chunkPayloadData{} var sisToReset []uint16 // stream identifieres to reset if a.pendingQueue.size() > 0 { // RFC 4960 sec 6.1. Transmission of DATA Chunks // A) At any given time, the data sender MUST NOT transmit new data to // any destination transport address if its peer's rwnd indicates // that the peer has no buffer space (i.e., rwnd is 0; see Section // 6.2.1). However, regardless of the value of rwnd (including if it // is 0), the data sender can always have one DATA chunk in flight to // the receiver if allowed by cwnd (see rule B, below). for { c := a.pendingQueue.peek() if c == nil { break // no more pending data } dataLen := uint32(len(c.userData)) if dataLen == 0 { sisToReset = append(sisToReset, c.streamIdentifier) err := a.pendingQueue.pop(c) if err != nil { a.log.Errorf("failed to pop from pending queue: %s", err.Error()) } continue } if uint32(a.inflightQueue.getNumBytes())+dataLen > a.cwnd { break // would exceeds cwnd } if dataLen > a.rwnd { break // no more rwnd } a.rwnd -= dataLen a.movePendingDataChunkToInflightQueue(c) chunks = append(chunks, c) } // the data sender can always have one DATA chunk in flight to the receiver if len(chunks) == 0 && a.inflightQueue.size() == 0 { // Send zero window probe c := a.pendingQueue.peek() if c != nil { a.movePendingDataChunkToInflightQueue(c) chunks = append(chunks, c) } } } return chunks, sisToReset } // bundleDataChunksIntoPackets packs DATA chunks into packets. It tries to bundle // DATA chunks into a packet so long as the resulting packet size does not exceed // the path MTU. // The caller should hold the lock. func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []*packet { packets := []*packet{} chunksToSend := []chunk{} bytesInPacket := int(commonHeaderSize) for _, c := range chunks { // RFC 4960 sec 6.1. 
Transmission of DATA Chunks // Multiple DATA chunks committed for transmission MAY be bundled in a // single packet. Furthermore, DATA chunks being retransmitted MAY be // bundled with new DATA chunks, as long as the resulting packet size // does not exceed the path MTU. if bytesInPacket+len(c.userData) > int(a.mtu) { packets = append(packets, a.createPacket(chunksToSend)) chunksToSend = []chunk{} bytesInPacket = int(commonHeaderSize) } chunksToSend = append(chunksToSend, c) bytesInPacket += int(dataChunkHeaderSize) + len(c.userData) } if len(chunksToSend) > 0 { packets = append(packets, a.createPacket(chunksToSend)) } return packets } // sendPayloadData sends the data chunks. func (a *Association) sendPayloadData(chunks []*chunkPayloadData) error { a.lock.Lock() defer a.lock.Unlock() state := a.getState() if state != established { return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, getAssociationStateString(state)) } // Push the chunks into the pending queue first. for _, c := range chunks { a.pendingQueue.push(c) } a.awakeWriteLoop() return nil } // The caller should hold the lock. func (a *Association) checkPartialReliabilityStatus(c *chunkPayloadData) { if !a.useForwardTSN { return } // draft-ietf-rtcweb-data-protocol-09.txt section 6 // 6. Procedures // All Data Channel Establishment Protocol messages MUST be sent using // ordered delivery and reliable transmission. 
// if c.payloadType == PayloadTypeWebRTCDCEP { return } // PR-SCTP if s, ok := a.streams[c.streamIdentifier]; ok { s.lock.RLock() if s.reliabilityType == ReliabilityTypeRexmit { if c.nSent >= s.reliabilityValue { c.setAbandoned(true) a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", a.name, c.tsn, c.payloadType, c.nSent) } } else if s.reliabilityType == ReliabilityTypeTimed { elapsed := int64(time.Since(c.since).Seconds() * 1000) if elapsed >= int64(s.reliabilityValue) { c.setAbandoned(true) a.log.Tracef("[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", a.name, c.tsn, c.payloadType, elapsed) } } s.lock.RUnlock() } else { a.log.Errorf("[%s] stream %d not found)", a.name, c.streamIdentifier) } } // getDataPacketsToRetransmit is called when T3-rtx is timed out and retransmit outstanding data chunks // that are not acked or abandoned yet. // The caller should hold the lock. func (a *Association) getDataPacketsToRetransmit() []*packet { awnd := min32(a.cwnd, a.rwnd) chunks := []*chunkPayloadData{} var bytesToSend int var done bool for i := 0; !done; i++ { c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) if !ok { break // end of pending data } if !c.retransmit { continue } if i == 0 && int(a.rwnd) < len(c.userData) { // Send it as a zero window probe done = true } else if bytesToSend+len(c.userData) > int(awnd) { break } // reset the retransmit flag not to retransmit again before the next // t3-rtx timer fires c.retransmit = false bytesToSend += len(c.userData) c.nSent++ a.checkPartialReliabilityStatus(c) a.log.Tracef("[%s] retransmitting tsn=%d ssn=%d sent=%d", a.name, c.tsn, c.streamSequenceNumber, c.nSent) chunks = append(chunks, c) } return a.bundleDataChunksIntoPackets(chunks) } // generateNextTSN returns the myNextTSN and increases it. The caller should hold the lock. // The caller should hold the lock. 
func (a *Association) generateNextTSN() uint32 { tsn := a.myNextTSN a.myNextTSN++ return tsn } // generateNextRSN returns the myNextRSN and increases it. The caller should hold the lock. // The caller should hold the lock. func (a *Association) generateNextRSN() uint32 { rsn := a.myNextRSN a.myNextRSN++ return rsn } func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { sack := &chunkSelectiveAck{} sack.cumulativeTSNAck = a.peerLastTSN sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() sack.duplicateTSN = a.payloadQueue.popDuplicates() sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks(a.peerLastTSN) return sack } func pack(p *packet) []*packet { return []*packet{p} } func (a *Association) handleChunkStart() { a.lock.Lock() defer a.lock.Unlock() a.delayedAckTriggered = false a.immediateAckTriggered = false } func (a *Association) handleChunkEnd() { a.lock.Lock() defer a.lock.Unlock() if a.immediateAckTriggered { a.ackState = ackStateImmediate a.ackTimer.stop() a.awakeWriteLoop() } else if a.delayedAckTriggered { // Will send delayed ack in the next ack timeout a.ackState = ackStateDelay a.ackTimer.start() } } func (a *Association) handleChunk(p *packet, c chunk) error { a.lock.Lock() defer a.lock.Unlock() var packets []*packet var err error if _, err = c.check(); err != nil { a.log.Errorf("[ %s ] failed validating chunk: %s ", a.name, err) return nil } isAbort := false switch c := c.(type) { case *chunkInit: packets, err = a.handleInit(p, c) case *chunkInitAck: err = a.handleInitAck(p, c) case *chunkAbort: isAbort = true err = a.handleAbort(c) case *chunkError: var errStr string for _, e := range c.errorCauses { errStr += fmt.Sprintf("(%s)", e) } a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) case *chunkHeartbeat: packets = a.handleHeartbeat(c) case *chunkCookieEcho: packets = a.handleCookieEcho(c) case *chunkCookieAck: a.handleCookieAck() case *chunkPayloadData: packets = a.handleData(c) case 
*chunkSelectiveAck: err = a.handleSack(c) case *chunkReconfig: packets, err = a.handleReconfig(c) case *chunkForwardTSN: packets = a.handleForwardTSN(c) case *chunkShutdown: a.handleShutdown(c) case *chunkShutdownAck: a.handleShutdownAck(c) case *chunkShutdownComplete: err = a.handleShutdownComplete(c) default: err = ErrChunkTypeUnhandled } // Log and return, the only condition that is fatal is a ABORT chunk if err != nil { if isAbort { return err } a.log.Errorf("Failed to handle chunk: %v", err) return nil } if len(packets) > 0 { a.controlQueue.pushAll(packets) a.awakeWriteLoop() } return nil } func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { a.lock.Lock() defer a.lock.Unlock() if id == timerT1Init { err := a.sendInit() if err != nil { a.log.Debugf("[%s] failed to retransmit init (nRtos=%d): %v", a.name, nRtos, err) } return } if id == timerT1Cookie { err := a.sendCookieEcho() if err != nil { a.log.Debugf("[%s] failed to retransmit cookie-echo (nRtos=%d): %v", a.name, nRtos, err) } return } if id == timerT2Shutdown { a.log.Debugf("[%s] retransmission of shutdown timeout (nRtos=%d): %v", a.name, nRtos) state := a.getState() switch state { case shutdownSent: a.willSendShutdown = true a.awakeWriteLoop() case shutdownAckSent: a.willSendShutdownAck = true a.awakeWriteLoop() } } if id == timerT3RTX { a.stats.incT3Timeouts() // RFC 4960 sec 6.3.3 // E1) For the destination address for which the timer expires, adjust // its ssthresh with rules defined in Section 7.2.3 and set the // cwnd <- MTU. 
// RFC 4960 sec 7.2.3 // When the T3-rtx timer expires on an address, SCTP should perform slow // start by: // ssthresh = max(cwnd/2, 4*MTU) // cwnd = 1*MTU a.ssthresh = max32(a.cwnd/2, 4*a.mtu) a.cwnd = a.mtu a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)", a.name, a.cwnd, a.ssthresh, a.inflightQueue.getNumBytes()) // RFC 3758 sec 3.5 // A5) Any time the T3-rtx timer expires, on any destination, the sender // SHOULD try to advance the "Advanced.Peer.Ack.Point" by following // the procedures outlined in C2 - C5. if a.useForwardTSN { // RFC 3758 Sec 3.5 C2 for i := a.advancedPeerTSNAckPoint + 1; ; i++ { c, ok := a.inflightQueue.get(i) if !ok { break } if !c.abandoned() { break } a.advancedPeerTSNAckPoint = i } // RFC 3758 Sec 3.5 C3 if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { a.willSendForwardTSN = true } } a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.cwnd, a.ssthresh) /* a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint) a.log.Debugf(" - cumulativeTSNAckPoint=%d", a.cumulativeTSNAckPoint) a.inflightQueue.updateSortedKeys() for i, tsn := range a.inflightQueue.sorted { if c, ok := a.inflightQueue.get(tsn); ok { a.log.Debugf(" - [%d] tsn=%d acked=%v abandoned=%v (%v,%v) len=%d", i, c.tsn, c.acked, c.abandoned(), c.beginningFragment, c.endingFragment, len(c.userData)) } } */ a.inflightQueue.markAllToRetrasmit() a.awakeWriteLoop() return } if id == timerReconfig { a.willRetransmitReconfig = true a.awakeWriteLoop() } } func (a *Association) onRetransmissionFailure(id int) { a.lock.Lock() defer a.lock.Unlock() if id == timerT1Init { a.log.Errorf("[%s] retransmission failure: T1-init", a.name) a.handshakeCompletedCh <- ErrHandshakeInitAck return } if id == timerT1Cookie { a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name) a.handshakeCompletedCh <- ErrHandshakeCookieEcho return } if id == timerT2Shutdown { a.log.Errorf("[%s] retransmission failure: T2-shutdown", 
a.name) return } if id == timerT3RTX { // T3-rtx timer will not fail by design // Justifications: // * ICE would fail if the connectivity is lost // * WebRTC spec is not clear how this incident should be reported to ULP a.log.Errorf("[%s] retransmission failure: T3-rtx (DATA)", a.name) return } } func (a *Association) onAckTimeout() { a.lock.Lock() defer a.lock.Unlock() a.log.Tracef("[%s] ack timed out (ackState: %d)", a.name, a.ackState) a.stats.incAckTimeouts() a.ackState = ackStateImmediate a.awakeWriteLoop() } // bufferedAmount returns total amount (in bytes) of currently buffered user data. // This is used only by testing. func (a *Association) bufferedAmount() int { a.lock.RLock() defer a.lock.RUnlock() return a.pendingQueue.getNumBytes() + a.inflightQueue.getNumBytes() } // MaxMessageSize returns the maximum message size you can send. func (a *Association) MaxMessageSize() uint32 { return atomic.LoadUint32(&a.maxMessageSize) } // SetMaxMessageSize sets the maximum message size you can send. 
// associationStats holds per-association event counters. Every field is read
// and written exclusively through sync/atomic operations.
type associationStats struct {
	nDATAs       uint64
	nSACKs       uint64
	nT3Timeouts  uint64
	nAckTimeouts uint64
	nFastRetrans uint64
}

// incDATAs adds one to the DATA counter.
func (st *associationStats) incDATAs() {
	atomic.AddUint64(&st.nDATAs, 1)
}

// getNumDATAs returns the current DATA counter.
func (st *associationStats) getNumDATAs() uint64 {
	return atomic.LoadUint64(&st.nDATAs)
}

// incSACKs adds one to the SACK counter.
func (st *associationStats) incSACKs() {
	atomic.AddUint64(&st.nSACKs, 1)
}

// getNumSACKs returns the current SACK counter.
func (st *associationStats) getNumSACKs() uint64 {
	return atomic.LoadUint64(&st.nSACKs)
}

// incT3Timeouts adds one to the T3 timeout counter.
func (st *associationStats) incT3Timeouts() {
	atomic.AddUint64(&st.nT3Timeouts, 1)
}

// getNumT3Timeouts returns the current T3 timeout counter.
func (st *associationStats) getNumT3Timeouts() uint64 {
	return atomic.LoadUint64(&st.nT3Timeouts)
}

// incAckTimeouts adds one to the ack timeout counter.
func (st *associationStats) incAckTimeouts() {
	atomic.AddUint64(&st.nAckTimeouts, 1)
}

// getNumAckTimeouts returns the current ack timeout counter.
func (st *associationStats) getNumAckTimeouts() uint64 {
	return atomic.LoadUint64(&st.nAckTimeouts)
}

// incFastRetrans adds one to the fast retransmission counter.
func (st *associationStats) incFastRetrans() {
	atomic.AddUint64(&st.nFastRetrans, 1)
}

// getNumFastRetrans returns the current fast retransmission counter.
func (st *associationStats) getNumFastRetrans() uint64 {
	return atomic.LoadUint64(&st.nFastRetrans)
}

// reset zeroes every counter, in field declaration order.
func (st *associationStats) reset() {
	counters := [...]*uint64{
		&st.nDATAs,
		&st.nSACKs,
		&st.nT3Timeouts,
		&st.nAckTimeouts,
		&st.nFastRetrans,
	}
	for _, counter := range counters {
		atomic.StoreUint64(counter, 0)
	}
}
errors.New("SI should match") errReadData = errors.New("failed to read data") errReceivedDataNot3Bytes = errors.New("received data must by 3 bytes") errPPIUnexpected = errors.New("unexpected ppi") errReceivedDataMismatch = errors.New("received data mismatch") ) func TestAssocStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() stressDuplex(t) } func stressDuplex(t *testing.T) { ca, cb, stop, err := pipe(pipeDump) if err != nil { t.Fatal(err) } defer stop(t) // Need to Increase once SCTP is more reliable in case of slow reader opt := test.Options{ MsgSize: 2048, // 65535, MsgCount: 10, // 1000, } err = test.StressDuplex(ca, cb, opt) if err != nil { t.Fatal(err) } } func pipe(piper piperFunc) (*Stream, *Stream, func(*testing.T), error) { var err error var aa, ab *Association aa, ab, err = association(piper) if err != nil { return nil, nil, nil, err } var sa, sb *Stream sa, err = aa.OpenStream(0, 0) if err != nil { return nil, nil, nil, err } sb, err = ab.OpenStream(0, 0) if err != nil { return nil, nil, nil, err } stop := func(t *testing.T) { err = sa.Close() if err != nil { t.Error(err) } err = sb.Close() if err != nil { t.Error(err) } err = aa.Close() if err != nil { t.Error(err) } err = ab.Close() if err != nil { t.Error(err) } } return sa, sb, stop, nil } func association(piper piperFunc) (*Association, *Association, error) { ca, cb := piper() type result struct { a *Association err error } c := make(chan result) loggerFactory := logging.NewDefaultLoggerFactory() // Setup client go func() { client, err := Client(Config{ NetConn: ca, LoggerFactory: loggerFactory, }) c <- result{client, err} }() // Setup server server, err := Server(Config{ NetConn: cb, LoggerFactory: loggerFactory, }) if err != nil { return nil, nil, err } // Receive client res := <-c if res.err != nil { return nil, nil, res.err } return res.a, server, 
nil } type piperFunc func() (net.Conn, net.Conn) func check(err error) { if err != nil { panic(err) } } func pipeDump() (net.Conn, net.Conn) { aConn := acceptDumbConn() bConn, err := net.DialUDP("udp4", nil, aConn.LocalAddr().(*net.UDPAddr)) check(err) // Dumb handshake mgs := "Test" _, err = bConn.Write([]byte(mgs)) check(err) b := make([]byte, 4) _, err = aConn.Read(b) check(err) if string(b) != mgs { panic("Dumb handshake failed") } return aConn, bConn } type dumbConn struct { mu sync.RWMutex rAddr net.Addr pConn net.PacketConn } func acceptDumbConn() *dumbConn { pConn, err := net.ListenUDP("udp4", nil) check(err) return &dumbConn{ pConn: pConn, } } // Read func (c *dumbConn) Read(p []byte) (int, error) { i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { return 0, err } c.mu.Lock() c.rAddr = rAddr c.mu.Unlock() return i, err } // Write writes len(p) bytes from p to the DTLS connection func (c *dumbConn) Write(p []byte) (n int, err error) { return c.pConn.WriteTo(p, c.RemoteAddr()) } // Close closes the conn and releases any Read calls func (c *dumbConn) Close() error { return c.pConn.Close() } // LocalAddr is a stub func (c *dumbConn) LocalAddr() net.Addr { if c.pConn != nil { return c.pConn.LocalAddr() } return nil } // RemoteAddr is a stub func (c *dumbConn) RemoteAddr() net.Addr { c.mu.RLock() defer c.mu.RUnlock() return c.rAddr } // SetDeadline is a stub func (c *dumbConn) SetDeadline(t time.Time) error { return nil } // SetReadDeadline is a stub func (c *dumbConn) SetReadDeadline(t time.Time) error { return nil } // SetWriteDeadline is a stub func (c *dumbConn) SetWriteDeadline(t time.Time) error { return nil } func createNewAssociationPair(br *test.Bridge, ackMode int, recvBufSize uint32) (*Association, *Association, error) { var a0, a1 *Association var err0, err1 error loggerFactory := logging.NewDefaultLoggerFactory() handshake0Ch := make(chan bool) handshake1Ch := make(chan bool) go func() { a0, err0 = Client(Config{ NetConn: br.GetConn0(), 
MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake0Ch <- true }() go func() { a1, err1 = Client(Config{ NetConn: br.GetConn1(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake1Ch <- true }() a0handshakeDone := false a1handshakeDone := false loop1: for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) br.Tick() select { case a0handshakeDone = <-handshake0Ch: if a1handshakeDone { break loop1 } case a1handshakeDone = <-handshake1Ch: if a0handshakeDone { break loop1 } default: } } if !a0handshakeDone || !a1handshakeDone { return nil, nil, errHandshakeFailed } if err0 != nil { return nil, nil, err0 } if err1 != nil { return nil, nil, err1 } a0.ackMode = ackMode a1.ackMode = ackMode return a0, a1, nil } func closeAssociationPair(br *test.Bridge, a0, a1 *Association) { close0Ch := make(chan bool) close1Ch := make(chan bool) go func() { // nolint:errcheck,gosec a0.Close() close0Ch <- true }() go func() { // nolint:errcheck,gosec a1.Close() close1Ch <- true }() a0closed := false a1closed := false loop1: for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) br.Tick() select { case a0closed = <-close0Ch: if a1closed { break loop1 } case a1closed = <-close1Ch: if a0closed { break loop1 } default: } } } func flushBuffers(br *test.Bridge, a0, a1 *Association) { for { for { n := br.Tick() if n == 0 { break } } if a0.bufferedAmount() == 0 && a1.bufferedAmount() == 0 { break } time.Sleep(10 * time.Millisecond) } } func establishSessionPair(br *test.Bridge, a0, a1 *Association, si uint16) (*Stream, *Stream, error) { helloMsg := "Hello" // mimic datachannel.channelOpen s0, err := a0.OpenStream(si, PayloadTypeWebRTCBinary) if err != nil { return nil, nil, err } _, err = s0.WriteSCTP([]byte(helloMsg), PayloadTypeWebRTCDCEP) if err != nil { return nil, nil, err } flushBuffers(br, a0, a1) s1, err := a1.AcceptStream() if err != nil { return nil, nil, err } if s0.streamIdentifier != s1.streamIdentifier { return nil, 
nil, errSINotMatch } br.Process() buf := make([]byte, 1024) n, ppi, err := s1.ReadSCTP(buf) if err != nil { return nil, nil, errReadData } if n != len(helloMsg) { return nil, nil, errReceivedDataNot3Bytes } if ppi != PayloadTypeWebRTCDCEP { return nil, nil, errPPIUnexpected } if string(buf[:n]) != helloMsg { return nil, nil, errReceivedDataMismatch } flushBuffers(br, a0, a1) return s0, s1, nil } func TestAssocReliable(t *testing.T) { // sbuf - small enogh not to be fragmented // large enobh not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) // sbufL - large enogh to be fragmented into two chunks and each chunks are // large enobh not to be bundled sbufL := make([]byte, 2000) for i := 0; i < len(sbufL); i++ { sbufL[i] = byte(i & 0xff) } rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbufL), func(i, j int) { sbufL[i], sbufL[j] = sbufL[j], sbufL[i] }) t.Run("Simple", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) buf := make([]byte, 32) n, ppi, err := s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, 
len(msg), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") closeAssociationPair(br, a0, a1) }) t.Run("ReadDeadline", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") assert.NoError(t, s1.SetReadDeadline(time.Now().Add(time.Millisecond)), "failed to set read deadline") buf := make([]byte, 32) // First fails n, ppi, err := s1.ReadSCTP(buf) assert.Equal(t, 0, n) assert.Equal(t, PayloadProtocolIdentifier(0), ppi) assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) // Second too n, ppi, err = s1.ReadSCTP(buf) assert.Equal(t, 0, n) assert.Equal(t, PayloadProtocolIdentifier(0), ppi) assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) assert.NoError(t, s1.SetReadDeadline(time.Time{}), "failed to disable read deadline") n, err = s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(msg), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") closeAssociationPair(br, a0, a1) }) t.Run("ordered reordered", func(t *testing.T) { 
lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 2 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") binary.BigEndian.PutUint32(sbuf, 0) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") binary.BigEndian.PutUint32(sbuf, 1) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") time.Sleep(10 * time.Millisecond) err = br.Reorder(0) assert.Nil(t, err, "reorder failed") br.Process() buf := make([]byte, 2000) n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(0), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("ordered fragmented then defragmented", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, 
err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") s0.SetReliabilityParams(false, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(false, ReliabilityTypeReliable, 0) n, err = s0.WriteSCTP(sbufL, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") rbuf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(sbufL), "unexpected length of received data") assert.Equal(t, sbufL, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("unordered fragmented then defragmented", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 4 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") s0.SetReliabilityParams(true, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(true, ReliabilityTypeReliable, 0) n, err = s0.WriteSCTP(sbufL, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") rbuf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to 
earlier error") } assert.Equal(t, n, len(sbufL), "unexpected length of received data") assert.Equal(t, sbufL, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("unordered reordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 5 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") s0.SetReliabilityParams(true, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(true, ReliabilityTypeReliable, 0) br.ReorderNextNWrites(0, 2) binary.BigEndian.PutUint32(sbuf, 0) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") binary.BigEndian.PutUint32(sbuf, 1) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") buf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(0), 
binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("retransmission", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 const msg1 = "ABC" const msg2 = "DEFG" var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } // lock RTO value at 100 [msec] a0.rtoMgr.setRTO(100.0, true) s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") n, err = s0.WriteSCTP([]byte(msg1), PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") n, err = s0.WriteSCTP([]byte(msg2), PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg2), "unexpected length of received data") br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) // process packets for 200 msec for i := 0; i < 20; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } buf := make([]byte, 32) n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(msg1), "unexpected length of received data") assert.Equal(t, msg1, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(msg2), "unexpected length of received data") assert.Equal(t, msg2, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, 
"unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("short buffer", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "Hello" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) buf := make([]byte, 3) n, ppi, err := s1.ReadSCTP(buf) assert.Equal(t, err, io.ErrShortBuffer, "expected error to be io.ErrShortBuffer") assert.Equal(t, n, 0, "unexpected length of received data") assert.Equal(t, ppi, PayloadProtocolIdentifier(0), "unexpected ppi") assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") closeAssociationPair(br, a0, a1) }) } func TestAssocUnreliable(t *testing.T) { // sbuf1, sbuf2: // large enogh to be fragmented into two chunks and each chunks are // large enobh not to be bundled sbuf1 := make([]byte, 2000) sbuf2 := make([]byte, 2000) for i := 0; i < len(sbuf1); i++ { sbuf1[i] = byte(i & 0xff) } rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf1), func(i, j int) { sbuf1[i], sbuf1[j] = sbuf1[j], sbuf1[i] }) for i := 0; i < len(sbuf2); i++ { sbuf2[i] = byte(i & 0xff) } rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf2), func(i, j int) { sbuf2[i], sbuf2[j] = sbuf2[j], sbuf2[i] }) // sbuf - small 
enogh not to be fragmented // large enobh not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) t.Run("Rexmit ordered no fragment", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, 
s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit ordered fragments", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") // lock RTO value at 100 [msec] a0.rtoMgr.setRTO(100.0, true) // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first fragment of the first chunk (second chunk should be sacked) var n int n, err = s0.WriteSCTP(sbuf1, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf1), n, "unexpected length of written data") n, err = s0.WriteSCTP(sbuf2, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf2), n, "unexpected length of written data") flushBuffers(br, a0, a1) rbuf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, sbuf2, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.ordered), "should be nothing in the ordered queue") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit unordered no fragment", func(t *testing.T) { // nolint:dupl 
lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 2 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit unordered fragments", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if 
!assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) // doesn't matter var n int n, err = s0.WriteSCTP(sbuf1, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf1), n, "unexpected length of written data") n, err = s0.WriteSCTP(sbuf2, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf2), n, "unexpected length of written data") time.Sleep(10 * time.Millisecond) br.Drop(0, 0, 2) // drop the second fragment of the first chunk (second chunk should be sacked) flushBuffers(br, a0, a1) rbuf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, sbuf2, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.unordered), "should be nothing in the unordered queue") assert.Equal(t, 0, len(s0.reassemblyQueue.unorderedChunks), "should be nothing in the unorderedChunks list") closeAssociationPair(br, a0, a1) }) t.Run("Timed ordered", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to 
earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") // When we set the reliability value to 0 [msec], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(false, ReliabilityTypeTimed, 0) s1.SetReliabilityParams(false, ReliabilityTypeTimed, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") // br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Timed unordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish 
session pair") // When we set the reliability value to 0 [msec], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(true, ReliabilityTypeTimed, 0) s1.SetReliabilityParams(true, ReliabilityTypeTimed, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.unordered), "should be nothing in the unordered queue") assert.Equal(t, 0, len(s0.reassemblyQueue.unorderedChunks), "should be nothing in the unorderedChunks list") closeAssociationPair(br, a0, a1) }) } func TestCreateForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward one abandoned", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.cumulativeTSNAckPoint = 9 a.advancedPeerTSNAckPoint = 10 a.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 
10, streamIdentifier: 1, streamSequenceNumber: 2, userData: []byte("ABC"), nSent: 1, _abandoned: true, }) fwdtsn := a.createForwardTSN() assert.Equal(t, uint32(10), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 1, len(fwdtsn.streams), "there should be one stream") assert.Equal(t, uint16(1), fwdtsn.streams[0].identifier, "si should be 1") assert.Equal(t, uint16(2), fwdtsn.streams[0].sequence, "ssn should be 2") }) t.Run("forward two abandoned with the same SI", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.cumulativeTSNAckPoint = 9 a.advancedPeerTSNAckPoint = 12 a.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, streamIdentifier: 1, streamSequenceNumber: 2, userData: []byte("ABC"), nSent: 1, _abandoned: true, }) a.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 11, streamIdentifier: 1, streamSequenceNumber: 3, userData: []byte("DEF"), nSent: 1, _abandoned: true, }) a.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 12, streamIdentifier: 2, streamSequenceNumber: 1, userData: []byte("123"), nSent: 1, _abandoned: true, }) fwdtsn := a.createForwardTSN() assert.Equal(t, uint32(12), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 2, len(fwdtsn.streams), "there should be two stream") si1OK := false si2OK := false for _, s := range fwdtsn.streams { switch s.identifier { case 1: assert.Equal(t, uint16(3), s.sequence, "ssn should be 3") si1OK = true case 2: assert.Equal(t, uint16(1), s.sequence, "ssn should be 1") si2OK = true default: assert.Fail(t, "unexpected stream indentifier") } } assert.True(t, si1OK, "si=1 should be present") assert.True(t, si2OK, "si=2 should be present") }) } func TestHandleForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward 3 unreceived chunks", 
func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.useForwardTSN = true prevTSN := a.peerLastTSN fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 3, streams: []chunkForwardTSNStream{{identifier: 0, sequence: 0}}, } p := a.handleForwardTSN(fwdtsn) a.lock.Lock() delayedAckTriggered := a.delayedAckTriggered immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() assert.Equal(t, a.peerLastTSN, prevTSN+3, "peerLastTSN should advance by 3 ") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 1 missing", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.useForwardTSN = true prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 a.payloadQueue.push(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: a.peerLastTSN + 2, streamIdentifier: 0, streamSequenceNumber: 1, userData: []byte("ABC"), }, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := a.handleForwardTSN(fwdtsn) a.lock.Lock() delayedAckTriggered := a.delayedAckTriggered immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() assert.Equal(t, a.peerLastTSN, prevTSN+2, "peerLastTSN should advance by 3") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 2 missing", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.useForwardTSN = true prevTSN := a.peerLastTSN // this chunk is blocked by the missing chunk at tsn=1 a.payloadQueue.push(&chunkPayloadData{ 
beginningFragment: true, endingFragment: true, tsn: a.peerLastTSN + 3, streamIdentifier: 0, streamSequenceNumber: 1, userData: []byte("ABC"), }, a.peerLastTSN) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := a.handleForwardTSN(fwdtsn) a.lock.Lock() immediateAckTriggered := a.immediateAckTriggered a.lock.Unlock() assert.Equal(t, a.peerLastTSN, prevTSN+1, "peerLastTSN should advance by 1") assert.True(t, immediateAckTriggered, "immediate sack should be triggered") assert.Nil(t, p, "should return nil") }) t.Run("dup forward TSN chunk should generate sack", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.useForwardTSN = true prevTSN := a.peerLastTSN fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.peerLastTSN, // old TSN streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := a.handleForwardTSN(fwdtsn) a.lock.Lock() ackState := a.ackState a.lock.Unlock() assert.Equal(t, a.peerLastTSN, prevTSN, "peerLastTSN should not advance") assert.Equal(t, ackStateImmediate, ackState, "sack should be requested") assert.Nil(t, p, "should return nil") }) } func TestAssocT1InitTimer(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh := make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Drop 
the first write br.DropNextNWrites(0, 1) // Start the handlshake a0.init(true) a1.init(true) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } flushBuffers(br, a0, a1) assert.Nil(t, err0, "should be nil") assert.Nil(t, err1, "should be nil") closeAssociationPair(br, a0, a1) }) t.Run("Retransmission failure", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh := make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) a1.rtoMgr.setRTO(20, false) // fail after 4 retransmission a0.t1Init.maxRetrans = 4 a1.t1Init.maxRetrans = 4 go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Drop all INIT br.DropNextNWrites(0, 99) br.DropNextNWrites(1, 99) // Start the handlshake a0.init(true) a1.init(true) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } assert.NotNil(t, err0, "should NOT be nil") assert.NotNil(t, err1, "should NOT be nil") closeAssociationPair(br, a0, a1) }) } func TestAssocT1CookieTimer(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh 
:= make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Start the handlshake a0.init(true) a1.init(true) // Let the INIT go. br.Tick() // Drop COOKIE-ECHO br.DropNextNWrites(0, 1) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } assert.Nil(t, err0, "should be nil") assert.Nil(t, err1, "should be nil") closeAssociationPair(br, a0, a1) }) t.Run("Retransmission failure", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0 error a0ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) // fail after 4 retransmission a0.t1Cookie.maxRetrans = 4 go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() // Drop all COOKIE-ECHO br.Filter(0, func(raw []byte) bool { p := &packet{} err := p.unmarshal(raw) if !assert.Nil(t, err, "failed to parse packet") { return false // drop } for _, c := range p.chunks { switch c.(type) { case *chunkCookieEcho: return false // drop default: return true } } return true }) // Start the handlshake a0.init(true) a1.init(false) a0Ready := false for !a0Ready { br.Process() select { case a0Ready = <-a0ReadyCh: default: } } assert.NotNil(t, err0, "should be an error") time.Sleep(1000 * time.Millisecond) br.Process() closeAssociationPair(br, a0, a1) }) } func TestAssocCreateNewStream(t *testing.T) { 
loggerFactory := logging.NewDefaultLoggerFactory() t.Run("acceptChSize", func(t *testing.T) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) for i := 0; i < acceptChSize; i++ { s := a.createStream(uint16(i), true) _, ok := a.streams[s.streamIdentifier] assert.True(t, ok, "should be in a.streams map") } newSI := uint16(acceptChSize) s := a.createStream(newSI, true) assert.Nil(t, s, "should be nil") _, ok := a.streams[newSI] assert.False(t, ok, "should NOT be in a.streams map") toBeIgnored := &chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: a.peerLastTSN + 1, streamIdentifier: newSI, userData: []byte("ABC"), } p := a.handleData(toBeIgnored) assert.Nil(t, p, "should be nil") }) } func TestAssocT3RtxTimer(t *testing.T) { // Send one packet, drop it, then retransmitted successfully. t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 const msg1 = "ABC" var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } // lock RTO value at 20 [msec] a0.rtoMgr.setRTO(20.0, false) a0.rtoMgr.noUpdate = true s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") n, err = s0.WriteSCTP([]byte(msg1), PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) // process packets for 100 msec for i := 0; i < 10; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } buf := make([]byte, 32) n, ppi, err = s1.ReadSCTP(buf) if !assert.Nil(t, err, "ReadSCTP failed") { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, n, len(msg1), "unexpected length of received data") 
assert.Equal(t, msg1, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") a0.lock.RLock() assert.Equal(t, 0, a0.pendingQueue.size(), "should be no packet pending") assert.Equal(t, 0, a0.inflightQueue.size(), "should be no packet inflight") a0.lock.RUnlock() closeAssociationPair(br, a0, a1) }) } func TestAssocCongestionControl(t *testing.T) { // sbuf - large enobh not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xcc) } // 1) Send 4 packets. drop the first one. // 2) Last 3 packet will be received, which triggers fast-retransmission // 3) The first one is retransmitted, which makes s1 readable // Above should be done before RTO occurs (fast recovery) t.Run("Fast retransmission", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNormal, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") br.DropNextNWrites(0, 1) // drop the next write for i := 0; i < 4; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") } // process packets for 500 msec, assuming that the fast retrans/recover // should complete within 500 msec. 
for i := 0; i < 50; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } rbuf := make([]byte, 3000) // Try to read all 4 packets for i := 0; i < 4; i++ { // The receiver (s1) should be readable s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !assert.True(t, readable, "should be readable") { return } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, i, int(binary.BigEndian.Uint32(rbuf)), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") } a0.lock.RLock() inFastRecovery := a0.inFastRecovery a0.lock.RUnlock() assert.False(t, inFastRecovery, "should not be in fast-recovery") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) t.Logf("nFastRetrans: %d\n", a0.stats.getNumFastRetrans()) assert.Equal(t, uint64(1), a0.stats.getNumFastRetrans(), "should be 1") closeAssociationPair(br, a0, a1) }) t.Run("Congestion Avoidance", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 const nPacketsToSend = 2000 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier rbuf := make([]byte, 3000) br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") a0.stats.reset() a1.stats.reset() for i := 0; i < nPacketsToSend; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected 
length of received data") } // Repeat calling br.Tick() until the buffered amount becomes 0 for s0.BufferedAmount() > 0 && nPacketsReceived < nPacketsToSend { for { n = br.Tick() if n == 0 { break } } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, nPacketsReceived, int(binary.BigEndian.Uint32(rbuf)), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } } br.Process() a0.lock.RLock() inFastRecovery := a0.inFastRecovery cwnd := a0.cwnd ssthresh := a0.ssthresh a0.lock.RUnlock() assert.False(t, inFastRecovery, "should not be in fast-recovery") assert.True(t, cwnd > ssthresh, "should be in congestion avoidance mode") assert.True(t, ssthresh >= maxReceiveBufferSize, "should not be less than the initial size of 128KB") assert.Equal(t, nPacketsReceived, nPacketsToSend, "unexpected num of packets received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) t.Logf("nT3Timeouts : %d\n", a0.stats.getNumT3Timeouts()) assert.Equal(t, uint64(nPacketsToSend), a1.stats.getNumDATAs(), "packet count mismatch") assert.True(t, a0.stats.getNumSACKs() <= nPacketsToSend/2, "too many sacks") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") closeAssociationPair(br, a0, a1) }) // This is to test even rwnd becomes 0, sender should be able to send a zero window probe // on T3-rtx retramission timeout to complete receiving all the packets. 
t.Run("Slow reader", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 nPacketsToSend := int(math.Floor(float64(maxReceiveBufferSize)/1000.0)) * 2 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier rbuf := make([]byte, 3000) br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, maxReceiveBufferSize) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") for i := 0; i < nPacketsToSend; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) // uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") } // 1. First forward packets to receiver until rwnd becomes 0 // 2. Wait until the sender's cwnd becomes 1*MTU (RTO occurred) // 3. 
Stat reading a1's data var hasRTOed bool for s0.BufferedAmount() > 0 && nPacketsReceived < nPacketsToSend { for { n = br.Tick() if n == 0 { break } } if !hasRTOed { a1.lock.RLock() rwnd := a1.getMyReceiverWindowCredit() a1.lock.RUnlock() a0.lock.RLock() cwnd := a0.cwnd a0.lock.RUnlock() if cwnd > a0.mtu || rwnd > 0 { // Do not read until a1.getMyReceiverWindowCredit() becomes zero continue } hasRTOed = true } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, nPacketsReceived, int(binary.BigEndian.Uint32(rbuf)), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } time.Sleep(4 * time.Millisecond) } br.Process() assert.Equal(t, nPacketsReceived, nPacketsToSend, "unexpected num of packets received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) closeAssociationPair(br, a0, a1) }) } func TestAssocDelayedAck(t *testing.T) { t.Run("First DATA chunk gets acked with delay", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier sbuf := make([]byte, 1000) // size should be less than initial cwnd (4380) rbuf := make([]byte, 1500) _, err := cryptoRand.Read(sbuf) if !assert.Nil(t, err, "failed to create associations") { return } br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeAlwaysDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) 
assert.Nil(t, err, "failed to establish session pair") a0.stats.reset() a1.stats.reset() // Writes data (will fragmented) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.Nil(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") // Repeat calling br.Tick() until the buffered amount becomes 0 since := time.Now() for s0.BufferedAmount() > 0 { for { n = br.Tick() if n == 0 { break } } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.Nil(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } } delay := time.Since(since).Seconds() t.Logf("received in %.03f seconds", delay) assert.True(t, delay >= 0.2, "should be >= 200msec") br.Process() assert.Equal(t, 1, nPacketsReceived, "should be one packet received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKs()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) assert.Equal(t, uint64(1), a1.stats.getNumDATAs(), "DATA chunk count mismatch") assert.Equal(t, a0.stats.getNumSACKs(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks") assert.Equal(t, uint64(1), a1.stats.getNumAckTimeouts(), "ackTimeout count mismatch") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") closeAssociationPair(br, a0, a1) }) } func TestAssocReset(t *testing.T) { t.Run("Close one way", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed 
due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") err = s0.Close() // send reset if err != nil { t.Error(err) } doneCh := make(chan error) buf := make([]byte, 32) go func() { for { var ppi PayloadProtocolIdentifier n, ppi, err = s1.ReadSCTP(buf) if err != nil { doneCh <- err return } assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.Equal(t, n, len(msg), "unexpected length of received data") } }() loop: for { br.Process() select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") break loop default: } } closeAssociationPair(br, a0, a1) }) t.Run("Close both ways", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) if !assert.Nil(t, err, "failed to create associations") { assert.FailNow(t, "failed due to earlier error") } s0, s1, err := establishSessionPair(br, a0, a1, si) assert.Nil(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.bufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) if err != nil { assert.FailNow(t, "failed due to earlier error") } assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.bufferedAmount(), "incorrect bufferedAmount") err = s0.Close() // send reset if err != nil { t.Error(err) } doneCh := make(chan error) buf := make([]byte, 32) go func() { for { var ppi PayloadProtocolIdentifier n, ppi, err = s1.ReadSCTP(buf) if err != nil { 
doneCh <- err return } assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.Equal(t, n, len(msg), "unexpected length of received data") } }() loop0: for { br.Process() select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") break loop0 default: } } err = s1.Close() // send reset if err != nil { t.Error(err) } go func() { for { _, _, err = s0.ReadSCTP(buf) assert.Equal(t, io.EOF, err, "should be EOF") doneCh <- err } }() loop1: for { br.Process() select { case <-doneCh: break loop1 default: } } time.Sleep(2 * time.Second) closeAssociationPair(br, a0, a1) }) } func TestAssocAbort(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err) abort := &chunkAbort{ errorCauses: []errorCause{&errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }}, } packet, err := a0.createPacket([]chunk{abort}).marshal() assert.NoError(t, err) _, _, err = establishSessionPair(br, a0, a1, si) assert.NoError(t, err) // Both associations are established assert.Equal(t, established, a0.getState()) assert.Equal(t, established, a1.getState()) _, err = a0.netConn.Write(packet) assert.NoError(t, err) flushBuffers(br, a0, a1) // There is a little delay before changing the state to closed time.Sleep(10 * time.Millisecond) // The receiving association should be closed because it got an ABORT assert.Equal(t, established, a0.getState()) assert.Equal(t, closed, a1.getState()) closeAssociationPair(br, a0, a1) } type fakeEchoConn struct { echo chan []byte done chan struct{} closed chan struct{} once sync.Once errClose error mu sync.Mutex bytesSent uint64 bytesReceived uint64 } func newFakeEchoConn(errClose error) *fakeEchoConn { return &fakeEchoConn{ echo: make(chan []byte, 1), done: make(chan struct{}), closed: make(chan struct{}), errClose: errClose, } } func (c *fakeEchoConn) Read(b 
[]byte) (int, error) { r, ok := <-c.echo if ok { copy(b, r) c.once.Do(func() { close(c.done) }) c.mu.Lock() c.bytesReceived += uint64(len(r)) c.mu.Unlock() return len(r), nil } return 0, io.EOF } func (c *fakeEchoConn) Write(b []byte) (int, error) { c.mu.Lock() defer c.mu.Unlock() select { case <-c.closed: return 0, io.EOF default: } c.echo <- b c.bytesSent += uint64(len(b)) return len(b), nil } func (c *fakeEchoConn) Close() error { c.mu.Lock() defer c.mu.Unlock() close(c.echo) close(c.closed) return c.errClose } func (c *fakeEchoConn) LocalAddr() net.Addr { return nil } func (c *fakeEchoConn) RemoteAddr() net.Addr { return nil } func (c *fakeEchoConn) SetDeadline(t time.Time) error { return nil } func (c *fakeEchoConn) SetReadDeadline(t time.Time) error { return nil } func (c *fakeEchoConn) SetWriteDeadline(t time.Time) error { return nil } func TestRoutineLeak(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Close failed", func(t *testing.T) { runtime.GC() n0 := runtime.NumGoroutine() conn := newFakeEchoConn(io.EOF) a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done err = a.Close() assert.Equal(t, io.EOF, err, "Close() should fail with EOF") select { case _, ok := <-a.closeWriteLoopCh: if ok { t.Errorf("closeWriteLoopCh is expected to be closed, but received signal") } default: t.Errorf("closeWriteLoopCh is expected to be closed, but not") } _ = a runtime.GC() assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }) t.Run("Connection closed by remote host", func(t *testing.T) { runtime.GC() n0 := runtime.NumGoroutine() conn := newFakeEchoConn(nil) a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done err = conn.Close() // close connection assert.Equal(t, nil, err, "fake connection returned unexpected error") <-conn.closed <-time.After(10 * 
time.Millisecond) // switch context to make read/write loops finished select { case _, ok := <-a.closeWriteLoopCh: if ok { t.Errorf("closeWriteLoopCh is expected to be closed, but received signal") } default: t.Errorf("closeWriteLoopCh is expected to be closed, but not") } runtime.GC() assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }) } func TestStats(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() conn := newFakeEchoConn(nil) a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done assert.NoError(t, conn.Close()) conn.mu.Lock() defer conn.mu.Unlock() assert.Equal(t, conn.bytesReceived, a.BytesReceived()) assert.Equal(t, conn.bytesSent, a.BytesSent()) } func TestAssocHandleInit(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() handleInitTest := func(t *testing.T, initialState uint32, expectErr bool) { a := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) a.setState(initialState) pkt := &packet{ sourcePort: 5001, destinationPort: 5002, } init := &chunkInit{} init.initialTSN = 1234 init.numOutboundStreams = 1001 init.numInboundStreams = 1002 init.initiateTag = 5678 init.advertisedReceiverWindowCredit = 512 * 1024 setSupportedExtensions(&init.chunkInitCommon) _, err := a.handleInit(pkt, init) if expectErr { assert.Error(t, err, "should fail") return } assert.NoError(t, err, "should succeed") assert.Equal(t, init.initialTSN-1, a.peerLastTSN, "should match") assert.Equal(t, uint16(1001), a.myMaxNumOutboundStreams, "should match") assert.Equal(t, uint16(1002), a.myMaxNumInboundStreams, "should match") assert.Equal(t, uint32(5678), a.peerVerificationTag, "should match") assert.Equal(t, pkt.sourcePort, a.destinationPort, "should match") assert.Equal(t, pkt.destinationPort, a.sourcePort, "should match") assert.True(t, a.useForwardTSN, "should be set to true") } t.Run("normal", func(t *testing.T) { 
handleInitTest(t, closed, false) }) t.Run("unexpected state established", func(t *testing.T) { handleInitTest(t, established, true) }) t.Run("unexpected state shutdownAckSent", func(t *testing.T) { handleInitTest(t, shutdownAckSent, true) }) t.Run("unexpected state shutdownPending", func(t *testing.T) { handleInitTest(t, shutdownPending, true) }) t.Run("unexpected state shutdownReceived", func(t *testing.T) { handleInitTest(t, shutdownReceived, true) }) t.Run("unexpected state shutdownSent", func(t *testing.T) { handleInitTest(t, shutdownSent, true) }) } func TestAssocMaxMessageSize(t *testing.T) { t.Run("default", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() a := createAssociation(Config{ LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(65536), a.MaxMessageSize(), "should match") s := a.createStream(1, false) assert.NotNil(t, s, "should succeed") p := make([]byte, 65537) var err error _, err = s.WriteSCTP(p[:65536], s.defaultPayloadType) assert.False(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") _, err = s.WriteSCTP(p[:65537], s.defaultPayloadType) assert.True(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") }) t.Run("explicit", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() a := createAssociation(Config{ MaxMessageSize: 30000, LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(30000), a.MaxMessageSize(), "should match") s := a.createStream(1, false) assert.NotNil(t, s, "should succeed") p := make([]byte, 30001) var err error _, err = s.WriteSCTP(p[:30000], s.defaultPayloadType) assert.False(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") _, err = s.WriteSCTP(p[:30001], s.defaultPayloadType) assert.True(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") }) t.Run("set value", func(t *testing.T) { loggerFactory := 
logging.NewDefaultLoggerFactory() a := createAssociation(Config{ LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(65536), a.MaxMessageSize(), "should match") a.SetMaxMessageSize(20000) assert.Equal(t, uint32(20000), a.MaxMessageSize(), "should match") }) } func createAssocs(t *testing.T) (a1, a2 *Association) { addr1 := &net.UDPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 1234, } addr2 := &net.UDPAddr{ IP: net.IP{127, 0, 0, 1}, Port: 5678, } udp1, err := net.DialUDP("udp", addr1, addr2) if err != nil { panic(err) } udp2, err := net.DialUDP("udp", addr2, addr1) if err != nil { panic(err) } loggerFactory := logging.NewDefaultLoggerFactory() a1Chan := make(chan *Association) a2Chan := make(chan *Association) go func() { a, err := Client(Config{ NetConn: udp1, LoggerFactory: loggerFactory, }) require.NoError(t, err) a1Chan <- a }() go func() { a, err := Client(Config{ NetConn: udp2, LoggerFactory: loggerFactory, }) require.NoError(t, err) a2Chan <- a }() select { case a1 = <-a1Chan: case <-time.After(time.Second): assert.Fail(t, "timed out waiting for a1") } select { case a2 = <-a2Chan: case <-time.After(time.Second): assert.Fail(t, "timed out waiting for a2") } return a1, a2 } func TestAssociation_Shutdown(t *testing.T) { runtime.GC() n0 := runtime.NumGoroutine() defer func() { runtime.GC() assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() a1, a2 := createAssocs(t) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = a1.Shutdown(ctx) require.NoError(t, err) // 
Wait for close read loop channels to prevent flaky tests. select { case <-a2.readLoopCloseCh: case <-time.After(1 * time.Second): assert.Fail(t, "timed out waiting for a2 read loop to close") } } func TestAssociation_ShutdownDuringWrite(t *testing.T) { runtime.GC() n0 := runtime.NumGoroutine() defer func() { runtime.GC() assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() a1, a2 := createAssocs(t) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) writingDone := make(chan struct{}) go func() { defer close(writingDone) var i byte for { i++ if i%100 == 0 { time.Sleep(20 * time.Millisecond) } _, writeErr := s21.Write([]byte{i}) if writeErr != nil { return } } }() testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) // running this test with -race flag is very slow so timeout needs to be high. timeout := 5 * time.Minute ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() err = a1.Shutdown(ctx) require.NoError(t, err, "timed out waiting for a1 shutdown to complete") select { case <-writingDone: case <-time.After(timeout): assert.Fail(t, "timed out waiting writing goroutine to exit") } // Wait for close read loop channels to prevent flaky tests. 
select { case <-a2.readLoopCloseCh: case <-time.After(timeout): assert.Fail(t, "timed out waiting for a2 read loop to close") } } func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() testCases := map[string]struct { inputPacket *packet skipClose bool }{ "InitAck": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{ &chunkInitAck{ chunkInitCommon: chunkInitCommon{ initiateTag: 1, numInboundStreams: 1, numOutboundStreams: 1, advertisedReceiverWindowCredit: 1500, }, }, }, }, }, "Abort": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkAbort{}}, }, // Prevent "use of close network connection" error on close. skipClose: true, }, "CoockeEcho": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkCookieEcho{}}, }, }, "HeartBeat": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkHeartbeat{}}, }, }, "PayloadData": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkPayloadData{}}, }, }, "Sack": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkSelectiveAck{ cumulativeTSNAck: 1000, advertisedReceiverWindowCredit: 1500, gapAckBlocks: []gapAckBlock{ {start: 100, end: 200}, }, }}, }, }, "Reconfig": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkReconfig{ paramA: ¶mOutgoingResetRequest{}, paramB: ¶mReconfigResponse{}, }}, }, }, "ForwardTSN": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkForwardTSN{ newCumulativeTSN: 100, }}, }, }, "Error": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkError{}}, }, }, "Shutdown": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdown{}}, }, }, "ShutdownAck": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdownAck{}}, }, }, 
"ShutdownComplete": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdownComplete{}}, }, }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { aConn, charlieConn := pipeDump() a := createAssociation(Config{ NetConn: aConn, MaxReceiveBufferSize: 0, LoggerFactory: loggerFactory, }) a.init(true) if !testCase.skipClose { defer func() { assert.NoError(t, a.close()) }() } packet, err := testCase.inputPacket.marshal() assert.NoError(t, err) _, err = charlieConn.Write(packet) assert.NoError(t, err) // Should not panic. time.Sleep(100 * time.Millisecond) }) } } func TestAssociation_Abort(t *testing.T) { runtime.GC() n0 := runtime.NumGoroutine() defer func() { runtime.GC() assert.Equal(t, n0, runtime.NumGoroutine(), "goroutine is leaked") }() a1, a2 := createAssocs(t) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) a1.Abort("1234") // Wait for close read loop channels to prevent flaky tests. 
select { case <-a2.readLoopCloseCh: case <-time.After(1 * time.Second): assert.Fail(t, "timed out waiting for a2 read loop to close") } i, err = s21.Read(buf) assert.Equal(t, i, 0, "expected no data read") assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason") } sctp-1.8.6/chunk.go000066400000000000000000000002121436021606300141320ustar00rootroot00000000000000package sctp type chunk interface { unmarshal(raw []byte) error marshal() ([]byte, error) check() (bool, error) valueLength() int } sctp-1.8.6/chunk_abort.go000066400000000000000000000046211436021606300153310ustar00rootroot00000000000000package sctp // nolint:dupl import ( "errors" "fmt" ) /* Abort represents an SCTP Chunk of type ABORT The ABORT chunk is sent to the peer of an association to close the association. The ABORT chunk may contain Cause Parameters to inform the receiver about the reason of the abort. DATA chunks MUST NOT be bundled with ABORT. Control chunks (except for INIT, INIT ACK, and SHUTDOWN COMPLETE) MAY be bundled with an ABORT, but they MUST be placed before the ABORT in the SCTP packet or they will be ignored by the receiver. 
0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 6 |Reserved |T| Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | zero or more Error Causes | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkAbort struct { chunkHeader errorCauses []errorCause } // Abort chunk errors var ( ErrChunkTypeNotAbort = errors.New("ChunkType is not of type ABORT") ErrBuildAbortChunkFailed = errors.New("failed build Abort Chunk") ) func (a *chunkAbort) unmarshal(raw []byte) error { if err := a.chunkHeader.unmarshal(raw); err != nil { return err } if a.typ != ctAbort { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotAbort, a.typ.String()) } offset := chunkHeaderSize for { if len(raw)-offset < 4 { break } e, err := buildErrorCause(raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrBuildAbortChunkFailed, err) } offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } return nil } func (a *chunkAbort) marshal() ([]byte, error) { a.chunkHeader.typ = ctAbort a.flags = 0x00 a.raw = []byte{} for _, ec := range a.errorCauses { raw, err := ec.marshal() if err != nil { return nil, err } a.raw = append(a.raw, raw...) 
} return a.chunkHeader.marshal() } func (a *chunkAbort) check() (abort bool, err error) { return false, nil } // String makes chunkAbort printable func (a *chunkAbort) String() string { res := a.chunkHeader.String() for _, cause := range a.errorCauses { res += fmt.Sprintf("\n - %s", cause) } return res } sctp-1.8.6/chunk_abort_test.go000066400000000000000000000031121436021606300163620ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestAbortChunk(t *testing.T) { t.Run("One error cause", func(t *testing.T) { abort1 := &chunkAbort{ errorCauses: []errorCause{&errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }}, } bytes, err := abort1.marshal() assert.NoError(t, err, "should succeed") abort2 := &chunkAbort{} err = abort2.unmarshal(bytes) assert.NoError(t, err, "should succeed") assert.Equal(t, 1, len(abort2.errorCauses), "should have only one cause") assert.Equal(t, abort1.errorCauses[0].errorCauseCode(), abort2.errorCauses[0].errorCauseCode(), "errorCause code should match") }) t.Run("Many error causes", func(t *testing.T) { abort1 := &chunkAbort{ errorCauses: []errorCause{ &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: invalidMandatoryParameter}, }, &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: unrecognizedChunkType}, }, &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }, }, } bytes, err := abort1.marshal() assert.NoError(t, err, "should succeed") abort2 := &chunkAbort{} err = abort2.unmarshal(bytes) assert.NoError(t, err, "should succeed") assert.Equal(t, 3, len(abort2.errorCauses), "should have only one cause") for i, errorCause := range abort1.errorCauses { assert.Equal(t, errorCause.errorCauseCode(), abort2.errorCauses[i].errorCauseCode(), "errorCause code should match") } }) } 
sctp-1.8.6/chunk_cookie_ack.go000066400000000000000000000022241436021606300163060ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* chunkCookieAck represents an SCTP Chunk of type chunkCookieAck 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 11 |Chunk Flags | Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkCookieAck struct { chunkHeader } // Cookie ack chunk errors var ( ErrChunkTypeNotCookieAck = errors.New("ChunkType is not of type COOKIEACK") ) func (c *chunkCookieAck) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctCookieAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieAck, c.typ.String()) } return nil } func (c *chunkCookieAck) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieAck return c.chunkHeader.marshal() } func (c *chunkCookieAck) check() (abort bool, err error) { return false, nil } // String makes chunkCookieAck printable func (c *chunkCookieAck) String() string { return c.chunkHeader.String() } sctp-1.8.6/chunk_cookie_echo.go000066400000000000000000000024451436021606300164730ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* CookieEcho represents an SCTP Chunk of type CookieEcho 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 10 |Chunk Flags | Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cookie | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkCookieEcho struct { chunkHeader cookie []byte } // Cookie echo chunk errors var ( ErrChunkTypeNotCookieEcho = errors.New("ChunkType is not of type COOKIEECHO") ) func (c *chunkCookieEcho) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } 
if c.typ != ctCookieEcho { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieEcho, c.typ.String()) } c.cookie = c.raw return nil } func (c *chunkCookieEcho) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieEcho c.chunkHeader.raw = c.cookie return c.chunkHeader.marshal() } func (c *chunkCookieEcho) check() (abort bool, err error) { return false, nil } sctp-1.8.6/chunk_error.go000066400000000000000000000050741436021606300153560ustar00rootroot00000000000000package sctp // nolint:dupl import ( "errors" "fmt" ) /* Operation Error (ERROR) (9) An endpoint sends this chunk to its peer endpoint to notify it of certain error conditions. It contains one or more error causes. An Operation Error is not considered fatal in and of itself, but may be used with an ERROR chunk to report a fatal condition. It has the following parameters: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 9 | Chunk Flags | Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ \ \ / one or more Error Causes / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ Chunk Flags: 8 bits Set to 0 on transmit and ignored on receipt. Length: 16 bits (unsigned integer) Set to the size of the chunk in bytes, including the chunk header and all the Error Cause fields present. 
*/ type chunkError struct { chunkHeader errorCauses []errorCause } // Error chunk errors var ( ErrChunkTypeNotCtError = errors.New("ChunkType is not of type ctError") ErrBuildErrorChunkFailed = errors.New("failed build Error Chunk") ) func (a *chunkError) unmarshal(raw []byte) error { if err := a.chunkHeader.unmarshal(raw); err != nil { return err } if a.typ != ctError { return fmt.Errorf("%w, actually is %s", ErrChunkTypeNotCtError, a.typ.String()) } offset := chunkHeaderSize for { if len(raw)-offset < 4 { break } e, err := buildErrorCause(raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrBuildErrorChunkFailed, err) } offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } return nil } func (a *chunkError) marshal() ([]byte, error) { a.chunkHeader.typ = ctError a.flags = 0x00 a.raw = []byte{} for _, ec := range a.errorCauses { raw, err := ec.marshal() if err != nil { return nil, err } a.raw = append(a.raw, raw...) } return a.chunkHeader.marshal() } func (a *chunkError) check() (abort bool, err error) { return false, nil } // String makes chunkError printable func (a *chunkError) String() string { res := a.chunkHeader.String() for _, cause := range a.errorCauses { res += fmt.Sprintf("\n - %s", cause) } return res } sctp-1.8.6/chunk_error_test.go000066400000000000000000000036361436021606300164170ustar00rootroot00000000000000package sctp import ( "reflect" "testing" "github.com/stretchr/testify/assert" ) func TestChunkErrorUnrecognizedChunkType(t *testing.T) { const chunkFlags byte = 0x00 orgUnrecognizedChunk := []byte{0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3} rawIn := append([]byte{byte(ctError), chunkFlags, 0x00, 0x10, 0x00, 0x06, 0x00, 0x0c}, orgUnrecognizedChunk...) 
t.Run("unmarshal", func(t *testing.T) { c := &chunkError{} err := c.unmarshal(rawIn) assert.Nil(t, err, "unmarshal should succeed") assert.Equal(t, ctError, c.typ, "chunk type should be ERROR") assert.Equal(t, 1, len(c.errorCauses), "there should be on errorCause") ec := c.errorCauses[0] assert.Equal(t, unrecognizedChunkType, ec.errorCauseCode(), "cause code should be unrecognizedChunkType") ecUnrecognizedChunkType, ok := ec.(*errorCauseUnrecognizedChunkType) assert.True(t, ok) unrecognizedChunk := ecUnrecognizedChunkType.unrecognizedChunk assert.True(t, reflect.DeepEqual(unrecognizedChunk, orgUnrecognizedChunk), "should have valid unrecognizedChunk") }) t.Run("marshal", func(t *testing.T) { ecUnrecognizedChunkType := &errorCauseUnrecognizedChunkType{ unrecognizedChunk: orgUnrecognizedChunk, } ec := &chunkError{ errorCauses: []errorCause{ errorCause(ecUnrecognizedChunkType), }, } raw, err := ec.marshal() assert.Nil(t, err, "marshal should succeed") assert.True(t, reflect.DeepEqual(raw, rawIn), "unexpected serialization result") }) t.Run("marshal with cause value being nil", func(t *testing.T) { expected := []byte{byte(ctError), chunkFlags, 0x00, 0x08, 0x00, 0x06, 0x00, 0x04} ecUnrecognizedChunkType := &errorCauseUnrecognizedChunkType{} ec := &chunkError{ errorCauses: []errorCause{ errorCause(ecUnrecognizedChunkType), }, } raw, err := ec.marshal() assert.Nil(t, err, "marshal should succeed") assert.True(t, reflect.DeepEqual(raw, expected), "unexpected serialization result") }) } sctp-1.8.6/chunk_forward_tsn.go000066400000000000000000000106021436021606300165460ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) // This chunk shall be used by the data sender to inform the data // receiver to adjust its cumulative received TSN point forward because // some missing TSNs are associated with data chunks that SHOULD NOT be // transmitted or retransmitted by the sender. 
// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type = 192 | Flags = 0x00 | Length = Variable | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | New Cumulative TSN | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream-1 | Stream Sequence-1 | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ / // / \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream-N | Stream Sequence-N | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type chunkForwardTSN struct { chunkHeader // This indicates the new cumulative TSN to the data receiver. Upon // the reception of this value, the data receiver MUST consider // any missing TSNs earlier than or equal to this value as received, // and stop reporting them as gaps in any subsequent SACKs. newCumulativeTSN uint32 streams []chunkForwardTSNStream } const ( newCumulativeTSNLength = 4 forwardTSNStreamLength = 4 ) // Forward TSN chunk errors var ( ErrMarshalStreamFailed = errors.New("failed to marshal stream") ErrChunkTooShort = errors.New("chunk too short") ) func (c *chunkForwardTSN) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if len(c.raw) < newCumulativeTSNLength { return ErrChunkTooShort } c.newCumulativeTSN = binary.BigEndian.Uint32(c.raw[0:]) offset := newCumulativeTSNLength remaining := len(c.raw) - offset for remaining > 0 { s := chunkForwardTSNStream{} if err := s.unmarshal(c.raw[offset:]); err != nil { return fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) } c.streams = append(c.streams, s) offset += s.length() remaining -= s.length() } return nil } func (c *chunkForwardTSN) marshal() ([]byte, error) { out := make([]byte, newCumulativeTSNLength) binary.BigEndian.PutUint32(out[0:], c.newCumulativeTSN) for _, s := range c.streams { b, err := 
s.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) } out = append(out, b...) } c.typ = ctForwardTSN c.raw = out return c.chunkHeader.marshal() } func (c *chunkForwardTSN) check() (abort bool, err error) { return true, nil } // String makes chunkForwardTSN printable func (c *chunkForwardTSN) String() string { res := fmt.Sprintf("New Cumulative TSN: %d\n", c.newCumulativeTSN) for _, s := range c.streams { res += fmt.Sprintf(" - si=%d, ssn=%d\n", s.identifier, s.sequence) } return res } type chunkForwardTSNStream struct { // This field holds a stream number that was skipped by this // FWD-TSN. identifier uint16 // This field holds the sequence number associated with the stream // that was skipped. The stream sequence field holds the largest // stream sequence number in this stream being skipped. The receiver // of the FWD-TSN's can use the Stream-N and Stream Sequence-N fields // to enable delivery of any stranded TSN's that remain on the stream // re-ordering queues. This field MUST NOT report TSN's corresponding // to DATA chunks that are marked as unordered. For ordered DATA // chunks this field MUST be filled in. 
sequence uint16 } func (s *chunkForwardTSNStream) length() int { return forwardTSNStreamLength } func (s *chunkForwardTSNStream) unmarshal(raw []byte) error { if len(raw) < forwardTSNStreamLength { return ErrChunkTooShort } s.identifier = binary.BigEndian.Uint16(raw[0:]) s.sequence = binary.BigEndian.Uint16(raw[2:]) return nil } func (s *chunkForwardTSNStream) marshal() ([]byte, error) { // nolint:unparam out := make([]byte, forwardTSNStreamLength) binary.BigEndian.PutUint16(out[0:], s.identifier) binary.BigEndian.PutUint16(out[2:], s.sequence) return out, nil } sctp-1.8.6/chunk_forward_tsn_test.go000066400000000000000000000024171436021606300176120ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkForwardTSN() []byte { return []byte{0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3} } func TestChunkForwardTSN_Success(t *testing.T) { tt := []struct { binary []byte }{ {testChunkForwardTSN()}, {[]byte{0xc0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5}}, {[]byte{0xc0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6, 0x0, 0x7}}, } for i, tc := range tt { actual := &chunkForwardTSN{} err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b, "test %d not equal", i) } } func TestChunkForwardTSNUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"chunk header to short", []byte{0xc0}}, {"missing New Cumulative TSN", []byte{0xc0, 0x0, 0x0, 0x4}}, {"missing stream sequence", []byte{0xc0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6}}, } for i, tc := range tt { actual := &chunkForwardTSN{} err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } 
sctp-1.8.6/chunk_heartbeat.go000066400000000000000000000052741436021606300161660ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* chunkHeartbeat represents an SCTP Chunk of type HEARTBEAT An endpoint should send this chunk to its peer endpoint to probe the reachability of a particular destination transport address defined in the present association. The parameter field contains the Heartbeat Information, which is a variable-length opaque data structure understood only by the sender. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 4 | Chunk Flags | Heartbeat Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Heartbeat Information TLV (Variable-Length) | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ Defined as a variable-length parameter using the format described in Section 3.2.1, i.e.: Variable Parameters Status Type Value ------------------------------------------------------------- heartbeat Info Mandatory 1 */ type chunkHeartbeat struct { chunkHeader params []param } // Heartbeat chunk errors var ( ErrChunkTypeNotHeartbeat = errors.New("ChunkType is not of type HEARTBEAT") ErrHeartbeatNotLongEnoughInfo = errors.New("heartbeat is not long enough to contain Heartbeat Info") ErrParseParamTypeFailed = errors.New("failed to parse param type") ErrHeartbeatParam = errors.New("heartbeat should only have HEARTBEAT param") ErrHeartbeatChunkUnmarshal = errors.New("failed unmarshalling param in Heartbeat Chunk") ) func (h *chunkHeartbeat) unmarshal(raw []byte) error { if err := h.chunkHeader.unmarshal(raw); err != nil { return err } else if h.typ != ctHeartbeat { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotHeartbeat, h.typ.String()) } if len(raw) <= chunkHeaderSize { return fmt.Errorf("%w: %d", ErrHeartbeatNotLongEnoughInfo, len(raw)) } pType, err := parseParamType(raw[chunkHeaderSize:]) if 
err != nil { return fmt.Errorf("%w: %v", ErrParseParamTypeFailed, err) } if pType != heartbeatInfo { return fmt.Errorf("%w: instead have %s", ErrHeartbeatParam, pType.String()) } p, err := buildParam(pType, raw[chunkHeaderSize:]) if err != nil { return fmt.Errorf("%w: %v", ErrHeartbeatChunkUnmarshal, err) } h.params = append(h.params, p) return nil } func (h *chunkHeartbeat) Marshal() ([]byte, error) { return nil, ErrUnimplemented } func (h *chunkHeartbeat) check() (abort bool, err error) { return false, nil } sctp-1.8.6/chunk_heartbeat_ack.go000066400000000000000000000057111436021606300170000ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* chunkHeartbeatAck represents an SCTP Chunk of type HEARTBEAT ACK An endpoint should send this chunk to its peer endpoint as a response to a HEARTBEAT chunk (see Section 8.3). A HEARTBEAT ACK is always sent to the source IP address of the IP datagram containing the HEARTBEAT chunk to which this ack is responding. The parameter field contains a variable-length opaque data structure. 
0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 5 | Chunk Flags | Heartbeat Ack Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Heartbeat Information TLV (Variable-Length) | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ Defined as a variable-length parameter using the format described in Section 3.2.1, i.e.: Variable Parameters Status Type Value ------------------------------------------------------------- Heartbeat Info Mandatory 1 */ type chunkHeartbeatAck struct { chunkHeader params []param } // Heartbeat ack chunk errors var ( ErrUnimplemented = errors.New("unimplemented") ErrHeartbeatAckParams = errors.New("heartbeat Ack must have one param") ErrHeartbeatAckNotHeartbeatInfo = errors.New("heartbeat Ack must have one param, and it should be a HeartbeatInfo") ErrHeartbeatAckMarshalParam = errors.New("unable to marshal parameter for Heartbeat Ack") ) func (h *chunkHeartbeatAck) unmarshal(raw []byte) error { return ErrUnimplemented } func (h *chunkHeartbeatAck) marshal() ([]byte, error) { if len(h.params) != 1 { return nil, ErrHeartbeatAckParams } switch h.params[0].(type) { case *paramHeartbeatInfo: // ParamHeartbeatInfo is valid default: return nil, ErrHeartbeatAckNotHeartbeatInfo } out := make([]byte, 0) for idx, p := range h.params { pp, err := p.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrHeartbeatAckMarshalParam, err) } out = append(out, pp...) // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. *However, it does include padding of any variable-length // parameter except the last parameter in the chunk.* The receiver // MUST ignore the padding. 
if idx != len(h.params)-1 { out = padByte(out, getPadding(len(pp))) } } h.chunkHeader.typ = ctHeartbeatAck h.chunkHeader.raw = out return h.chunkHeader.marshal() } func (h *chunkHeartbeatAck) check() (abort bool, err error) { return false, nil } sctp-1.8.6/chunk_init.go000066400000000000000000000111101436021606300151540ustar00rootroot00000000000000package sctp // nolint:dupl import ( "errors" "fmt" ) /* Init represents an SCTP Chunk of type INIT See chunkInitCommon for the fixed headers Variable Parameters Status Type Value ------------------------------------------------------------- IPv4 IP (Note 1) Optional 5 IPv6 IP (Note 1) Optional 6 Cookie Preservative Optional 9 Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) Host Name IP (Note 3) Optional 11 Supported IP Types (Note 4) Optional 12 */ type chunkInit struct { chunkHeader chunkInitCommon } // Init chunk errors var ( ErrChunkTypeNotTypeInit = errors.New("ChunkType is not of type INIT") ErrChunkValueNotLongEnough = errors.New("chunk Value isn't long enough for mandatory parameters exp") ErrChunkTypeInitFlagZero = errors.New("ChunkType of type INIT flags must be all 0") ErrChunkTypeInitUnmarshalFailed = errors.New("failed to unmarshal INIT body") ErrChunkTypeInitMarshalFailed = errors.New("failed marshaling INIT common data") ErrChunkTypeInitInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") ErrInitInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") ErrInitOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") ErrInitAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") ) func (i *chunkInit) unmarshal(raw []byte) error { if err := i.chunkHeader.unmarshal(raw); err != nil { return err } if i.typ != ctInit { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotTypeInit, i.typ.String()) } else if len(i.raw) < initChunkMinLength { return fmt.Errorf("%w: %d 
actual: %d", ErrChunkValueNotLongEnough, initChunkMinLength, len(i.raw)) } // The Chunk Flags field in INIT is reserved, and all bits in it should // be set to 0 by the sender and ignored by the receiver. The sequence // of parameters within an INIT can be processed in any order. if i.flags != 0 { return ErrChunkTypeInitFlagZero } if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { return fmt.Errorf("%w: %v", ErrChunkTypeInitUnmarshalFailed, err) } return nil } func (i *chunkInit) marshal() ([]byte, error) { initShared, err := i.chunkInitCommon.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkTypeInitMarshalFailed, err) } i.chunkHeader.typ = ctInit i.chunkHeader.raw = initShared return i.chunkHeader.marshal() } func (i *chunkInit) check() (abort bool, err error) { // The receiver of the INIT (the responding end) records the value of // the Initiate Tag parameter. This value MUST be placed into the // Verification Tag field of every SCTP packet that the receiver of // the INIT transmits within this association. // // The Initiate Tag is allowed to have any value except 0. See // Section 5.3.1 for more on the selection of the tag value. // // If the value of the Initiate Tag in a received INIT chunk is found // to be 0, the receiver MUST treat it as an error and close the // association by transmitting an ABORT. if i.initiateTag == 0 { abort = true return abort, ErrChunkTypeInitInitateTagZero } // Defines the maximum number of streams the sender of this INIT // chunk allows the peer end to create in this association. The // value 0 MUST NOT be used. // // Note: There is no negotiation of the actual number of streams but // instead the two endpoints will use the min(requested, offered). // See Section 5.1.1 for details. // // Note: A receiver of an INIT with the MIS value of 0 SHOULD abort // the association. 
if i.numInboundStreams == 0 { abort = true return abort, ErrInitInboundStreamRequestZero } // Defines the number of outbound streams the sender of this INIT // chunk wishes to create in this association. The value of 0 MUST // NOT be used. // // Note: A receiver of an INIT with the OS value set to 0 SHOULD // abort the association. if i.numOutboundStreams == 0 { abort = true return abort, ErrInitOutboundStreamRequestZero } // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT // ACK. if i.advertisedReceiverWindowCredit < 1500 { abort = true return abort, ErrInitAdvertisedReceiver1500 } return false, nil } // String makes chunkInit printable func (i *chunkInit) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } sctp-1.8.6/chunk_init_ack.go000066400000000000000000000115051436021606300160020ustar00rootroot00000000000000package sctp // nolint:dupl import ( "errors" "fmt" ) /* chunkInitAck represents an SCTP Chunk of type INIT ACK See chunkInitCommon for the fixed headers Variable Parameters Status Type Value ------------------------------------------------------------- State Cookie Mandatory 7 IPv4 IP (Note 1) Optional 5 IPv6 IP (Note 1) Optional 6 Unrecognized Parameter Optional 8 Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) Host Name IP (Note 3) Optional 11 */ type chunkInitAck struct { chunkHeader chunkInitCommon } // Init ack chunk errors var ( ErrChunkTypeNotInitAck = errors.New("ChunkType is not of type INIT ACK") ErrChunkNotLongEnoughForParams = errors.New("chunk Value isn't long enough for mandatory parameters exp") ErrChunkTypeInitAckFlagZero = errors.New("ChunkType of type INIT ACK flags must be all 0") ErrInitAckUnmarshalFailed = errors.New("failed to unmarshal INIT body") ErrInitCommonDataMarshalFailed = errors.New("failed marshaling INIT common data") 
ErrChunkTypeInitAckInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") ErrInitAckInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") ErrInitAckOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") ErrInitAckAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") ) func (i *chunkInitAck) unmarshal(raw []byte) error { if err := i.chunkHeader.unmarshal(raw); err != nil { return err } if i.typ != ctInitAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotInitAck, i.typ.String()) } else if len(i.raw) < initChunkMinLength { return fmt.Errorf("%w: %d actual: %d", ErrChunkNotLongEnoughForParams, initChunkMinLength, len(i.raw)) } // The Chunk Flags field in INIT is reserved, and all bits in it should // be set to 0 by the sender and ignored by the receiver. The sequence // of parameters within an INIT can be processed in any order. if i.flags != 0 { return ErrChunkTypeInitAckFlagZero } if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { return fmt.Errorf("%w: %v", ErrInitAckUnmarshalFailed, err) } return nil } func (i *chunkInitAck) marshal() ([]byte, error) { initShared, err := i.chunkInitCommon.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrInitCommonDataMarshalFailed, err) } i.chunkHeader.typ = ctInitAck i.chunkHeader.raw = initShared return i.chunkHeader.marshal() } func (i *chunkInitAck) check() (abort bool, err error) { // The receiver of the INIT ACK records the value of the Initiate Tag // parameter. This value MUST be placed into the Verification Tag // field of every SCTP packet that the INIT ACK receiver transmits // within this association. // // The Initiate Tag MUST NOT take the value 0. See Section 5.3.1 for // more on the selection of the Initiate Tag value. 
// // If the value of the Initiate Tag in a received INIT ACK chunk is // found to be 0, the receiver MUST destroy the association // discarding its TCB. The receiver MAY send an ABORT for debugging // purpose. if i.initiateTag == 0 { abort = true return abort, ErrChunkTypeInitAckInitateTagZero } // Defines the maximum number of streams the sender of this INIT ACK // chunk allows the peer end to create in this association. The // value 0 MUST NOT be used. // // Note: There is no negotiation of the actual number of streams but // instead the two endpoints will use the min(requested, offered). // See Section 5.1.1 for details. // // Note: A receiver of an INIT ACK with the MIS value set to 0 SHOULD // destroy the association discarding its TCB. if i.numInboundStreams == 0 { abort = true return abort, ErrInitAckInboundStreamRequestZero } // Defines the number of outbound streams the sender of this INIT ACK // chunk wishes to create in this association. The value of 0 MUST // NOT be used, and the value MUST NOT be greater than the MIS value // sent in the INIT chunk. // // Note: A receiver of an INIT ACK with the OS value set to 0 SHOULD // destroy the association discarding its TCB. if i.numOutboundStreams == 0 { abort = true return abort, ErrInitAckOutboundStreamRequestZero } // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT // ACK. 
if i.advertisedReceiverWindowCredit < 1500 { abort = true return abort, ErrInitAckAdvertisedReceiver1500 } return false, nil } // String makes chunkInitAck printable func (i *chunkInitAck) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } sctp-1.8.6/chunk_init_common.go000066400000000000000000000132711436021606300165360ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkInitCommon represents an SCTP Chunk body of type INIT and INIT ACK 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 1 | Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Initiate Tag | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Advertised Receiver Window Credit (a_rwnd) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Number of Outbound Streams | Number of Inbound Streams | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Initial TSN | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Optional/Variable-Length Parameters | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ The INIT chunk contains the following parameters. Unless otherwise noted, each parameter MUST only be included once in the INIT chunk. 
Fixed Parameters Status ---------------------------------------------- Initiate Tag Mandatory Advertised Receiver Window Credit Mandatory Number of Outbound Streams Mandatory Number of Inbound Streams Mandatory Initial TSN Mandatory */ type chunkInitCommon struct { initiateTag uint32 advertisedReceiverWindowCredit uint32 numOutboundStreams uint16 numInboundStreams uint16 initialTSN uint32 params []param } const ( initChunkMinLength = 16 initOptionalVarHeaderLength = 4 ) // Init chunk errors var ( ErrInitChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrInitChunkUnmarshalParam = errors.New("failed unmarshalling param in Init Chunk") ErrInitAckMarshalParam = errors.New("unable to marshal parameter for INIT/INITACK") ) func (i *chunkInitCommon) unmarshal(raw []byte) error { i.initiateTag = binary.BigEndian.Uint32(raw[0:]) i.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(raw[4:]) i.numOutboundStreams = binary.BigEndian.Uint16(raw[8:]) i.numInboundStreams = binary.BigEndian.Uint16(raw[10:]) i.initialTSN = binary.BigEndian.Uint32(raw[12:]) // https://tools.ietf.org/html/rfc4960#section-3.2.1 // // Chunk values of SCTP control chunks consist of a chunk-type-specific // header of required fields, followed by zero or more parameters. The // optional and variable-length parameters contained in a chunk are // defined in a Type-Length-Value format as shown below. 
// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type | Parameter Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | | // | Parameter Value | // | | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ offset := initChunkMinLength remaining := len(raw) - offset for remaining > 0 { if remaining > initOptionalVarHeaderLength { pType, err := parseParamType(raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrInitChunkParseParamTypeFailed, err) } p, err := buildParam(pType, raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrInitChunkUnmarshalParam, err) } i.params = append(i.params, p) padding := getPadding(p.length()) offset += p.length() + padding remaining -= p.length() + padding } else { break } } return nil } func (i *chunkInitCommon) marshal() ([]byte, error) { out := make([]byte, initChunkMinLength) binary.BigEndian.PutUint32(out[0:], i.initiateTag) binary.BigEndian.PutUint32(out[4:], i.advertisedReceiverWindowCredit) binary.BigEndian.PutUint16(out[8:], i.numOutboundStreams) binary.BigEndian.PutUint16(out[10:], i.numInboundStreams) binary.BigEndian.PutUint32(out[12:], i.initialTSN) for idx, p := range i.params { pp, err := p.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrInitAckMarshalParam, err) } out = append(out, pp...) // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. *However, it does include padding of any variable-length // parameter except the last parameter in the chunk.* The receiver // MUST ignore the padding. 
if idx != len(i.params)-1 { out = padByte(out, getPadding(len(pp))) } } return out, nil } // String makes chunkInitCommon printable func (i chunkInitCommon) String() string { format := `initiateTag: %d advertisedReceiverWindowCredit: %d numOutboundStreams: %d numInboundStreams: %d initialTSN: %d` res := fmt.Sprintf(format, i.initiateTag, i.advertisedReceiverWindowCredit, i.numOutboundStreams, i.numInboundStreams, i.initialTSN, ) for i, param := range i.params { res += fmt.Sprintf("Param %d:\n %s", i, param) } return res } sctp-1.8.6/chunk_payload_data.go000066400000000000000000000144031436021606300166430ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" "time" ) /* chunkPayloadData represents an SCTP Chunk of type DATA 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 0 | Reserved|U|B|E| Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | TSN | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Stream Identifier S | Stream Sequence Number n | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Payload Protocol Identifier | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | User Data (seq n of Stream S) | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ An unfragmented user message shall have both the B and E bits set to '1'. 
Setting both B and E bits to '0' indicates a middle fragment of a multi-fragment user message, as summarized in the following table: B E Description ============================================================ | 1 0 | First piece of a fragmented user message | +----------------------------------------------------------+ | 0 0 | Middle piece of a fragmented user message | +----------------------------------------------------------+ | 0 1 | Last piece of a fragmented user message | +----------------------------------------------------------+ | 1 1 | Unfragmented message | ============================================================ | Table 1: Fragment Description Flags | ============================================================ */ type chunkPayloadData struct { chunkHeader unordered bool beginningFragment bool endingFragment bool immediateSack bool tsn uint32 streamIdentifier uint16 streamSequenceNumber uint16 payloadType PayloadProtocolIdentifier userData []byte // Whether this data chunk was acknowledged (received by peer) acked bool missIndicator uint32 // Partial-reliability parameters used only by sender since time.Time nSent uint32 // number of transmission made for this chunk _abandoned bool _allInflight bool // valid only with the first fragment // Retransmission flag set when T1-RTX timeout occurred and this // chunk is still in the inflight queue retransmit bool head *chunkPayloadData // link to the head of the fragment } const ( payloadDataEndingFragmentBitmask = 1 payloadDataBeginingFragmentBitmask = 2 payloadDataUnorderedBitmask = 4 payloadDataImmediateSACK = 8 payloadDataHeaderSize = 12 ) // PayloadProtocolIdentifier is an enum for DataChannel payload types type PayloadProtocolIdentifier uint32 // PayloadProtocolIdentifier enums // https://www.iana.org/assignments/sctp-parameters/sctp-parameters.xhtml#sctp-parameters-25 const ( PayloadTypeUnknown PayloadProtocolIdentifier = 0 PayloadTypeWebRTCDCEP PayloadProtocolIdentifier = 50 
PayloadTypeWebRTCString PayloadProtocolIdentifier = 51 PayloadTypeWebRTCBinary PayloadProtocolIdentifier = 53 PayloadTypeWebRTCStringEmpty PayloadProtocolIdentifier = 56 PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57 ) // Data chunk errors var ( ErrChunkPayloadSmall = errors.New("packet is smaller than the header size") ) func (p PayloadProtocolIdentifier) String() string { switch p { case PayloadTypeWebRTCDCEP: return "WebRTC DCEP" case PayloadTypeWebRTCString: return "WebRTC String" case PayloadTypeWebRTCBinary: return "WebRTC Binary" case PayloadTypeWebRTCStringEmpty: return "WebRTC String (Empty)" case PayloadTypeWebRTCBinaryEmpty: return "WebRTC Binary (Empty)" default: return fmt.Sprintf("Unknown Payload Protocol Identifier: %d", p) } } func (p *chunkPayloadData) unmarshal(raw []byte) error { if err := p.chunkHeader.unmarshal(raw); err != nil { return err } p.immediateSack = p.flags&payloadDataImmediateSACK != 0 p.unordered = p.flags&payloadDataUnorderedBitmask != 0 p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0 p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0 if len(raw) < payloadDataHeaderSize { return ErrChunkPayloadSmall } p.tsn = binary.BigEndian.Uint32(p.raw[0:]) p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:]) p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:]) p.payloadType = PayloadProtocolIdentifier(binary.BigEndian.Uint32(p.raw[8:])) p.userData = p.raw[payloadDataHeaderSize:] return nil } func (p *chunkPayloadData) marshal() ([]byte, error) { payRaw := make([]byte, payloadDataHeaderSize+len(p.userData)) binary.BigEndian.PutUint32(payRaw[0:], p.tsn) binary.BigEndian.PutUint16(payRaw[4:], p.streamIdentifier) binary.BigEndian.PutUint16(payRaw[6:], p.streamSequenceNumber) binary.BigEndian.PutUint32(payRaw[8:], uint32(p.payloadType)) copy(payRaw[payloadDataHeaderSize:], p.userData) flags := uint8(0) if p.endingFragment { flags = 1 } if p.beginningFragment { flags |= 1 << 1 } if 
p.unordered { flags |= 1 << 2 } if p.immediateSack { flags |= 1 << 3 } p.chunkHeader.flags = flags p.chunkHeader.typ = ctPayloadData p.chunkHeader.raw = payRaw return p.chunkHeader.marshal() } func (p *chunkPayloadData) check() (abort bool, err error) { return false, nil } // String makes chunkPayloadData printable func (p *chunkPayloadData) String() string { return fmt.Sprintf("%s\n%d", p.chunkHeader, p.tsn) } func (p *chunkPayloadData) abandoned() bool { if p.head != nil { return p.head._abandoned && p.head._allInflight } return p._abandoned && p._allInflight } func (p *chunkPayloadData) setAbandoned(abandoned bool) { if p.head != nil { p.head._abandoned = abandoned return } p._abandoned = abandoned } func (p *chunkPayloadData) setAllInflight() { if p.endingFragment { if p.head != nil { p.head._allInflight = true } else { p._allInflight = true } } } sctp-1.8.6/chunk_reconfig.go000066400000000000000000000057131436021606300160210ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) // https://tools.ietf.org/html/rfc6525#section-3.1 // chunkReconfig represents an SCTP Chunk used to reconfigure streams. 
// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type = 130 | Chunk Flags | Chunk Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ \ // / Re-configuration Parameter / // \ \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ \ // / Re-configuration Parameter (optional) / // \ \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type chunkReconfig struct { chunkHeader paramA param paramB param } // Reconfigure chunk errors var ( ErrChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrChunkMarshalParamAReconfigFailed = errors.New("unable to marshal parameter A for reconfig") ErrChunkMarshalParamBReconfigFailed = errors.New("unable to marshal parameter B for reconfig") ) func (c *chunkReconfig) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } pType, err := parseParamType(c.raw) if err != nil { return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) } a, err := buildParam(pType, c.raw) if err != nil { return err } c.paramA = a padding := getPadding(a.length()) offset := a.length() + padding if len(c.raw) > offset { pType, err := parseParamType(c.raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) } b, err := buildParam(pType, c.raw[offset:]) if err != nil { return err } c.paramB = b } return nil } func (c *chunkReconfig) marshal() ([]byte, error) { out, err := c.paramA.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamAReconfigFailed, err) } if c.paramB != nil { // Pad param A out = padByte(out, getPadding(len(out))) outB, err := c.paramB.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamBReconfigFailed, err) } out = append(out, outB...) 
} c.typ = ctReconfig c.raw = out return c.chunkHeader.marshal() } func (c *chunkReconfig) check() (abort bool, err error) { // nolint:godox // TODO: check allowed combinations: // https://tools.ietf.org/html/rfc6525#section-3.1 return true, nil } // String makes chunkReconfig printable func (c *chunkReconfig) String() string { res := fmt.Sprintf("Param A:\n %s", c.paramA) if c.paramB != nil { res += fmt.Sprintf("Param B:\n %s", c.paramB) } return res } sctp-1.8.6/chunk_reconfig_test.go000066400000000000000000000031741436021606300170570ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkReconfig_Success(t *testing.T) { tt := []struct { binary []byte }{ {append([]byte{0x82, 0x0, 0x0, 0x1a}, testChunkReconfigParamA()...)}, // Note: chunk trailing padding is added in packet.marshal {append([]byte{0x82, 0x0, 0x0, 0x14}, testChunkReconfigParamB()...)}, {append([]byte{0x82, 0x0, 0x0, 0x10}, testChunkReconfigResponce()...)}, {append(append([]byte{0x82, 0x0, 0x0, 0x2c}, padByte(testChunkReconfigParamA(), 2)...), testChunkReconfigParamB()...)}, {append(append([]byte{0x82, 0x0, 0x0, 0x2a}, testChunkReconfigParamB()...), testChunkReconfigParamA()...)}, // Note: chunk trailing padding is added in packet.marshal } for i, tc := range tt { actual := &chunkReconfig{} err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b, "test %d not equal", i) } } func TestChunkReconfigUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"chunk header to short", []byte{0x82}}, {"missing parse param type (A)", []byte{0x82, 0x0, 0x0, 0x4}}, {"wrong param (A)", []byte{0x82, 0x0, 0x0, 0x8, 0x0, 0xd, 0x0, 0x0}}, {"wrong param (B)", append(append([]byte{0x82, 0x0, 0x0, 0x18}, testChunkReconfigParamB()...), []byte{0x0, 0xd, 0x0, 0x0}...)}, } for i, tc := 
range tt { actual := &chunkReconfig{} err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/chunk_selective_ack.go000066400000000000000000000122161436021606300170220ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkSelectiveAck represents an SCTP Chunk of type SACK This chunk is sent to the peer endpoint to acknowledge received DATA chunks and to inform the peer endpoint of gaps in the received subsequences of DATA chunks as represented by their TSNs. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 3 |Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cumulative TSN Ack | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Advertised Receiver Window Credit (a_rwnd) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Gap Ack Block #1 Start | Gap Ack Block #1 End | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / \ ... \ / / +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Gap Ack Block #N Start | Gap Ack Block #N End | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Duplicate TSN 1 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / \ ... 
\ / / +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Duplicate TSN X | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type gapAckBlock struct { start uint16 end uint16 } // Selective ack chunk errors var ( ErrChunkTypeNotSack = errors.New("ChunkType is not of type SACK") ErrSackSizeNotLargeEnoughInfo = errors.New("SACK Chunk size is not large enough to contain header") ErrSackSizeNotMatchPredicted = errors.New("SACK Chunk size does not match predicted amount from header values") ) // String makes gapAckBlock printable func (g gapAckBlock) String() string { return fmt.Sprintf("%d - %d", g.start, g.end) } type chunkSelectiveAck struct { chunkHeader cumulativeTSNAck uint32 advertisedReceiverWindowCredit uint32 gapAckBlocks []gapAckBlock duplicateTSN []uint32 } const ( selectiveAckHeaderSize = 12 ) func (s *chunkSelectiveAck) unmarshal(raw []byte) error { if err := s.chunkHeader.unmarshal(raw); err != nil { return err } if s.typ != ctSack { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotSack, s.typ.String()) } if len(s.raw) < selectiveAckHeaderSize { return fmt.Errorf("%w: %v remaining, needs %v bytes", ErrSackSizeNotLargeEnoughInfo, len(s.raw), selectiveAckHeaderSize) } s.cumulativeTSNAck = binary.BigEndian.Uint32(s.raw[0:]) s.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(s.raw[4:]) s.gapAckBlocks = make([]gapAckBlock, binary.BigEndian.Uint16(s.raw[8:])) s.duplicateTSN = make([]uint32, binary.BigEndian.Uint16(s.raw[10:])) if len(s.raw) != selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN))) { return ErrSackSizeNotMatchPredicted } offset := selectiveAckHeaderSize for i := range s.gapAckBlocks { s.gapAckBlocks[i].start = binary.BigEndian.Uint16(s.raw[offset:]) s.gapAckBlocks[i].end = binary.BigEndian.Uint16(s.raw[offset+2:]) offset += 4 } for i := range s.duplicateTSN { s.duplicateTSN[i] = binary.BigEndian.Uint32(s.raw[offset:]) offset += 4 } return nil } func (s *chunkSelectiveAck) 
marshal() ([]byte, error) { sackRaw := make([]byte, selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN)))) binary.BigEndian.PutUint32(sackRaw[0:], s.cumulativeTSNAck) binary.BigEndian.PutUint32(sackRaw[4:], s.advertisedReceiverWindowCredit) binary.BigEndian.PutUint16(sackRaw[8:], uint16(len(s.gapAckBlocks))) binary.BigEndian.PutUint16(sackRaw[10:], uint16(len(s.duplicateTSN))) offset := selectiveAckHeaderSize for _, g := range s.gapAckBlocks { binary.BigEndian.PutUint16(sackRaw[offset:], g.start) binary.BigEndian.PutUint16(sackRaw[offset+2:], g.end) offset += 4 } for _, t := range s.duplicateTSN { binary.BigEndian.PutUint32(sackRaw[offset:], t) offset += 4 } s.chunkHeader.typ = ctSack s.chunkHeader.raw = sackRaw return s.chunkHeader.marshal() } func (s *chunkSelectiveAck) check() (abort bool, err error) { return false, nil } // String makes chunkSelectiveAck printable func (s *chunkSelectiveAck) String() string { res := fmt.Sprintf("SACK cumTsnAck=%d arwnd=%d dupTsn=%d", s.cumulativeTSNAck, s.advertisedReceiverWindowCredit, s.duplicateTSN) for _, gap := range s.gapAckBlocks { res = fmt.Sprintf("%s\n gap ack: %s", res, gap) } return res } sctp-1.8.6/chunk_shutdown.go000066400000000000000000000032041436021606300160710ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkShutdown represents an SCTP Chunk of type chunkShutdown 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 7 | Chunk Flags | Length = 8 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cumulative TSN Ack | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkShutdown struct { chunkHeader cumulativeTSNAck uint32 } const ( cumulativeTSNAckLength = 4 ) // Shutdown chunk errors var ( ErrInvalidChunkSize = errors.New("invalid chunk size") ErrChunkTypeNotShutdown = errors.New("ChunkType is not of type 
SHUTDOWN") ) func (c *chunkShutdown) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdown { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdown, c.typ.String()) } if len(c.raw) != cumulativeTSNAckLength { return ErrInvalidChunkSize } c.cumulativeTSNAck = binary.BigEndian.Uint32(c.raw[0:]) return nil } func (c *chunkShutdown) marshal() ([]byte, error) { out := make([]byte, cumulativeTSNAckLength) binary.BigEndian.PutUint32(out[0:], c.cumulativeTSNAck) c.typ = ctShutdown c.raw = out return c.chunkHeader.marshal() } func (c *chunkShutdown) check() (abort bool, err error) { return false, nil } // String makes chunkShutdown printable func (c *chunkShutdown) String() string { return c.chunkHeader.String() } sctp-1.8.6/chunk_shutdown_ack.go000066400000000000000000000022431436021606300167110ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 8 | Chunk Flags | Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkShutdownAck struct { chunkHeader } // Shutdown ack chunk errors var ( ErrChunkTypeNotShutdownAck = errors.New("ChunkType is not of type SHUTDOWN-ACK") ) func (c *chunkShutdownAck) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdownAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownAck, c.typ.String()) } return nil } func (c *chunkShutdownAck) marshal() ([]byte, error) { c.typ = ctShutdownAck return c.chunkHeader.marshal() } func (c *chunkShutdownAck) check() (abort bool, err error) { return false, nil } // String makes chunkShutdownAck printable func (c *chunkShutdownAck) String() string { return c.chunkHeader.String() } 
sctp-1.8.6/chunk_shutdown_ack_test.go000066400000000000000000000017571436021606300177610ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestChunkShutdownAck_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x08, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownAck{} err := actual.unmarshal(tc.binary) require.NoError(t, err, "failed to unmarshal #%d: %v", i, err) b, err := actual.marshal() require.NoError(t, err, "failed to marshal: %v", err) assert.Equal(t, tc.binary, b, "test %d not equal", i) } } func TestChunkShutdownAck_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"length too short", []byte{0x08, 0x00, 0x00}}, {"length too long", []byte{0x08, 0x00, 0x00, 0x04, 0x12}}, {"invalid type", []byte{0x0f, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownAck{} err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/chunk_shutdown_complete.go000066400000000000000000000023511436021606300177630ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* chunkShutdownComplete represents an SCTP Chunk of type chunkShutdownComplete 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 14 |Reserved |T| Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkShutdownComplete struct { chunkHeader } // Shutdown complete chunk errors var ( ErrChunkTypeNotShutdownComplete = errors.New("ChunkType is not of type SHUTDOWN-COMPLETE") ) func (c *chunkShutdownComplete) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdownComplete { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownComplete, c.typ.String()) } return nil } func 
(c *chunkShutdownComplete) marshal() ([]byte, error) { c.typ = ctShutdownComplete return c.chunkHeader.marshal() } func (c *chunkShutdownComplete) check() (abort bool, err error) { return false, nil } // String makes chunkShutdownComplete printable func (c *chunkShutdownComplete) String() string { return c.chunkHeader.String() } sctp-1.8.6/chunk_shutdown_complete_test.go000066400000000000000000000017711436021606300210270ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestChunkShutdownComplete_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x0e, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownComplete{} err := actual.unmarshal(tc.binary) require.NoError(t, err, "failed to unmarshal #%d: %v", i, err) b, err := actual.marshal() require.NoError(t, err, "failed to marshal: %v", err) assert.Equal(t, tc.binary, b, "test %d not equal", i) } } func TestChunkShutdownComplete_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"length too short", []byte{0x0e, 0x00, 0x00}}, {"length too long", []byte{0x0e, 0x00, 0x00, 0x04, 0x12}}, {"invalid type", []byte{0x0f, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownComplete{} err := actual.unmarshal(tc.binary) require.Error(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.8.6/chunk_shutdown_test.go000066400000000000000000000023151436021606300171320ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkShutdown_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}}, } for i, tc := range tt { actual := &chunkShutdown{} err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, 
tc.binary, b, "test %d not equal", i) } } func TestChunkShutdown_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"length too short", []byte{0x07, 0x00, 0x00, 0x07, 0x12, 0x34, 0x56, 0x78}}, {"length too long", []byte{0x07, 0x00, 0x00, 0x09, 0x12, 0x34, 0x56, 0x78}}, {"payload too short", []byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56}}, {"payload too long", []byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9f}}, {"invalid type", []byte{0x08, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}}, } for i, tc := range tt { actual := &chunkShutdown{} err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/chunk_test.go000066400000000000000000000242211436021606300151770ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestInitChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk") } i, ok := pkt.chunks[0].(*chunkInit) if !ok { t.Errorf("Failed to cast Chunk -> Init") } switch { case err != nil: t.Errorf("Unmarshal init Chunk failed: %v", err) case i.initiateTag != 1438213285: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 1438213285, i.initiateTag) case i.advertisedReceiverWindowCredit != 131072: 
t.Errorf("Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", 131072, i.advertisedReceiverWindowCredit) case i.numOutboundStreams != 1024: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1024, i.numOutboundStreams) case i.numInboundStreams != 2048: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 2048, i.numInboundStreams) case i.initialTSN != uint32(3899461680): t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", uint32(3899461680), i.initialTSN) } } func TestInitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, 0x1c, 0xeb, 0x81, 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, 0x94, 0x06, 0x2f, 0x93} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } _, ok := pkt.chunks[0].(*chunkInitAck) if !ok { t.Error("Failed to cast Chunk -> Init") } else if err != nil { t.Errorf("Unmarshal init Chunk failed: %v", err) } } func TestChromeChunk1Init(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } rawPkt2, err 
:= pkt.marshal() if err != nil { t.Errorf("Remarshal failed: %v", err) } assert.Equal(t, rawPkt, rawPkt2) } func TestChromeChunk2InitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 
0x00, 0x00, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } rawPkt2, err := pkt.marshal() if err != nil { t.Errorf("Remarshal failed: %v", err) } assert.Equal(t, rawPkt, rawPkt2) } func TestInitMarshalUnmarshal(t *testing.T) { p := &packet{} p.destinationPort = 1 p.sourcePort = 1 p.verificationTag = 123 initAck := &chunkInitAck{} initAck.initialTSN = 123 initAck.numOutboundStreams = 1 initAck.numInboundStreams = 1 initAck.initiateTag = 123 initAck.advertisedReceiverWindowCredit = 1024 cookie, ErrRand := newRandomStateCookie() if ErrRand != nil { t.Fatalf("Failed to generate random state cookie: %v", ErrRand) } initAck.params = []param{cookie} p.chunks = []chunk{initAck} rawPkt, err := p.marshal() if err != nil { t.Errorf("Failed to marshal packet: %v", err) } pkt := &packet{} err = pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } i, ok := pkt.chunks[0].(*chunkInitAck) if !ok { t.Error("Failed to cast Chunk -> InitAck") } switch { case err != nil: t.Errorf("Unmarshal init ack Chunk failed: %v", err) case i.initiateTag != 123: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 123, i.initiateTag) case i.advertisedReceiverWindowCredit != 1024: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect 
advertisedReceiverWindowCredit exp: %d act: %d", 1024, i.advertisedReceiverWindowCredit) case i.numOutboundStreams != 1: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1, i.numOutboundStreams) case i.numInboundStreams != 1: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 1, i.numInboundStreams) case i.initialTSN != 123: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", 123, i.initialTSN) } } func TestPayloadDataMarshalUnmarshal(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, 0x24, 0x9b, 0x28, 0x7e, 0x48, 0xa3, 0x7b, 0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x00} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } _, ok := pkt.chunks[1].(*chunkPayloadData) if !ok { t.Error("Failed to cast Chunk -> PayloadData") } } func TestSelectAckChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, 0x78, 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } _, ok := pkt.chunks[0].(*chunkSelectiveAck) if !ok { t.Error("Failed to cast Chunk -> SelectiveAck") } } func TestReconfigChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, 0x0, 0x0, 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 
0x8d, 0xff, 0x4e, 0x1c, 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0} err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } c, ok := pkt.chunks[0].(*chunkReconfig) if !ok { t.Error("Failed to cast Chunk -> Reconfig") } if c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0] != uint16(1) { //nolint:forcetypeassert t.Errorf("unexpected stream identifier: %d", c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0]) //nolint:forcetypeassert } } func TestForwardTSNChunk(t *testing.T) { pkt := &packet{} rawPkt := append([]byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb}, testChunkForwardTSN()...) err := pkt.unmarshal(rawPkt) if err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } c, ok := pkt.chunks[0].(*chunkForwardTSN) if !ok { t.Error("Failed to cast Chunk -> Forward TSN") } if c.newCumulativeTSN != uint32(3) { t.Errorf("unexpected New Cumulative TSN: %d", c.newCumulativeTSN) } } sctp-1.8.6/chunkheader.go000066400000000000000000000063201436021606300153110ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkHeader represents a SCTP Chunk header, defined in https://tools.ietf.org/html/rfc4960#section-3.2 The figure below illustrates the field format for the chunks to be transmitted in the SCTP packet. Each chunk is formatted with a Chunk Type field, a chunk-specific Flag field, a Chunk Length field, and a Value field. 
0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk Type | Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Chunk Value | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkHeader struct { typ chunkType flags byte raw []byte } const ( chunkHeaderSize = 4 ) // SCTP chunk header errors var ( ErrChunkHeaderTooSmall = errors.New("raw is too small for a SCTP chunk") ErrChunkHeaderNotEnoughSpace = errors.New("not enough data left in SCTP packet to satisfy requested length") ErrChunkHeaderPaddingNonZero = errors.New("chunk padding is non-zero at offset") ) func (c *chunkHeader) unmarshal(raw []byte) error { if len(raw) < chunkHeaderSize { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrChunkHeaderTooSmall, len(raw), chunkHeaderSize) } c.typ = chunkType(raw[0]) c.flags = raw[1] length := binary.BigEndian.Uint16(raw[2:]) // Length includes Chunk header valueLength := int(length - chunkHeaderSize) lengthAfterValue := len(raw) - (chunkHeaderSize + valueLength) if lengthAfterValue < 0 { return fmt.Errorf("%w: remain %d req %d ", ErrChunkHeaderNotEnoughSpace, valueLength, len(raw)-chunkHeaderSize) } else if lengthAfterValue < 4 { // https://tools.ietf.org/html/rfc4960#section-3.2 // The Chunk Length field does not count any chunk padding. // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. However, it does include padding of any variable-length // parameter except the last parameter in the chunk. The receiver // MUST ignore the padding. 
for i := lengthAfterValue; i > 0; i-- { paddingOffset := chunkHeaderSize + valueLength + (i - 1) if raw[paddingOffset] != 0 { return fmt.Errorf("%w: %d ", ErrChunkHeaderPaddingNonZero, paddingOffset) } } } c.raw = raw[chunkHeaderSize : chunkHeaderSize+valueLength] return nil } func (c *chunkHeader) marshal() ([]byte, error) { raw := make([]byte, 4+len(c.raw)) raw[0] = uint8(c.typ) raw[1] = c.flags binary.BigEndian.PutUint16(raw[2:], uint16(len(c.raw)+chunkHeaderSize)) copy(raw[4:], c.raw) return raw, nil } func (c *chunkHeader) valueLength() int { return len(c.raw) } // String makes chunkHeader printable func (c chunkHeader) String() string { return c.typ.String() } sctp-1.8.6/chunkheader_test.go000066400000000000000000000000151436021606300163430ustar00rootroot00000000000000package sctp sctp-1.8.6/chunktype.go000066400000000000000000000030471436021606300150450ustar00rootroot00000000000000package sctp import "fmt" // chunkType is an enum for SCTP Chunk Type field // This field identifies the type of information contained in the // Chunk Value field. 
type chunkType uint8 // List of known chunkType enums const ( ctPayloadData chunkType = 0 ctInit chunkType = 1 ctInitAck chunkType = 2 ctSack chunkType = 3 ctHeartbeat chunkType = 4 ctHeartbeatAck chunkType = 5 ctAbort chunkType = 6 ctShutdown chunkType = 7 ctShutdownAck chunkType = 8 ctError chunkType = 9 ctCookieEcho chunkType = 10 ctCookieAck chunkType = 11 ctCWR chunkType = 13 ctShutdownComplete chunkType = 14 ctReconfig chunkType = 130 ctForwardTSN chunkType = 192 ) func (c chunkType) String() string { switch c { case ctPayloadData: return "DATA" case ctInit: return "INIT" case ctInitAck: return "INIT-ACK" case ctSack: return "SACK" case ctHeartbeat: return "HEARTBEAT" case ctHeartbeatAck: return "HEARTBEAT-ACK" case ctAbort: return "ABORT" case ctShutdown: return "SHUTDOWN" case ctShutdownAck: return "SHUTDOWN-ACK" case ctError: return "ERROR" case ctCookieEcho: return "COOKIE-ECHO" case ctCookieAck: return "COOKIE-ACK" case ctCWR: return "ECNE" // Explicit Congestion Notification Echo case ctShutdownComplete: return "SHUTDOWN-COMPLETE" case ctReconfig: return "RECONFIG" // Re-configuration case ctForwardTSN: return "FORWARD-TSN" default: return fmt.Sprintf("Unknown ChunkType: %d", c) } } sctp-1.8.6/chunktype_test.go000066400000000000000000000014521436021606300161020ustar00rootroot00000000000000package sctp import "testing" func TestChunkType_String(t *testing.T) { tt := []struct { chunkType chunkType expected string }{ {ctPayloadData, "DATA"}, {ctInit, "INIT"}, {ctInitAck, "INIT-ACK"}, {ctSack, "SACK"}, {ctHeartbeat, "HEARTBEAT"}, {ctHeartbeatAck, "HEARTBEAT-ACK"}, {ctAbort, "ABORT"}, {ctShutdown, "SHUTDOWN"}, {ctShutdownAck, "SHUTDOWN-ACK"}, {ctError, "ERROR"}, {ctCookieEcho, "COOKIE-ECHO"}, {ctCookieAck, "COOKIE-ACK"}, {ctCWR, "ECNE"}, {ctShutdownComplete, "SHUTDOWN-COMPLETE"}, {ctReconfig, "RECONFIG"}, {ctForwardTSN, "FORWARD-TSN"}, {chunkType(255), "Unknown ChunkType: 255"}, } for _, tc := range tt { if tc.chunkType.String() != tc.expected { 
t.Errorf("failed to stringify chunkType %v, expected %s", tc.chunkType, tc.expected) } } } sctp-1.8.6/codecov.yml000066400000000000000000000005521436021606300146470ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" sctp-1.8.6/control_queue.go000066400000000000000000000007601436021606300157160ustar00rootroot00000000000000package sctp // control queue type controlQueue struct { queue []*packet } func newControlQueue() *controlQueue { return &controlQueue{queue: []*packet{}} } func (q *controlQueue) push(c *packet) { q.queue = append(q.queue, c) } func (q *controlQueue) pushAll(packets []*packet) { q.queue = append(q.queue, packets...) } func (q *controlQueue) popAll() []*packet { packets := q.queue q.queue = []*packet{} return packets } func (q *controlQueue) size() int { return len(q.queue) } sctp-1.8.6/error_cause.go000066400000000000000000000054511436021606300153450ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) // errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk type errorCauseCode uint16 type errorCause interface { unmarshal([]byte) error marshal() ([]byte, error) length() uint16 String() string errorCauseCode() errorCauseCode } // Error and abort chunk errors var ( ErrBuildErrorCaseHandle = errors.New("BuildErrorCause does not handle") ) // buildErrorCause delegates the building of a error cause from raw bytes to the correct structure func buildErrorCause(raw []byte) (errorCause, error) { var e errorCause c := errorCauseCode(binary.BigEndian.Uint16(raw[0:])) switch c { case invalidMandatoryParameter: e = &errorCauseInvalidMandatoryParameter{} case unrecognizedChunkType: e = &errorCauseUnrecognizedChunkType{} case 
protocolViolation: e = &errorCauseProtocolViolation{} case userInitiatedAbort: e = &errorCauseUserInitiatedAbort{} default: return nil, fmt.Errorf("%w: %s", ErrBuildErrorCaseHandle, c.String()) } if err := e.unmarshal(raw); err != nil { return nil, err } return e, nil } const ( invalidStreamIdentifier errorCauseCode = 1 missingMandatoryParameter errorCauseCode = 2 staleCookieError errorCauseCode = 3 outOfResource errorCauseCode = 4 unresolvableAddress errorCauseCode = 5 unrecognizedChunkType errorCauseCode = 6 invalidMandatoryParameter errorCauseCode = 7 unrecognizedParameters errorCauseCode = 8 noUserData errorCauseCode = 9 cookieReceivedWhileShuttingDown errorCauseCode = 10 restartOfAnAssociationWithNewAddresses errorCauseCode = 11 userInitiatedAbort errorCauseCode = 12 protocolViolation errorCauseCode = 13 ) func (e errorCauseCode) String() string { switch e { case invalidStreamIdentifier: return "Invalid Stream Identifier" case missingMandatoryParameter: return "Missing Mandatory Parameter" case staleCookieError: return "Stale Cookie Error" case outOfResource: return "Out Of Resource" case unresolvableAddress: return "Unresolvable IP" case unrecognizedChunkType: return "Unrecognized Chunk Type" case invalidMandatoryParameter: return "Invalid Mandatory Parameter" case unrecognizedParameters: return "Unrecognized Parameters" case noUserData: return "No User Data" case cookieReceivedWhileShuttingDown: return "Cookie Received While Shutting Down" case restartOfAnAssociationWithNewAddresses: return "Restart Of An Association With New Addresses" case userInitiatedAbort: return "User Initiated Abort" case protocolViolation: return "Protocol Violation" default: return fmt.Sprintf("Unknown CauseCode: %d", e) } } sctp-1.8.6/error_cause_header.go000066400000000000000000000021161436021606300166500ustar00rootroot00000000000000package sctp import ( "encoding/binary" ) // errorCauseHeader represents the shared header that is shared by all error causes type errorCauseHeader 
struct { code errorCauseCode len uint16 raw []byte } const ( errorCauseHeaderLength = 4 ) func (e *errorCauseHeader) marshal() ([]byte, error) { e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) raw := make([]byte, e.len) binary.BigEndian.PutUint16(raw[0:], uint16(e.code)) binary.BigEndian.PutUint16(raw[2:], e.len) copy(raw[errorCauseHeaderLength:], e.raw) return raw, nil } func (e *errorCauseHeader) unmarshal(raw []byte) error { e.code = errorCauseCode(binary.BigEndian.Uint16(raw[0:])) e.len = binary.BigEndian.Uint16(raw[2:]) valueLength := e.len - errorCauseHeaderLength e.raw = raw[errorCauseHeaderLength : errorCauseHeaderLength+valueLength] return nil } func (e *errorCauseHeader) length() uint16 { return e.len } func (e *errorCauseHeader) errorCauseCode() errorCauseCode { return e.code } // String makes errorCauseHeader printable func (e errorCauseHeader) String() string { return e.code.String() } sctp-1.8.6/error_cause_invalid_mandatory_parameter.go000066400000000000000000000010521436021606300231620ustar00rootroot00000000000000package sctp // errorCauseInvalidMandatoryParameter represents an SCTP error cause type errorCauseInvalidMandatoryParameter struct { errorCauseHeader } func (e *errorCauseInvalidMandatoryParameter) marshal() ([]byte, error) { return e.errorCauseHeader.marshal() } func (e *errorCauseInvalidMandatoryParameter) unmarshal(raw []byte) error { return e.errorCauseHeader.unmarshal(raw) } // String makes errorCauseInvalidMandatoryParameter printable func (e *errorCauseInvalidMandatoryParameter) String() string { return e.errorCauseHeader.String() } sctp-1.8.6/error_cause_protocol_violation.go000066400000000000000000000033551436021606300213530ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) /* This error cause MAY be included in ABORT chunks that are sent because an SCTP endpoint detects a protocol violation of the peer that is not covered by the error causes described in Section 3.3.10.1 to Section 3.3.10.12. 
An implementation MAY provide additional information specifying what kind of protocol violation has been detected. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cause Code=13 | Cause Length=Variable | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / Additional Information / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type errorCauseProtocolViolation struct { errorCauseHeader additionalInformation []byte } // Abort chunk errors var ( ErrProtocolViolationUnmarshal = errors.New("unable to unmarshal Protocol Violation error") ) func (e *errorCauseProtocolViolation) marshal() ([]byte, error) { e.raw = e.additionalInformation return e.errorCauseHeader.marshal() } func (e *errorCauseProtocolViolation) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return fmt.Errorf("%w: %v", ErrProtocolViolationUnmarshal, err) } e.additionalInformation = e.raw return nil } // String makes errorCauseProtocolViolation printable func (e *errorCauseProtocolViolation) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader, e.additionalInformation) } sctp-1.8.6/error_cause_unrecognized_chunk_type.go000066400000000000000000000013261436021606300223470ustar00rootroot00000000000000package sctp // errorCauseUnrecognizedChunkType represents an SCTP error cause type errorCauseUnrecognizedChunkType struct { errorCauseHeader unrecognizedChunk []byte } func (e *errorCauseUnrecognizedChunkType) marshal() ([]byte, error) { e.code = unrecognizedChunkType e.errorCauseHeader.raw = e.unrecognizedChunk return e.errorCauseHeader.marshal() } func (e *errorCauseUnrecognizedChunkType) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return err } e.unrecognizedChunk = e.errorCauseHeader.raw return nil } // String makes errorCauseUnrecognizedChunkType printable func (e 
*errorCauseUnrecognizedChunkType) String() string { return e.errorCauseHeader.String() } sctp-1.8.6/error_cause_user_initiated_abort.go000066400000000000000000000030531436021606300216200ustar00rootroot00000000000000package sctp import ( "fmt" ) /* This error cause MAY be included in ABORT chunks that are sent because of an upper-layer request. The upper layer can specify an Upper Layer Abort Reason that is transported by SCTP transparently and MAY be delivered to the upper-layer protocol at the peer. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cause Code=12 | Cause Length=Variable | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / Upper Layer Abort Reason / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type errorCauseUserInitiatedAbort struct { errorCauseHeader upperLayerAbortReason []byte } func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) { e.code = userInitiatedAbort e.errorCauseHeader.raw = e.upperLayerAbortReason return e.errorCauseHeader.marshal() } func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return err } e.upperLayerAbortReason = e.errorCauseHeader.raw return nil } // String makes errorCauseUserInitiatedAbort printable func (e *errorCauseUserInitiatedAbort) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason) } sctp-1.8.6/examples/000077500000000000000000000000001436021606300143165ustar00rootroot00000000000000sctp-1.8.6/examples/ping-pong/000077500000000000000000000000001436021606300162145ustar00rootroot00000000000000sctp-1.8.6/examples/ping-pong/Makefile000066400000000000000000000001451436021606300176540ustar00rootroot00000000000000all: ping pong ping: ping.go conn.go go build -o $@ pong: pong.go conn.go go build -o $@ -tags 
$@sctp-1.8.6/examples/ping-pong/README.md000066400000000000000000000010001436021606300174620ustar00rootroot00000000000000# ping-pong ping-pong is a sctp example that shows how you can send/recv messages. In this example, there are 2 types of peers: **ping** and **pong**. **Ping** will always send `ping ` messages to **pong** and receive `pong ` messages from **pong**. **Pong** will always receive `ping ` from **ping** and send `pong ` messages to **ping**. ## Instruction ### Build ping and pong ```sh make ``` ### Run pong ```sh ./pong ``` ### Run ping ```sh ./ping ```sctp-1.8.6/examples/ping-pong/conn.go000066400000000000000000000033061436021606300175020ustar00rootroot00000000000000package main import ( "net" "sync" "time" ) // Reference: https://github.com/pion/sctp/blob/master/association_test.go // Since UDP is connectionless, as a server, it doesn't know how to reply // simply using the `Write` method. So, to make it work, `disconnectedPacketConn` // will infer the last packet that it reads as the reply address for `Write` type disconnectedPacketConn struct { // nolint: unused mu sync.RWMutex rAddr net.Addr pConn net.PacketConn } // Read func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { return 0, err } c.mu.Lock() c.rAddr = rAddr c.mu.Unlock() return i, err } // Write writes len(p) bytes from p to the DTLS connection func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) { //nolint:unused return c.pConn.WriteTo(p, c.RemoteAddr()) } // Close closes the conn and releases any Read calls func (c *disconnectedPacketConn) Close() error { //nolint:unused return c.pConn.Close() } // LocalAddr is a stub func (c *disconnectedPacketConn) LocalAddr() net.Addr { //nolint:unused if c.pConn != nil { return c.pConn.LocalAddr() } return nil } // RemoteAddr is a stub func (c *disconnectedPacketConn) RemoteAddr() net.Addr { //nolint:unused c.mu.RLock() defer c.mu.RUnlock() return 
c.rAddr } // SetDeadline is a stub func (c *disconnectedPacketConn) SetDeadline(t time.Time) error { //nolint:unused return nil } // SetReadDeadline is a stub func (c *disconnectedPacketConn) SetReadDeadline(t time.Time) error { //nolint:unused return nil } // SetWriteDeadline is a stub func (c *disconnectedPacketConn) SetWriteDeadline(t time.Time) error { //nolint:unused return nil } sctp-1.8.6/examples/ping-pong/ping.go000066400000000000000000000025741436021606300175100ustar00rootroot00000000000000//go:build !pong // +build !pong package main import ( "fmt" "log" "net" "github.com/pion/logging" "github.com/pion/sctp" ) func main() { conn, err := net.Dial("udp", "127.0.0.1:5678") if err != nil { log.Panic(err) } defer func() { if closeErr := conn.Close(); closeErr != nil { panic(err) } }() fmt.Println("dialed udp ponger") config := sctp.Config{ NetConn: conn, LoggerFactory: logging.NewDefaultLoggerFactory(), } a, err := sctp.Client(config) if err != nil { log.Panic(err) } defer func() { if closeErr := a.Close(); closeErr != nil { panic(err) } }() fmt.Println("created a client") stream, err := a.OpenStream(0, sctp.PayloadTypeWebRTCString) if err != nil { log.Panic(err) } defer func() { if closeErr := stream.Close(); closeErr != nil { panic(err) } }() fmt.Println("opened a stream") // set unordered = true and 10ms treshold for dropping packets stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, 10) go func() { var pingSeqNum int for { pingMsg := fmt.Sprintf("ping %d", pingSeqNum) _, err = stream.Write([]byte(pingMsg)) if err != nil { log.Panic(err) } fmt.Println("sent:", pingMsg) pingSeqNum++ } }() for { buff := make([]byte, 1024) _, err = stream.Read(buff) if err != nil { log.Panic(err) } pongMsg := string(buff) fmt.Println("received:", pongMsg) } } sctp-1.8.6/examples/ping-pong/pong.go000066400000000000000000000024051436021606300175070ustar00rootroot00000000000000// +build pong package main import ( "fmt" "log" "net" "time" "github.com/pion/logging" 
"github.com/pion/sctp" ) func main() { addr := net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 5678, } conn, err := net.ListenUDP("udp", &addr) if err != nil { log.Panic(err) } defer conn.Close() fmt.Println("created a udp listener") config := sctp.Config{ NetConn: &disconnectedPacketConn{pConn: conn}, LoggerFactory: logging.NewDefaultLoggerFactory(), } a, err := sctp.Server(config) if err != nil { log.Panic(err) } defer a.Close() fmt.Println("created a server") stream, err := a.AcceptStream() if err != nil { log.Panic(err) } defer stream.Close() fmt.Println("accepted a stream") // set unordered = true and 10ms treshold for dropping packets stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, 10) var pongSeqNum int for { buff := make([]byte, 1024) _, err = stream.Read(buff) if err != nil { log.Panic(err) } pingMsg := string(buff) fmt.Println("received:", pingMsg) fmt.Sscanf(pingMsg, "ping %d", &pongSeqNum) pongMsg := fmt.Sprintf("pong %d", pongSeqNum) _, err = stream.Write([]byte(pongMsg)) if err != nil { log.Panic(err) } fmt.Println("sent:", pongMsg) time.Sleep(time.Second) } } sctp-1.8.6/go.mod000066400000000000000000000004451436021606300136110ustar00rootroot00000000000000module github.com/pion/sctp require ( github.com/kr/pretty v0.1.0 // indirect github.com/pion/logging v0.2.2 github.com/pion/randutil v0.1.0 github.com/pion/transport v0.14.1 github.com/stretchr/testify v1.8.1 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect ) go 1.13 sctp-1.8.6/go.sum000066400000000000000000000123501436021606300136340ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod 
h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/transport v0.14.1 h1:XSM6olwW+o8J4SCmOBb/BpwZypkHeyM0PGFCxNQBr40= github.com/pion/transport v0.14.1/go.mod h1:4tGmbk00NeYA3rUa9+n+dzCCoKkcy3YlYb99Jn2fNnI= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= sctp-1.8.6/packet.go000066400000000000000000000133571436021606300143070ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" "hash/crc32" ) // Create the crc32 table we'll use for the checksum var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals // Allocate and zero this data once. // We need to use it for the checksum and don't want to allocate/clear each time. var fourZeroes [4]byte // nolint:gochecknoglobals /* Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3 An SCTP packet is composed of a common header and chunks. A chunk contains either control information or user data. SCTP Packet Format 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Common Header | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk #1 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | ... 
| +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk #n | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ SCTP Common Header Format 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Source Value Number | Destination Value Number | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Verification Tag | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Checksum | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type packet struct { sourcePort uint16 destinationPort uint16 verificationTag uint32 chunks []chunk } const ( packetHeaderSize = 12 ) // SCTP packet errors var ( ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet") ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header") ErrUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type") ErrChecksumMismatch = errors.New("checksum mismatch theirs") ) func (p *packet) unmarshal(raw []byte) error { if len(raw) < packetHeaderSize { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) } p.sourcePort = binary.BigEndian.Uint16(raw[0:]) p.destinationPort = binary.BigEndian.Uint16(raw[2:]) p.verificationTag = binary.BigEndian.Uint32(raw[4:]) offset := packetHeaderSize for { // Exact match, no more chunks if offset == len(raw) { break } else if offset+chunkHeaderSize > len(raw) { return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, offset, len(raw)) } var c chunk switch chunkType(raw[offset]) { case ctInit: c = &chunkInit{} case ctInitAck: c = &chunkInitAck{} case ctAbort: c = &chunkAbort{} case ctCookieEcho: c = &chunkCookieEcho{} case ctCookieAck: c = &chunkCookieAck{} case ctHeartbeat: c = &chunkHeartbeat{} case ctPayloadData: 
c = &chunkPayloadData{} case ctSack: c = &chunkSelectiveAck{} case ctReconfig: c = &chunkReconfig{} case ctForwardTSN: c = &chunkForwardTSN{} case ctError: c = &chunkError{} case ctShutdown: c = &chunkShutdown{} case ctShutdownAck: c = &chunkShutdownAck{} case ctShutdownComplete: c = &chunkShutdownComplete{} default: return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, chunkType(raw[offset]).String()) } if err := c.unmarshal(raw[offset:]); err != nil { return err } p.chunks = append(p.chunks, c) chunkValuePadding := getPadding(c.valueLength()) offset += chunkHeaderSize + c.valueLength() + chunkValuePadding } theirChecksum := binary.LittleEndian.Uint32(raw[8:]) ourChecksum := generatePacketChecksum(raw) if theirChecksum != ourChecksum { return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) } return nil } func (p *packet) marshal() ([]byte, error) { raw := make([]byte, packetHeaderSize) // Populate static headers // 8-12 is Checksum which will be populated when packet is complete binary.BigEndian.PutUint16(raw[0:], p.sourcePort) binary.BigEndian.PutUint16(raw[2:], p.destinationPort) binary.BigEndian.PutUint32(raw[4:], p.verificationTag) // Populate chunks for _, c := range p.chunks { chunkRaw, err := c.marshal() if err != nil { return nil, err } raw = append(raw, chunkRaw...) paddingNeeded := getPadding(len(raw)) if paddingNeeded != 0 { raw = append(raw, make([]byte, paddingNeeded)...) } } // Checksum is already in BigEndian // Using LittleEndian.PutUint32 stops it from being flipped binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw)) return raw, nil } func generatePacketChecksum(raw []byte) (sum uint32) { // Fastest way to do a crc32 without allocating. 
sum = crc32.Update(sum, castagnoliTable, raw[0:8]) sum = crc32.Update(sum, castagnoliTable, fourZeroes[:]) sum = crc32.Update(sum, castagnoliTable, raw[12:]) return sum } // String makes packet printable func (p *packet) String() string { format := `Packet: sourcePort: %d destinationPort: %d verificationTag: %d ` res := fmt.Sprintf(format, p.sourcePort, p.destinationPort, p.verificationTag, ) for i, chunk := range p.chunks { res += fmt.Sprintf("Chunk %d:\n %s", i, chunk) } return res } sctp-1.8.6/packet_test.go000066400000000000000000000046011436021606300153360ustar00rootroot00000000000000package sctp import ( "bytes" "testing" ) func TestPacketUnmarshal(t *testing.T) { pkt := &packet{} if err := pkt.unmarshal([]byte{}); err == nil { t.Errorf("Unmarshal should fail when a packet is too small to be SCTP") } headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} err := pkt.unmarshal(headerOnly) switch { case err != nil: t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) case pkt.sourcePort != 5000: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", 5000, pkt.sourcePort) case pkt.destinationPort != 5000: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", 5000, pkt.destinationPort) case pkt.verificationTag != 0: t.Errorf("Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag) } rawChunk := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 
0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } if err := pkt.unmarshal(rawChunk); err != nil { t.Errorf("Unmarshal failed, has chunk: %v", err) } } func TestPacketMarshal(t *testing.T) { pkt := &packet{} headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} if err := pkt.unmarshal(headerOnly); err != nil { t.Errorf("Unmarshal failed for SCTP packet with no chunks: %v", err) } headerOnlyMarshaled, err := pkt.marshal() if err != nil { t.Errorf("Marshal failed for SCTP packet with no chunks: %v", err) } else if !bytes.Equal(headerOnly, headerOnlyMarshaled) { t.Errorf("Unmarshal/Marshaled header only packet did not match \nheaderOnly: % 02x \nheaderOnlyMarshaled % 02x", headerOnly, headerOnlyMarshaled) } } func BenchmarkPacketGenerateChecksum(b *testing.B) { var data [1024]byte for i := 0; i < b.N; i++ { _ = generatePacketChecksum(data[:]) } } sctp-1.8.6/param.go000066400000000000000000000021621436021606300141300ustar00rootroot00000000000000package sctp import ( "errors" "fmt" ) type param interface { marshal() ([]byte, error) length() int } // ErrParamTypeUnhandled is returned if unknown parameter type is specified. 
var ErrParamTypeUnhandled = errors.New("unhandled ParamType") func buildParam(t paramType, rawParam []byte) (param, error) { switch t { case forwardTSNSupp: return (¶mForwardTSNSupported{}).unmarshal(rawParam) case supportedExt: return (¶mSupportedExtensions{}).unmarshal(rawParam) case ecnCapable: return (¶mECNCapable{}).unmarshal(rawParam) case random: return (¶mRandom{}).unmarshal(rawParam) case reqHMACAlgo: return (¶mRequestedHMACAlgorithm{}).unmarshal(rawParam) case chunkList: return (¶mChunkList{}).unmarshal(rawParam) case stateCookie: return (¶mStateCookie{}).unmarshal(rawParam) case heartbeatInfo: return (¶mHeartbeatInfo{}).unmarshal(rawParam) case outSSNResetReq: return (¶mOutgoingResetRequest{}).unmarshal(rawParam) case reconfigResp: return (¶mReconfigResponse{}).unmarshal(rawParam) default: return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, t) } } sctp-1.8.6/param_chunk_list.go000066400000000000000000000010211436021606300163440ustar00rootroot00000000000000package sctp type paramChunkList struct { paramHeader chunkTypes []chunkType } func (c *paramChunkList) marshal() ([]byte, error) { c.typ = chunkList c.raw = make([]byte, len(c.chunkTypes)) for i, t := range c.chunkTypes { c.raw[i] = byte(t) } return c.paramHeader.marshal() } func (c *paramChunkList) unmarshal(raw []byte) (param, error) { err := c.paramHeader.unmarshal(raw) if err != nil { return nil, err } for _, t := range c.raw { c.chunkTypes = append(c.chunkTypes, chunkType(t)) } return c, nil } sctp-1.8.6/param_ecn_capable.go000066400000000000000000000005301436021606300164210ustar00rootroot00000000000000package sctp type paramECNCapable struct { paramHeader } func (r *paramECNCapable) marshal() ([]byte, error) { r.typ = ecnCapable r.raw = []byte{} return r.paramHeader.marshal() } func (r *paramECNCapable) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } return r, nil } 
sctp-1.8.6/param_ecn_capable_test.go000066400000000000000000000021371436021606300174650ustar00rootroot00000000000000package sctp // nolint:dupl import ( "testing" "github.com/stretchr/testify/assert" ) func testParamECNCapabale() []byte { return []byte{0x80, 0x0, 0x0, 0x4} } func TestParamECNCapabale_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramECNCapable }{ { testParamECNCapabale(), ¶mECNCapable{ paramHeader: paramHeader{ typ: ecnCapable, len: 4, raw: []byte{}, }, }, }, } for i, tc := range tt { actual := ¶mECNCapable{} _, err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestParamECNCapabale_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"param too short", []byte{0x0, 0xd, 0x0}}, } for i, tc := range tt { actual := ¶mECNCapable{} _, err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/param_forward_tsn_supported.go000066400000000000000000000015261436021606300206500ustar00rootroot00000000000000package sctp // At the initialization of the association, the sender of the INIT or // INIT ACK chunk MAY include this OPTIONAL parameter to inform its peer // that it is able to support the Forward TSN chunk // // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 49152 | Parameter Length = 4 | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramForwardTSNSupported struct { paramHeader } func (f *paramForwardTSNSupported) marshal() ([]byte, error) { f.typ = forwardTSNSupp f.raw = []byte{} return f.paramHeader.marshal() } func (f *paramForwardTSNSupported) unmarshal(raw []byte) (param, error) { err := 
f.paramHeader.unmarshal(raw) if err != nil { return nil, err } return f, nil } sctp-1.8.6/param_forward_tsn_supported_test.go000066400000000000000000000022471436021606300217100ustar00rootroot00000000000000package sctp // nolint:dupl import ( "testing" "github.com/stretchr/testify/assert" ) func testParamForwardTSNSupported() []byte { return []byte{0xc0, 0x0, 0x0, 0x4} } func TestParamForwardTSNSupported_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramForwardTSNSupported }{ { testParamForwardTSNSupported(), ¶mForwardTSNSupported{ paramHeader: paramHeader{ typ: forwardTSNSupp, len: 4, raw: []byte{}, }, }, }, } for i, tc := range tt { actual := ¶mForwardTSNSupported{} _, err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestParamForwardTSNSupported_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"param too short", []byte{0x0, 0xd, 0x0}}, } for i, tc := range tt { actual := ¶mForwardTSNSupported{} _, err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/param_heartbeat_info.go000066400000000000000000000006571436021606300171710ustar00rootroot00000000000000package sctp type paramHeartbeatInfo struct { paramHeader heartbeatInformation []byte } func (h *paramHeartbeatInfo) marshal() ([]byte, error) { h.typ = heartbeatInfo h.raw = h.heartbeatInformation return h.paramHeader.marshal() } func (h *paramHeartbeatInfo) unmarshal(raw []byte) (param, error) { err := h.paramHeader.unmarshal(raw) if err != nil { return nil, err } h.heartbeatInformation = h.raw return h, nil } sctp-1.8.6/param_outgoing_reset_request.go000066400000000000000000000075461436021606300210300ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" ) const ( 
paramOutgoingResetRequestStreamIdentifiersOffset = 12 ) // This parameter is used by the sender to request the reset of some or // all outgoing streams. // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 13 | Parameter Length = 16 + 2 * N | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Request Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Response Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Sender's Last Assigned TSN | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream Number 1 (optional) | Stream Number 2 (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // / ...... / // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream Number N-1 (optional) | Stream Number N (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramOutgoingResetRequest struct { paramHeader // reconfigRequestSequenceNumber is used to identify the request. It is a monotonically // increasing number that is initialized to the same value as the // initial TSN. It is increased by 1 whenever sending a new Re- // configuration Request Parameter. reconfigRequestSequenceNumber uint32 // When this Outgoing SSN Reset Request Parameter is sent in response // to an Incoming SSN Reset Request Parameter, this parameter is also // an implicit response to the incoming request. This field then // holds the Re-configuration Request Sequence Number of the incoming // request. In other cases, it holds the next expected // Re-configuration Request Sequence Number minus 1. reconfigResponseSequenceNumber uint32 // This value holds the next TSN minus 1 -- in other words, the last // TSN that this sender assigned. 
senderLastTSN uint32 // This optional field, if included, is used to indicate specific // streams that are to be reset. If no streams are listed, then all // streams are to be reset. streamIdentifiers []uint16 } // Outgoing reset request parameter errors var ( ErrSSNResetRequestParamTooShort = errors.New("outgoing SSN reset request parameter too short") ) func (r *paramOutgoingResetRequest) marshal() ([]byte, error) { r.typ = outSSNResetReq r.raw = make([]byte, paramOutgoingResetRequestStreamIdentifiersOffset+2*len(r.streamIdentifiers)) binary.BigEndian.PutUint32(r.raw, r.reconfigRequestSequenceNumber) binary.BigEndian.PutUint32(r.raw[4:], r.reconfigResponseSequenceNumber) binary.BigEndian.PutUint32(r.raw[8:], r.senderLastTSN) for i, sID := range r.streamIdentifiers { binary.BigEndian.PutUint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:], sID) } return r.paramHeader.marshal() } func (r *paramOutgoingResetRequest) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw) < paramOutgoingResetRequestStreamIdentifiersOffset { return nil, ErrSSNResetRequestParamTooShort } r.reconfigRequestSequenceNumber = binary.BigEndian.Uint32(r.raw) r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw[4:]) r.senderLastTSN = binary.BigEndian.Uint32(r.raw[8:]) lim := (len(r.raw) - paramOutgoingResetRequestStreamIdentifiersOffset) / 2 r.streamIdentifiers = make([]uint16, lim) for i := 0; i < lim; i++ { r.streamIdentifiers[i] = binary.BigEndian.Uint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:]) } return r, nil } sctp-1.8.6/param_outgoing_reset_request_test.go000066400000000000000000000037361436021606300220640ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkReconfigParamA() []byte { return []byte{0x0, 0xd, 0x0, 0x16, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6} } func 
testChunkReconfigParamB() []byte { return []byte{0x0, 0xd, 0x0, 0x10, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3} } func TestParamOutgoingResetRequest_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramOutgoingResetRequest }{ { testChunkReconfigParamA(), ¶mOutgoingResetRequest{ paramHeader: paramHeader{ typ: outSSNResetReq, len: 22, raw: testChunkReconfigParamA()[4:], }, reconfigRequestSequenceNumber: 1, reconfigResponseSequenceNumber: 2, senderLastTSN: 3, streamIdentifiers: []uint16{4, 5, 6}, }, }, { testChunkReconfigParamB(), ¶mOutgoingResetRequest{ paramHeader: paramHeader{ typ: outSSNResetReq, len: 16, raw: testChunkReconfigParamB()[4:], }, reconfigRequestSequenceNumber: 1, reconfigResponseSequenceNumber: 2, senderLastTSN: 3, streamIdentifiers: []uint16{}, }, }, } for i, tc := range tt { actual := ¶mOutgoingResetRequest{} _, err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestParamOutgoingResetRequest_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"packet too short", testChunkReconfigParamA()[:8]}, {"param too short", []byte{0x0, 0xd, 0x0, 0x4}}, } for i, tc := range tt { actual := ¶mOutgoingResetRequest{} _, err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/param_random.go000066400000000000000000000005651436021606300154750ustar00rootroot00000000000000package sctp type paramRandom struct { paramHeader randomData []byte } func (r *paramRandom) marshal() ([]byte, error) { r.typ = random r.raw = r.randomData return r.paramHeader.marshal() } func (r *paramRandom) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } r.randomData = r.raw return r, nil } 
sctp-1.8.6/param_reconfig_response.go000066400000000000000000000063341436021606300177270ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) // This parameter is used by the receiver of a Re-configuration Request // Parameter to respond to the request. // // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 16 | Parameter Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Response Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Result | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Sender's Next TSN (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Receiver's Next TSN (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramReconfigResponse struct { paramHeader // This value is copied from the request parameter and is used by the // receiver of the Re-configuration Response Parameter to tie the // response to the request. reconfigResponseSequenceNumber uint32 // This value describes the result of the processing of the request. 
result reconfigResult } type reconfigResult uint32 const ( reconfigResultSuccessNOP reconfigResult = 0 reconfigResultSuccessPerformed reconfigResult = 1 reconfigResultDenied reconfigResult = 2 reconfigResultErrorWrongSSN reconfigResult = 3 reconfigResultErrorRequestAlreadyInProgress reconfigResult = 4 reconfigResultErrorBadSequenceNumber reconfigResult = 5 reconfigResultInProgress reconfigResult = 6 ) // Reconfiguration response errors var ( ErrReconfigRespParamTooShort = errors.New("reconfig response parameter too short") ) func (t reconfigResult) String() string { switch t { case reconfigResultSuccessNOP: return "0: Success - Nothing to do" case reconfigResultSuccessPerformed: return "1: Success - Performed" case reconfigResultDenied: return "2: Denied" case reconfigResultErrorWrongSSN: return "3: Error - Wrong SSN" case reconfigResultErrorRequestAlreadyInProgress: return "4: Error - Request already in progress" case reconfigResultErrorBadSequenceNumber: return "5: Error - Bad Sequence Number" case reconfigResultInProgress: return "6: In progress" default: return fmt.Sprintf("Unknown reconfigResult: %d", t) } } func (r *paramReconfigResponse) marshal() ([]byte, error) { r.typ = reconfigResp r.raw = make([]byte, 8) binary.BigEndian.PutUint32(r.raw, r.reconfigResponseSequenceNumber) binary.BigEndian.PutUint32(r.raw[4:], uint32(r.result)) return r.paramHeader.marshal() } func (r *paramReconfigResponse) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw) < 8 { return nil, ErrReconfigRespParamTooShort } r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw) r.result = reconfigResult(binary.BigEndian.Uint32(r.raw[4:])) return r, nil } sctp-1.8.6/param_reconfig_response_test.go000066400000000000000000000040051436021606300207570ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkReconfigResponce() []byte { return []byte{0x0, 
0x10, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1} } func TestParamReconfigResponse_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramReconfigResponse }{ { testChunkReconfigResponce(), ¶mReconfigResponse{ paramHeader: paramHeader{ typ: reconfigResp, len: 12, raw: testChunkReconfigResponce()[4:], }, reconfigResponseSequenceNumber: 1, result: reconfigResultSuccessPerformed, }, }, } for i, tc := range tt { actual := ¶mReconfigResponse{} _, err := actual.unmarshal(tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestParamReconfigResponse_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"packet too short", testChunkReconfigParamA()[:8]}, {"param too short", []byte{0x0, 0x10, 0x0, 0x4}}, } for i, tc := range tt { actual := ¶mReconfigResponse{} _, err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } func TestReconfigResultStringer(t *testing.T) { tt := []struct { result reconfigResult expected string }{ {reconfigResultSuccessNOP, "0: Success - Nothing to do"}, {reconfigResultSuccessPerformed, "1: Success - Performed"}, {reconfigResultDenied, "2: Denied"}, {reconfigResultErrorWrongSSN, "3: Error - Wrong SSN"}, {reconfigResultErrorRequestAlreadyInProgress, "4: Error - Request already in progress"}, {reconfigResultErrorBadSequenceNumber, "5: Error - Bad Sequence Number"}, {reconfigResultInProgress, "6: In progress"}, } for i, tc := range tt { actual := tc.result.String() assert.Equal(t, tc.expected, actual, "Test case %d", i) } } sctp-1.8.6/param_requested_hmac_algorithm.go000066400000000000000000000030011436021606300212400ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) type hmacAlgorithm uint16 const ( hmacResv1 hmacAlgorithm = 0 hmacSHA128 = 1 
hmacResv2 hmacAlgorithm = 2 hmacSHA256 hmacAlgorithm = 3 ) // ErrInvalidAlgorithmType is returned if unknown auth algorithm is specified. var ErrInvalidAlgorithmType = errors.New("invalid algorithm type") func (c hmacAlgorithm) String() string { switch c { case hmacResv1: return "HMAC Reserved (0x00)" case hmacSHA128: return "HMAC SHA-128" case hmacResv2: return "HMAC Reserved (0x02)" case hmacSHA256: return "HMAC SHA-256" default: return fmt.Sprintf("Unknown HMAC Algorithm type: %d", c) } } type paramRequestedHMACAlgorithm struct { paramHeader availableAlgorithms []hmacAlgorithm } func (r *paramRequestedHMACAlgorithm) marshal() ([]byte, error) { r.typ = reqHMACAlgo r.raw = make([]byte, len(r.availableAlgorithms)*2) i := 0 for _, a := range r.availableAlgorithms { binary.BigEndian.PutUint16(r.raw[i:], uint16(a)) i += 2 } return r.paramHeader.marshal() } func (r *paramRequestedHMACAlgorithm) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } i := 0 for i < len(r.raw) { a := hmacAlgorithm(binary.BigEndian.Uint16(r.raw[i:])) switch a { case hmacSHA128: fallthrough case hmacSHA256: r.availableAlgorithms = append(r.availableAlgorithms, a) default: return nil, fmt.Errorf("%w: %v", ErrInvalidAlgorithmType, a) } i += 2 } return r, nil } sctp-1.8.6/param_state_cookie.go000066400000000000000000000015331436021606300166620ustar00rootroot00000000000000package sctp import ( "crypto/rand" "fmt" ) type paramStateCookie struct { paramHeader cookie []byte } func newRandomStateCookie() (*paramStateCookie, error) { randCookie := make([]byte, 32) _, err := rand.Read(randCookie) // crypto/rand.Read returns n == len(b) if and only if err == nil. 
if err != nil { return nil, err } s := ¶mStateCookie{ cookie: randCookie, } return s, nil } func (s *paramStateCookie) marshal() ([]byte, error) { s.typ = stateCookie s.raw = s.cookie return s.paramHeader.marshal() } func (s *paramStateCookie) unmarshal(raw []byte) (param, error) { err := s.paramHeader.unmarshal(raw) if err != nil { return nil, err } s.cookie = s.raw return s, nil } // String makes paramStateCookie printable func (s *paramStateCookie) String() string { return fmt.Sprintf("%s: %s", s.paramHeader, s.cookie) } sctp-1.8.6/param_supported_extensions.go000066400000000000000000000010631436021606300205130ustar00rootroot00000000000000package sctp type paramSupportedExtensions struct { paramHeader ChunkTypes []chunkType } func (s *paramSupportedExtensions) marshal() ([]byte, error) { s.typ = supportedExt s.raw = make([]byte, len(s.ChunkTypes)) for i, c := range s.ChunkTypes { s.raw[i] = byte(c) } return s.paramHeader.marshal() } func (s *paramSupportedExtensions) unmarshal(raw []byte) (param, error) { err := s.paramHeader.unmarshal(raw) if err != nil { return nil, err } for _, t := range s.raw { s.ChunkTypes = append(s.ChunkTypes, chunkType(t)) } return s, nil } sctp-1.8.6/param_test.go000066400000000000000000000020221436021606300151620ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestBuildParam_Success(t *testing.T) { tt := []struct { binary []byte }{ {testChunkReconfigParamA()}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) if err != nil { t.Fatalf("failed to parse param type: %v", err) } p, err := buildParam(pType, tc.binary) if err != nil { t.Fatalf("failed to unmarshal #%d: %v", i, err) } b, err := p.marshal() if err != nil { t.Fatalf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestBuildParam_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"invalid ParamType", []byte{0x0, 0x0}}, {"build failure", 
testChunkReconfigParamA()[:8]}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) if err != nil { t.Fatalf("failed to parse param type: %v", err) } _, err = buildParam(pType, tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/paramheader.go000066400000000000000000000037621436021606300153100ustar00rootroot00000000000000package sctp import ( "encoding/binary" "encoding/hex" "errors" "fmt" ) type paramHeader struct { typ paramType len int raw []byte } const ( paramHeaderLength = 4 ) // Parameter header parse errors var ( ErrParamHeaderTooShort = errors.New("param header too short") ErrParamHeaderSelfReportedLengthShorter = errors.New("param self reported length is shorter than header length") ErrParamHeaderSelfReportedLengthLonger = errors.New("param self reported length is longer than header length") ErrParamHeaderParseFailed = errors.New("failed to parse param type") ) func (p *paramHeader) marshal() ([]byte, error) { paramLengthPlusHeader := paramHeaderLength + len(p.raw) rawParam := make([]byte, paramLengthPlusHeader) binary.BigEndian.PutUint16(rawParam[0:], uint16(p.typ)) binary.BigEndian.PutUint16(rawParam[2:], uint16(paramLengthPlusHeader)) copy(rawParam[paramHeaderLength:], p.raw) return rawParam, nil } func (p *paramHeader) unmarshal(raw []byte) error { if len(raw) < paramHeaderLength { return ErrParamHeaderTooShort } paramLengthPlusHeader := binary.BigEndian.Uint16(raw[2:]) if int(paramLengthPlusHeader) < paramHeaderLength { return fmt.Errorf("%w: param self reported length (%d) shorter than header length (%d)", ErrParamHeaderSelfReportedLengthShorter, int(paramLengthPlusHeader), paramHeaderLength) } if len(raw) < int(paramLengthPlusHeader) { return fmt.Errorf("%w: param length (%d) shorter than its self reported length (%d)", ErrParamHeaderSelfReportedLengthLonger, len(raw), int(paramLengthPlusHeader)) } typ, err := parseParamType(raw[0:]) if err != nil { return fmt.Errorf("%w: %v", 
ErrParamHeaderParseFailed, err) } p.typ = typ p.raw = raw[paramHeaderLength:paramLengthPlusHeader] p.len = int(paramLengthPlusHeader) return nil } func (p *paramHeader) length() int { return p.len } // String makes paramHeader printable func (p paramHeader) String() string { return fmt.Sprintf("%s (%d): %s", p.typ, p.len, hex.Dump(p.raw)) } sctp-1.8.6/paramheader_test.go000066400000000000000000000023601436021606300163400ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testParamHeader() []byte { return []byte{0x0, 0x1, 0x0, 0x4} } func TestParamHeader_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramHeader }{ { testParamHeader(), ¶mHeader{ typ: heartbeatInfo, len: 4, raw: []byte{}, }, }, } for i, tc := range tt { actual := ¶mHeader{} err := actual.unmarshal(tc.binary) if err != nil { t.Errorf("failed to unmarshal #%d: %v", i, err) } assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() if err != nil { t.Errorf("failed to marshal: %v", err) } assert.Equal(t, tc.binary, b) } } func TestParamHeaderUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"header too short", testParamHeader()[:2]}, // {"wrong param type", []byte{0x0, 0x0, 0x0, 0x4}}, // Not possible to fail parseParamType atm. 
{"reported length below header length", []byte{0x0, 0xd, 0x0, 0x3}}, {"wrong reported length", testChunkReconfigParamA()[:4]}, } for i, tc := range tt { actual := ¶mHeader{} err := actual.unmarshal(tc.binary) if err == nil { t.Errorf("expected unmarshal #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/paramtype.go000066400000000000000000000076141436021606300150410ustar00rootroot00000000000000package sctp import ( "encoding/binary" "errors" "fmt" ) // paramType represents a SCTP INIT/INITACK parameter type paramType uint16 const ( heartbeatInfo paramType = 1 // Heartbeat Info [RFC4960] ipV4Addr paramType = 5 // IPv4 IP [RFC4960] ipV6Addr paramType = 6 // IPv6 IP [RFC4960] stateCookie paramType = 7 // State Cookie [RFC4960] unrecognizedParam paramType = 8 // Unrecognized Parameters [RFC4960] cookiePreservative paramType = 9 // Cookie Preservative [RFC4960] hostNameAddr paramType = 11 // Host Name IP [RFC4960] supportedAddrTypes paramType = 12 // Supported IP Types [RFC4960] outSSNResetReq paramType = 13 // Outgoing SSN Reset Request Parameter [RFC6525] incSSNResetReq paramType = 14 // Incoming SSN Reset Request Parameter [RFC6525] ssnTSNResetReq paramType = 15 // SSN/TSN Reset Request Parameter [RFC6525] reconfigResp paramType = 16 // Re-configuration Response Parameter [RFC6525] addOutStreamsReq paramType = 17 // Add Outgoing Streams Request Parameter [RFC6525] addIncStreamsReq paramType = 18 // Add Incoming Streams Request Parameter [RFC6525] ecnCapable paramType = 32768 // ECN Capable (0x8000) [RFC2960] random paramType = 32770 // Random (0x8002) [RFC4805] chunkList paramType = 32771 // Chunk List (0x8003) [RFC4895] reqHMACAlgo paramType = 32772 // Requested HMAC Algorithm Parameter (0x8004) [RFC4895] padding paramType = 32773 // Padding (0x8005) supportedExt paramType = 32776 // Supported Extensions (0x8008) [RFC5061] forwardTSNSupp paramType = 49152 // Forward TSN supported (0xC000) [RFC3758] addIPAddr paramType = 49153 // Add IP IP (0xC001) [RFC5061] delIPAddr 
paramType = 49154 // Delete IP IP (0xC002) [RFC5061] errClauseInd paramType = 49155 // Error Cause Indication (0xC003) [RFC5061] setPriAddr paramType = 49156 // Set Primary IP (0xC004) [RFC5061] successInd paramType = 49157 // Success Indication (0xC005) [RFC5061] adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] ) // Parameter packet errors var ( ErrParamPacketTooShort = errors.New("packet to short") ) func parseParamType(raw []byte) (paramType, error) { if len(raw) < 2 { return paramType(0), ErrParamPacketTooShort } return paramType(binary.BigEndian.Uint16(raw)), nil } func (p paramType) String() string { switch p { case heartbeatInfo: return "Heartbeat Info" case ipV4Addr: return "IPv4 IP" case ipV6Addr: return "IPv6 IP" case stateCookie: return "State Cookie" case unrecognizedParam: return "Unrecognized Parameters" case cookiePreservative: return "Cookie Preservative" case hostNameAddr: return "Host Name IP" case supportedAddrTypes: return "Supported IP Types" case outSSNResetReq: return "Outgoing SSN Reset Request Parameter" case incSSNResetReq: return "Incoming SSN Reset Request Parameter" case ssnTSNResetReq: return "SSN/TSN Reset Request Parameter" case reconfigResp: return "Re-configuration Response Parameter" case addOutStreamsReq: return "Add Outgoing Streams Request Parameter" case addIncStreamsReq: return "Add Incoming Streams Request Parameter" case ecnCapable: return "ECN Capable" case random: return "Random" case chunkList: return "Chunk List" case reqHMACAlgo: return "Requested HMAC Algorithm Parameter" case padding: return "Padding" case supportedExt: return "Supported Extensions" case forwardTSNSupp: return "Forward TSN supported" case addIPAddr: return "Add IP IP" case delIPAddr: return "Delete IP IP" case errClauseInd: return "Error Cause Indication" case setPriAddr: return "Set Primary IP" case successInd: return "Success Indication" case adaptLayerInd: return "Adaptation Layer Indication" default: return 
fmt.Sprintf("Unknown ParamType: %d", p) } } sctp-1.8.6/paramtype_test.go000066400000000000000000000013551436021606300160740ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestParseParamType_Success(t *testing.T) { tt := []struct { binary []byte expected paramType }{ {[]byte{0x0, 0x1}, heartbeatInfo}, {[]byte{0x0, 0xd}, outSSNResetReq}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) if err != nil { t.Fatalf("failed to parse paramType %d: %v", i, err) } assert.Equal(t, tc.expected, pType) } } func TestParseParamType_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"empty packet", []byte{}}, } for i, tc := range tt { _, err := parseParamType(tc.binary) if err == nil { t.Errorf("expected parseParamType #%d: '%s' to fail.", i, tc.name) } } } sctp-1.8.6/payload_queue.go000066400000000000000000000072661436021606300156770ustar00rootroot00000000000000package sctp import ( "fmt" "sort" ) type payloadQueue struct { chunkMap map[uint32]*chunkPayloadData sorted []uint32 dupTSN []uint32 nBytes int } func newPayloadQueue() *payloadQueue { return &payloadQueue{chunkMap: map[uint32]*chunkPayloadData{}} } func (q *payloadQueue) updateSortedKeys() { if q.sorted != nil { return } q.sorted = make([]uint32, len(q.chunkMap)) i := 0 for k := range q.chunkMap { q.sorted[i] = k i++ } sort.Slice(q.sorted, func(i, j int) bool { return sna32LT(q.sorted[i], q.sorted[j]) }) } func (q *payloadQueue) canPush(p *chunkPayloadData, cumulativeTSN uint32) bool { _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { return false } return true } func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) q.sorted = nil } // push pushes a payload data. If the payload data is already in our queue or // older than our cumulativeTSN marker, it will be recored as duplications, // which can later be retrieved using popDuplicates. 
func (q *payloadQueue) push(p *chunkPayloadData, cumulativeTSN uint32) bool { _, ok := q.chunkMap[p.tsn] if ok || sna32LTE(p.tsn, cumulativeTSN) { // Found the packet, log in dups q.dupTSN = append(q.dupTSN, p.tsn) return false } q.chunkMap[p.tsn] = p q.nBytes += len(p.userData) q.sorted = nil return true } // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { q.updateSortedKeys() if len(q.chunkMap) > 0 && tsn == q.sorted[0] { q.sorted = q.sorted[1:] if c, ok := q.chunkMap[tsn]; ok { delete(q.chunkMap, tsn) q.nBytes -= len(c.userData) return c, true } } return nil, false } // get returns reference to chunkPayloadData with the given TSN value. func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { c, ok := q.chunkMap[tsn] return c, ok } // popDuplicates returns an array of TSN values that were found duplicate. func (q *payloadQueue) popDuplicates() []uint32 { dups := q.dupTSN q.dupTSN = []uint32{} return dups } func (q *payloadQueue) getGapAckBlocks(cumulativeTSN uint32) (gapAckBlocks []gapAckBlock) { var b gapAckBlock if len(q.chunkMap) == 0 { return []gapAckBlock{} } q.updateSortedKeys() for i, tsn := range q.sorted { if i == 0 { b.start = uint16(tsn - cumulativeTSN) b.end = b.start continue } diff := uint16(tsn - cumulativeTSN) if b.end+1 == diff { b.end++ } else { gapAckBlocks = append(gapAckBlocks, gapAckBlock{ start: b.start, end: b.end, }) b.start = diff b.end = diff } } gapAckBlocks = append(gapAckBlocks, gapAckBlock{ start: b.start, end: b.end, }) return gapAckBlocks } func (q *payloadQueue) getGapAckBlocksString(cumulativeTSN uint32) string { gapAckBlocks := q.getGapAckBlocks(cumulativeTSN) str := fmt.Sprintf("cumTSN=%d", cumulativeTSN) for _, b := range gapAckBlocks { str += fmt.Sprintf(",%d-%d", b.start, b.end) } return str } func (q *payloadQueue) markAsAcked(tsn uint32) int { var nBytesAcked int if c, ok := q.chunkMap[tsn]; ok { c.acked = true c.retransmit = 
false nBytesAcked = len(c.userData) q.nBytes -= nBytesAcked c.userData = []byte{} } return nBytesAcked } func (q *payloadQueue) getLastTSNReceived() (uint32, bool) { q.updateSortedKeys() qlen := len(q.sorted) if qlen == 0 { return 0, false } return q.sorted[qlen-1], true } func (q *payloadQueue) markAllToRetrasmit() { for _, c := range q.chunkMap { if c.acked || c.abandoned() { continue } c.retransmit = true } } func (q *payloadQueue) getNumBytes() int { return q.nBytes } func (q *payloadQueue) size() int { return len(q.chunkMap) } sctp-1.8.6/payload_queue_test.go000066400000000000000000000114761436021606300167340ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func makePayload(tsn uint32, nBytes int) *chunkPayloadData { return &chunkPayloadData{tsn: tsn, userData: make([]byte, nBytes)} } func TestPayloadQueue(t *testing.T) { t.Run("pushNoCheck", func(t *testing.T) { pq := newPayloadQueue() pq.pushNoCheck(makePayload(0, 10)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 1, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(1, 11)) assert.Equal(t, 21, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 2, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(2, 12)) assert.Equal(t, 33, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 3, pq.size(), "item count mismatch") for i := uint32(0); i < 3; i++ { c, ok := pq.pop(i) assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") assert.NotNil(t, pq.sorted, "should not be nil") } } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 0, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(3, 13)) assert.Nil(t, pq.sorted, "should be nil") assert.Equal(t, 13, pq.getNumBytes(), "total bytes mismatch") pq.pushNoCheck(makePayload(4, 14)) assert.Nil(t, pq.sorted, "should be nil") assert.Equal(t, 27, pq.getNumBytes(), "total bytes mismatch") 
for i := uint32(3); i < 5; i++ { c, ok := pq.pop(i) assert.True(t, ok, "pop should succeed") if ok { assert.Equal(t, i, c.tsn, "TSN should match") assert.NotNil(t, pq.sorted, "should not be nil") } } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 0, pq.size(), "item count mismatch") }) t.Run("getGapAckBlocks", func(t *testing.T) { pq := newPayloadQueue() pq.push(makePayload(1, 0), 0) pq.push(makePayload(2, 0), 0) pq.push(makePayload(3, 0), 0) pq.push(makePayload(4, 0), 0) pq.push(makePayload(5, 0), 0) pq.push(makePayload(6, 0), 0) gab1 := []*gapAckBlock{{start: 1, end: 6}} gab2 := pq.getGapAckBlocks(0) assert.NotNil(t, gab2) assert.Len(t, gab2, 1) assert.Equal(t, gab1[0].start, gab2[0].start) assert.Equal(t, gab1[0].end, gab2[0].end) pq.push(makePayload(8, 0), 0) pq.push(makePayload(9, 0), 0) gab1 = []*gapAckBlock{{start: 1, end: 6}, {start: 8, end: 9}} gab2 = pq.getGapAckBlocks(0) assert.NotNil(t, gab2) assert.Len(t, gab2, 2) assert.Equal(t, gab1[0].start, gab2[0].start) assert.Equal(t, gab1[0].end, gab2[0].end) assert.Equal(t, gab1[1].start, gab2[1].start) assert.Equal(t, gab1[1].end, gab2[1].end) }) t.Run("getLastTSNReceived", func(t *testing.T) { pq := newPayloadQueue() // empty queie should return false _, ok := pq.getLastTSNReceived() assert.False(t, ok, "should be false") ok = pq.push(makePayload(20, 0), 0) assert.True(t, ok, "should be true") tsn, ok := pq.getLastTSNReceived() assert.True(t, ok, "should be false") assert.Equal(t, uint32(20), tsn, "should match") // append should work ok = pq.push(makePayload(21, 0), 0) assert.True(t, ok, "should be true") tsn, ok = pq.getLastTSNReceived() assert.True(t, ok, "should be false") assert.Equal(t, uint32(21), tsn, "should match") // check if sorting applied ok = pq.push(makePayload(19, 0), 0) assert.True(t, ok, "should be true") tsn, ok = pq.getLastTSNReceived() assert.True(t, ok, "should be false") assert.Equal(t, uint32(21), tsn, "should match") }) t.Run("markAllToRetrasmit", 
func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 3; i++ { pq.push(makePayload(uint32(i+1), 10), 0) } pq.markAsAcked(2) pq.markAllToRetrasmit() c, ok := pq.get(1) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(2) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, "should NOT be marked as retransmit") c, ok = pq.get(3) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") }) t.Run("reset retransmit flag on ack", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 4; i++ { pq.push(makePayload(uint32(i+1), 10), 0) } pq.markAllToRetrasmit() pq.markAsAcked(2) // should cancel retransmission for TSN 2 pq.markAsAcked(4) // should cancel retransmission for TSN 4 c, ok := pq.get(1) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(2) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, "should NOT be marked as retransmit") c, ok = pq.get(3) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(4) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, "should NOT be marked as retransmit") }) } sctp-1.8.6/pending_queue.go000066400000000000000000000054561436021606300156710ustar00rootroot00000000000000package sctp import ( "errors" ) // pendingBaseQueue type pendingBaseQueue struct { queue []*chunkPayloadData } func newPendingBaseQueue() *pendingBaseQueue { return &pendingBaseQueue{queue: []*chunkPayloadData{}} } func (q *pendingBaseQueue) push(c *chunkPayloadData) { q.queue = append(q.queue, c) } func (q *pendingBaseQueue) pop() *chunkPayloadData { if len(q.queue) == 0 { return nil } c := q.queue[0] q.queue = q.queue[1:] return c } func (q *pendingBaseQueue) get(i int) *chunkPayloadData { if len(q.queue) == 0 || i < 0 || i >= len(q.queue) { return nil } return q.queue[i] } func (q 
*pendingBaseQueue) size() int { return len(q.queue) } // pendingQueue type pendingQueue struct { unorderedQueue *pendingBaseQueue orderedQueue *pendingBaseQueue nBytes int selected bool unorderedIsSelected bool } // Pending queue errors var ( ErrUnexpectedChuckPoppedUnordered = errors.New("unexpected chunk popped (unordered)") ErrUnexpectedChuckPoppedOrdered = errors.New("unexpected chunk popped (ordered)") ErrUnexpectedQState = errors.New("unexpected q state (should've been selected)") ) func newPendingQueue() *pendingQueue { return &pendingQueue{ unorderedQueue: newPendingBaseQueue(), orderedQueue: newPendingBaseQueue(), } } func (q *pendingQueue) push(c *chunkPayloadData) { if c.unordered { q.unorderedQueue.push(c) } else { q.orderedQueue.push(c) } q.nBytes += len(c.userData) } func (q *pendingQueue) peek() *chunkPayloadData { if q.selected { if q.unorderedIsSelected { return q.unorderedQueue.get(0) } return q.orderedQueue.get(0) } if c := q.unorderedQueue.get(0); c != nil { return c } return q.orderedQueue.get(0) } func (q *pendingQueue) pop(c *chunkPayloadData) error { if q.selected { var popped *chunkPayloadData if q.unorderedIsSelected { popped = q.unorderedQueue.pop() if popped != c { return ErrUnexpectedChuckPoppedUnordered } } else { popped = q.orderedQueue.pop() if popped != c { return ErrUnexpectedChuckPoppedOrdered } } if popped.endingFragment { q.selected = false } } else { if !c.beginningFragment { return ErrUnexpectedQState } if c.unordered { popped := q.unorderedQueue.pop() if popped != c { return ErrUnexpectedChuckPoppedUnordered } if !popped.endingFragment { q.selected = true q.unorderedIsSelected = true } } else { popped := q.orderedQueue.pop() if popped != c { return ErrUnexpectedChuckPoppedOrdered } if !popped.endingFragment { q.selected = true q.unorderedIsSelected = false } } } q.nBytes -= len(c.userData) return nil } func (q *pendingQueue) getNumBytes() int { return q.nBytes } func (q *pendingQueue) size() int { return 
q.unorderedQueue.size() + q.orderedQueue.size() } sctp-1.8.6/pending_queue_test.go000066400000000000000000000123321436021606300167170ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) const ( noFragment = iota fragBegin fragMiddle fragEnd ) func makeDataChunk(tsn uint32, unordered bool, frag int) *chunkPayloadData { var b, e bool switch frag { case noFragment: b = true e = true case fragBegin: b = true case fragEnd: e = true } return &chunkPayloadData{ tsn: tsn, unordered: unordered, beginningFragment: b, endingFragment: e, userData: make([]byte, 10), // always 10 bytes } } func TestPendingBaseQueue(t *testing.T) { t.Run("push and pop", func(t *testing.T) { pq := newPendingBaseQueue() pq.push(makeDataChunk(0, false, noFragment)) pq.push(makeDataChunk(1, false, noFragment)) pq.push(makeDataChunk(2, false, noFragment)) for i := uint32(0); i < 3; i++ { c := pq.get(int(i)) assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } for i := uint32(0); i < 3; i++ { c := pq.pop() assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } pq.push(makeDataChunk(3, false, noFragment)) pq.push(makeDataChunk(4, false, noFragment)) for i := uint32(3); i < 5; i++ { c := pq.pop() assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } }) t.Run("out of bounce", func(t *testing.T) { pq := newPendingBaseQueue() assert.Nil(t, pq.pop(), "should be nil") assert.Nil(t, pq.get(0), "should be nil") pq.push(makeDataChunk(0, false, noFragment)) assert.Nil(t, pq.get(-1), "should be nil") assert.Nil(t, pq.get(1), "should be nil") }) } func TestPendingQueue(t *testing.T) { // NOTE: TSN is not used in pendingQueue in the actual usage. // Following tests use TSN field as a chunk ID. 
t.Run("push and pop", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(1, false, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(2, false, noFragment)) assert.Equal(t, 30, pq.getNumBytes(), "total bytes mismatch") for i := uint32(0); i < 3; i++ { c := pq.peek() err := pq.pop(c) assert.Nil(t, err, "should not error") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(3, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(4, false, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") for i := uint32(3); i < 5; i++ { c := pq.peek() err := pq.pop(c) assert.Nil(t, err, "should not error") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") }) t.Run("unordered wins", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(1, true, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(2, false, noFragment)) assert.Equal(t, 30, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(3, true, noFragment)) assert.Equal(t, 40, pq.getNumBytes(), "total bytes mismatch") c := pq.peek() err := pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(1), c.tsn, "TSN should match") c = pq.peek() err = pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(3), c.tsn, "TSN should match") c = pq.peek() err = pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(0), c.tsn, "TSN should match") c = pq.peek() err = pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, 
uint32(2), c.tsn, "TSN should match") assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") }) t.Run("fragments", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, fragBegin)) pq.push(makeDataChunk(1, false, fragMiddle)) pq.push(makeDataChunk(2, false, fragEnd)) pq.push(makeDataChunk(3, true, fragBegin)) pq.push(makeDataChunk(4, true, fragMiddle)) pq.push(makeDataChunk(5, true, fragEnd)) expects := []uint32{3, 4, 5, 0, 1, 2} for _, exp := range expects { c := pq.peek() err := pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, exp, c.tsn, "TSN should match") } }) // Once decided ordered or unordered, the decision should persist until // it pops a chunk with endingFragment flags set to true. t.Run("selection persistence", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, fragBegin)) c := pq.peek() err := pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(0), c.tsn, "TSN should match") pq.push(makeDataChunk(1, true, noFragment)) pq.push(makeDataChunk(2, false, fragMiddle)) pq.push(makeDataChunk(3, false, fragEnd)) expects := []uint32{2, 3, 1} for _, exp := range expects { c = pq.peek() err = pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, exp, c.tsn, "TSN should match") } }) } sctp-1.8.6/reassembly_queue.go000066400000000000000000000167261436021606300164150ustar00rootroot00000000000000package sctp import ( "errors" "io" "sort" "sync/atomic" ) func sortChunksByTSN(a []*chunkPayloadData) { sort.Slice(a, func(i, j int) bool { return sna32LT(a[i].tsn, a[j].tsn) }) } func sortChunksBySSN(a []*chunkSet) { sort.Slice(a, func(i, j int) bool { return sna16LT(a[i].ssn, a[j].ssn) }) } // chunkSet is a set of chunks that share the same SSN type chunkSet struct { ssn uint16 // used only with the ordered chunks ppi PayloadProtocolIdentifier chunks []*chunkPayloadData } func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet { return &chunkSet{ 
ssn: ssn, ppi: ppi, chunks: []*chunkPayloadData{}, } } func (set *chunkSet) push(chunk *chunkPayloadData) bool { // check if dup for _, c := range set.chunks { if c.tsn == chunk.tsn { return false } } // append and sort set.chunks = append(set.chunks, chunk) sortChunksByTSN(set.chunks) // Check if we now have a complete set complete := set.isComplete() return complete } func (set *chunkSet) isComplete() bool { // Condition for complete set // 0. Has at least one chunk. // 1. Begins with beginningFragment set to true // 2. Ends with endingFragment set to true // 3. TSN monotinically increase by 1 from beginning to end // 0. nChunks := len(set.chunks) if nChunks == 0 { return false } // 1. if !set.chunks[0].beginningFragment { return false } // 2. if !set.chunks[nChunks-1].endingFragment { return false } // 3. var lastTSN uint32 for i, c := range set.chunks { if i > 0 { // Fragments must have contiguous TSN // From RFC 4960 Section 3.3.1: // When a user message is fragmented into multiple chunks, the TSNs are // used by the receiver to reassemble the message. This means that the // TSNs for each fragment of a fragmented user message MUST be strictly // sequential. if c.tsn != lastTSN+1 { // mid or end fragment is missing return false } } lastTSN = c.tsn } return true } type reassemblyQueue struct { si uint16 nextSSN uint16 // expected SSN for next ordered chunk ordered []*chunkSet unordered []*chunkSet unorderedChunks []*chunkPayloadData nBytes uint64 } var errTryAgain = errors.New("try again") func newReassemblyQueue(si uint16) *reassemblyQueue { // From RFC 4960 Sec 6.5: // The Stream Sequence Number in all the streams MUST start from 0 when // the association is established. Also, when the Stream Sequence // Number reaches the value 65535 the next Stream Sequence Number MUST // be set to 0. 
return &reassemblyQueue{ si: si, nextSSN: 0, // From RFC 4960 Sec 6.5: ordered: make([]*chunkSet, 0), unordered: make([]*chunkSet, 0), } } func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { var cset *chunkSet if chunk.streamIdentifier != r.si { return false } if chunk.unordered { // First, insert into unorderedChunks array r.unorderedChunks = append(r.unorderedChunks, chunk) atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) sortChunksByTSN(r.unorderedChunks) // Scan unorderedChunks that are contiguous (in TSN) cset = r.findCompleteUnorderedChunkSet() // If found, append the complete set to the unordered array if cset != nil { r.unordered = append(r.unordered, cset) return true } return false } // This is an ordered chunk if sna16LT(chunk.streamSequenceNumber, r.nextSSN) { return false } // Check if a chunkSet with the SSN already exists for _, set := range r.ordered { if set.ssn == chunk.streamSequenceNumber { cset = set break } } // If not found, create a new chunkSet if cset == nil { cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType) r.ordered = append(r.ordered, cset) if !chunk.unordered { sortChunksBySSN(r.ordered) } } atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) return cset.push(chunk) } func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { startIdx := -1 nChunks := 0 var lastTSN uint32 var found bool for i, c := range r.unorderedChunks { // seek beigining if c.beginningFragment { startIdx = i nChunks = 1 lastTSN = c.tsn if c.endingFragment { found = true break } continue } if startIdx < 0 { continue } // Check if contiguous in TSN if c.tsn != lastTSN+1 { startIdx = -1 continue } lastTSN = c.tsn nChunks++ if c.endingFragment { found = true break } } if !found { return nil } // Extract the range of chunks var chunks []*chunkPayloadData chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...) 
r.unorderedChunks = append( r.unorderedChunks[:startIdx], r.unorderedChunks[startIdx+nChunks:]...) chunkSet := newChunkSet(0, chunks[0].payloadType) chunkSet.chunks = chunks return chunkSet } func (r *reassemblyQueue) isReadable() bool { // Check unordered first if len(r.unordered) > 0 { // The chunk sets in r.unordered should all be complete. return true } // Check ordered sets if len(r.ordered) > 0 { cset := r.ordered[0] if cset.isComplete() { if sna16LTE(cset.ssn, r.nextSSN) { return true } } } return false } func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { var cset *chunkSet // Check unordered first switch { case len(r.unordered) > 0: cset = r.unordered[0] r.unordered = r.unordered[1:] case len(r.ordered) > 0: // Now, check ordered cset = r.ordered[0] if !cset.isComplete() { return 0, 0, errTryAgain } if sna16GT(cset.ssn, r.nextSSN) { return 0, 0, errTryAgain } r.ordered = r.ordered[1:] if cset.ssn == r.nextSSN { r.nextSSN++ } default: return 0, 0, errTryAgain } // Concat all fragments into the buffer nWritten := 0 ppi := cset.ppi var err error for _, c := range cset.chunks { toCopy := len(c.userData) r.subtractNumBytes(toCopy) if err == nil { n := copy(buf[nWritten:], c.userData) nWritten += n if n < toCopy { err = io.ErrShortBuffer } } } return nWritten, ppi, err } func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { // Use lastSSN to locate a chunkSet then remove it if the set has // not been complete keep := []*chunkSet{} for _, set := range r.ordered { if sna16LTE(set.ssn, lastSSN) { if !set.isComplete() { // drop the set for _, c := range set.chunks { r.subtractNumBytes(len(c.userData)) } continue } } keep = append(keep, set) } r.ordered = keep // Finally, forward nextSSN if sna16LTE(r.nextSSN, lastSSN) { r.nextSSN = lastSSN + 1 } } func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { // Remove all fragments in the unordered sets that contains chunks // equal to or older than 
`newCumulativeTSN`. // We know all sets in the r.unordered are complete ones. // Just remove chunks that are equal to or older than newCumulativeTSN // from the unorderedChunks lastIdx := -1 for i, c := range r.unorderedChunks { if sna32GT(c.tsn, newCumulativeTSN) { break } lastIdx = i } if lastIdx >= 0 { for _, c := range r.unorderedChunks[0 : lastIdx+1] { r.subtractNumBytes(len(c.userData)) } r.unorderedChunks = r.unorderedChunks[lastIdx+1:] } } func (r *reassemblyQueue) subtractNumBytes(nBytes int) { cur := atomic.LoadUint64(&r.nBytes) if int(cur) >= nBytes { atomic.AddUint64(&r.nBytes, -uint64(nBytes)) } else { atomic.StoreUint64(&r.nBytes, 0) } } func (r *reassemblyQueue) getNumBytes() int { return int(atomic.LoadUint64(&r.nBytes)) } sctp-1.8.6/reassembly_queue_test.go000066400000000000000000000343541436021606300174510ustar00rootroot00000000000000package sctp import ( "io" "testing" "github.com/stretchr/testify/assert" ) func TestReassemblyQueue(t *testing.T) { t.Run("ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, endingFragment: true, tsn: 2, streamSequenceNumber: 0, userData: []byte("DEFG"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 7, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 7, n, "should received 7 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABCDEFG", 
"data should match") }) t.Run("ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 2, streamSequenceNumber: 0, userData: []byte("DEFG"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 7, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, endingFragment: true, tsn: 3, streamSequenceNumber: 0, userData: []byte("H"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 8, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 8, n, "should received 8 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABCDEFGH", "data should match") }) t.Run("ordered and unordered in the mix", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, endingFragment: true, tsn: 2, streamSequenceNumber: 1, userData: 
[]byte("DEF"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") // // Now we have two complete chunks ready to read in the reassemblyQueue. // buf := make([]byte, 16) // Should read unordered chunks first n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 3, n, "should received 3 bytes") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "DEF", "data should match") // Next should read ordered chunks n, ppi, err = rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 3, n, "should received 3 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABC", "data should match") }) t.Run("unordered complete skips incomplete", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 10, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, endingFragment: true, tsn: 12, // <- incongiguous streamSequenceNumber: 1, userData: []byte("COMPLETE"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 10, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, endingFragment: true, tsn: 13, streamSequenceNumber: 1, userData: []byte("GOOD"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 
14, rq.getNumBytes(), "num bytes mismatch") // // Now we have two complete chunks ready to read in the reassemblyQueue. // buf := make([]byte, 16) // Should pick the one that has "GOOD" n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 4, n, "should receive 4 bytes") assert.Equal(t, 10, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "GOOD", "data should match") }) t.Run("ignores chunk with wrong SI", func(t *testing.T) { rq := newReassemblyQueue(123) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, streamIdentifier: 124, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk should be ignored") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") }) t.Run("ignores chunk with stale SSN", func(t *testing.T) { rq := newReassemblyQueue(0) rq.nextSSN = 7 // forcibly set expected SSN to 7 orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: 6, // <-- stale userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk should not be ignored") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") }) t.Run("should fail to read incomplete chunk", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 123, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "the set should not be complete") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) _, _, err := 
rq.read(buf) assert.NotNil(t, err, "read() should not succeed") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") }) t.Run("should fail to read if the next SSN is not ready", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 123, streamSequenceNumber: 1, userData: []byte("IN"), } complete = rq.push(chunk) assert.True(t, complete, "the set should be complete") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) _, _, err := rq.read(buf) assert.NotNil(t, err, "read() should not succeed") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") }) t.Run("detect buffer too short", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 123, streamSequenceNumber: 0, userData: []byte("0123456789"), } complete = rq.push(chunk) assert.True(t, complete, "the set should be complete") assert.Equal(t, 10, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 8) // <- passing buffer too short _, _, err := rq.read(buf) assert.Equal(t, io.ErrShortBuffer, err, "read() should not succeed") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") }) t.Run("forwardTSN for ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool var ssnComplete uint16 = 5 var ssnDropped uint16 = 6 chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: ssnComplete, userData: []byte("123"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = 
&chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 11, streamSequenceNumber: ssnDropped, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, tsn: 12, streamSequenceNumber: ssnDropped, userData: []byte("DEF"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 9, rq.getNumBytes(), "num bytes mismatch") rq.forwardTSNForOrdered(ssnDropped) assert.Equal(t, 1, len(rq.ordered), "there should be one chunk left") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") }) t.Run("forwardTSN for unordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool var ssnDropped uint16 = 6 var ssnKept uint16 = 7 chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 11, streamSequenceNumber: ssnDropped, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 12, streamSequenceNumber: ssnDropped, userData: []byte("DEF"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 14, beginningFragment: true, streamSequenceNumber: ssnKept, userData: []byte("SOS"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 9, rq.getNumBytes(), "num bytes mismatch") // At this point, there are 3 chunks in the rq.unorderedChunks. // This call should remove chunks with tsn equals to 13 or older. 
rq.forwardTSNForUnordered(13) // As a result, there should be one chunk (tsn=14) assert.Equal(t, 1, len(rq.unorderedChunks), "there should be one chunk kept") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") }) } func TestChunkSet(t *testing.T) { t.Run("Empty chunkSet", func(t *testing.T) { cset := newChunkSet(0, 0) assert.False(t, cset.isComplete(), "empty chunkSet cannot be complete") }) t.Run("Push dup chunks to chunkSet", func(t *testing.T) { cset := newChunkSet(0, 0) cset.push(&chunkPayloadData{ tsn: 100, beginningFragment: true, }) complete := cset.push(&chunkPayloadData{ tsn: 100, endingFragment: true, }) assert.False(t, complete, "chunk with dup TSN is not complete") nChunks := len(cset.chunks) assert.Equal(t, 1, nChunks, "chunk with dup TSN should be ignored") }) t.Run("Incomplete chunkSet: no beginning", func(t *testing.T) { cset := &chunkSet{ ssn: 0, ppi: 0, chunks: []*chunkPayloadData{{}}, } assert.False(t, cset.isComplete(), "chunkSet not starting with B=1 cannot be complete") }) t.Run("Incomplete chunkSet: no contiguous tsn", func(t *testing.T) { cset := &chunkSet{ ssn: 0, ppi: 0, chunks: []*chunkPayloadData{ { tsn: 100, beginningFragment: true, }, { tsn: 101, }, { tsn: 103, endingFragment: true, }, }, } assert.False(t, cset.isComplete(), "chunkSet not starting with incontiguous tsn cannot be complete") }) } sctp-1.8.6/renovate.json000066400000000000000000000001731436021606300152170ustar00rootroot00000000000000{ "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "github>pion/renovate-config" ] } sctp-1.8.6/rtx_timer.go000066400000000000000000000113051436021606300150440ustar00rootroot00000000000000package sctp import ( "math" "sync" "time" ) const ( rtoInitial float64 = 3.0 * 1000 // msec rtoMin float64 = 1.0 * 1000 // msec rtoMax float64 = 60.0 * 1000 // msec rtoAlpha float64 = 0.125 rtoBeta float64 = 0.25 maxInitRetrans uint = 8 pathMaxRetrans uint = 5 noMaxRetrans uint = 0 ) // rtoManager manages Rtx timeout 
values. // This is an implementation of RFC 4960 sec 6.3.1. type rtoManager struct { srtt float64 rttvar float64 rto float64 noUpdate bool mutex sync.RWMutex } // newRTOManager creates a new rtoManager. func newRTOManager() *rtoManager { return &rtoManager{ rto: rtoInitial, } } // setNewRTT takes a newly measured RTT then adjust the RTO in msec. func (m *rtoManager) setNewRTT(rtt float64) float64 { m.mutex.Lock() defer m.mutex.Unlock() if m.noUpdate { return m.srtt } if m.srtt == 0 { // First measurement m.srtt = rtt m.rttvar = rtt / 2 } else { // Subsequent rtt measurement m.rttvar = (1-rtoBeta)*m.rttvar + rtoBeta*(math.Abs(m.srtt-rtt)) m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt } m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), rtoMax) return m.srtt } // getRTO simply returns the current RTO in msec. func (m *rtoManager) getRTO() float64 { m.mutex.RLock() defer m.mutex.RUnlock() return m.rto } // reset resets the RTO variables to the initial values. func (m *rtoManager) reset() { m.mutex.Lock() defer m.mutex.Unlock() if m.noUpdate { return } m.srtt = 0 m.rttvar = 0 m.rto = rtoInitial } // set RTO value for testing func (m *rtoManager) setRTO(rto float64, noUpdate bool) { m.mutex.Lock() defer m.mutex.Unlock() m.rto = rto m.noUpdate = noUpdate } // rtxTimerObserver is the inteface to a timer observer. // NOTE: Observers MUST NOT call start() or stop() method on rtxTimer // from within these callbacks. type rtxTimerObserver interface { onRetransmissionTimeout(timerID int, n uint) onRetransmissionFailure(timerID int) } // rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1 type rtxTimer struct { id int observer rtxTimerObserver maxRetrans uint stopFunc stopTimerLoop closed bool mutex sync.RWMutex } type stopTimerLoop func() // newRTXTimer creates a new retransmission timer. // if maxRetrans is set to 0, it will keep retransmitting until stop() is called. // (it will never make onRetransmissionFailure() callback. 
func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint) *rtxTimer {
	return &rtxTimer{
		id:         id,
		observer:   observer,
		maxRetrans: maxRetrans,
	}
}

// start starts the timer with the given base RTO (msec).
// It returns false, without doing anything, when the timer has been closed
// or is already running; otherwise it launches the timer loop and returns true.
func (t *rtxTimer) start(rto float64) bool {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	// this timer is already closed
	if t.closed {
		return false
	}

	// this is a noop if the timer is already running
	if t.stopFunc != nil {
		return false
	}

	// Note: rto value is intentionally not capped by RTO.Min to allow
	// fast timeout for the tests. Non-test code should pass in the
	// rto generated by rtoManager getRTO() method which caps the
	// value at RTO.Min or at RTO.Max.

	// nRtos counts the expirations seen so far; it drives the back-off.
	var nRtos uint

	// Closing cancelCh is how stop()/close() terminate the loop below.
	cancelCh := make(chan struct{})

	go func() {
		canceling := false

		for !canceling {
			timeout := calculateNextTimeout(rto, nRtos)
			timer := time.NewTimer(time.Duration(timeout) * time.Millisecond)

			select {
			case <-timer.C:
				nRtos++
				// maxRetrans == 0 means "retry forever".
				if t.maxRetrans == 0 || nRtos <= t.maxRetrans {
					t.observer.onRetransmissionTimeout(t.id, nRtos)
				} else {
					// stop() closes cancelCh, so the next iteration leaves
					// the loop through the cancel case.
					t.stop()
					t.observer.onRetransmissionFailure(t.id)
				}
			case <-cancelCh:
				canceling = true
				timer.Stop()
			}
		}
	}()

	t.stopFunc = func() {
		close(cancelCh)
	}

	return true
}

// stop stops the timer. It is a no-op when the timer is not running.
func (t *rtxTimer) stop() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.stopFunc != nil {
		t.stopFunc()
		t.stopFunc = nil
	}
}

// closes the timer. this is similar to stop() but subsequent start() call
// will fail (the timer is no longer usable)
func (t *rtxTimer) close() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.stopFunc != nil {
		t.stopFunc()
		t.stopFunc = nil
	}

	t.closed = true
}

// isRunning tests if the timer is running.
// Debug purpose only
func (t *rtxTimer) isRunning() bool {
	t.mutex.RLock()
	defer t.mutex.RUnlock()

	return (t.stopFunc != nil)
}

// calculateNextTimeout returns the timeout (msec) to use after nRtos
// expirations: rto doubled nRtos times, capped at rtoMax.
func calculateNextTimeout(rto float64, nRtos uint) float64 {
	// RFC 4960 sec 6.3.3.  Handle T3-rtx Expiration
	//   E2)  For the destination address for which the timer expires, set RTO
	//        <- RTO * 2 ("back off the timer").  The maximum value discussed
	//        in rule C7 above (RTO.max) may be used to provide an upper bound
	//        to this doubling operation.

	// The nRtos < 31 guard keeps 1 << nRtos representable even where int is
	// 32 bits wide; beyond that the result is simply the cap.
	if nRtos < 31 {
		m := 1 << nRtos
		return math.Min(rto*float64(m), rtoMax)
	}
	return rtoMax
}
sctp-1.8.6/rtx_timer_test.go000066400000000000000000000240101436021606300161000ustar00rootroot00000000000000package sctp import ( "math" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestRTOManager(t *testing.T) { t.Run("initial values", func(t *testing.T) { m := newRTOManager() assert.Equal(t, rtoInitial, m.rto, "should be rtoInitial") assert.Equal(t, rtoInitial, m.getRTO(), "should be rtoInitial") assert.Equal(t, float64(0), m.srtt, "should be 0") assert.Equal(t, float64(0), m.rttvar, "should be 0") }) t.Run("RTO calculation (small RTT)", func(t *testing.T) { var rto float64 m := newRTOManager() exp := []int32{ 1800, 1500, 1275, 1106, 1000, // capped at RTO.Min } for i := 0; i < 5; i++ { m.setNewRTT(600) rto = m.getRTO() assert.Equal(t, exp[i], int32(math.Floor(rto)), "should be equal") } }) t.Run("RTO calculation (large RTT)", func(t *testing.T) { var rto float64 m := newRTOManager() exp := []int32{ 60000, // capped at RTO.Max 60000, // capped at RTO.Max 60000, // capped at RTO.Max 55312, 48984, } for i := 0; i < 5; i++ { m.setNewRTT(30000) rto = m.getRTO() assert.Equal(t, exp[i], int32(math.Floor(rto)), "should be equal") } }) t.Run("calculateNextTimeout", func(t *testing.T) { var rto float64 rto = calculateNextTimeout(1.0, 0) assert.Equal(t, float64(1), rto, "should match") rto = calculateNextTimeout(1.0, 1) assert.Equal(t, float64(2), rto, "should match") rto = calculateNextTimeout(1.0, 2) assert.Equal(t, float64(4), rto, "should match") rto = calculateNextTimeout(1.0, 30) assert.Equal(t, float64(60000), rto, "should match") rto = calculateNextTimeout(1.0, 63) assert.Equal(t, float64(60000), rto, "should match") rto = calculateNextTimeout(1.0, 64) assert.Equal(t, float64(60000), rto, "should match") }) t.Run("reset",
func(t *testing.T) { m := newRTOManager() for i := 0; i < 10; i++ { m.setNewRTT(200) } m.reset() assert.Equal(t, rtoInitial, m.getRTO(), "should be rtoInitial") assert.Equal(t, float64(0), m.srtt, "should be 0") assert.Equal(t, float64(0), m.rttvar, "should be 0") }) } type ( onRTO func(id int, n uint) onRtxFailure func(id int) ) type testTimerObserver struct { onRTO onRTO onRtxFailure onRtxFailure } func (o *testTimerObserver) onRetransmissionTimeout(id int, n uint) { o.onRTO(id, n) } func (o *testTimerObserver) onRetransmissionFailure(id int) { o.onRtxFailure(id) } func TestRtxTimer(t *testing.T) { t.Run("callback interval", func(t *testing.T) { timerID := 0 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { atomic.AddInt32(&nCbs, 1) // 30 : 1 (30) // 60 : 2 (90) // 120: 3 (210) // 240: 4 (550) <== expected in 650 msec assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) assert.False(t, rt.isRunning(), "should not be running") // since := time.Now() ok := rt.start(30) assert.True(t, ok, "should be true") assert.True(t, rt.isRunning(), "should be running") time.Sleep(650 * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(4), atomic.LoadInt32(&nCbs), "should be called 4 times") }) t.Run("last start wins", func(t *testing.T) { timerID := 3 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") ok = rt.start(interval * 99) // should ignored assert.False(t, ok, "should be ignored") ok = rt.start(interval * 99) // should ignored assert.False(t, ok, "should be ignored") time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() 
assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs), "must be called once") }) t.Run("stop right afeter start", func(t *testing.T) { timerID := 3 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") rt.stop() time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(0), atomic.LoadInt32(&nCbs), "no callback should be made") }) t.Run("start, stop then start", func(t *testing.T) { timerID := 1 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { atomic.AddInt32(&nCbs, 1) assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") ok = rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs), "must be called once") }) t.Run("start and stop in a tight loop", func(t *testing.T) { timerID := 2 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { atomic.AddInt32(&nCbs, 1) t.Log("onRTO() called") assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) for i := 0; i < 1000; i++ { ok := rt.start(30) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") 
rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") } assert.Equal(t, int32(0), atomic.LoadInt32(&nCbs), "no callback should be made") }) t.Run("timer should stop after rtx failure", func(t *testing.T) { timerID := 4 var nCbs int32 doneCh := make(chan bool) since := time.Now() var elapsed float64 // in seconds rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, time.Since(since).Seconds()) atomic.AddInt32(&nCbs, 1) }, onRtxFailure: func(id int) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) elapsed = time.Since(since).Seconds() t.Logf("onRtxFailure: elapsed=%.03f\n", elapsed) doneCh <- true }, }, pathMaxRetrans) // RTO(msec) Total(msec) // 10 10 1st RTO // 20 30 2nd RTO // 40 70 3rd RTO // 80 150 4th RTO // 160 310 5th RTO (== Path.Max.Retrans) // 320 630 Failure interval := float64(10.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") <-doneCh assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(5), atomic.LoadInt32(&nCbs), "should be called 5 times") assert.True(t, elapsed > 0.600, "must have taken more than 600 msec") assert.True(t, elapsed < 0.700, "must fail in less than 700 msec") }) t.Run("timer should not stop if maxRetrans is 0", func(t *testing.T) { timerID := 4 maxRtos := uint(6) var nCbs int32 doneCh := make(chan bool) since := time.Now() var elapsed float64 // in seconds rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) elapsed = time.Since(since).Seconds() t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, elapsed) atomic.AddInt32(&nCbs, 1) if nRtos == maxRtos { doneCh <- true } }, onRtxFailure: func(id int) { assert.Fail(t, "timer should not fail") }, }, 0) // RTO(msec) Total(msec) // 10 10 1st RTO // 20 30 2nd RTO // 
40 70 3rd RTO // 80 150 4th RTO // 160 310 5th RTO // 320 630 6th RTO => exit test (timer should still be running) interval := float64(10.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") <-doneCh assert.True(t, rt.isRunning(), "should still be running") assert.Equal(t, int32(6), atomic.LoadInt32(&nCbs), "should be called 6 times") assert.True(t, elapsed > 0.600, "must have taken more than 600 msec") assert.True(t, elapsed < 0.700, "must fail in less than 700 msec") rt.stop() }) t.Run("stop timer that is not running is noop", func(t *testing.T) { timerID := 5 doneCh := make(chan bool) rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) doneCh <- true }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) for i := 0; i < 10; i++ { rt.stop() } ok := rt.start(20) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "must be running") <-doneCh rt.stop() assert.False(t, rt.isRunning(), "must be false") }) t.Run("closed timer won't start", func(t *testing.T) { var rtoCount int timerID := 6 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { rtoCount++ }, onRtxFailure: func(id int) {}, }, pathMaxRetrans) ok := rt.start(20) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "must be running") rt.close() assert.False(t, rt.isRunning(), "must be false") ok = rt.start(20) assert.False(t, ok, "should not start") assert.False(t, rt.isRunning(), "must not be running") time.Sleep(100 * time.Millisecond) assert.Equal(t, 0, rtoCount, "RTO should not occur") }) } sctp-1.8.6/sctp.go000066400000000000000000000000661436021606300140020ustar00rootroot00000000000000// Package sctp implements the SCTP spec package sctp sctp-1.8.6/stream.go000066400000000000000000000304061436021606300143250ustar00rootroot00000000000000package sctp import ( "errors" "fmt" "io" "math" 
"os" "sync" "sync/atomic" "time" "github.com/pion/logging" ) const ( // ReliabilityTypeReliable is used for reliable transmission ReliabilityTypeReliable byte = 0 // ReliabilityTypeRexmit is used for partial reliability by retransmission count ReliabilityTypeRexmit byte = 1 // ReliabilityTypeTimed is used for partial reliability by retransmission duration ReliabilityTypeTimed byte = 2 ) // StreamState is an enum for SCTP Stream state field // This field identifies the state of stream. type StreamState int // StreamState enums const ( StreamStateOpen StreamState = iota // Stream object starts with StreamStateOpen StreamStateClosing // Outgoing stream is being reset StreamStateClosed // Stream has been closed ) func (ss StreamState) String() string { switch ss { case StreamStateOpen: return "open" case StreamStateClosing: return "closing" case StreamStateClosed: return "closed" } return "unknown" } // SCTP stream errors var ( ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size") ErrStreamClosed = errors.New("stream closed") ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded) ) // Stream represents an SCTP stream type Stream struct { association *Association lock sync.RWMutex streamIdentifier uint16 defaultPayloadType PayloadProtocolIdentifier reassemblyQueue *reassemblyQueue sequenceNumber uint16 readNotifier *sync.Cond readErr error readTimeoutCancel chan struct{} unordered bool reliabilityType byte reliabilityValue uint32 bufferedAmount uint64 bufferedAmountLow uint64 onBufferedAmountLow func() state StreamState log logging.LeveledLogger name string } // StreamIdentifier returns the Stream identifier associated to the stream. func (s *Stream) StreamIdentifier() uint16 { s.lock.RLock() defer s.lock.RUnlock() return s.streamIdentifier } // SetDefaultPayloadType sets the default payload type used by Write. 
func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) {
	// Stored atomically so Write can load it without taking s.lock.
	slot := (*uint32)(&s.defaultPayloadType)
	atomic.StoreUint32(slot, uint32(defaultPayloadType))
}

// SetReliabilityParams updates the stream's ordering flag and reliability
// type/value while holding the stream lock.
func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.setReliabilityParams(unordered, relType, relVal)
}

// setReliabilityParams is the lock-free worker behind SetReliabilityParams:
// it logs the new parameters and stores them on the stream.
// The caller should hold the lock.
func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) {
	ordered := !unordered
	s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d",
		s.name, ordered, relType, relVal)

	s.unordered, s.reliabilityType, s.reliabilityValue = unordered, relType, relVal
}

// Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier.
// Returns EOF when the stream is reset or an error if the stream is closed
// otherwise.
func (s *Stream) Read(p []byte) (int, error) {
	size, _, readErr := s.ReadSCTP(p)
	return size, readErr
}

// ReadSCTP reads a packet of len(p) bytes and returns the associated Payload
// Protocol Identifier.
// Returns EOF when the stream is reset or an error if the stream is closed
// otherwise.
func (s *Stream) ReadSCTP(p []byte) (int, PayloadProtocolIdentifier, error) { s.lock.Lock() defer s.lock.Unlock() defer func() { // close readTimeoutCancel if the current read timeout routine is no longer effective if s.readTimeoutCancel != nil && s.readErr != nil { close(s.readTimeoutCancel) s.readTimeoutCancel = nil } }() for { n, ppi, err := s.reassemblyQueue.read(p) if err == nil { return n, ppi, nil } else if errors.Is(err, io.ErrShortBuffer) { return 0, PayloadProtocolIdentifier(0), err } err = s.readErr if err != nil { return 0, PayloadProtocolIdentifier(0), err } s.readNotifier.Wait() } } // SetReadDeadline sets the read deadline in an identical way to net.Conn func (s *Stream) SetReadDeadline(deadline time.Time) error { s.lock.Lock() defer s.lock.Unlock() if s.readTimeoutCancel != nil { close(s.readTimeoutCancel) s.readTimeoutCancel = nil } if s.readErr != nil { if !errors.Is(s.readErr, ErrReadDeadlineExceeded) { return nil } s.readErr = nil } if !deadline.IsZero() { s.readTimeoutCancel = make(chan struct{}) go func(readTimeoutCancel chan struct{}) { t := time.NewTimer(time.Until(deadline)) select { case <-readTimeoutCancel: t.Stop() return case <-t.C: s.lock.Lock() if s.readErr == nil { s.readErr = ErrReadDeadlineExceeded } s.readTimeoutCancel = nil s.lock.Unlock() s.readNotifier.Signal() } }(s.readTimeoutCancel) } return nil } func (s *Stream) handleData(pd *chunkPayloadData) { s.lock.Lock() defer s.lock.Unlock() var readable bool if s.reassemblyQueue.push(pd) { readable = s.reassemblyQueue.isReadable() s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable) if readable { s.log.Debugf("[%s] readNotifier.signal()", s.name) s.readNotifier.Signal() s.log.Debugf("[%s] readNotifier.signal() done", s.name) } } } func (s *Stream) handleForwardTSNForOrdered(ssn uint16) { var readable bool func() { s.lock.Lock() defer s.lock.Unlock() if s.unordered { return // unordered chunks are handled by handleForwardUnordered method } // Remove all chunks older 
than or equal to the new TSN from // the reassemblyQueue. s.reassemblyQueue.forwardTSNForOrdered(ssn) readable = s.reassemblyQueue.isReadable() }() // Notify the reader asynchronously if there's a data chunk to read. if readable { s.readNotifier.Signal() } } func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) { var readable bool func() { s.lock.Lock() defer s.lock.Unlock() if !s.unordered { return // ordered chunks are handled by handleForwardTSNOrdered method } // Remove all chunks older than or equal to the new TSN from // the reassemblyQueue. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN) readable = s.reassemblyQueue.isReadable() }() // Notify the reader asynchronously if there's a data chunk to read. if readable { s.readNotifier.Signal() } } // Write writes len(p) bytes from p with the default Payload Protocol Identifier func (s *Stream) Write(p []byte) (n int, err error) { ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType))) return s.WriteSCTP(p, ppi) } // WriteSCTP writes len(p) bytes from p to the DTLS connection func (s *Stream) WriteSCTP(p []byte, ppi PayloadProtocolIdentifier) (int, error) { maxMessageSize := s.association.MaxMessageSize() if len(p) > int(maxMessageSize) { return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, math.MaxUint16) } if s.State() != StreamStateOpen { return 0, ErrStreamClosed } chunks := s.packetize(p, ppi) n := len(p) err := s.association.sendPayloadData(chunks) if err != nil { return n, ErrStreamClosed } return n, nil } func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) []*chunkPayloadData { s.lock.Lock() defer s.lock.Unlock() i := uint32(0) remaining := uint32(len(raw)) // From draft-ietf-rtcweb-data-protocol-09, section 6: // All Data Channel Establishment Protocol messages MUST be sent using // ordered delivery and reliable transmission. 
unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered var chunks []*chunkPayloadData var head *chunkPayloadData for remaining != 0 { fragmentSize := min32(s.association.maxPayloadSize, remaining) // Copy the userdata since we'll have to store it until acked // and the caller may re-use the buffer in the mean time userData := make([]byte, fragmentSize) copy(userData, raw[i:i+fragmentSize]) chunk := &chunkPayloadData{ streamIdentifier: s.streamIdentifier, userData: userData, unordered: unordered, beginningFragment: i == 0, endingFragment: remaining-fragmentSize == 0, immediateSack: false, payloadType: ppi, streamSequenceNumber: s.sequenceNumber, head: head, } if head == nil { head = chunk } chunks = append(chunks, chunk) remaining -= fragmentSize i += fragmentSize } // RFC 4960 Sec 6.6 // Note: When transmitting ordered and unordered data, an endpoint does // not increment its Stream Sequence Number when transmitting a DATA // chunk with U flag set to 1. if !unordered { s.sequenceNumber++ } s.bufferedAmount += uint64(len(raw)) s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) return chunks } // Close closes the write-direction of the stream. // Future calls to Write are not permitted after calling Close. func (s *Stream) Close() error { if sid, resetOutbound := func() (uint16, bool) { s.lock.Lock() defer s.lock.Unlock() s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String()) if s.state == StreamStateOpen { if s.readErr == nil { s.state = StreamStateClosing } else { s.state = StreamStateClosed } s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String()) return s.streamIdentifier, true } return s.streamIdentifier, false }(); resetOutbound { // Reset the outgoing stream // https://tools.ietf.org/html/rfc6525 return s.association.sendResetRequest(sid) } return nil } // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream. 
func (s *Stream) BufferedAmount() uint64 {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.bufferedAmount
}

// BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is
// considered "low." Defaults to 0.
func (s *Stream) BufferedAmountLowThreshold() uint64 {
	s.lock.RLock()
	defer s.lock.RUnlock()

	return s.bufferedAmountLow
}

// SetBufferedAmountLowThreshold is used to update the threshold.
// See BufferedAmountLowThreshold().
func (s *Stream) SetBufferedAmountLowThreshold(th uint64) {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.bufferedAmountLow = th
}

// OnBufferedAmountLow sets the callback handler which would be called when the number of
// bytes of outgoing data buffered is lower than the threshold.
func (s *Stream) OnBufferedAmountLow(f func()) {
	s.lock.Lock()
	defer s.lock.Unlock()

	s.onBufferedAmountLow = f
}

// This method is called by association's readLoop (go-)routine to notify this stream
// of the specified amount of outgoing data has been delivered to the peer.
// Non-positive amounts are ignored. bufferedAmount is reduced (capped at 0),
// and the OnBufferedAmountLow callback fires only when the amount crosses from
// above the threshold to at-or-below it.
func (s *Stream) onBufferReleased(nBytesReleased int) {
	if nBytesReleased <= 0 {
		return
	}

	s.lock.Lock()

	// Remember the amount before the release; it is needed both for the
	// threshold-crossing check and for the diagnostic below.
	fromAmount := s.bufferedAmount

	if s.bufferedAmount < uint64(nBytesReleased) {
		s.bufferedAmount = 0
		// Log the amount that was actually buffered, not the value that has
		// just been reset to 0.
		s.log.Errorf("[%s] released buffer size %d should be <= %d",
			s.name, nBytesReleased, fromAmount)
	} else {
		s.bufferedAmount -= uint64(nBytesReleased)
	}

	s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount)

	if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow {
		// Release the lock before invoking the callback so the handler can
		// call back into the stream.
		f := s.onBufferedAmountLow
		s.lock.Unlock()
		f()
		return
	}

	s.lock.Unlock()
}

// getNumBytesInReassemblyQueue returns the reassembly queue's byte count.
func (s *Stream) getNumBytesInReassemblyQueue() int {
	// No lock is required as it reads the size with atomic load function.
	return s.reassemblyQueue.getNumBytes()
}

func (s *Stream) onInboundStreamReset() { s.lock.Lock() defer s.lock.Unlock() s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String()) // No more inbound data to read.
Unblock the read with io.EOF. // This should cause DCEP layer (datachannel package) to call Close() which // will reset outgoing stream also. // See RFC 8831 section 6.7: // if one side decides to close the data channel, it resets the corresponding // outgoing stream. When the peer sees that an incoming stream was // reset, it also resets its corresponding outgoing stream. Once this // is completed, the data channel is closed. s.readErr = io.EOF s.readNotifier.Broadcast() if s.state == StreamStateClosing { s.log.Debugf("[%s] state change: closing => closed", s.name) s.state = StreamStateClosed } } // State return the stream state. func (s *Stream) State() StreamState { s.lock.RLock() defer s.lock.RUnlock() return s.state } sctp-1.8.6/stream_test.go000066400000000000000000000043041436021606300153620ustar00rootroot00000000000000package sctp import ( "testing" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) func TestSessionBufferedAmount(t *testing.T) { t.Run("bufferedAmount", func(t *testing.T) { s := &Stream{ log: logging.NewDefaultLoggerFactory().NewLogger("sctp-test"), } assert.Equal(t, uint64(0), s.BufferedAmount()) assert.Equal(t, uint64(0), s.BufferedAmountLowThreshold()) s.bufferedAmount = 8192 s.SetBufferedAmountLowThreshold(2048) assert.Equal(t, uint64(8192), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, uint64(2048), s.BufferedAmountLowThreshold(), "unexpected threshold") }) t.Run("OnBufferedAmountLow", func(t *testing.T) { s := &Stream{ log: logging.NewDefaultLoggerFactory().NewLogger("sctp-test"), } s.bufferedAmount = 4096 s.SetBufferedAmountLowThreshold(2048) nCbs := 0 s.OnBufferedAmountLow(func() { nCbs++ }) // Negative value should be ignored (by design) s.onBufferReleased(-32) // bufferedAmount = 3072 assert.Equal(t, uint64(4096), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to above, no callback s.onBufferReleased(1024) // bufferedAmount = 3072 
assert.Equal(t, uint64(3072), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to equal, callback should be made s.onBufferReleased(1024) // bufferedAmount = 2048 assert.Equal(t, uint64(2048), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Eaual to below, no callback s.onBufferReleased(1024) // bufferedAmount = 1024 assert.Equal(t, uint64(1024), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Blow to below, no callback s.onBufferReleased(1024) // bufferedAmount = 0 assert.Equal(t, uint64(0), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Capped at 0, no callback s.onBufferReleased(1024) // bufferedAmount = 0 assert.Equal(t, uint64(0), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") }) } sctp-1.8.6/util.go000066400000000000000000000022031436021606300140010ustar00rootroot00000000000000package sctp const ( paddingMultiple = 4 ) func getPadding(l int) int { return (paddingMultiple - (l % paddingMultiple)) % paddingMultiple } func padByte(in []byte, cnt int) []byte { if cnt < 0 { cnt = 0 } padding := make([]byte, cnt) return append(in, padding...) 
} // Serial Number Arithmetic (RFC 1982) func sna32LT(i1, i2 uint32) bool { return (i1 < i2 && i2-i1 < 1<<31) || (i1 > i2 && i1-i2 > 1<<31) } func sna32LTE(i1, i2 uint32) bool { return i1 == i2 || sna32LT(i1, i2) } func sna32GT(i1, i2 uint32) bool { return (i1 < i2 && (i2-i1) >= 1<<31) || (i1 > i2 && (i1-i2) <= 1<<31) } func sna32GTE(i1, i2 uint32) bool { return i1 == i2 || sna32GT(i1, i2) } func sna32EQ(i1, i2 uint32) bool { return i1 == i2 } func sna16LT(i1, i2 uint16) bool { return (i1 < i2 && (i2-i1) < 1<<15) || (i1 > i2 && (i1-i2) > 1<<15) } func sna16LTE(i1, i2 uint16) bool { return i1 == i2 || sna16LT(i1, i2) } func sna16GT(i1, i2 uint16) bool { return (i1 < i2 && (i2-i1) >= 1<<15) || (i1 > i2 && (i1-i2) <= 1<<15) } func sna16GTE(i1, i2 uint16) bool { return i1 == i2 || sna16GT(i1, i2) } func sna16EQ(i1, i2 uint16) bool { return i1 == i2 } sctp-1.8.6/util_test.go000066400000000000000000000105041436021606300150430ustar00rootroot00000000000000package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestPadByte_Success(t *testing.T) { tt := []struct { value []byte padLen int expected []byte }{ {[]byte{0x1, 0x2}, 0, []byte{0x1, 0x2}}, {[]byte{0x1, 0x2}, 1, []byte{0x1, 0x2, 0x0}}, {[]byte{0x1, 0x2}, 2, []byte{0x1, 0x2, 0x0, 0x0}}, {[]byte{0x1, 0x2}, 3, []byte{0x1, 0x2, 0x0, 0x0, 0x0}}, {[]byte{0x1, 0x2}, -1, []byte{0x1, 0x2}}, } for i, tc := range tt { actual := padByte(tc.value, tc.padLen) assert.Equal(t, tc.expected, actual, "test %d not equal", i) } } func TestSerialNumberArithmetic(t *testing.T) { const div int = 16 t.Run("32-bit", func(t *testing.T) { // nolint:dupl const serialBits uint32 = 32 const interval uint32 = uint32((uint64(1) << uint64(serialBits)) / uint64(div)) const maxForwardDistance uint32 = 1<<(serialBits-1) - 1 const maxBackwardDistance uint32 = 1 << (serialBits - 1) for i := uint32(0); i < uint32(div); i++ { s1 := i * interval s2f := s1 + maxForwardDistance s2b := s1 + maxBackwardDistance assert.True(t, sna32LT(s1, s2f), 
"s1 < s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.False(t, sna32LT(s1, s2b), "s1 < s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.False(t, sna32GT(s1, s2f), "s1 > s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.True(t, sna32GT(s1, s2b), "s1 > s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.True(t, sna32LTE(s1, s2f), "s1 <= s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.False(t, sna32LTE(s1, s2b), "s1 <= s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.False(t, sna32GTE(s1, s2f), "s1 >= s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.True(t, sna32GTE(s1, s2b), "s1 >= s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.True(t, sna32EQ(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna32EQ(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.False(t, sna32EQ(s1, s1+1), "s1 == s1+1 should be false: s1=0x%x s1+1=0x%x", s1, s1+1) assert.False(t, sna32EQ(s1, s1-1), "s1 == s1-1 hould be false: s1=0x%x s1-1=0x%x", s1, s1-1) assert.True(t, sna32LTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna32LTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.True(t, sna32GTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna32GTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) } }) t.Run("16-bit", func(t *testing.T) { // nolint:dupl const serialBits uint16 = 16 const interval uint16 = uint16((uint64(1) << uint64(serialBits)) / uint64(div)) const maxForwardDistance uint16 = 1<<(serialBits-1) - 1 const maxBackwardDistance uint16 = 1 << (serialBits - 1) for i := uint16(0); i < uint16(div); i++ { s1 := i * interval s2f := s1 + maxForwardDistance s2b := s1 + maxBackwardDistance assert.True(t, sna16LT(s1, s2f), "s1 < s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.False(t, sna16LT(s1, s2b), "s1 < s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.False(t, sna16GT(s1, s2f), "s1 > s2 should be 
fales: s1=0x%x s2=0x%x", s1, s2f) assert.True(t, sna16GT(s1, s2b), "s1 > s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.True(t, sna16LTE(s1, s2f), "s1 <= s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.False(t, sna16LTE(s1, s2b), "s1 <= s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.False(t, sna16GTE(s1, s2f), "s1 >= s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.True(t, sna16GTE(s1, s2b), "s1 >= s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.True(t, sna16EQ(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna16EQ(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.False(t, sna16EQ(s1, s1+1), "s1 == s1+1 should be false: s1=0x%x s1+1=0x%x", s1, s1+1) assert.False(t, sna16EQ(s1, s1-1), "s1 == s1-1 hould be false: s1=0x%x s1-1=0x%x", s1, s1-1) assert.True(t, sna16LTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna16LTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.True(t, sna16GTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.True(t, sna16GTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) } }) } sctp-1.8.6/vnet_test.go000066400000000000000000000407671436021606300150600ustar00rootroot00000000000000package sctp import ( "bytes" "fmt" "math/rand" "net" "reflect" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/test" "github.com/pion/transport/vnet" "github.com/stretchr/testify/assert" ) type vNetEnvConfig struct { minDelay time.Duration loggerFactory logging.LoggerFactory log logging.LeveledLogger } type vNetEnv struct { wan *vnet.Router net0 *vnet.Net net1 *vnet.Net numToDropData int numToDropReconfig int numToDropCookieEcho int numToDropCookieAck int } var errSCTPPacketParse = fmt.Errorf("unable to parse SCTP packet") func (venv *vNetEnv) dropNextDataChunk(numToDrop int) { venv.numToDropData = numToDrop } func (venv *vNetEnv) dropNextReconfigChunk(numToDrop int) { // 
nolint:unused venv.numToDropReconfig = numToDrop } func (venv *vNetEnv) dropNextCookieEchoChunk(numToDrop int) { venv.numToDropCookieEcho = numToDrop } func (venv *vNetEnv) dropNextCookieAckChunk(numToDrop int) { venv.numToDropCookieAck = numToDrop } func buildVNetEnv(cfg *vNetEnvConfig) (*vNetEnv, error) { log := cfg.log var venv *vNetEnv serverIP := "1.1.1.1" clientIP := "2.2.2.2" wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", MinDelay: cfg.minDelay, MaxJitter: 0 * time.Millisecond, LoggerFactory: cfg.loggerFactory, }) if err != nil { return nil, err } tsnAutoLockOnFilter := func() func(vnet.Chunk) bool { var lockedOnTSN bool var tsn uint32 return func(c vnet.Chunk) bool { var toDrop bool p := &packet{} if err2 := p.unmarshal(c.UserData()); err2 != nil { panic(errSCTPPacketParse) } loop: for i := 0; i < len(p.chunks); i++ { switch chunk := p.chunks[i].(type) { case *chunkPayloadData: if venv.numToDropData > 0 { if !lockedOnTSN { tsn = chunk.tsn lockedOnTSN = true log.Infof("Chunk filter: lock on TSN %d", tsn) } if chunk.tsn == tsn { toDrop = true venv.numToDropData-- log.Infof("Chunk filter: drop TSN %d", tsn) break loop } } case *chunkReconfig: if venv.numToDropReconfig > 0 { toDrop = true venv.numToDropReconfig-- log.Infof("Chunk filter: drop RECONFIG %s", chunk.String()) break loop } case *chunkCookieEcho: if venv.numToDropCookieEcho > 0 { toDrop = true venv.numToDropCookieEcho-- log.Infof("Chunk filter: drop %s", chunk.String()) break loop } case *chunkCookieAck: if venv.numToDropCookieAck > 0 { toDrop = true venv.numToDropCookieAck-- log.Infof("Chunk filter: drop %s", chunk.String()) break loop } } } return !toDrop } } wan.AddChunkFilter(tsnAutoLockOnFilter()) net0 := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{serverIP}, }) err = wan.AddNet(net0) if err != nil { return nil, err } net1 := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{clientIP}, }) err = wan.AddNet(net1) if err != nil { return nil, err } err = wan.Start() if err != 
nil { return nil, err } venv = &vNetEnv{ wan: wan, net0: net0, net1: net1, } return venv, nil } func testRwndFull(t *testing.T, unordered bool) { loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(&vNetEnvConfig{ minDelay: 200 * time.Millisecond, loggerFactory: loggerFactory, log: log, }) if !assert.NoError(t, err, "should succeed") { return } if !assert.NotNil(t, venv, "should not be nil") { return } defer venv.wan.Stop() // nolint:errcheck serverHandshakeDone := make(chan struct{}) clientHandshakeDone := make(chan struct{}) serverStreamReady := make(chan struct{}) clientStreamReady := make(chan struct{}) clientStartWrite := make(chan struct{}) serverRecvBufFull := make(chan struct{}) serverStartRead := make(chan struct{}) serverReadAll := make(chan struct{}) clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) shutDownClient := make(chan struct{}) shutDownServer := make(chan struct{}) maxReceiveBufferSize := uint32(64 * 1024) msgSize := int(float32(maxReceiveBufferSize)/2) + int(initialMTU) msg := make([]byte, msgSize) rand.Read(msg) // nolint:errcheck,gosec go func() { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association assoc, err := Server(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") close(serverHandshakeDone) stream, err := assoc.AcceptStream() if !assert.NoError(t, err, "should succeed") { return } defer stream.Close() // nolint:errcheck // Expunge the first HELLO packet buf := make([]byte, 64*1024) n, err := 
stream.Read(buf) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "HELLO", string(buf[:n]), "should match") stream.SetReliabilityParams(unordered, ReliabilityTypeReliable, 0) log.Info("server stream ready") close(serverStreamReady) for { assoc.lock.RLock() rbufSize := assoc.getMyReceiverWindowCredit() log.Infof("rbufSize = %d", rbufSize) assoc.lock.RUnlock() if rbufSize == 0 { break } time.Sleep(50 * time.Millisecond) } close(serverRecvBufFull) <-serverStartRead for i := 0; i < 2; i++ { n, err = stream.Read(buf) if !assert.NoError(t, err, "should succeed") { return } if !assert.NoError(t, err, "should succeed") { return } log.Infof("server read %d bytes", n) assert.True(t, reflect.DeepEqual(msg, buf[:n]), "msg %d should match", i) } close(serverReadAll) <-shutDownServer log.Info("server closing") }() go func() { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, ) if !assert.NoError(t, err, "should succeed") { return } // client association assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("client handshake complete") close(clientHandshakeDone) stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if !assert.NoError(t, err, "should succeed") { return } defer stream.Close() // nolint:errcheck // Send a message to let server side stream to open _, err = stream.Write([]byte("HELLO")) if !assert.NoError(t, err, "should succeed") { return } stream.SetReliabilityParams(unordered, ReliabilityTypeReliable, 0) log.Info("client stream ready") close(clientStreamReady) <-clientStartWrite // Set the cwnd and rwnd to the size large enough to send the large messages // right away assoc.lock.Lock() assoc.cwnd = 2 * 
maxReceiveBufferSize assoc.rwnd = 2 * maxReceiveBufferSize assoc.lock.Unlock() // Send two large messages so that the second one will // cause receiver side buffer full for i := 0; i < 2; i++ { _, err = stream.Write(msg) if !assert.NoError(t, err, "should succeed") { return } } <-shutDownClient log.Info("client closing") }() // // Scenario // // wait until both handshake complete <-clientHandshakeDone <-serverHandshakeDone log.Info("handshake complete") // wait until both establish a stream <-clientStreamReady <-serverStreamReady log.Info("stream ready") // drop next 1 DATA chunk sent to the server venv.dropNextDataChunk(1) // let client begin writing log.Info("client start writing") close(clientStartWrite) // wait until the server's receive buffer becomes full <-serverRecvBufFull // let server start reading close(serverStartRead) // wait until the server receives all data log.Info("let server start reading") <-serverReadAll log.Info("server received all data") close(shutDownClient) <-clientShutDown close(shutDownServer) <-serverShutDown log.Info("all done") } func TestRwndFull(t *testing.T) { t.Run("Ordered", func(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 10) defer lim.Stop() testRwndFull(t, false) }) t.Run("Unordered", func(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 10) defer lim.Stop() testRwndFull(t, true) }) } func TestStreamClose(t *testing.T) { loopBackTest := func(t *testing.T, dropReconfigChunk bool) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(&vNetEnvConfig{ loggerFactory: loggerFactory, log: log, }) if !assert.NoError(t, err, "should succeed") { return } if !assert.NotNil(t, venv, "should not be nil") { return } defer venv.wan.Stop() // nolint:errcheck clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) const 
numMessages = 10 const messageSize = 1024 var messages [][]byte var numServerReceived int var numClientReceived int for i := 0; i < numMessages; i++ { bytes := make([]byte, messageSize) messages = append(messages, bytes) } go func() { defer close(serverShutDown) // connected UDP conn for server conn, innerErr := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, ) if !assert.NoError(t, innerErr, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association assoc, innerErr := Server(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if !assert.NoError(t, innerErr, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") stream, innerErr := assoc.AcceptStream() if !assert.NoError(t, innerErr, "should succeed") { return } assert.Equal(t, StreamStateOpen, stream.State()) buf := make([]byte, 1500) for { n, errRead := stream.Read(buf) if errRead != nil { log.Infof("server: Read returned %v", errRead) _ = stream.Close() // nolint:errcheck assert.Equal(t, StreamStateClosed, stream.State()) break } log.Infof("server: received %d bytes (%d)", n, numServerReceived) assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numServerReceived]), "should receive HELLO") _, err2 := stream.Write(buf[:n]) assert.NoError(t, err2, "should succeed") numServerReceived++ } // don't close association until the client's stream routine is complete <-clientShutDown }() // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // client association assoc, err := Client(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck 
log.Info("client handshake complete") stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, StreamStateOpen, stream.State()) stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0) // begin client read-loop buf := make([]byte, 1500) go func() { defer close(clientShutDown) for { n, err2 := stream.Read(buf) if err2 != nil { log.Infof("client: Read returned %v", err2) assert.Equal(t, StreamStateClosed, stream.State()) break } log.Infof("client: received %d bytes (%d)", n, numClientReceived) assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numClientReceived]), "should receive HELLO") numClientReceived++ } }() // Send messages to the server for i := 0; i < numMessages; i++ { _, err = stream.Write(messages[i]) assert.NoError(t, err, "should succeed") } if dropReconfigChunk { venv.dropNextReconfigChunk(1) } // Immediately close the stream err = stream.Close() assert.NoError(t, err, "should succeed") assert.Equal(t, StreamStateClosing, stream.State()) log.Info("client wait for exit reading..") <-clientShutDown assert.Equal(t, numMessages, numServerReceived, "all messages should be received") assert.Equal(t, numMessages, numClientReceived, "all messages should be received") _, err = stream.Write([]byte{1}) assert.Equal(t, err, ErrStreamClosed, "after closed should not allow write") // Check if RECONFIG was actually dropped assert.Equal(t, 0, venv.numToDropReconfig, "should be zero") // Sleep enough time for reconfig response to come back time.Sleep(100 * time.Millisecond) // Verify there's no more pending reconfig assoc.lock.RLock() pendingReconfigs := len(assoc.reconfigs) assoc.lock.RUnlock() assert.Equal(t, 0, pendingReconfigs, "should be zero") } t.Run("without dropping Reconfig", func(t *testing.T) { loopBackTest(t, false) }) t.Run("with dropping Reconfig", func(t *testing.T) { loopBackTest(t, true) }) } // this test case reproduces the issue mentioned in // 
https://github.com/pion/webrtc/issues/1270#issuecomment-653953743 // and confirmes the fix. // To reproduce the case mentioned above: // * Use simultaneous-open (SCTP) // * Drop both of the first COOKIE-ECHO and COOKIE-ACK func TestCookieEchoRetransmission(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(&vNetEnvConfig{ minDelay: 200 * time.Millisecond, loggerFactory: loggerFactory, log: log, }) if !assert.NoError(t, err, "should succeed") { return } if !assert.NotNil(t, venv, "should not be nil") { return } defer venv.wan.Stop() // nolint:errcheck // To cause the cookie echo retransmission, both COOKIE-ECHO // and COOKIE-ACK chunks need to be dropped at the same time. venv.dropNextCookieEchoChunk(1) venv.dropNextCookieAckChunk(1) serverHandshakeDone := make(chan struct{}) clientHandshakeDone := make(chan struct{}) waitAllHandshakeDone := make(chan struct{}) clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) maxReceiveBufferSize := uint32(64 * 1024) // Go routine for Server go func() { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: 5000}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association // using Client for simultaneous open assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") close(serverHandshakeDone) <-waitAllHandshakeDone }() // Go routine for Client go func() { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: 
net.ParseIP("2.2.2.2"), Port: 5000}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: 5000}, ) if !assert.NoError(t, err, "should succeed") { return } // client association assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("client handshake complete") close(clientHandshakeDone) <-waitAllHandshakeDone }() // // Scenario // // wait until both handshake complete <-clientHandshakeDone <-serverHandshakeDone close(waitAllHandshakeDone) log.Info("handshake complete") <-clientShutDown <-serverShutDown log.Info("all done") }