pax_global_header00006660000000000000000000000064143710606240014515gustar00rootroot0000000000000052 comment=1bdeef256f2943f28f23d37b621be7d1a0a49572 srtp-2.0.12/000077500000000000000000000000001437106062400125675ustar00rootroot00000000000000srtp-2.0.12/.github/000077500000000000000000000000001437106062400141275ustar00rootroot00000000000000srtp-2.0.12/.github/.gitignore000066400000000000000000000000121437106062400161100ustar00rootroot00000000000000.goassets srtp-2.0.12/.github/fetch-scripts.sh000077500000000000000000000014351437106062400172470ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # set -eu SCRIPT_PATH="$(realpath "$(dirname "$0")")" GOASSETS_PATH="${SCRIPT_PATH}/.goassets" GOASSETS_REF=${GOASSETS_REF:-master} if [ -d "${GOASSETS_PATH}" ]; then if ! git -C "${GOASSETS_PATH}" diff --exit-code; then echo "${GOASSETS_PATH} has uncommitted changes" >&2 exit 1 fi git -C "${GOASSETS_PATH}" fetch origin git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} else git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" fi srtp-2.0.12/.github/install-hooks.sh000077500000000000000000000010771437106062400172620ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # SCRIPT_PATH="$(realpath "$(dirname "$0")")" . 
${SCRIPT_PATH}/fetch-scripts.sh cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" cp "${GOASSETS_PATH}/hooks/pre-push.sh" "${SCRIPT_PATH}/../.git/hooks/pre-push" srtp-2.0.12/.github/workflows/000077500000000000000000000000001437106062400161645ustar00rootroot00000000000000srtp-2.0.12/.github/workflows/codeql-analysis.yml000066400000000000000000000011551437106062400220010ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: CodeQL on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master srtp-2.0.12/.github/workflows/generate-authors.yml000066400000000000000000000011041437106062400221600ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# name: Generate Authors on: pull_request: jobs: generate: uses: pion/.goassets/.github/workflows/generate-authors.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} srtp-2.0.12/.github/workflows/lint.yaml000066400000000000000000000007521437106062400200220ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Lint on: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/lint.reusable.yml@master srtp-2.0.12/.github/workflows/release.yml000066400000000000000000000011051437106062400203240ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Release on: push: tags: - 'v*' jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version srtp-2.0.12/.github/workflows/renovate-go-sum-fix.yaml000066400000000000000000000011241437106062400226620ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# name: Fix go.sum on: push: branches: - renovate/* jobs: fix: uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} srtp-2.0.12/.github/workflows/test.yaml000066400000000000000000000021121437106062400200230ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # name: Test on: push: branches: - master pull_request: jobs: test: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: go: ['1.19', '1.18'] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version srtp-2.0.12/.github/workflows/tidy-check.yaml000066400000000000000000000011371437106062400210760ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# name: Go mod tidy on: pull_request: push: branches: - master jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: go-version: '1.19' # auto-update/latest-go-version srtp-2.0.12/.gitignore000066400000000000000000000004661437106062400145650ustar00rootroot00000000000000### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js srtp-2.0.12/.golangci.yml000066400000000000000000000172751437106062400151670ustar00rootroot00000000000000linters-settings: govet: check-shadowing: true misspell: locale: US exhaustive: default-signifies-exhaustive: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - contextcheck # check the function whether use a non-inherited context - decorder # check declaration order and count of types, constants, variables and functions - depguard # Go linter that checks if package imports are in a list of acceptable packages - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. 
- errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - exportloopref # checks for pointers to enclosing loop variables - forcetypeassert # finds forced type assertions - gci # Gci control golang package import order and make it always deterministic. - gochecknoglobals # Checks that no globals are present in Go code - gochecknoinits # Checks that no init functions are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - godox # Tool for detection of FIXME, TODO and other comment keywords - goerr113 # Golang linter to check the errors handling expressions - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goheader # Checks is file header matches to pattern - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. 
- goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - gosimple # Linter for Go source code that specializes in simplifying a code - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. - importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - stylecheck # Stylecheck is a replacement for golint - tagliatelle # Checks the struct tags. 
- tenv # tenv is analyzer that detects using os.Setenv instead of t.Setenv since Go1.17 - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - typecheck # Like the front-end of a Go compiler, parses and type-checks Go code - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - containedctx # containedctx is a linter that detects struct contained context.Context field - cyclop # checks function and package cyclomatic complexity - exhaustivestruct # Checks if all struct's fields are initialized - forbidigo # Forbids identifiers - funlen # Tool for detection of long functions - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - gomnd # An analyzer to detect magic numbers. - ifshort # Checks that your code uses short syntax for if-statements whenever possible - ireturn # Accept Interfaces, Return Concrete Types - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. 
- makezero # Finds slice declarations with non-zero initial length - maligned # Tool to detect Go structs that would take less memory if their fields were sorted - nestif # Reports deeply nested if statements - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - varnamelen # checks that the length of a variable's name matches its scope - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! issues: exclude-use-default: false exclude-rules: # Allow complex tests, better to be self contained - path: _test\.go linters: - gocognit # Allow complex main function in examples - path: examples text: "of func `main` is high" linters: - gocognit run: skip-dirs-use-default: false srtp-2.0.12/.goreleaser.yml000066400000000000000000000000251437106062400155150ustar00rootroot00000000000000builds: - skip: true srtp-2.0.12/AUTHORS.txt000066400000000000000000000022321437106062400144540ustar00rootroot00000000000000# Thank you to everyone that made Pion possible. If you are interested in contributing # we would love to have you https://github.com/pion/webrtc/wiki/Contributing # # This file is auto generated, using git to list all individuals contributors. 
# see https://github.com/pion/.goassets/blob/master/scripts/generate-authors.sh for the scripting adamroach Adrian Cable Agniva De Sarker Atsushi Watanabe backkem chenkaiC4 Chris Hiszpanski cszdlt Hugo Arregui Jerko Steiner Juliusz Chroboczek Luke Curley Luke Curley Max Hawkins mission-liao Novel Corpse OrlandoCo Sean DuBois Sean DuBois Tobias Fridén Woodrow Douglass Yutaka Takeda # List of contributors not appearing in Git history srtp-2.0.12/DESIGN.md000066400000000000000000000014451437106062400140660ustar00rootroot00000000000000

Design

### Portable
Pion SRTP is written in Go and extremely portable. Anywhere Golang runs, Pion SRTP should work as well! Instead of dealing with complicated cross-compiling of multiple libraries, you now can run anywhere with one `go build`.

### Simple API
The API is based on an io.ReadWriteCloser.

### Readable
If code comes from an RFC we try to make sure everything is commented with a link to the spec. This makes learning and debugging easier; this library was also written to serve as a guide for others.

### Tested
Every commit is tested via GitHub Actions (see `.github/workflows`). Go provides fantastic facilities for testing, and more will be added as time goes on.

### Shared libraries
Every pion product is built using shared libraries, allowing others to review and reuse our libraries.
srtp-2.0.12/LICENSE000066400000000000000000000020411437106062400135710ustar00rootroot00000000000000MIT License Copyright (c) 2018 Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
srtp-2.0.12/README.md000066400000000000000000000043151437106062400140510ustar00rootroot00000000000000


Pion SRTP

A Go implementation of SRTP

Pion SRTP Sourcegraph Widget Slack Widget
Build Status GoDoc Coverage Status Go Report Card License: MIT


See [DESIGN.md](DESIGN.md) for an overview of features and future goals. ### Roadmap The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones. ### Community Pion has an active community on the [Golang Slack](https://invite.slack.golangbridge.org/). Sign up and join the **#pion** channel for discussions and support. You can also use [Pion mailing list](https://groups.google.com/forum/#!forum/pion). We are always looking to support **your projects**. Please reach out if you have something to build! If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly) ### Contributing Check out the **[contributing wiki](https://github.com/pion/webrtc/wiki/Contributing)** to join the group of amazing people making this project possible: ### License MIT License - see [LICENSE](LICENSE) for full text srtp-2.0.12/codecov.yml000066400000000000000000000005521437106062400147360ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. 
threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" srtp-2.0.12/context.go000066400000000000000000000133121437106062400146020ustar00rootroot00000000000000package srtp

import (
	"fmt"

	"github.com/pion/transport/v2/replaydetector"
)

const (
	// Key-derivation labels. NOTE(review): the values appear to match the
	// SRTP/SRTCP labels of RFC 3711 section 4.3 — confirm before relying on it.
	labelSRTPEncryption        = 0x00
	labelSRTPAuthenticationTag = 0x01
	labelSRTPSalt              = 0x02

	labelSRTCPEncryption        = 0x03
	labelSRTCPAuthenticationTag = 0x04
	labelSRTCPSalt              = 0x05

	maxSequenceNumber = 65535
	// maxROC is the largest 32-bit rollover counter; nextRolloverCount
	// reports overflow when a guessed ROC wraps from maxROC back to 0.
	maxROC = (1 << 32) - 1

	// seqNumMedian (half of the 16-bit sequence-number space) and seqNumMax
	// (its size) are used by nextRolloverCount to decide whether a sequence
	// number belongs to the previous, current or next ROC.
	seqNumMedian = 1 << 15
	seqNumMax    = 1 << 16

	srtcpIndexSize = 4
)

// Encrypt/Decrypt state for a single SRTP SSRC
type srtpSSRCState struct {
	ssrc uint32
	// index holds the rollover counter in bits 16 and up and a 16-bit
	// sequence number in the low 16 bits (see ROC, SetROC and
	// nextRolloverCount).
	index uint64
	// rolloverHasProcessed stays false until the first sequence number has
	// been recorded after the state was created or reset by SetROC.
	rolloverHasProcessed bool
	replayDetector       replaydetector.ReplayDetector
}

// Encrypt/Decrypt state for a single SRTCP SSRC
type srtcpSSRCState struct {
	srtcpIndex     uint32
	ssrc           uint32
	replayDetector replaydetector.ReplayDetector
}

// Context represents a SRTP cryptographic context.
// Context can only be used for one-way operations:
// it must be used either ONLY for encryption or ONLY for decryption.
// Note that Context does not provide any concurrency protection:
// access to a Context from multiple goroutines requires external
// synchronization.
type Context struct {
	cipher srtpCipher

	// Per-SSRC state, created lazily by getSRTPSSRCState/getSRTCPSSRCState.
	srtpSSRCStates  map[uint32]*srtpSSRCState
	srtcpSSRCStates map[uint32]*srtcpSSRCState

	// Factories for per-SSRC replay detectors. Presumably installed by
	// ContextOptions; CreateContext applies the no-replay-protection options
	// before any user options.
	newSRTCPReplayDetector func() replaydetector.ReplayDetector
	newSRTPReplayDetector  func() replaydetector.ReplayDetector
}

// CreateContext creates a new SRTP Context.
//
// CreateContext receives a variable number of ContextOption-s.
// When several options set the same parameter, the last one wins.
// The following example creates an SRTP Context with replay protection
// using a window size of 256:
// // decCtx, err := srtp.CreateContext(key, salt, profile, srtp.SRTPReplayProtection(256)) func CreateContext(masterKey, masterSalt []byte, profile ProtectionProfile, opts ...ContextOption) (c *Context, err error) { keyLen, err := profile.keyLen() if err != nil { return nil, err } saltLen, err := profile.saltLen() if err != nil { return nil, err } if masterKeyLen := len(masterKey); masterKeyLen != keyLen { return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterKey, masterKey, keyLen) } else if masterSaltLen := len(masterSalt); masterSaltLen != saltLen { return c, fmt.Errorf("%w expected(%d) actual(%d)", errShortSrtpMasterSalt, saltLen, masterSaltLen) } c = &Context{ srtpSSRCStates: map[uint32]*srtpSSRCState{}, srtcpSSRCStates: map[uint32]*srtcpSSRCState{}, } switch profile { case ProtectionProfileAeadAes128Gcm: c.cipher, err = newSrtpCipherAeadAesGcm(profile, masterKey, masterSalt) case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: c.cipher, err = newSrtpCipherAesCmHmacSha1(profile, masterKey, masterSalt) default: return nil, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, profile) } if err != nil { return nil, err } for _, o := range append( []ContextOption{ // Default options SRTPNoReplayProtection(), SRTCPNoReplayProtection(), }, opts..., // User specified options ) { if errOpt := o(c); errOpt != nil { return nil, errOpt } } return c, nil } // https://tools.ietf.org/html/rfc3550#appendix-A.1 func (s *srtpSSRCState) nextRolloverCount(sequenceNumber uint16) (roc uint32, diff int32, overflow bool) { seq := int32(sequenceNumber) localRoc := uint32(s.index >> 16) localSeq := int32(s.index & (seqNumMax - 1)) guessRoc := localRoc var difference int32 if s.rolloverHasProcessed { // When localROC is equal to 0, and entering seq-localSeq > seqNumMedian // judgment, it will cause guessRoc calculation error if s.index > seqNumMedian { if localSeq < seqNumMedian { if seq-localSeq > seqNumMedian { guessRoc = localRoc - 1 difference = 
seq - localSeq - seqNumMax } else { guessRoc = localRoc difference = seq - localSeq } } else { if localSeq-seqNumMedian > seq { guessRoc = localRoc + 1 difference = seq - localSeq + seqNumMax } else { guessRoc = localRoc difference = seq - localSeq } } } else { // localRoc is equal to 0 difference = seq - localSeq } } return guessRoc, difference, (guessRoc == 0 && localRoc == maxROC) } func (s *srtpSSRCState) updateRolloverCount(sequenceNumber uint16, difference int32) { if !s.rolloverHasProcessed { s.index |= uint64(sequenceNumber) s.rolloverHasProcessed = true return } if difference > 0 { s.index += uint64(difference) } } func (c *Context) getSRTPSSRCState(ssrc uint32) *srtpSSRCState { s, ok := c.srtpSSRCStates[ssrc] if ok { return s } s = &srtpSSRCState{ ssrc: ssrc, replayDetector: c.newSRTPReplayDetector(), } c.srtpSSRCStates[ssrc] = s return s } func (c *Context) getSRTCPSSRCState(ssrc uint32) *srtcpSSRCState { s, ok := c.srtcpSSRCStates[ssrc] if ok { return s } s = &srtcpSSRCState{ ssrc: ssrc, replayDetector: c.newSRTCPReplayDetector(), } c.srtcpSSRCStates[ssrc] = s return s } // ROC returns SRTP rollover counter value of specified SSRC. func (c *Context) ROC(ssrc uint32) (uint32, bool) { s, ok := c.srtpSSRCStates[ssrc] if !ok { return 0, false } return uint32(s.index >> 16), true } // SetROC sets SRTP rollover counter value of specified SSRC. func (c *Context) SetROC(ssrc uint32, roc uint32) { s := c.getSRTPSSRCState(ssrc) s.index = uint64(roc) << 16 s.rolloverHasProcessed = false } // Index returns SRTCP index value of specified SSRC. func (c *Context) Index(ssrc uint32) (uint32, bool) { s, ok := c.srtcpSSRCStates[ssrc] if !ok { return 0, false } return s.srtcpIndex, true } // SetIndex sets SRTCP index value of specified SSRC. 
func (c *Context) SetIndex(ssrc uint32, index uint32) {
	s := c.getSRTCPSSRCState(ssrc)
	// Reduce modulo maxSRTCPIndex+1 so the stored index always lies in
	// [0, maxSRTCPIndex]; larger inputs wrap around.
	s.srtcpIndex = index % (maxSRTCPIndex + 1)
}
srtp-2.0.12/context_test.go000066400000000000000000000015541437106062400156460ustar00rootroot00000000000000package srtp

import (
	"testing"
)

// TestContextROC checks that ROC reports an unknown SSRC as absent and
// returns the value stored by SetROC afterwards.
func TestContextROC(t *testing.T) {
	c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR)
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := c.ROC(123); ok {
		t.Error("ROC must return false for unused SSRC")
	}

	c.SetROC(123, 100)
	roc, ok := c.ROC(123)
	if !ok {
		t.Fatal("ROC must return true for used SSRC")
	}
	if roc != 100 {
		t.Errorf("ROC is set to 100, but returned %d", roc)
	}
}

// TestContextIndex checks that Index reports an unknown SSRC as absent and
// returns the value stored by SetIndex afterwards.
func TestContextIndex(t *testing.T) {
	c, err := CreateContext(make([]byte, 16), make([]byte, 14), profileCTR)
	if err != nil {
		t.Fatal(err)
	}

	if _, ok := c.Index(123); ok {
		t.Error("Index must return false for unused SSRC")
	}

	c.SetIndex(123, 100)
	index, ok := c.Index(123)
	if !ok {
		t.Fatal("Index must return true for used SSRC")
	}
	if index != 100 {
		t.Errorf("Index is set to 100, but returned %d", index)
	}
}
srtp-2.0.12/crypto.go000066400000000000000000000014411437106062400144360ustar00rootroot00000000000000package srtp

import (
	"crypto/cipher"

	"github.com/pion/transport/v2/utils/xor"
)

// incrementCTR increments a big-endian integer of arbitrary size.
func incrementCTR(ctr []byte) {
	for i := len(ctr) - 1; i >= 0; i-- {
		ctr[i]++
		// Stop at the first byte that did not wrap to zero; otherwise the
		// carry propagates into the next more-significant byte.
		if ctr[i] != 0 {
			break
		}
	}
}

// xorBytesCTR performs CTR encryption and decryption.
// It is equivalent to cipher.NewCTR followed by XORKeyStream.
// iv must be exactly block.BlockSize() bytes long, otherwise errBadIVLength
// is returned; iv itself is not modified.
func xorBytesCTR(block cipher.Block, iv []byte, dst, src []byte) error { if len(iv) != block.BlockSize() { return errBadIVLength } ctr := make([]byte, len(iv)) copy(ctr, iv) bs := block.BlockSize() stream := make([]byte, bs) i := 0 for i < len(src) { block.Encrypt(stream, ctr) incrementCTR(ctr) n := xor.XorBytes(dst[i:], src[i:], stream) if n == 0 { break } i += n } return nil } srtp-2.0.12/crypto_test.go000066400000000000000000000030101437106062400154670ustar00rootroot00000000000000package srtp import ( "crypto/aes" "crypto/cipher" "math/rand" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func xorBytesCTRReference(block cipher.Block, iv []byte, dst, src []byte) { stream := cipher.NewCTR(block, iv) stream.XORKeyStream(dst, src) } func TestXorBytesCTR(t *testing.T) { for keysize := 16; keysize < 64; keysize *= 2 { key := make([]byte, keysize) _, err := rand.Read(key) //nolint: gosec require.NoError(t, err) block, err := aes.NewCipher(key) require.NoError(t, err) iv := make([]byte, block.BlockSize()) for i := 0; i < 1500; i++ { src := make([]byte, i) dst := make([]byte, i) reference := make([]byte, i) _, err = rand.Read(iv) //nolint: gosec require.NoError(t, err) _, err = rand.Read(src) //nolint: gosec require.NoError(t, err) assert.NoError(t, xorBytesCTR(block, iv, dst, src)) xorBytesCTRReference(block, iv, reference, src) require.Equal(t, dst, reference) // test overlap assert.NoError(t, xorBytesCTR(block, iv, dst, dst)) xorBytesCTRReference(block, iv, reference, reference) require.Equal(t, dst, reference) } } } func TestXorBytesCTRInvalidIvLength(t *testing.T) { key := make([]byte, 16) block, err := aes.NewCipher(key) require.NoError(t, err) src := make([]byte, 1024) dst := make([]byte, 1024) test := func(iv []byte) { assert.Error(t, errBadIVLength, xorBytesCTR(block, iv, dst, src)) } test(make([]byte, block.BlockSize()-1)) test(make([]byte, block.BlockSize()+1)) } 
srtp-2.0.12/errors.go000066400000000000000000000033521437106062400144350ustar00rootroot00000000000000package srtp import ( "errors" "fmt" ) var ( errDuplicated = errors.New("duplicated packet") errShortSrtpMasterKey = errors.New("SRTP master key is not long enough") errShortSrtpMasterSalt = errors.New("SRTP master salt is not long enough") errNoSuchSRTPProfile = errors.New("no such SRTP Profile") errNonZeroKDRNotSupported = errors.New("indexOverKdr > 0 is not supported yet") errExporterWrongLabel = errors.New("exporter called with wrong label") errNoConfig = errors.New("no config provided") errNoConn = errors.New("no conn provided") errFailedToVerifyAuthTag = errors.New("failed to verify auth tag") errTooShortRTCP = errors.New("packet is too short to be rtcp packet") errPayloadDiffers = errors.New("payload differs") errStartedChannelUsedIncorrectly = errors.New("started channel used incorrectly, should only be closed") errBadIVLength = errors.New("bad iv length in xorBytesCTR") errExceededMaxPackets = errors.New("exceeded the maximum number of packets") errStreamNotInited = errors.New("stream has not been inited, unable to close") errStreamAlreadyClosed = errors.New("stream is already closed") errStreamAlreadyInited = errors.New("stream is already inited") errFailedTypeAssertion = errors.New("failed to cast child") ) type duplicatedError struct { Proto string // srtp or srtcp SSRC uint32 Index uint32 // sequence number or index } func (e *duplicatedError) Error() string { return fmt.Sprintf("%s ssrc=%d index=%d: %v", e.Proto, e.SSRC, e.Index, errDuplicated) } func (e *duplicatedError) Unwrap() error { return errDuplicated } srtp-2.0.12/go.mod000066400000000000000000000003311437106062400136720ustar00rootroot00000000000000module github.com/pion/srtp/v2 go 1.14 require ( github.com/pion/logging v0.2.2 github.com/pion/rtcp v1.2.10 github.com/pion/rtp v1.7.13 github.com/pion/transport/v2 v2.0.0 github.com/stretchr/testify v1.8.1 ) 
srtp-2.0.12/go.sum000066400000000000000000000121761437106062400137310ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pion/logging v0.2.2 h1:M9+AIj/+pxNsDfAT64+MAVgJO0rsyLnoJKCqf//DoeY= github.com/pion/logging v0.2.2/go.mod h1:k0/tDVsRCX2Mb2ZEmTqNa7CWsQPc+YYCB7Q+5pahoms= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/rtcp v1.2.10 h1:nkr3uj+8Sp97zyItdN60tE/S6vk4al5CPRR6Gejsdjc= github.com/pion/rtcp v1.2.10/go.mod h1:ztfEwXZNLGyF1oQDttz/ZKIBaeeg/oWbRYqzBM9TL1I= github.com/pion/rtp v1.7.13 h1:qcHwlmtiI50t1XivvoawdCGTP4Uiypzfrsap+bijcoA= github.com/pion/rtp v1.7.13/go.mod h1:bDb5n+BFZxXx0Ea7E5qe+klMuqiBrP+w8XSjiWtCUko= github.com/pion/transport/v2 v2.0.0 h1:bsMYyqHCbkvHwj+eNCFBuxtlKndKfyGI2vaQmM3fIE4= github.com/pion/transport/v2 v2.0.0/go.mod h1:HS2MEBJTwD+1ZI2eSXSvHJx/HnzQqRy2/LXxt6eVMHc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= 
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.2.0 h1:ljd4t30dBnAvMZaQCevtY0xLLD0A+bRZXbgLMLU1F/A= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term 
v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= srtp-2.0.12/key_derivation.go000066400000000000000000000041651437106062400161400ustar00rootroot00000000000000package srtp import ( "crypto/aes" "encoding/binary" ) func aesCmKeyDerivation(label byte, masterKey, masterSalt []byte, indexOverKdr int, outLen int) ([]byte, error) { if indexOverKdr != 0 { // 24-bit "index DIV kdr" must be xored to prf input. 
return nil, errNonZeroKDRNotSupported } // https://tools.ietf.org/html/rfc3711#appendix-B.3 // The input block for AES-CM is generated by exclusive-oring the master salt with the // concatenation of the encryption key label 0x00 with (index DIV kdr), // - index is 'rollover count' and DIV is 'divided by' nMasterKey := len(masterKey) nMasterSalt := len(masterSalt) prfIn := make([]byte, nMasterKey) copy(prfIn[:nMasterSalt], masterSalt) prfIn[7] ^= label // The resulting value is then AES encrypted using the master key to get the cipher key. block, err := aes.NewCipher(masterKey) if err != nil { return nil, err } out := make([]byte, ((outLen+nMasterKey)/nMasterKey)*nMasterKey) var i uint16 for n := 0; n < outLen; n += nMasterKey { binary.BigEndian.PutUint16(prfIn[nMasterKey-2:], i) block.Encrypt(out[n:n+nMasterKey], prfIn) i++ } return out[:outLen], nil } // Generate IV https://tools.ietf.org/html/rfc3711#section-4.1.1 // where the 128-bit integer value IV SHALL be defined by the SSRC, the // SRTP packet index i, and the SRTP session salting key k_s, as below. 
// - ROC = a 32-bit unsigned rollover counter (ROC), which records how many // - times the 16-bit RTP sequence number has been reset to zero after // - passing through 65,535 // i = 2^16 * ROC + SEQ // IV = (salt*2 ^ 16) | (ssrc*2 ^ 64) | (i*2 ^ 16) func generateCounter(sequenceNumber uint16, rolloverCounter uint32, ssrc uint32, sessionSalt []byte) (counter [16]byte) { copy(counter[:], sessionSalt) counter[4] ^= byte(ssrc >> 24) counter[5] ^= byte(ssrc >> 16) counter[6] ^= byte(ssrc >> 8) counter[7] ^= byte(ssrc) counter[8] ^= byte(rolloverCounter >> 24) counter[9] ^= byte(rolloverCounter >> 16) counter[10] ^= byte(rolloverCounter >> 8) counter[11] ^= byte(rolloverCounter) counter[12] ^= byte(sequenceNumber >> 8) counter[13] ^= byte(sequenceNumber) return counter } srtp-2.0.12/key_derivation_test.go000066400000000000000000000052531437106062400171760ustar00rootroot00000000000000package srtp import ( "bytes" "testing" "github.com/stretchr/testify/assert" ) func TestValidSessionKeys(t *testing.T) { masterKey := []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39} masterSalt := []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6} expectedSessionKey := []byte{0xC6, 0x1E, 0x7A, 0x93, 0x74, 0x4F, 0x39, 0xEE, 0x10, 0x73, 0x4A, 0xFE, 0x3F, 0xF7, 0xA0, 0x87} expectedSessionSalt := []byte{0x30, 0xCB, 0xBC, 0x08, 0x86, 0x3D, 0x8C, 0x85, 0xD4, 0x9D, 0xB3, 0x4A, 0x9A, 0xE1} expectedSessionAuthTag := []byte{0xCE, 0xBE, 0x32, 0x1F, 0x6F, 0xF7, 0x71, 0x6B, 0x6F, 0xD4, 0xAB, 0x49, 0xAF, 0x25, 0x6A, 0x15, 0x6D, 0x38, 0xBA, 0xA4} sessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { t.Errorf("generateSessionKey failed: %v", err) } else if !bytes.Equal(sessionKey, expectedSessionKey) { t.Errorf("Session Key % 02x does not match expected % 02x", sessionKey, expectedSessionKey) } sessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, 
masterKey, masterSalt, 0, len(masterSalt)) if err != nil { t.Errorf("generateSessionSalt failed: %v", err) } else if !bytes.Equal(sessionSalt, expectedSessionSalt) { t.Errorf("Session Salt % 02x does not match expected % 02x", sessionSalt, expectedSessionSalt) } authKeyLen, err := ProtectionProfileAes128CmHmacSha1_80.authKeyLen() assert.NoError(t, err) sessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) if err != nil { t.Errorf("generateSessionAuthTag failed: %v", err) } else if !bytes.Equal(sessionAuthTag, expectedSessionAuthTag) { t.Errorf("Session Auth Tag % 02x does not match expected % 02x", sessionAuthTag, expectedSessionAuthTag) } } // This test asserts that calling aesCmKeyDerivation with a non-zero indexOverKdr fails // Currently this isn't supported, but the API makes sure we can add this in the future func TestIndexOverKDR(t *testing.T) { _, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, []byte{}, []byte{}, 1, 0) assert.Error(t, err) } func BenchmarkGenerateCounter(b *testing.B) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} s := &srtpSSRCState{ssrc: 4160032510} srtpSessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)) assert.NoError(b, err) b.ResetTimer() for i := 0; i < b.N; i++ { generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) } } srtp-2.0.12/keying.go000066400000000000000000000032041437106062400144030ustar00rootroot00000000000000package srtp const labelExtractorDtlsSrtp = "EXTRACTOR-dtls_srtp" // KeyingMaterialExporter allows package SRTP to extract keying material type KeyingMaterialExporter interface { ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) } // ExtractSessionKeysFromDTLS allows setting the Config SessionKeys by 
// extracting them from DTLS. This behavior is defined in RFC5764: // https://tools.ietf.org/html/rfc5764 func (c *Config) ExtractSessionKeysFromDTLS(exporter KeyingMaterialExporter, isClient bool) error { keyLen, err := c.Profile.keyLen() if err != nil { return err } saltLen, err := c.Profile.saltLen() if err != nil { return err } keyingMaterial, err := exporter.ExportKeyingMaterial(labelExtractorDtlsSrtp, nil, (keyLen*2)+(saltLen*2)) if err != nil { return err } offset := 0 clientWriteKey := append([]byte{}, keyingMaterial[offset:offset+keyLen]...) offset += keyLen serverWriteKey := append([]byte{}, keyingMaterial[offset:offset+keyLen]...) offset += keyLen clientWriteKey = append(clientWriteKey, keyingMaterial[offset:offset+saltLen]...) offset += saltLen serverWriteKey = append(serverWriteKey, keyingMaterial[offset:offset+saltLen]...) if isClient { c.Keys.LocalMasterKey = clientWriteKey[0:keyLen] c.Keys.LocalMasterSalt = clientWriteKey[keyLen:] c.Keys.RemoteMasterKey = serverWriteKey[0:keyLen] c.Keys.RemoteMasterSalt = serverWriteKey[keyLen:] return nil } c.Keys.LocalMasterKey = serverWriteKey[0:keyLen] c.Keys.LocalMasterSalt = serverWriteKey[keyLen:] c.Keys.RemoteMasterKey = clientWriteKey[0:keyLen] c.Keys.RemoteMasterSalt = clientWriteKey[keyLen:] return nil } srtp-2.0.12/keying_test.go000066400000000000000000000036731437106062400154540ustar00rootroot00000000000000package srtp import ( "bytes" "crypto/rand" "fmt" "testing" ) type mockKeyingMaterialExporter struct { exported []byte } func (m *mockKeyingMaterialExporter) ExportKeyingMaterial(label string, context []byte, length int) ([]byte, error) { if label != labelExtractorDtlsSrtp { return nil, fmt.Errorf("%w: expected(%s) actual(%s)", errExporterWrongLabel, label, labelExtractorDtlsSrtp) } m.exported = make([]byte, length) if _, err := rand.Read(m.exported); err != nil { return nil, err } return m.exported, nil } func TestExtractSessionKeysFromDTLS(t *testing.T) { tt := []struct { config *Config }{ 
{&Config{Profile: ProtectionProfileAes128CmHmacSha1_80}}, } m := &mockKeyingMaterialExporter{} for i, tc := range tt { // Test client err := tc.config.ExtractSessionKeysFromDTLS(m, true) if err != nil { t.Errorf("failed to extract keys for %d-client: %v", i, err) } keys := tc.config.Keys clientMaterial := append([]byte{}, keys.LocalMasterKey...) clientMaterial = append(clientMaterial, keys.RemoteMasterKey...) clientMaterial = append(clientMaterial, keys.LocalMasterSalt...) clientMaterial = append(clientMaterial, keys.RemoteMasterSalt...) if !bytes.Equal(clientMaterial, m.exported) { t.Errorf("material reconstruction failed for %d-client:\n%#v\nexpected\n%#v", i, clientMaterial, m.exported) } // Test server err = tc.config.ExtractSessionKeysFromDTLS(m, false) if err != nil { t.Errorf("failed to extract keys for %d-server: %v", i, err) } keys = tc.config.Keys serverMaterial := append([]byte{}, keys.RemoteMasterKey...) serverMaterial = append(serverMaterial, keys.LocalMasterKey...) serverMaterial = append(serverMaterial, keys.RemoteMasterSalt...) serverMaterial = append(serverMaterial, keys.LocalMasterSalt...) if !bytes.Equal(serverMaterial, m.exported) { t.Errorf("material reconstruction failed for %d-server:\n%#v\nexpected\n%#v", i, serverMaterial, m.exported) } } } srtp-2.0.12/option.go000066400000000000000000000027311437106062400144310ustar00rootroot00000000000000package srtp import ( "github.com/pion/transport/v2/replaydetector" ) // ContextOption represents option of Context using the functional options pattern. type ContextOption func(*Context) error // SRTPReplayProtection sets SRTP replay protection window size. func SRTPReplayProtection(windowSize uint) ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return replaydetector.New(windowSize, maxROC<<16|maxSequenceNumber) } return nil } } // SRTCPReplayProtection sets SRTCP replay protection window size. 
func SRTCPReplayProtection(windowSize uint) ContextOption { return func(c *Context) error { c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { return replaydetector.New(windowSize, maxSRTCPIndex) } return nil } } // SRTPNoReplayProtection disables SRTP replay protection. func SRTPNoReplayProtection() ContextOption { // nolint:revive return func(c *Context) error { c.newSRTPReplayDetector = func() replaydetector.ReplayDetector { return &nopReplayDetector{} } return nil } } // SRTCPNoReplayProtection disables SRTCP replay protection. func SRTCPNoReplayProtection() ContextOption { return func(c *Context) error { c.newSRTCPReplayDetector = func() replaydetector.ReplayDetector { return &nopReplayDetector{} } return nil } } type nopReplayDetector struct{} func (s *nopReplayDetector) Check(uint64) (func(), bool) { return func() {}, true } srtp-2.0.12/protection_profile.go000066400000000000000000000042751437106062400170340ustar00rootroot00000000000000package srtp import "fmt" // ProtectionProfile specifies Cipher and AuthTag details, similar to TLS cipher suite type ProtectionProfile uint16 // Supported protection profiles // See https://www.iana.org/assignments/srtp-protection/srtp-protection.xhtml const ( ProtectionProfileAes128CmHmacSha1_80 ProtectionProfile = 0x0001 ProtectionProfileAes128CmHmacSha1_32 ProtectionProfile = 0x0002 ProtectionProfileAeadAes128Gcm ProtectionProfile = 0x0007 ) func (p ProtectionProfile) keyLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80, ProtectionProfileAeadAes128Gcm: return 16, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } func (p ProtectionProfile) saltLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 14, nil case ProtectionProfileAeadAes128Gcm: return 12, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } func (p ProtectionProfile) 
rtpAuthTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_80: return 10, nil case ProtectionProfileAes128CmHmacSha1_32: return 4, nil case ProtectionProfileAeadAes128Gcm: return 0, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } func (p ProtectionProfile) rtcpAuthTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 10, nil case ProtectionProfileAeadAes128Gcm: return 0, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } func (p ProtectionProfile) aeadAuthTagLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 0, nil case ProtectionProfileAeadAes128Gcm: return 16, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } func (p ProtectionProfile) authKeyLen() (int, error) { switch p { case ProtectionProfileAes128CmHmacSha1_32, ProtectionProfileAes128CmHmacSha1_80: return 20, nil case ProtectionProfileAeadAes128Gcm: return 0, nil default: return 0, fmt.Errorf("%w: %#v", errNoSuchSRTPProfile, p) } } srtp-2.0.12/protection_profile_test.go000066400000000000000000000004701437106062400200640ustar00rootroot00000000000000package srtp import ( "testing" "github.com/stretchr/testify/assert" ) func TestInvalidProtectionProfile(t *testing.T) { var invalidProtectionProfile ProtectionProfile _, err := invalidProtectionProfile.keyLen() assert.Error(t, err) _, err = invalidProtectionProfile.saltLen() assert.Error(t, err) } srtp-2.0.12/renovate.json000066400000000000000000000001731437106062400153060ustar00rootroot00000000000000{ "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "github>pion/renovate-config" ] } srtp-2.0.12/session.go000066400000000000000000000062301437106062400146020ustar00rootroot00000000000000package srtp import ( "errors" "io" "net" "sync" "github.com/pion/logging" "github.com/pion/transport/v2/packetio" ) type 
streamSession interface { Close() error write([]byte) (int, error) decrypt([]byte) error } type session struct { localContextMutex sync.Mutex localContext, remoteContext *Context localOptions, remoteOptions []ContextOption newStream chan readStream started chan interface{} closed chan interface{} readStreamsClosed bool readStreams map[uint32]readStream readStreamsLock sync.Mutex log logging.LeveledLogger bufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser nextConn net.Conn } // Config is used to configure a session. // You can provide either a KeyingMaterialExporter to export keys // or directly pass the keys themselves. // After a Config is passed to a session it must not be modified. type Config struct { Keys SessionKeys Profile ProtectionProfile BufferFactory func(packetType packetio.BufferPacketType, ssrc uint32) io.ReadWriteCloser LoggerFactory logging.LoggerFactory // List of local/remote context options. // ReplayProtection is enabled on remote context by default. // Default replay protection window size is 64. LocalOptions, RemoteOptions []ContextOption } // SessionKeys bundles the keys required to setup an SRTP session type SessionKeys struct { LocalMasterKey []byte LocalMasterSalt []byte RemoteMasterKey []byte RemoteMasterSalt []byte } func (s *session) getOrCreateReadStream(ssrc uint32, child streamSession, proto func() readStream) (readStream, bool) { s.readStreamsLock.Lock() defer s.readStreamsLock.Unlock() if s.readStreamsClosed { return nil, false } r, ok := s.readStreams[ssrc] if ok { return r, false } // Create the readStream. 
r = proto() if err := r.init(child, ssrc); err != nil { return nil, false } s.readStreams[ssrc] = r return r, true } func (s *session) removeReadStream(ssrc uint32) { s.readStreamsLock.Lock() defer s.readStreamsLock.Unlock() if s.readStreamsClosed { return } delete(s.readStreams, ssrc) } func (s *session) close() error { if s.nextConn == nil { return nil } else if err := s.nextConn.Close(); err != nil { return err } <-s.closed return nil } func (s *session) start(localMasterKey, localMasterSalt, remoteMasterKey, remoteMasterSalt []byte, profile ProtectionProfile, child streamSession) error { var err error s.localContext, err = CreateContext(localMasterKey, localMasterSalt, profile, s.localOptions...) if err != nil { return err } s.remoteContext, err = CreateContext(remoteMasterKey, remoteMasterSalt, profile, s.remoteOptions...) if err != nil { return err } go func() { defer func() { close(s.newStream) s.readStreamsLock.Lock() s.readStreamsClosed = true s.readStreamsLock.Unlock() close(s.closed) }() b := make([]byte, 8192) for { var i int i, err = s.nextConn.Read(b) if err != nil { if !errors.Is(err, io.EOF) { s.log.Error(err.Error()) } return } if err = child.decrypt(b[:i]); err != nil { s.log.Info(err.Error()) } } }() close(s.started) return nil } srtp-2.0.12/session_srtcp.go000066400000000000000000000105671437106062400160250ustar00rootroot00000000000000package srtp import ( "net" "time" "github.com/pion/logging" "github.com/pion/rtcp" ) const defaultSessionSRTCPReplayProtectionWindow = 64 // SessionSRTCP implements io.ReadWriteCloser and provides a bi-directional SRTCP session // SRTCP itself does not have a design like this, but it is common in most applications // for local/remote to each have their own keying material. This provides those patterns // instead of making everyone re-implement type SessionSRTCP struct { session writeStream *WriteStreamSRTCP } // NewSessionSRTCP creates a SRTCP session using conn as the underlying transport. 
func NewSessionSRTCP(conn net.Conn, config *Config) (*SessionSRTCP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { return nil, errNoConn } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() } localOpts := append( []ContextOption{}, config.LocalOptions..., ) remoteOpts := append( []ContextOption{ // Default options SRTCPReplayProtection(defaultSessionSRTCPReplayProtectionWindow), }, config.RemoteOptions..., ) s := &SessionSRTCP{ session: session{ nextConn: conn, localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, newStream: make(chan readStream), started: make(chan interface{}), closed: make(chan interface{}), bufferFactory: config.BufferFactory, log: loggerFactory.NewLogger("srtp"), }, } s.writeStream = &WriteStreamSRTCP{s} err := s.session.start( config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, s, ) if err != nil { return nil, err } return s, nil } // OpenWriteStream returns the global write stream for the Session func (s *SessionSRTCP) OpenWriteStream() (*WriteStreamSRTCP, error) { return s.writeStream, nil } // OpenReadStream opens a read stream for the given SSRC, it can be used // if you want a certain SSRC, but don't want to wait for AcceptStream func (s *SessionSRTCP) OpenReadStream(ssrc uint32) (*ReadStreamSRTCP, error) { r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) if readStream, ok := r.(*ReadStreamSRTCP); ok { return readStream, nil } return nil, errFailedTypeAssertion } // AcceptStream returns a stream to handle RTCP for a single SSRC func (s *SessionSRTCP) AcceptStream() (*ReadStreamSRTCP, uint32, error) { stream, ok := <-s.newStream if !ok { return nil, 0, errStreamAlreadyClosed } readStream, ok := stream.(*ReadStreamSRTCP) if !ok { return nil, 0, errFailedTypeAssertion } return readStream, 
stream.GetSSRC(), nil } // Close ends the session func (s *SessionSRTCP) Close() error { return s.session.close() } // Private func (s *SessionSRTCP) write(buf []byte) (int, error) { if _, ok := <-s.session.started; ok { return 0, errStartedChannelUsedIncorrectly } ibuf := bufferpool.Get() defer bufferpool.Put(ibuf) s.session.localContextMutex.Lock() encrypted, err := s.localContext.EncryptRTCP(ibuf.([]byte), buf, nil) s.session.localContextMutex.Unlock() if err != nil { return 0, err } return s.session.nextConn.Write(encrypted) } func (s *SessionSRTCP) setWriteDeadline(t time.Time) error { return s.session.nextConn.SetWriteDeadline(t) } // create a list of Destination SSRCs // that's a superset of all Destinations in the slice. func destinationSSRC(pkts []rtcp.Packet) []uint32 { ssrcSet := make(map[uint32]struct{}) for _, p := range pkts { for _, ssrc := range p.DestinationSSRC() { ssrcSet[ssrc] = struct{}{} } } out := make([]uint32, 0, len(ssrcSet)) for ssrc := range ssrcSet { out = append(out, ssrc) } return out } func (s *SessionSRTCP) decrypt(buf []byte) error { decrypted, err := s.remoteContext.DecryptRTCP(buf, buf, nil) if err != nil { return err } pkt, err := rtcp.Unmarshal(decrypted) if err != nil { return err } for _, ssrc := range destinationSSRC(pkt) { r, isNew := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTCP) if r == nil { return nil // Session has been closed } else if isNew { s.session.newStream <- r // Notify AcceptStream } readStream, ok := r.(*ReadStreamSRTCP) if !ok { return errFailedTypeAssertion } _, err = readStream.write(decrypted) if err != nil { return err } } return nil } srtp-2.0.12/session_srtcp_test.go000066400000000000000000000210131437106062400170500ustar00rootroot00000000000000package srtp import ( "bytes" "errors" "io" "net" "reflect" "strings" "sync" "testing" "time" "github.com/pion/rtcp" "github.com/pion/transport/v2/test" ) const rtcpHeaderSize = 4 func TestSessionSRTCPBadInit(t *testing.T) { if _, err := 
NewSessionSRTCP(nil, nil); err == nil { t.Fatal("NewSessionSRTCP should error if no config was provided") } else if _, err := NewSessionSRTCP(nil, &Config{}); err == nil { t.Fatal("NewSessionSRTCP should error if no net was provided") } } func buildSessionSRTCP(t *testing.T) (*SessionSRTCP, net.Conn, *Config) { aPipe, bPipe := net.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, Keys: SessionKeys{ []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, }, } aSession, err := NewSessionSRTCP(aPipe, config) if err != nil { t.Fatal(err) } else if aSession == nil { t.Fatal("NewSessionSRTCP did not error, but returned nil session") } return aSession, bPipe, config } func buildSessionSRTCPPair(t *testing.T) (*SessionSRTCP, *SessionSRTCP) { //nolint:dupl aSession, bPipe, config := buildSessionSRTCP(t) bSession, err := NewSessionSRTCP(bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { t.Fatal("NewSessionSRTCP did not error, but returned nil session") } return aSession, bSession } func TestSessionSRTCP(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() report := test.CheckRoutines(t) defer report() testPayload, err := rtcp.Marshal([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 5000}}) if err != nil { t.Fatal(err) } readBuffer := make([]byte, len(testPayload)) aSession, bSession := buildSessionSRTCPPair(t) aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } if _, err = aWriteStream.Write(testPayload); err != nil { t.Fatal(err) } bReadStream, _, err := bSession.AcceptStream() if err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); err 
!= nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTCPWithIODeadline(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() report := test.CheckRoutines(t) defer report() testPayload, err := rtcp.Marshal([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 5000}}) if err != nil { t.Fatal(err) } readBuffer := make([]byte, len(testPayload)) aSession, bPipe, config := buildSessionSRTCP(t) aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } // When the other peer is not ready, the Write would be blocked if no deadline. if err = aWriteStream.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { t.Fatal(err) } if _, err = aWriteStream.Write(testPayload); !errIsTimeout(err) { t.Fatalf("Unexcepted write-error(%v)", err) } if err = aWriteStream.SetWriteDeadline(time.Time{}); err != nil { t.Fatal(err) } // Setup another peer. bSession, err := NewSessionSRTCP(bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { t.Fatal("NewSessionSRTCP did not error, but returned nil session") } // The second attempt to write. if _, err = aWriteStream.Write(testPayload); err != nil { // The other peer is ready, this write attempt should work. t.Fatal(err) } bReadStream, _, err := bSession.AcceptStream() if err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer) } // The second Read attempt would be blocked if the deadline is not set. 
if err = bReadStream.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); !errIsTimeout(err) { t.Fatalf("Unexpected read-error(%v)", err) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTCPOpenReadStream(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() report := test.CheckRoutines(t) defer report() testPayload, err := rtcp.Marshal([]rtcp.Packet{&rtcp.PictureLossIndication{MediaSSRC: 5000}}) if err != nil { t.Fatal(err) } readBuffer := make([]byte, len(testPayload)) aSession, bSession := buildSessionSRTCPPair(t) bReadStream, err := bSession.OpenReadStream(5000) if err != nil { t.Fatal(err) } aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } if _, err = aWriteStream.Write(testPayload); err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTCPReplayProtection(t *testing.T) { lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) defer report() const ( testSSRC = 5000 ) aSession, bSession := buildSessionSRTCPPair(t) bReadStream, err := bSession.OpenReadStream(testSSRC) if err != nil { t.Fatal(err) } // Generate test packets var packets [][]byte var expectedSSRC []uint32 for i := uint32(0); i < 0x100; i++ { testPacket := &rtcp.PictureLossIndication{ MediaSSRC: testSSRC, SenderSSRC: i, } expectedSSRC = append(expectedSSRC, i) encrypted, eerr := encryptSRTCP(aSession.session.localContext, testPacket) if eerr != nil { t.Fatal(eerr) } packets = append(packets, encrypted) } // Receive SRTCP packets with replay protection var 
receivedSSRC []uint32 var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for { if ssrc, perr := getSenderSSRC(t, bReadStream); perr == nil { receivedSSRC = append(receivedSSRC, ssrc) } else if errors.Is(perr, io.EOF) { return } } }() // Write with replay attack for _, p := range packets { if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } // Immediately replay if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } } for _, p := range packets { // Delayed replay if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } if err = bReadStream.Close(); err != nil { t.Fatal(err) } wg.Wait() if !reflect.DeepEqual(expectedSSRC, receivedSSRC) { t.Errorf("Expected and received packet differs,\nexpected:\n%v\nreceived:\n%v", expectedSSRC, receivedSSRC, ) } } func getSenderSSRC(t *testing.T, stream *ReadStreamSRTCP) (ssrc uint32, err error) { authTagSize, err := ProtectionProfileAes128CmHmacSha1_80.rtcpAuthTagLen() if err != nil { return 0, err } const pliPacketSize = 8 readBuffer := make([]byte, pliPacketSize+authTagSize+srtcpIndexSize) n, _, err := stream.ReadRTCP(readBuffer) if errors.Is(err, io.EOF) { return 0, err } if err != nil { t.Error(err) return 0, err } pli := &rtcp.PictureLossIndication{} if uerr := pli.Unmarshal(readBuffer[:n]); uerr != nil { t.Error(uerr) return 0, uerr } return pli.SenderSSRC, nil } func encryptSRTCP(context *Context, pkt rtcp.Packet) ([]byte, error) { decryptedRaw, err := pkt.Marshal() if err != nil { return nil, err } encryptInput := make([]byte, len(decryptedRaw), rtcpHeaderSize+len(decryptedRaw)+10) copy(encryptInput, decryptedRaw) encrypted, eerr := context.EncryptRTCP(encryptInput, encryptInput, nil) if eerr != nil { return nil, eerr } return encrypted, nil } func errIsTimeout(err error) bool { if err == nil { return false } s := err.Error() switch { case 
strings.Contains(s, "i/o timeout"): // error message when timeout before go1.15. return true case strings.Contains(s, "deadline exceeded"): // error message when timeout after go1.15. return true } return false } srtp-2.0.12/session_srtp.go000066400000000000000000000116501437106062400156540ustar00rootroot00000000000000package srtp import ( "net" "sync" "time" "github.com/pion/logging" "github.com/pion/rtp" ) const defaultSessionSRTPReplayProtectionWindow = 64 // SessionSRTP implements io.ReadWriteCloser and provides a bi-directional SRTP session // SRTP itself does not have a design like this, but it is common in most applications // for local/remote to each have their own keying material. This provides those patterns // instead of making everyone re-implement type SessionSRTP struct { session writeStream *WriteStreamSRTP } // NewSessionSRTP creates a SRTP session using conn as the underlying transport. func NewSessionSRTP(conn net.Conn, config *Config) (*SessionSRTP, error) { //nolint:dupl if config == nil { return nil, errNoConfig } else if conn == nil { return nil, errNoConn } loggerFactory := config.LoggerFactory if loggerFactory == nil { loggerFactory = logging.NewDefaultLoggerFactory() } localOpts := append( []ContextOption{}, config.LocalOptions..., ) remoteOpts := append( []ContextOption{ // Default options SRTPReplayProtection(defaultSessionSRTPReplayProtectionWindow), }, config.RemoteOptions..., ) s := &SessionSRTP{ session: session{ nextConn: conn, localOptions: localOpts, remoteOptions: remoteOpts, readStreams: map[uint32]readStream{}, newStream: make(chan readStream), started: make(chan interface{}), closed: make(chan interface{}), bufferFactory: config.BufferFactory, log: loggerFactory.NewLogger("srtp"), }, } s.writeStream = &WriteStreamSRTP{s} err := s.session.start( config.Keys.LocalMasterKey, config.Keys.LocalMasterSalt, config.Keys.RemoteMasterKey, config.Keys.RemoteMasterSalt, config.Profile, s, ) if err != nil { return nil, err } return s, nil } 
// OpenWriteStream returns the global write stream for the Session func (s *SessionSRTP) OpenWriteStream() (*WriteStreamSRTP, error) { return s.writeStream, nil } // OpenReadStream opens a read stream for the given SSRC, it can be used // if you want a certain SSRC, but don't want to wait for AcceptStream func (s *SessionSRTP) OpenReadStream(ssrc uint32) (*ReadStreamSRTP, error) { r, _ := s.session.getOrCreateReadStream(ssrc, s, newReadStreamSRTP) if readStream, ok := r.(*ReadStreamSRTP); ok { return readStream, nil } return nil, errFailedTypeAssertion } // AcceptStream returns a stream to handle RTCP for a single SSRC func (s *SessionSRTP) AcceptStream() (*ReadStreamSRTP, uint32, error) { stream, ok := <-s.newStream if !ok { return nil, 0, errStreamAlreadyClosed } readStream, ok := stream.(*ReadStreamSRTP) if !ok { return nil, 0, errFailedTypeAssertion } return readStream, stream.GetSSRC(), nil } // Close ends the session func (s *SessionSRTP) Close() error { return s.session.close() } func (s *SessionSRTP) write(b []byte) (int, error) { packet := &rtp.Packet{} if err := packet.Unmarshal(b); err != nil { return 0, err } return s.writeRTP(&packet.Header, packet.Payload) } // bufferpool is a global pool of buffers used for encrypted packets in // writeRTP below. Since it's global, buffers can be shared between // different sessions, which amortizes the cost of allocating the pool. // // 1472 is the maximum Ethernet UDP payload. We give ourselves 20 bytes // of slack for any authentication tags, which is more than enough for // either CTR or GCM. If the buffer is too small, no harm, it will just // get expanded by growBuffer. 
var bufferpool = sync.Pool{ // nolint:gochecknoglobals New: func() interface{} { return make([]byte, 1492) }, } func (s *SessionSRTP) writeRTP(header *rtp.Header, payload []byte) (int, error) { if _, ok := <-s.session.started; ok { return 0, errStartedChannelUsedIncorrectly } // encryptRTP will either return our buffer, or, if it is too // small, allocate a new buffer itself. In either case, it is // safe to put the buffer back into the pool, but only after // nextConn.Write has returned. ibuf := bufferpool.Get() defer bufferpool.Put(ibuf) s.session.localContextMutex.Lock() encrypted, err := s.localContext.encryptRTP(ibuf.([]byte), header, payload) s.session.localContextMutex.Unlock() if err != nil { return 0, err } return s.session.nextConn.Write(encrypted) } func (s *SessionSRTP) setWriteDeadline(t time.Time) error { return s.session.nextConn.SetWriteDeadline(t) } func (s *SessionSRTP) decrypt(buf []byte) error { h := &rtp.Header{} headerLen, err := h.Unmarshal(buf) if err != nil { return err } r, isNew := s.session.getOrCreateReadStream(h.SSRC, s, newReadStreamSRTP) if r == nil { return nil // Session has been closed } else if isNew { s.session.newStream <- r // Notify AcceptStream } readStream, ok := r.(*ReadStreamSRTP) if !ok { return errFailedTypeAssertion } decrypted, err := s.remoteContext.decryptRTP(buf, buf, h, headerLen) if err != nil { return err } _, err = readStream.write(decrypted) if err != nil { return err } return nil } srtp-2.0.12/session_srtp_test.go000066400000000000000000000236651437106062400167240ustar00rootroot00000000000000package srtp import ( "bytes" "errors" "io" "net" "reflect" "sync" "testing" "time" "github.com/pion/rtp" "github.com/pion/transport/v2/test" ) func TestSessionSRTPBadInit(t *testing.T) { if _, err := NewSessionSRTP(nil, nil); err == nil { t.Fatal("NewSessionSRTP should error if no config was provided") } else if _, err := NewSessionSRTP(nil, &Config{}); err == nil { t.Fatal("NewSessionSRTP should error if no net was 
provided") } } func buildSessionSRTP(t *testing.T) (*SessionSRTP, net.Conn, *Config) { aPipe, bPipe := net.Pipe() config := &Config{ Profile: ProtectionProfileAes128CmHmacSha1_80, Keys: SessionKeys{ []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, []byte{0xE1, 0xF9, 0x7A, 0x0D, 0x3E, 0x01, 0x8B, 0xE0, 0xD6, 0x4F, 0xA3, 0x2C, 0x06, 0xDE, 0x41, 0x39}, []byte{0x0E, 0xC6, 0x75, 0xAD, 0x49, 0x8A, 0xFE, 0xEB, 0xB6, 0x96, 0x0B, 0x3A, 0xAB, 0xE6}, }, } aSession, err := NewSessionSRTP(aPipe, config) if err != nil { t.Fatal(err) } else if aSession == nil { t.Fatal("NewSessionSRTP did not error, but returned nil session") } return aSession, bPipe, config } func buildSessionSRTPPair(t *testing.T) (*SessionSRTP, *SessionSRTP) { //nolint:dupl aSession, bPipe, config := buildSessionSRTP(t) bSession, err := NewSessionSRTP(bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { t.Fatal("NewSessionSRTP did not error, but returned nil session") } return aSession, bSession } func TestSessionSRTP(t *testing.T) { lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) defer report() const ( testSSRC = 5000 rtpHeaderSize = 12 ) testPayload := []byte{0x00, 0x01, 0x03, 0x04} readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) aSession, bSession := buildSessionSRTPPair(t) aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } bReadStream, ssrc, err := bSession.AcceptStream() if err != nil { t.Fatal(err) } else if ssrc != testSSRC { t.Fatalf("SSRC mismatch during accept exp(%v) actual%v)", testSSRC, ssrc) } if _, err = bReadStream.Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer[rtpHeaderSize:]) { t.Fatalf("Sent 
buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer[rtpHeaderSize:]) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTPWithIODeadline(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() report := test.CheckRoutines(t) defer report() const ( testSSRC = 5000 rtpHeaderSize = 12 ) testPayload := []byte{0x00, 0x01, 0x03, 0x04} readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) aSession, bPipe, config := buildSessionSRTP(t) aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } // When the other peer is not ready, the Write would be blocked if no deadline. if err = aWriteStream.SetWriteDeadline(time.Now().Add(1 * time.Second)); err != nil { t.Fatal(err) } if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); !errIsTimeout(err) { t.Fatal(err) } if err = aWriteStream.SetWriteDeadline(time.Time{}); err != nil { t.Fatal(err) } // Setup another peer. bSession, err := NewSessionSRTP(bPipe, config) if err != nil { t.Fatal(err) } else if bSession == nil { t.Fatal("NewSessionSRTCP did not error, but returned nil session") } // The second attempt to write, even without deadline. if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } bReadStream, ssrc, err := bSession.AcceptStream() if err != nil { t.Fatal(err) } else if ssrc != testSSRC { t.Fatalf("SSRC mismatch during accept exp(%v) actual%v)", testSSRC, ssrc) } if _, err = bReadStream.Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer[rtpHeaderSize:]) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer[rtpHeaderSize:]) } // The second Read attempt would be blocked if the deadline is not set. 
if err = bReadStream.SetReadDeadline(time.Now().Add(1 * time.Second)); err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); !errIsTimeout(err) { t.Fatalf("Unexpected read-error(%v)", err) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTPOpenReadStream(t *testing.T) { lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) defer report() const ( testSSRC = 5000 rtpHeaderSize = 12 ) testPayload := []byte{0x00, 0x01, 0x03, 0x04} readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) aSession, bSession := buildSessionSRTPPair(t) bReadStream, err := bSession.OpenReadStream(5000) if err != nil { t.Fatal(err) } aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: testSSRC}, append([]byte{}, testPayload...)); err != nil { t.Fatal(err) } if _, err = bReadStream.Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer[rtpHeaderSize:]) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer[rtpHeaderSize:]) } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTPMultiSSRC(t *testing.T) { lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) defer report() const rtpHeaderSize = 12 ssrcs := []uint32{5000, 5001, 5002} testPayload := []byte{0x00, 0x01, 0x03, 0x04} aSession, bSession := buildSessionSRTPPair(t) bReadStreams := make(map[uint32]*ReadStreamSRTP) for _, ssrc := range ssrcs { bReadStream, err := bSession.OpenReadStream(ssrc) if err != nil { t.Fatal(err) } bReadStreams[ssrc] = bReadStream } aWriteStream, err := aSession.OpenWriteStream() if err != nil { t.Fatal(err) } for _, ssrc := range ssrcs { if _, err = aWriteStream.WriteRTP(&rtp.Header{SSRC: ssrc}, append([]byte{}, 
testPayload...)); err != nil { t.Fatal(err) } readBuffer := make([]byte, rtpHeaderSize+len(testPayload)) if _, err = bReadStreams[ssrc].Read(readBuffer); err != nil { t.Fatal(err) } if !bytes.Equal(testPayload, readBuffer[rtpHeaderSize:]) { t.Fatalf("Sent buffer does not match the one received exp(%v) actual(%v)", testPayload, readBuffer[rtpHeaderSize:]) } } if err = aSession.Close(); err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } } func TestSessionSRTPReplayProtection(t *testing.T) { lim := test.TimeOut(time.Second * 5) defer lim.Stop() report := test.CheckRoutines(t) defer report() const ( testSSRC = 5000 rtpHeaderSize = 12 ) testPayload := []byte{0x00, 0x01, 0x03, 0x04} aSession, bSession := buildSessionSRTPPair(t) bReadStream, err := bSession.OpenReadStream(testSSRC) if err != nil { t.Fatal(err) } // Generate test packets var packets [][]byte var expectedSequenceNumber []uint16 for i := uint16(0xFF00); i != 0x100; i++ { expectedSequenceNumber = append(expectedSequenceNumber, i) encrypted, eerr := encryptSRTP(aSession.session.localContext, &rtp.Packet{ Header: rtp.Header{ SSRC: testSSRC, SequenceNumber: i, }, Payload: testPayload, }) if eerr != nil { t.Fatal(eerr) } packets = append(packets, encrypted) } // Receive SRTP packets with replay protection var receivedSequenceNumber []uint16 var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for { if seq, perr := assertPayloadSRTP(t, bReadStream, rtpHeaderSize, testPayload); perr == nil { receivedSequenceNumber = append(receivedSequenceNumber, seq) } else if errors.Is(perr, io.EOF) { return } } }() // Write with replay attack for _, p := range packets { if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } // Immediately replay if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } } for _, p := range packets { // Delayed replay if _, err = aSession.session.nextConn.Write(p); err != nil { t.Fatal(err) } } if err = aSession.Close(); 
err != nil { t.Fatal(err) } if err = bSession.Close(); err != nil { t.Fatal(err) } if err = bReadStream.Close(); err != nil { t.Fatal(err) } wg.Wait() if !reflect.DeepEqual(expectedSequenceNumber, receivedSequenceNumber) { t.Errorf("Expected and received sequence number differs,\nexpected:\n%v\nreceived:\n%v", expectedSequenceNumber, receivedSequenceNumber, ) } } func assertPayloadSRTP(t *testing.T, stream *ReadStreamSRTP, headerSize int, expectedPayload []byte) (seq uint16, err error) { readBuffer := make([]byte, headerSize+len(expectedPayload)) n, hdr, err := stream.ReadRTP(readBuffer) if errors.Is(err, io.EOF) { return 0, err } if err != nil { t.Error(err) return 0, err } if !bytes.Equal(expectedPayload, readBuffer[headerSize:n]) { t.Errorf("Sent buffer does not match the one received exp(%v) actual(%v)", expectedPayload, readBuffer[headerSize:n]) return 0, errPayloadDiffers } return hdr.SequenceNumber, nil } func encryptSRTP(context *Context, pkt *rtp.Packet) ([]byte, error) { decryptedRaw, err := pkt.Marshal() if err != nil { return nil, err } encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+10) copy(encryptInput, decryptedRaw) encrypted, eerr := context.EncryptRTP(encryptInput, encryptInput, nil) if eerr != nil { return nil, eerr } return encrypted, nil } srtp-2.0.12/srtcp.go000066400000000000000000000045351437106062400142600ustar00rootroot00000000000000package srtp import ( "encoding/binary" "fmt" "github.com/pion/rtcp" ) const maxSRTCPIndex = 0x7FFFFFFF func (c *Context) decryptRTCP(dst, encrypted []byte) ([]byte, error) { out := allocateIfMismatch(dst, encrypted) authTagLen, err := c.cipher.rtcpAuthTagLen() if err != nil { return nil, err } aeadAuthTagLen, err := c.cipher.aeadAuthTagLen() if err != nil { return nil, err } tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) if tailOffset < aeadAuthTagLen { return nil, fmt.Errorf("%w: %d", errTooShortRTCP, len(encrypted)) } else if isEncrypted := encrypted[tailOffset] >> 7; 
isEncrypted == 0 { return out, nil } index := c.cipher.getRTCPIndex(encrypted) ssrc := binary.BigEndian.Uint32(encrypted[4:]) s := c.getSRTCPSSRCState(ssrc) markAsValid, ok := s.replayDetector.Check(uint64(index)) if !ok { return nil, &duplicatedError{Proto: "srtcp", SSRC: ssrc, Index: index} } out, err = c.cipher.decryptRTCP(out, encrypted, index, ssrc) if err != nil { return nil, err } markAsValid() return out, nil } // DecryptRTCP decrypts a buffer that contains a RTCP packet func (c *Context) DecryptRTCP(dst, encrypted []byte, header *rtcp.Header) ([]byte, error) { if header == nil { header = &rtcp.Header{} } if err := header.Unmarshal(encrypted); err != nil { return nil, err } return c.decryptRTCP(dst, encrypted) } func (c *Context) encryptRTCP(dst, decrypted []byte) ([]byte, error) { ssrc := binary.BigEndian.Uint32(decrypted[4:]) s := c.getSRTCPSSRCState(ssrc) if s.srtcpIndex >= maxSRTCPIndex { // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key // (whichever occurs before), the key management MUST be called to provide new master key(s) // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. 
// https://www.rfc-editor.org/rfc/rfc3711#section-9.2 return nil, errExceededMaxPackets } // We roll over early because MSB is used for marking as encrypted s.srtcpIndex++ return c.cipher.encryptRTCP(dst, decrypted, s.srtcpIndex, ssrc) } // EncryptRTCP Encrypts a RTCP packet func (c *Context) EncryptRTCP(dst, decrypted []byte, header *rtcp.Header) ([]byte, error) { if header == nil { header = &rtcp.Header{} } if err := header.Unmarshal(decrypted); err != nil { return nil, err } return c.encryptRTCP(dst, decrypted) } srtp-2.0.12/srtcp_test.go000066400000000000000000000500511437106062400153110ustar00rootroot00000000000000package srtp import ( "bytes" "encoding/binary" "errors" "testing" "github.com/pion/rtcp" "github.com/stretchr/testify/assert" ) type rtcpTestPacket struct { ssrc uint32 index uint32 pktType rtcp.PacketType encrypted []byte decrypted []byte } type rtcpTestCase struct { algo ProtectionProfile masterKey []byte masterSalt []byte packets []rtcpTestPacket } func rtcpTestCases() map[string]rtcpTestCase { return map[string]rtcpTestCase{ "AEAD_AES_128_GCM": { algo: ProtectionProfileAeadAes128Gcm, masterKey: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}, masterSalt: []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab}, packets: []rtcpTestPacket{ { ssrc: 0xcafebabe, index: 0, pktType: rtcp.TypeSenderReport, encrypted: []byte{ 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 0x25, 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, }, decrypted: []byte{ 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, }, }, }, }, "AES_128_CM_HMAC_SHA1_80": { algo: ProtectionProfileAes128CmHmacSha1_80, masterKey: []byte{0xfd, 0xa6, 0x25, 0x95, 0xd7, 
0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f, 0x34}, masterSalt: []byte{0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23}, packets: []rtcpTestPacket{ { ssrc: 0x66ef91ff, index: 0, pktType: rtcp.TypeSenderReport, encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xcd, 0x34, 0xc5, 0x78, 0xb2, 0x8b, 0xe1, 0x6b, 0xc5, 0x09, 0xd5, 0x77, 0xe4, 0xce, 0x5f, 0x20, 0x80, 0x21, 0xbd, 0x66, 0x74, 0x65, 0xe9, 0x5f, 0x49, 0xe5, 0xf5, 0xc0, 0x68, 0x4e, 0xe5, 0x6a, 0x78, 0x07, 0x75, 0x46, 0xed, 0x90, 0xf6, 0xdc, 0x9d, 0xef, 0x3b, 0xdf, 0xf2, 0x79, 0xa9, 0xd8, 0x80, 0x00, 0x00, 0x01, 0x60, 0xc0, 0xae, 0xb5, 0x6f, 0x40, 0x88, 0x0e, 0x28, 0xba, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6, 0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04, 0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35, 0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00, }, }, { ssrc: 0x11111111, index: 0, pktType: rtcp.TypeSenderReport, encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, 0x80, 0x00, 0x00, 0x01, 0x34, 0x3c, 0x2e, 0x83, 0x17, 0x13, 0x93, 0x69, 0xcf, 0xc0, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0xdf, 0x48, 0x80, 0xdd, 0x61, 0xa6, 0x2e, 0xd3, 0xd8, 0xbc, 0xde, 0xbe, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x16, 0x04, 0x81, 0xca, 0x00, 0x06, 0x66, 0xef, 0x91, 0xff, 0x01, 0x10, 0x52, 0x6e, 0x54, 0x35, 0x43, 0x6d, 0x4a, 0x68, 0x7a, 0x79, 0x65, 0x74, 0x41, 0x78, 0x77, 0x2b, 0x00, 0x00, }, }, }, }, } } func TestRTCPLifecycle(t *testing.T) { options := 
map[string][]ContextOption{ "Default": {}, "WithReplayProtection": {SRTCPReplayProtection(10)}, } for name, option := range options { option := option t.Run(name, func(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, option...) if err != nil { t.Errorf("CreateContext failed: %v", err) } decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, option...) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, pkt := range testCase.packets { decryptResult, err := decryptContext.DecryptRTCP(nil, pkt.encrypted, nil) if err != nil { t.Error(err) } assert.Equal(pkt.decrypted, decryptResult, "RTCP failed to decrypt") encryptContext.SetIndex(pkt.ssrc, pkt.index) encryptResult, err := encryptContext.EncryptRTCP(nil, pkt.decrypted, nil) if err != nil { t.Error(err) } assert.Equal(pkt.encrypted, encryptResult, "RTCP failed to encrypt") } }) } }) } } func TestRTCPLifecycleInPlace(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() assert.NoError(err) encryptHeader := &rtcp.Header{} encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } decryptHeader := &rtcp.Header{} decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, pkt := range testCase.packets { // Copy packet, asserts that everything was done in place decryptInput := append([]byte{}, pkt.encrypted...) 
actualDecrypted, err := decryptContext.DecryptRTCP(decryptInput, decryptInput, decryptHeader) switch { case err != nil: t.Error(err) case decryptHeader.Type != pkt.pktType: t.Fatalf("DecryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, decryptHeader.Type) case !bytes.Equal(decryptInput[:len(decryptInput)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], actualDecrypted): t.Fatalf("DecryptRTP failed to decrypt in place\nexpected: %v\n got: %v", decryptInput[:len(decryptInput)-(authTagLen+srtcpIndexSize)], actualDecrypted) } assert.Equal(decryptInput[:len(decryptInput)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], actualDecrypted, "DecryptRTP failed to decrypt in place") assert.Equal(pkt.decrypted, actualDecrypted, "RTCP failed to decrypt") // Destination buffer should have capacity to store the resutl. // Otherwise, the buffer may be realloc-ed and the actual result will be written to the other address. encryptInput := make([]byte, 0, len(pkt.encrypted)) // Copy packet, asserts that everything was done in place encryptInput = append(encryptInput, pkt.decrypted...) 
encryptContext.SetIndex(pkt.ssrc, pkt.index) actualEncrypted, err := encryptContext.EncryptRTCP(encryptInput, encryptInput, encryptHeader) switch { case err != nil: t.Error(err) case encryptHeader.Type != pkt.pktType: t.Fatalf("EncryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, encryptHeader.Type) } assert.Equal(actualEncrypted[:len(actualEncrypted)-(authTagLen+aeadAuthTagLen+srtcpIndexSize)], encryptInput, "EncryptRTCP failed to encrypt in place") assert.Equal(pkt.encrypted, actualEncrypted, "RTCP failed to encrypt") } }) } } // Assert that passing a dst buffer that is too short doesn't result in a failure func TestRTCPLifecyclePartialAllocation(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) encryptHeader := &rtcp.Header{} encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } decryptHeader := &rtcp.Header{} decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, pkt := range testCase.packets { // Copy packet, asserts that partial buffers can be used decryptDst := make([]byte, len(pkt.decrypted)*2) actualDecrypted, err := decryptContext.DecryptRTCP(decryptDst, pkt.encrypted, decryptHeader) if err != nil { t.Error(err) } else if decryptHeader.Type != pkt.pktType { t.Fatalf("DecryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, decryptHeader.Type) } assert.Equal(pkt.decrypted, actualDecrypted, "RTCP failed to decrypt") // Copy packet, asserts that partial buffers can be used encryptDst := make([]byte, len(pkt.encrypted)/2) encryptContext.SetIndex(pkt.ssrc, pkt.index) actualEncrypted, err := encryptContext.EncryptRTCP(encryptDst, pkt.decrypted, encryptHeader) if err != nil { t.Error(err) } else if 
encryptHeader.Type != pkt.pktType { t.Fatalf("EncryptRTCP failed to populate input rtcp.Header, expected: %d, got %d", pkt.pktType, encryptHeader.Type) } assert.Equal(pkt.encrypted, actualEncrypted, "RTCP failed to encrypt") } }) } } func TestRTCPInvalidAuthTag(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) aeadAuthTagLen, err := testCase.algo.aeadAuthTagLen() assert.NoError(err) decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, pkt := range testCase.packets { rtcpPacket := append([]byte{}, pkt.encrypted...) decryptResult, err := decryptContext.DecryptRTCP(nil, rtcpPacket, nil) if err != nil { t.Error(err) } assert.Equal(pkt.decrypted, decryptResult, "RTCP failed to decrypt") // Zero out auth tag if authTagLen > 0 { copy(rtcpPacket[len(rtcpPacket)-authTagLen:], make([]byte, authTagLen)) } if aeadAuthTagLen > 0 { authTagPos := len(rtcpPacket) - authTagLen - srtcpIndexSize - aeadAuthTagLen copy(rtcpPacket[authTagPos:authTagPos+aeadAuthTagLen], make([]byte, aeadAuthTagLen)) } if _, err = decryptContext.DecryptRTCP(nil, rtcpPacket, nil); err == nil { t.Errorf("Was able to decrypt RTCP packet with invalid Auth Tag") } } }) } } func TestRTCPReplayDetectorSeparation(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) decryptContext, err := CreateContext( testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10), ) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, pkt := range testCase.packets { rtcpPacket := append([]byte{}, pkt.encrypted...) 
decryptResult, errDec := decryptContext.DecryptRTCP(nil, rtcpPacket, nil) if errDec != nil { t.Error(errDec) } assert.Equal(pkt.decrypted, decryptResult, "RTCP failed to decrypt") } for i, pkt := range testCase.packets { rtcpPacket := append([]byte{}, pkt.encrypted...) if _, err = decryptContext.DecryptRTCP(nil, rtcpPacket, nil); !errors.Is(err, errDuplicated) { t.Error("Was able to decrypt duplicated RTCP packet", i) } } }) } } func getRTCPIndex(encrypted []byte, authTagLen int) uint32 { tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) srtcpIndexBuffer := encrypted[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } func TestEncryptRTCPSeparation(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) assert.NoError(err) authTagLen, err := testCase.algo.rtcpAuthTagLen() assert.NoError(err) decryptContext, err := CreateContext( testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10), ) assert.NoError(err) encryptHeader := &rtcp.Header{} inputs := [][]byte{} expectedIndexes := []uint32{} pktCnt := map[uint32]uint32{} for _, pkt := range testCase.packets { inputs = append(inputs, pkt.decrypted) pktCnt[pkt.ssrc]++ expectedIndexes = append(expectedIndexes, pktCnt[pkt.ssrc]) } for _, pkt := range testCase.packets { inputs = append(inputs, pkt.decrypted) pktCnt[pkt.ssrc]++ expectedIndexes = append(expectedIndexes, pktCnt[pkt.ssrc]) } encryptedRCTPs := make([][]byte, len(inputs)) for i, input := range inputs { encrypted, err := encryptContext.EncryptRTCP(nil, input, encryptHeader) assert.NoError(err) encryptedRCTPs[i] = encrypted } for i, expectedIndex := range expectedIndexes { assert.Equal(expectedIndex, getRTCPIndex(encryptedRCTPs[i], authTagLen), "RTCP index does not match") } for i, output := 
range encryptedRCTPs { decrypted, err := decryptContext.DecryptRTCP(nil, output, encryptHeader) assert.NoError(err) assert.Equal(decrypted, inputs[i]) } }) } } func TestRTCPDecryptShortenedPacket(t *testing.T) { for caseName, testCase := range rtcpTestCases() { testCase := testCase t.Run(caseName, func(t *testing.T) { pkt := testCase.packets[0] for i := 1; i < len(pkt.encrypted)-1; i++ { packet := pkt.encrypted[:i] decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } assert.NotPanics(t, func() { _, _ = decryptContext.DecryptRTCP(nil, packet, nil) }, "Panic on length %d/%d", i, len(pkt.encrypted)) } }) } } func TestRTCPMaxPackets(t *testing.T) { const ssrc = 0x11111111 testCases := map[string]rtcpTestCase{ "AEAD_AES_128_GCM": { algo: ProtectionProfileAeadAes128Gcm, masterKey: []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f}, masterSalt: []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab}, packets: []rtcpTestPacket{ { pktType: rtcp.TypeSenderReport, encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x02, 0xb6, 0xc1, 0x47, 0x92, 0xbe, 0xf0, 0xae, 0xd9, 0x40, 0xa5, 0x1c, 0xbe, 0xec, 0xaf, 0xfc, 0x7d, 0x86, 0x3b, 0xbb, 0x93, 0x0c, 0xb0, 0xd4, 0xea, 0x4a, 0x3c, 0x5b, 0xd1, 0xd5, 0x47, 0xb1, 0x1a, 0x61, 0xae, 0xa6, 0x1a, 0x0c, 0xb9, 0x14, 0xa5, 0x16, 0x08, 0xe4, 0xfb, 0x0d, 0x15, 0xba, 0x7f, 0x70, 0x2b, 0xb8, 0x99, 0x97, 0x91, 0xfd, 0x53, 0x03, 0xcd, 0x57, 0xbb, 0x8f, 0x93, 0xbe, 0xff, 0xff, 0xff, 0xff, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde, 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, 0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b, 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e, 0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d, }, }, { pktType: 
rtcp.TypeSenderReport, encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x77, 0x47, 0x0c, 0x21, 0xc2, 0xcd, 0x33, 0xa7, 0x5a, 0x81, 0xb5, 0xb5, 0x8f, 0xe2, 0x34, 0x28, 0x11, 0xa8, 0xa3, 0x34, 0xf8, 0x9d, 0xfc, 0xd8, 0xcb, 0x87, 0xe2, 0x51, 0x8e, 0xae, 0xdb, 0xfd, 0x9d, 0xf1, 0xfa, 0x18, 0xe2, 0xdc, 0x0a, 0xd4, 0xe3, 0x06, 0x18, 0xff, 0xf7, 0x27, 0x92, 0x1f, 0x28, 0xcd, 0x3c, 0xf8, 0xa4, 0x0a, 0x2b, 0xbb, 0x5b, 0x1f, 0x4d, 0x1f, 0xef, 0x0e, 0xc4, 0x91, 0x80, 0x00, 0x00, 0x01, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed, 0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40, 0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4, 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, 0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28, }, }, }, }, "AES_128_CM_HMAC_SHA1_80": { algo: ProtectionProfileAes128CmHmacSha1_80, masterKey: []byte{0xfd, 0xa6, 0x25, 0x95, 0xd7, 0xf6, 0x92, 0x6f, 0x7d, 0x9c, 0x02, 0x4c, 0xc9, 0x20, 0x9f, 0x34}, masterSalt: []byte{0xa9, 0x65, 0x19, 0x85, 0x54, 0x0b, 0x47, 0xbe, 0x2f, 0x27, 0xa8, 0xb8, 0x81, 0x23}, packets: []rtcpTestPacket{ { encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x17, 0x8c, 0x15, 0xf1, 0x4b, 0x11, 0xda, 0xf5, 0x74, 0x53, 0x86, 0x2b, 0xc9, 0x07, 0x29, 0x40, 0xbf, 0x22, 0xf6, 0x46, 0x11, 0xa4, 0xc1, 0x3a, 0xff, 0x5a, 0xbd, 0xd0, 0xf8, 0x8b, 0x38, 0xe4, 0x95, 0x38, 0x5d, 0xcf, 0x1b, 0xf5, 0x27, 0x77, 0xfb, 0xdb, 0x3f, 0x10, 0x68, 0x99, 0xd8, 0xad, 0xff, 0xff, 0xff, 0xff, 0x5a, 0x99, 0xce, 0xed, 0x9f, 0x2e, 0x4d, 0x9d, 0xfa, 0x97, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x04, 0x99, 0x47, 0x53, 0xc4, 0x1e, 0xb9, 0xde, 0x52, 0xa3, 0x1d, 0x77, 0x2f, 0xff, 0xcc, 0x75, 0xbb, 0x6a, 0x29, 0xb8, 0x01, 0xb7, 0x2e, 0x4b, 0x4e, 0xcb, 0xa4, 0x81, 0x2d, 0x46, 0x04, 0x5e, 0x86, 0x90, 0x17, 0x4f, 0x4d, 0x78, 0x2f, 0x58, 0xb8, 0x67, 0x91, 0x89, 0xe3, 0x61, 0x01, 0x7d, }, }, { 
encrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0x12, 0x71, 0x75, 0x7a, 0xb0, 0xfd, 0x80, 0xcb, 0x26, 0xbb, 0x54, 0x5a, 0x1c, 0x0e, 0x98, 0x09, 0xbe, 0x60, 0x23, 0xd8, 0xe6, 0x6e, 0x68, 0xe8, 0x6e, 0x9c, 0xb2, 0x7e, 0x02, 0xa7, 0xab, 0xfe, 0xb3, 0xf4, 0x4c, 0x13, 0xc3, 0xac, 0x97, 0x2c, 0x35, 0x91, 0xbb, 0x37, 0x9c, 0x86, 0x28, 0x85, 0x80, 0x00, 0x00, 0x01, 0x89, 0x76, 0x07, 0xca, 0xd9, 0xc4, 0xcb, 0xca, 0x66, 0xab, }, decrypted: []byte{ 0x80, 0xc8, 0x00, 0x06, 0x11, 0x11, 0x11, 0x11, 0xda, 0xb5, 0xe0, 0x56, 0x9a, 0x4a, 0x74, 0xed, 0x8a, 0x54, 0x0c, 0xcf, 0xd5, 0x09, 0xb1, 0x40, 0x01, 0x42, 0xc3, 0x9a, 0x76, 0x00, 0xa9, 0xd4, 0xf7, 0x29, 0x9e, 0x51, 0xfb, 0x3c, 0xc1, 0x74, 0x72, 0xf9, 0x52, 0xb1, 0x92, 0x31, 0xca, 0x22, 0xab, 0x3e, 0xc5, 0x5f, 0x83, 0x34, 0xf0, 0x28, }, }, }, }, } for caseName, testCase := range testCases { testCase := testCase t.Run(caseName, func(t *testing.T) { assert := assert.New(t) encryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo) if err != nil { t.Errorf("CreateContext failed: %v", err) } decryptContext, err := CreateContext(testCase.masterKey, testCase.masterSalt, testCase.algo, SRTCPReplayProtection(10)) if err != nil { t.Errorf("CreateContext failed: %v", err) } // Upper boundary of index encryptContext.SetIndex(ssrc, 0x7ffffffe) decryptResult, err := decryptContext.DecryptRTCP(nil, testCase.packets[0].encrypted, nil) if err != nil { t.Error(err) } assert.Equal(testCase.packets[0].decrypted, decryptResult, "RTCP failed to decrypt") encryptResult, err := encryptContext.EncryptRTCP(nil, testCase.packets[0].decrypted, nil) if err != nil { t.Error(err) } assert.Equal(testCase.packets[0].encrypted, encryptResult, "RTCP failed to encrypt") // Next packet will exceeds the maximum packet count _, err = decryptContext.DecryptRTCP(nil, testCase.packets[1].encrypted, nil) if !errors.Is(err, errDuplicated) { t.Errorf("Expected error: '%v', got: '%v'", errDuplicated, err) } _, err = 
encryptContext.EncryptRTCP(nil, testCase.packets[1].decrypted, nil) if !errors.Is(err, errExceededMaxPackets) { t.Errorf("Expected error: '%v', got: '%v'", errExceededMaxPackets, err) } }) } } srtp-2.0.12/srtp.go000066400000000000000000000053601437106062400141120ustar00rootroot00000000000000// Package srtp implements Secure Real-time Transport Protocol package srtp import ( "github.com/pion/rtp" ) func (c *Context) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int) ([]byte, error) { s := c.getSRTPSSRCState(header.SSRC) roc, diff, _ := s.nextRolloverCount(header.SequenceNumber) markAsValid, ok := s.replayDetector.Check( (uint64(roc) << 16) | uint64(header.SequenceNumber), ) if !ok { return nil, &duplicatedError{ Proto: "srtp", SSRC: header.SSRC, Index: uint32(header.SequenceNumber), } } authTagLen, err := c.cipher.rtpAuthTagLen() if err != nil { return nil, err } dst = growBufferSize(dst, len(ciphertext)-authTagLen) dst, err = c.cipher.decryptRTP(dst, ciphertext, header, headerLen, roc) if err != nil { return nil, err } markAsValid() s.updateRolloverCount(header.SequenceNumber, diff) return dst, nil } // DecryptRTP decrypts a RTP packet with an encrypted payload func (c *Context) DecryptRTP(dst, encrypted []byte, header *rtp.Header) ([]byte, error) { if header == nil { header = &rtp.Header{} } headerLen, err := header.Unmarshal(encrypted) if err != nil { return nil, err } return c.decryptRTP(dst, encrypted, header, headerLen) } // EncryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. // If the dst buffer does not have the capacity to hold `len(plaintext) + 10` bytes, a new one will be allocated and returned. // If a rtp.Header is provided, it will be Unmarshaled using the plaintext. 
func (c *Context) EncryptRTP(dst []byte, plaintext []byte, header *rtp.Header) ([]byte, error) { if header == nil { header = &rtp.Header{} } headerLen, err := header.Unmarshal(plaintext) if err != nil { return nil, err } return c.encryptRTP(dst, header, plaintext[headerLen:]) } // encryptRTP marshals and encrypts an RTP packet, writing to the dst buffer provided. // If the dst buffer does not have the capacity, a new one will be allocated and returned. // Similar to above but faster because it can avoid unmarshaling the header and marshaling the payload. func (c *Context) encryptRTP(dst []byte, header *rtp.Header, payload []byte) (ciphertext []byte, err error) { s := c.getSRTPSSRCState(header.SSRC) roc, diff, ovf := s.nextRolloverCount(header.SequenceNumber) if ovf { // ... when 2^48 SRTP packets or 2^31 SRTCP packets have been secured with the same key // (whichever occurs before), the key management MUST be called to provide new master key(s) // (previously stored and used keys MUST NOT be used again), or the session MUST be terminated. // https://www.rfc-editor.org/rfc/rfc3711#section-9.2 return nil, errExceededMaxPackets } s.updateRolloverCount(header.SequenceNumber, diff) return c.cipher.encryptRTP(dst, header, payload, roc) } srtp-2.0.12/srtp_cipher.go000066400000000000000000000033051437106062400154410ustar00rootroot00000000000000package srtp import "github.com/pion/rtp" // cipher represents a implementation of one // of the SRTP Specific ciphers type srtpCipher interface { // authTagLen returns auth key length of the cipher. // See the note below. rtpAuthTagLen() (int, error) rtcpAuthTagLen() (int, error) // aeadAuthTagLen returns AEAD auth key length of the cipher. // See the note below. 
aeadAuthTagLen() (int, error) getRTCPIndex([]byte) uint32 encryptRTP([]byte, *rtp.Header, []byte, uint32) ([]byte, error) encryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) decryptRTP([]byte, []byte, *rtp.Header, int, uint32) ([]byte, error) decryptRTCP([]byte, []byte, uint32, uint32) ([]byte, error) } /* NOTE: Auth tag and AEAD auth tag are placed at the different position in SRTCP In non-AEAD cipher, the authentication tag is placed *after* the ESRTCP word (Encrypted-flag and SRTCP index). > AES_128_CM_HMAC_SHA1_80 > | RTCP Header | Encrypted payload |E| SRTCP Index | Auth tag | > ^ |----------| > | ^ > | authTagLen=10 > aeadAuthTagLen=0 In AEAD cipher, the AEAD authentication tag is embedded in the ciphertext. It is *before* the ESRTCP word (Encrypted-flag and SRTCP index). > AEAD_AES_128_GCM > | RTCP Header | Encrypted payload | AEAD auth tag |E| SRTCP Index | > |---------------| ^ > ^ authTagLen=0 > aeadAuthTagLen=16 See https://tools.ietf.org/html/rfc7714 for the full specifications. 
*/ srtp-2.0.12/srtp_cipher_aead_aes_gcm.go000066400000000000000000000133501437106062400200720ustar00rootroot00000000000000package srtp import ( "crypto/aes" "crypto/cipher" "encoding/binary" "github.com/pion/rtp" ) const ( rtcpEncryptionFlag = 0x80 ) type srtpCipherAeadAesGcm struct { ProtectionProfile srtpCipher, srtcpCipher cipher.AEAD srtpSessionSalt, srtcpSessionSalt []byte } func newSrtpCipherAeadAesGcm(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAeadAesGcm, error) { s := &srtpCipherAeadAesGcm{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } srtpBlock, err := aes.NewCipher(srtpSessionKey) if err != nil { return nil, err } s.srtpCipher, err = cipher.NewGCM(srtpBlock) if err != nil { return nil, err } srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } srtcpBlock, err := aes.NewCipher(srtcpSessionKey) if err != nil { return nil, err } s.srtcpCipher, err = cipher.NewGCM(srtcpBlock) if err != nil { return nil, err } if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err } return s, nil } func (s *srtpCipherAeadAesGcm) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. 
authTagLen, err := s.aeadAuthTagLen() if err != nil { return nil, err } dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) n, err := header.MarshalTo(dst) if err != nil { return nil, err } iv := s.rtpInitializationVector(header, roc) s.srtpCipher.Seal(dst[n:n], iv[:], payload, dst[:n]) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Grow the given buffer to fit the output. authTagLen, err := s.aeadAuthTagLen() if err != nil { return nil, err } nDst := len(ciphertext) - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. return nil, errFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtpInitializationVector(header, roc) if _, err := s.srtpCipher.Open( dst[headerLen:headerLen], iv[:], ciphertext[headerLen:], ciphertext[:headerLen], ); err != nil { return nil, err } copy(dst[:headerLen], ciphertext[:headerLen]) return dst, nil } func (s *srtpCipherAeadAesGcm) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { authTagLen, err := s.aeadAuthTagLen() if err != nil { return nil, err } aadPos := len(decrypted) + authTagLen // Grow the given buffer to fit the output. dst = growBufferSize(dst, aadPos+srtcpIndexSize) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(decrypted, srtcpIndex) s.srtcpCipher.Seal(dst[8:8], iv[:], decrypted[8:], aad[:]) copy(dst[:8], decrypted[:8]) copy(dst[aadPos:aadPos+4], aad[8:12]) return dst, nil } func (s *srtpCipherAeadAesGcm) decryptRTCP(dst, encrypted []byte, srtcpIndex, ssrc uint32) ([]byte, error) { aadPos := len(encrypted) - srtcpIndexSize // Grow the given buffer to fit the output. authTagLen, err := s.aeadAuthTagLen() if err != nil { return nil, err } nDst := aadPos - authTagLen if nDst < 0 { // Size of ciphertext is shorter than AEAD auth tag len. 
return nil, errFailedToVerifyAuthTag } dst = growBufferSize(dst, nDst) iv := s.rtcpInitializationVector(srtcpIndex, ssrc) aad := s.rtcpAdditionalAuthenticatedData(encrypted, srtcpIndex) if _, err := s.srtcpCipher.Open(dst[8:8], iv[:], encrypted[8:aadPos], aad[:]); err != nil { return nil, err } copy(dst[:8], encrypted[:8]) return dst, nil } // The 12-octet IV used by AES-GCM SRTP is formed by first concatenating // 2 octets of zeroes, the 4-octet SSRC, the 4-octet rollover counter // (ROC), and the 2-octet sequence number (SEQ). The resulting 12-octet // value is then XORed to the 12-octet salt to form the 12-octet IV. // // https://tools.ietf.org/html/rfc7714#section-8.1 func (s *srtpCipherAeadAesGcm) rtpInitializationVector(header *rtp.Header, roc uint32) [12]byte { var iv [12]byte binary.BigEndian.PutUint32(iv[2:], header.SSRC) binary.BigEndian.PutUint32(iv[6:], roc) binary.BigEndian.PutUint16(iv[10:], header.SequenceNumber) for i := range iv { iv[i] ^= s.srtpSessionSalt[i] } return iv } // The 12-octet IV used by AES-GCM SRTCP is formed by first // concatenating 2 octets of zeroes, the 4-octet SSRC identifier, // 2 octets of zeroes, a single "0" bit, and the 31-bit SRTCP index. // The resulting 12-octet value is then XORed to the 12-octet salt to // form the 12-octet IV. 
// // https://tools.ietf.org/html/rfc7714#section-9.1 func (s *srtpCipherAeadAesGcm) rtcpInitializationVector(srtcpIndex uint32, ssrc uint32) [12]byte { var iv [12]byte binary.BigEndian.PutUint32(iv[2:], ssrc) binary.BigEndian.PutUint32(iv[8:], srtcpIndex) for i := range iv { iv[i] ^= s.srtcpSessionSalt[i] } return iv } // In an SRTCP packet, a 1-bit Encryption flag is prepended to the // 31-bit SRTCP index to form a 32-bit value we shall call the // "ESRTCP word" // // https://tools.ietf.org/html/rfc7714#section-17 func (s *srtpCipherAeadAesGcm) rtcpAdditionalAuthenticatedData(rtcpPacket []byte, srtcpIndex uint32) [12]byte { var aad [12]byte copy(aad[:], rtcpPacket[:8]) binary.BigEndian.PutUint32(aad[8:], srtcpIndex) aad[8] |= rtcpEncryptionFlag return aad } func (s *srtpCipherAeadAesGcm) getRTCPIndex(in []byte) uint32 { return binary.BigEndian.Uint32(in[len(in)-4:]) &^ (rtcpEncryptionFlag << 24) } srtp-2.0.12/srtp_cipher_aead_aes_gcm_test.go000066400000000000000000000053731437106062400211370ustar00rootroot00000000000000package srtp import ( "testing" "github.com/stretchr/testify/assert" ) func TestSrtpCipherAedAesGcm(t *testing.T) { decryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, } encryptedRTPPacket := []byte{ 0x80, 0x0f, 0x12, 0x34, 0xde, 0xca, 0xfb, 0xad, 0xca, 0xfe, 0xba, 0xbe, 0xc5, 0x00, 0x2e, 0xde, 0x04, 0xcf, 0xdd, 0x2e, 0xb9, 0x11, 0x59, 0xe0, 0x88, 0x0a, 0xa0, 0x6e, 0xd2, 0x97, 0x68, 0x26, 0xf7, 0x96, 0xb2, 0x01, 0xdf, 0x31, 0x31, 0xa1, 0x27, 0xe8, 0xa3, 0x92, } decryptedRtcpPacket := []byte{ 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, 0xab, } encryptedRtcpPacket := []byte{ 0x81, 0xc8, 0x00, 0x0b, 0xca, 0xfe, 0xba, 0xbe, 0xc9, 0x8b, 0x8b, 0x5d, 0xf0, 0x39, 0x2a, 0x55, 0x85, 0x2b, 0x6c, 0x21, 0xac, 0x8e, 0x70, 
0x25, 0xc5, 0x2c, 0x6f, 0xbe, 0xa2, 0xb3, 0xb4, 0x46, 0xea, 0x31, 0x12, 0x3b, 0xa8, 0x8c, 0xe6, 0x1e, 0x80, 0x00, 0x00, 0x01, } masterKey := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f} masterSalt := []byte{0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, 0xa8, 0xa9, 0xaa, 0xab} t.Run("Encrypt RTP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { actualEncrypted, err := ctx.EncryptRTP(nil, decryptedRTPPacket, nil) assert.NoError(t, err) assert.Equal(t, encryptedRTPPacket, actualEncrypted) }) }) t.Run("Decrypt RTP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { actualDecrypted, err := ctx.DecryptRTP(nil, encryptedRTPPacket, nil) assert.NoError(t, err) assert.Equal(t, decryptedRTPPacket, actualDecrypted) }) }) t.Run("Encrypt RTCP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { actualEncrypted, err := ctx.EncryptRTCP(nil, decryptedRtcpPacket, nil) assert.NoError(t, err) assert.Equal(t, encryptedRtcpPacket, actualEncrypted) }) }) t.Run("Decrypt RTCP", func(t *testing.T) { ctx, err := CreateContext(masterKey, masterSalt, ProtectionProfileAeadAes128Gcm) assert.NoError(t, err) t.Run("New Allocation", func(t *testing.T) { actualDecrypted, err := ctx.DecryptRTCP(nil, encryptedRtcpPacket, nil) assert.NoError(t, err) assert.Equal(t, decryptedRtcpPacket, actualDecrypted) }) }) } srtp-2.0.12/srtp_cipher_aes_cm_hmac_sha1.go000066400000000000000000000173061437106062400206620ustar00rootroot00000000000000package srtp import ( //nolint:gci "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/sha1" //nolint:gosec "crypto/subtle" "encoding/binary" "hash" 
"github.com/pion/rtp" ) type srtpCipherAesCmHmacSha1 struct { ProtectionProfile srtpSessionSalt []byte srtpSessionAuth hash.Hash srtpBlock cipher.Block srtcpSessionSalt []byte srtcpSessionAuth hash.Hash srtcpBlock cipher.Block } func newSrtpCipherAesCmHmacSha1(profile ProtectionProfile, masterKey, masterSalt []byte) (*srtpCipherAesCmHmacSha1, error) { s := &srtpCipherAesCmHmacSha1{ProtectionProfile: profile} srtpSessionKey, err := aesCmKeyDerivation(labelSRTPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } else if s.srtpBlock, err = aes.NewCipher(srtpSessionKey); err != nil { return nil, err } srtcpSessionKey, err := aesCmKeyDerivation(labelSRTCPEncryption, masterKey, masterSalt, 0, len(masterKey)) if err != nil { return nil, err } else if s.srtcpBlock, err = aes.NewCipher(srtcpSessionKey); err != nil { return nil, err } if s.srtpSessionSalt, err = aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err } else if s.srtcpSessionSalt, err = aesCmKeyDerivation(labelSRTCPSalt, masterKey, masterSalt, 0, len(masterSalt)); err != nil { return nil, err } authKeyLen, err := profile.authKeyLen() if err != nil { return nil, err } srtpSessionAuthTag, err := aesCmKeyDerivation(labelSRTPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) if err != nil { return nil, err } srtcpSessionAuthTag, err := aesCmKeyDerivation(labelSRTCPAuthenticationTag, masterKey, masterSalt, 0, authKeyLen) if err != nil { return nil, err } s.srtcpSessionAuth = hmac.New(sha1.New, srtcpSessionAuthTag) s.srtpSessionAuth = hmac.New(sha1.New, srtpSessionAuthTag) return s, nil } func (s *srtpCipherAesCmHmacSha1) encryptRTP(dst []byte, header *rtp.Header, payload []byte, roc uint32) (ciphertext []byte, err error) { // Grow the given buffer to fit the output. 
authTagLen, err := s.rtpAuthTagLen() if err != nil { return nil, err } dst = growBufferSize(dst, header.MarshalSize()+len(payload)+authTagLen) // Copy the header unencrypted. n, err := header.MarshalTo(dst) if err != nil { return nil, err } // Encrypt the payload counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) if err = xorBytesCTR(s.srtpBlock, counter[:], dst[n:], payload); err != nil { return nil, err } n += len(payload) // Generate the auth tag. authTag, err := s.generateSrtpAuthTag(dst[:n], roc) if err != nil { return nil, err } // Write the auth tag to the dest. copy(dst[n:], authTag) return dst, nil } func (s *srtpCipherAesCmHmacSha1) decryptRTP(dst, ciphertext []byte, header *rtp.Header, headerLen int, roc uint32) ([]byte, error) { // Split the auth tag and the cipher text into two parts. authTagLen, err := s.rtpAuthTagLen() if err != nil { return nil, err } actualTag := ciphertext[len(ciphertext)-authTagLen:] ciphertext = ciphertext[:len(ciphertext)-authTagLen] // Generate the auth tag we expect to see from the ciphertext. expectedTag, err := s.generateSrtpAuthTag(ciphertext, roc) if err != nil { return nil, err } // See if the auth tag actually matches. // We use a constant time comparison to prevent timing attacks. if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { return nil, errFailedToVerifyAuthTag } // Write the plaintext header to the destination buffer. copy(dst, ciphertext[:headerLen]) // Decrypt the ciphertext for the payload. 
counter := generateCounter(header.SequenceNumber, roc, header.SSRC, s.srtpSessionSalt) err = xorBytesCTR( s.srtpBlock, counter[:], dst[headerLen:], ciphertext[headerLen:], ) return dst, err } func (s *srtpCipherAesCmHmacSha1) encryptRTCP(dst, decrypted []byte, srtcpIndex uint32, ssrc uint32) ([]byte, error) { dst = allocateIfMismatch(dst, decrypted) // Encrypt everything after header counter := generateCounter(uint16(srtcpIndex&0xffff), srtcpIndex>>16, ssrc, s.srtcpSessionSalt) if err := xorBytesCTR(s.srtcpBlock, counter[:], dst[8:], dst[8:]); err != nil { return nil, err } // Add SRTCP Index and set Encryption bit dst = append(dst, make([]byte, 4)...) binary.BigEndian.PutUint32(dst[len(dst)-4:], srtcpIndex) dst[len(dst)-4] |= 0x80 authTag, err := s.generateSrtcpAuthTag(dst) if err != nil { return nil, err } return append(dst, authTag...), nil } func (s *srtpCipherAesCmHmacSha1) decryptRTCP(out, encrypted []byte, index, ssrc uint32) ([]byte, error) { authTagLen, err := s.rtcpAuthTagLen() if err != nil { return nil, err } tailOffset := len(encrypted) - (authTagLen + srtcpIndexSize) out = out[0:tailOffset] expectedTag, err := s.generateSrtcpAuthTag(encrypted[:len(encrypted)-authTagLen]) if err != nil { return nil, err } actualTag := encrypted[len(encrypted)-authTagLen:] if subtle.ConstantTimeCompare(actualTag, expectedTag) != 1 { return nil, errFailedToVerifyAuthTag } counter := generateCounter(uint16(index&0xffff), index>>16, ssrc, s.srtcpSessionSalt) err = xorBytesCTR(s.srtcpBlock, counter[:], out[8:], out[8:]) return out, err } func (s *srtpCipherAesCmHmacSha1) generateSrtpAuthTag(buf []byte, roc uint32) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#section-4.2 // In the case of SRTP, M SHALL consist of the Authenticated // Portion of the packet (as specified in Figure 1) concatenated with // the ROC, M = Authenticated Portion || ROC; // // The pre-defined authentication transform for SRTP is HMAC-SHA1 // [RFC2104]. 
With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL // be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to // the session authentication key and M as specified above, i.e., // HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag // left-most bits. // - Authenticated portion of the packet is everything BEFORE MKI // - k_a is the session message authentication key // - n_tag is the bit-length of the output authentication tag s.srtpSessionAuth.Reset() if _, err := s.srtpSessionAuth.Write(buf); err != nil { return nil, err } // For SRTP only, we need to hash the rollover counter as well. rocRaw := [4]byte{} binary.BigEndian.PutUint32(rocRaw[:], roc) _, err := s.srtpSessionAuth.Write(rocRaw[:]) if err != nil { return nil, err } // Truncate the hash to the size indicated by the profile authTagLen, err := s.rtpAuthTagLen() if err != nil { return nil, err } return s.srtpSessionAuth.Sum(nil)[0:authTagLen], nil } func (s *srtpCipherAesCmHmacSha1) generateSrtcpAuthTag(buf []byte) ([]byte, error) { // https://tools.ietf.org/html/rfc3711#section-4.2 // // The pre-defined authentication transform for SRTP is HMAC-SHA1 // [RFC2104]. With HMAC-SHA1, the SRTP_PREFIX_LENGTH (Figure 3) SHALL // be 0. For SRTP (respectively SRTCP), the HMAC SHALL be applied to // the session authentication key and M as specified above, i.e., // HMAC(k_a, M). The HMAC output SHALL then be truncated to the n_tag // left-most bits. 
// - Authenticated portion of the packet is everything BEFORE MKI // - k_a is the session message authentication key // - n_tag is the bit-length of the output authentication tag s.srtcpSessionAuth.Reset() if _, err := s.srtcpSessionAuth.Write(buf); err != nil { return nil, err } authTagLen, err := s.rtcpAuthTagLen() if err != nil { return nil, err } return s.srtcpSessionAuth.Sum(nil)[0:authTagLen], nil } func (s *srtpCipherAesCmHmacSha1) getRTCPIndex(in []byte) uint32 { authTagLen, _ := s.rtcpAuthTagLen() tailOffset := len(in) - (authTagLen + srtcpIndexSize) srtcpIndexBuffer := in[tailOffset : tailOffset+srtcpIndexSize] return binary.BigEndian.Uint32(srtcpIndexBuffer) &^ (1 << 31) } srtp-2.0.12/srtp_test.go000066400000000000000000000575071437106062400151630ustar00rootroot00000000000000package srtp import ( "bytes" "errors" "testing" "github.com/pion/rtp" "github.com/stretchr/testify/assert" ) const ( profileCTR = ProtectionProfileAes128CmHmacSha1_80 profileGCM = ProtectionProfileAeadAes128Gcm defaultSsrc = 0 ) type rtpTestCase struct { sequenceNumber uint16 encryptedCTR []byte encryptedGCM []byte } func (tc rtpTestCase) encrypted(profile ProtectionProfile) []byte { switch profile { case profileCTR: return tc.encryptedCTR case profileGCM: return tc.encryptedGCM default: panic("unknown profile") } } func testKeyLen(t *testing.T, profile ProtectionProfile) { keyLen, err := profile.keyLen() assert.NoError(t, err) saltLen, err := profile.saltLen() assert.NoError(t, err) if _, err := CreateContext([]byte{}, make([]byte, saltLen), profile); err == nil { t.Errorf("CreateContext accepted a 0 length key") } if _, err := CreateContext(make([]byte, keyLen), []byte{}, profile); err == nil { t.Errorf("CreateContext accepted a 0 length salt") } if _, err := CreateContext(make([]byte, keyLen), make([]byte, saltLen), profile); err != nil { t.Errorf("CreateContext failed with a valid length key and salt: %v", err) } } func TestKeyLen(t *testing.T) { t.Run("CTR", func(t *testing.T) 
{ testKeyLen(t, profileCTR) }) t.Run("GCM", func(t *testing.T) { testKeyLen(t, profileGCM) }) } func TestValidPacketCounter(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} srtpSessionSalt, err := aesCmKeyDerivation(labelSRTPSalt, masterKey, masterSalt, 0, len(masterSalt)) assert.NoError(t, err) s := &srtpSSRCState{ssrc: 4160032510} expectedCounter := []byte{0xcf, 0x90, 0x1e, 0xa5, 0xda, 0xd3, 0x2c, 0x15, 0x00, 0xa2, 0x24, 0xae, 0xae, 0xaf, 0x00, 0x00} counter := generateCounter(32846, uint32(s.index>>16), s.ssrc, srtpSessionSalt) if !bytes.Equal(counter[:], expectedCounter) { t.Errorf("Session Key % 02x does not match expected % 02x", counter, expectedCounter) } } func TestRolloverCount(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} // Set initial seqnum roc, diff, ovf := s.nextRolloverCount(65530) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(65530, diff) // Invalid packets never update ROC s.nextRolloverCount(0) s.nextRolloverCount(0x4000) s.nextRolloverCount(0x8000) s.nextRolloverCount(0xFFFF) s.nextRolloverCount(0) // We rolled over to 0 roc, diff, ovf = s.nextRolloverCount(0) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(0, diff) roc, diff, ovf = s.nextRolloverCount(65530) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(65530, diff) roc, diff, ovf = s.nextRolloverCount(5) if roc != 1 { t.Errorf("rolloverCounter was not updated when it rolled over initial, to handle out of order") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(5, diff) _, diff, _ = 
s.nextRolloverCount(6) s.updateRolloverCount(6, diff) _, diff, _ = s.nextRolloverCount(7) s.updateRolloverCount(7, diff) roc, diff, _ = s.nextRolloverCount(8) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } s.updateRolloverCount(8, diff) // valid packets never update ROC roc, diff, ovf = s.nextRolloverCount(0x4000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(0x4000, diff) roc, diff, ovf = s.nextRolloverCount(0x8000) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(0x8000, diff) roc, diff, ovf = s.nextRolloverCount(0xFFFF) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(0xFFFF, diff) roc, _, ovf = s.nextRolloverCount(0) if roc != 2 { t.Errorf("rolloverCounter must be incremented after wrapping, got %d", roc) } if ovf { t.Error("Should not overflow") } } func TestRolloverCountOverflow(t *testing.T) { s := &srtpSSRCState{ ssrc: defaultSsrc, index: maxROC << 16, } s.updateRolloverCount(0xFFFF, 0) _, _, ovf := s.nextRolloverCount(0) if !ovf { t.Error("Should overflow") } } func buildTestContext(profile ProtectionProfile, opts ...ContextOption) (*Context, error) { keyLen, err := profile.keyLen() if err != nil { return nil, err } saltLen, err := profile.saltLen() if err != nil { return nil, err } masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterKey = masterKey[:keyLen] masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} masterSalt = masterSalt[:saltLen] return CreateContext(masterKey, masterSalt, profile, opts...) 
} func TestRTPInvalidAuth(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} invalidSalt := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00} encryptContext, err := buildTestContext(profileCTR) if err != nil { t.Fatal(err) } invalidContext, err := CreateContext(masterKey, invalidSalt, profileCTR) if err != nil { t.Errorf("CreateContext failed: %v", err) } for _, testCase := range rtpTestCases() { pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} pktRaw, err := pkt.Marshal() if err != nil { t.Fatal(err) } out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) if err != nil { t.Fatal(err) } if _, err := invalidContext.DecryptRTP(nil, out, nil); err == nil { t.Errorf("Managed to decrypt with incorrect salt for packet with SeqNum: %d", testCase.sequenceNumber) } } } func rtpTestCaseDecrypted() []byte { return []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05} } func rtpTestCases() []rtpTestCase { return []rtpTestCase{ { sequenceNumber: 5000, encryptedCTR: []byte{0x6d, 0xd3, 0x7e, 0xd5, 0x99, 0xb7, 0x2d, 0x28, 0xb1, 0xf3, 0xa1, 0xf0, 0xc, 0xfb, 0xfd, 0x8}, encryptedGCM: []byte{0x05, 0x39, 0x62, 0xbb, 0x50, 0x2a, 0x08, 0x19, 0xc7, 0xcc, 0xc9, 0x24, 0xb8, 0xd9, 0x7a, 0xe5, 0xad, 0x99, 0x06, 0xc7, 0x3b, 0}, }, { sequenceNumber: 5001, encryptedCTR: []byte{0xda, 0x47, 0xb, 0x2a, 0x74, 0x53, 0x65, 0xbd, 0x2f, 0xeb, 0xdc, 0x4b, 0x6d, 0x23, 0xf3, 0xde}, encryptedGCM: []byte{0xb0, 0xbc, 0xfc, 0xb0, 0x15, 0x2c, 0xa0, 0x15, 0xb5, 0xa8, 0xcd, 0x0d, 0x65, 0xfa, 0x98, 0xb3, 0x09, 0xb1, 0xf8, 0x4b, 0x1c, 0xfa}, }, { sequenceNumber: 5002, encryptedCTR: []byte{0x6e, 0xa7, 0x69, 0x8d, 0x24, 0x6d, 0xdc, 0xbf, 0xec, 0x2, 0x1c, 0xd1, 0x60, 0x76, 0xc1, 0xe}, encryptedGCM: []byte{0x5e, 0x20, 0x6a, 0xbf, 0x58, 0x7e, 0x24, 0xc0, 0x15, 0x94, 0x7a, 0xe2, 0x49, 0x25, 0xd4, 0xd4, 0x08, 0xe2, 0xf1, 0x47, 0x7a, 0x33}, }, { 
sequenceNumber: 5003, encryptedCTR: []byte{0x24, 0x7e, 0x96, 0xc8, 0x7d, 0x33, 0xa2, 0x92, 0x8d, 0x13, 0x8d, 0xe0, 0x76, 0x9f, 0x8, 0xdc}, encryptedGCM: []byte{0xb0, 0x63, 0x14, 0xe7, 0xd2, 0x29, 0xca, 0x92, 0x8c, 0x97, 0x25, 0xd2, 0x50, 0x69, 0x6e, 0x1b, 0x04, 0xb9, 0x37, 0xa5, 0xa1, 0xc5}, }, { sequenceNumber: 5004, encryptedCTR: []byte{0x75, 0x43, 0x28, 0xe4, 0x3a, 0x77, 0x59, 0x9b, 0x2e, 0xdf, 0x7b, 0x12, 0x68, 0xb, 0x57, 0x49}, encryptedGCM: []byte{0xb2, 0x4f, 0x19, 0x53, 0x79, 0x8a, 0x9b, 0x9e, 0xe5, 0x22, 0x93, 0x14, 0x50, 0x8a, 0x8c, 0xd5, 0xfc, 0x61, 0xbf, 0x95, 0xd1, 0xfb}, }, { sequenceNumber: 65535, // upper boundary encryptedCTR: []byte{0xaf, 0xf7, 0xc2, 0x70, 0x37, 0x20, 0x83, 0x9c, 0x2c, 0x63, 0x85, 0x15, 0xe, 0x44, 0xca, 0x36}, encryptedGCM: []byte{0x40, 0x44, 0x6c, 0xd1, 0x33, 0x5f, 0xca, 0x9b, 0x2e, 0xa3, 0xe5, 0x03, 0xd7, 0x82, 0x36, 0xd8, 0xb7, 0xe8, 0x97, 0x3c, 0xe6, 0xb6}, }, } } func testRTPLifecyleNewAlloc(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) authTagLen, err := profile.rtpAuthTagLen() assert.NoError(err) for _, testCase := range rtpTestCases() { encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } actualEncrypted, err := encryptContext.EncryptRTP(nil, decryptedRaw, nil) if err != nil { t.Fatal(err) } assert.Equalf(actualEncrypted, encryptedRaw, "RTP packet with SeqNum invalid encryption: %d", testCase.sequenceNumber) actualDecrypted, err := decryptContext.DecryptRTP(nil, encryptedRaw, nil) if err != nil { 
t.Fatal(err) } else if bytes.Equal(encryptedRaw[:len(encryptedRaw)-authTagLen], actualDecrypted) { t.Fatal("DecryptRTP improperly encrypted in place") } assert.Equalf(actualDecrypted, decryptedRaw, "RTP packet with SeqNum invalid decryption: %d", testCase.sequenceNumber) } } func TestRTPLifecycleNewAlloc(t *testing.T) { t.Run("CTR", func(t *testing.T) { testRTPLifecyleNewAlloc(t, profileCTR) }) t.Run("GCM", func(t *testing.T) { testRTPLifecyleNewAlloc(t, profileGCM) }) } func testRTPLifecyleInPlace(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) for _, testCase := range rtpTestCases() { encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } decryptHeader := &rtp.Header{} decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } encryptHeader := &rtp.Header{} encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } // Copy packet, asserts that everything was done in place slack := 10 if profile == profileGCM { slack = 16 } encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+slack) copy(encryptInput, decryptedRaw) actualEncrypted, err := encryptContext.EncryptRTP(encryptInput, encryptInput, encryptHeader) switch { case err != nil: t.Fatal(err) case &encryptInput[0] != &actualEncrypted[0]: t.Errorf("EncryptRTP failed to encrypt in place") case encryptHeader.SequenceNumber != testCase.sequenceNumber: t.Errorf("EncryptRTP failed to populate input rtp.Header") } assert.Equalf(actualEncrypted, encryptedRaw, "RTP packet with SeqNum invalid encryption: %d", testCase.sequenceNumber) // Copy packet, asserts that everything was done in place decryptInput := 
make([]byte, len(encryptedRaw)) copy(decryptInput, encryptedRaw) actualDecrypted, err := decryptContext.DecryptRTP(decryptInput, decryptInput, decryptHeader) switch { case err != nil: t.Fatal(err) case &decryptInput[0] != &actualDecrypted[0]: t.Errorf("DecryptRTP failed to decrypt in place") case decryptHeader.SequenceNumber != testCase.sequenceNumber: t.Errorf("DecryptRTP failed to populate input rtp.Header") } assert.Equalf(actualDecrypted, decryptedRaw, "RTP packet with SeqNum invalid decryption: %d", testCase.sequenceNumber) } } func TestRTPLifecycleInPlace(t *testing.T) { t.Run("CTR", func(t *testing.T) { testRTPLifecyleInPlace(t, profileCTR) }) t.Run("GCM", func(t *testing.T) { testRTPLifecyleInPlace(t, profileGCM) }) } func testRTPReplayProtection(t *testing.T, profile ProtectionProfile) { assert := assert.New(t) for _, testCase := range rtpTestCases() { encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } decryptContext, err := buildTestContext( profile, SRTPReplayProtection(64), ) if err != nil { t.Fatal(err) } decryptHeader := &rtp.Header{} decryptedPkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} decryptedRaw, err := decryptedPkt.Marshal() if err != nil { t.Fatal(err) } encryptHeader := &rtp.Header{} encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } // Copy packet, asserts that everything was done in place slack := 10 if profile == profileGCM { slack = 16 } encryptInput := make([]byte, len(decryptedRaw), len(decryptedRaw)+slack) copy(encryptInput, decryptedRaw) actualEncrypted, err := encryptContext.EncryptRTP(encryptInput, encryptInput, encryptHeader) switch { case err != nil: t.Fatal(err) case &encryptInput[0] != &actualEncrypted[0]: t.Errorf("EncryptRTP failed to encrypt in place") case 
encryptHeader.SequenceNumber != testCase.sequenceNumber: t.Fatal("EncryptRTP failed to populate input rtp.Header") } assert.Equalf(actualEncrypted, encryptedRaw, "RTP packet with SeqNum invalid encryption: %d", testCase.sequenceNumber) // Copy packet, asserts that everything was done in place decryptInput := make([]byte, len(encryptedRaw)) copy(decryptInput, encryptedRaw) actualDecrypted, err := decryptContext.DecryptRTP(decryptInput, decryptInput, decryptHeader) switch { case err != nil: t.Fatal(err) case &decryptInput[0] != &actualDecrypted[0]: t.Errorf("DecryptRTP failed to decrypt in place") case decryptHeader.SequenceNumber != testCase.sequenceNumber: t.Errorf("DecryptRTP failed to populate input rtp.Header") } assert.Equalf(actualDecrypted, decryptedRaw, "RTP packet with SeqNum invalid decryption: %d", testCase.sequenceNumber) _, errReplay := decryptContext.DecryptRTP(decryptInput, decryptInput, decryptHeader) if !errors.Is(errReplay, errDuplicated) { t.Errorf("Replayed packet must be errored with %v, got %v", errDuplicated, errReplay) } } } func TestRTPReplayProtection(t *testing.T) { t.Run("CTR", func(t *testing.T) { testRTPReplayProtection(t, profileCTR) }) t.Run("GCM", func(t *testing.T) { testRTPReplayProtection(t, profileGCM) }) } func benchmarkEncryptRTP(b *testing.B, profile ProtectionProfile, size int) { encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) } pkt := &rtp.Packet{Payload: make([]byte, size)} pktRaw, err := pkt.Marshal() if err != nil { b.Fatal(err) } b.SetBytes(int64(len(pktRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { _, err = encryptContext.EncryptRTP(nil, pktRaw, nil) if err != nil { b.Fatal(err) } } } func BenchmarkEncryptRTP(b *testing.B) { b.Run("CTR-100", func(b *testing.B) { benchmarkEncryptRTP(b, profileCTR, 100) }) b.Run("CTR-1000", func(b *testing.B) { benchmarkEncryptRTP(b, profileCTR, 1000) }) b.Run("GCM-100", func(b *testing.B) { benchmarkEncryptRTP(b, profileGCM, 100) }) b.Run("GCM-1000", 
func(b *testing.B) { benchmarkEncryptRTP(b, profileGCM, 1000) }) } func benchmarkEncryptRTPInPlace(b *testing.B, profile ProtectionProfile, size int) { encryptContext, err := buildTestContext(profile) if err != nil { b.Fatal(err) } pkt := &rtp.Packet{Payload: make([]byte, size)} pktRaw, err := pkt.Marshal() if err != nil { b.Fatal(err) } buf := make([]byte, 0, len(pktRaw)+10) b.SetBytes(int64(len(pktRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { buf, err = encryptContext.EncryptRTP(buf[:0], pktRaw, nil) if err != nil { b.Fatal(err) } } } func BenchmarkEncryptRTPInPlace(b *testing.B) { b.Run("CTR-100", func(b *testing.B) { benchmarkEncryptRTPInPlace(b, profileCTR, 100) }) b.Run("CTR-1000", func(b *testing.B) { benchmarkEncryptRTPInPlace(b, profileCTR, 1000) }) b.Run("GCM-100", func(b *testing.B) { benchmarkEncryptRTPInPlace(b, profileGCM, 100) }) b.Run("GCM-1000", func(b *testing.B) { benchmarkEncryptRTPInPlace(b, profileGCM, 1000) }) } func benchmarkDecryptRTP(b *testing.B, profile ProtectionProfile) { sequenceNumber := uint16(5000) encrypted := rtpTestCases()[0].encrypted(profile) encryptedPkt := &rtp.Packet{ Payload: encrypted, Header: rtp.Header{ SequenceNumber: sequenceNumber, }, } encryptedRaw, err := encryptedPkt.Marshal() if err != nil { b.Fatal(err) } context, err := buildTestContext(profile) if err != nil { b.Fatal(err) } b.SetBytes(int64(len(encryptedRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { _, err := context.DecryptRTP(nil, encryptedRaw, nil) if err != nil { b.Fatal(err) } } } func BenchmarkDecryptRTP(b *testing.B) { b.Run("CTR", func(b *testing.B) { benchmarkDecryptRTP(b, profileCTR) }) b.Run("GCM", func(b *testing.B) { benchmarkDecryptRTP(b, profileGCM) }) } func TestRolloverCount2(t *testing.T) { s := &srtpSSRCState{ssrc: defaultSsrc} roc, diff, ovf := s.nextRolloverCount(30123) if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(30123, diff) roc, diff, ovf = 
s.nextRolloverCount(62892) // 30123 + (1 << 15) + 1 if roc != 0 { t.Errorf("Initial rolloverCounter must be 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(62892, diff) roc, diff, ovf = s.nextRolloverCount(204) if roc != 1 { t.Errorf("rolloverCounter was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(62892, diff) roc, diff, ovf = s.nextRolloverCount(64535) if roc != 0 { t.Errorf("rolloverCounter was not updated when it rolled back, failed to handle out of order") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(64535, diff) roc, diff, ovf = s.nextRolloverCount(205) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(205, diff) roc, diff, ovf = s.nextRolloverCount(1) if roc != 1 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(1, diff) roc, diff, ovf = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("rolloverCounter was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(64532, diff) roc, diff, ovf = s.nextRolloverCount(65534) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(65534, diff) roc, diff, ovf = s.nextRolloverCount(64532) if roc != 0 { t.Errorf("index was improperly updated for non-significant packets") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(65532, diff) roc, diff, ovf = s.nextRolloverCount(205) if roc != 1 { t.Errorf("index was not updated after it crossed 0") } if ovf { t.Error("Should not overflow") } s.updateRolloverCount(65532, diff) } func TestProtectionProfileAes128CmHmacSha1_32(t *testing.T) { masterKey := []byte{0x0d, 0xcd, 0x21, 0x3e, 0x4c, 0xbc, 0xf2, 0x8f, 0x01, 
0x7f, 0x69, 0x94, 0x40, 0x1e, 0x28, 0x89} masterSalt := []byte{0x62, 0x77, 0x60, 0x38, 0xc0, 0x6d, 0xc9, 0x41, 0x9f, 0x6d, 0xd9, 0x43, 0x3e, 0x7c} encryptContext, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) if err != nil { t.Fatal(err) } decryptContext, err := CreateContext(masterKey, masterSalt, ProtectionProfileAes128CmHmacSha1_32) if err != nil { t.Fatal(err) } pkt := &rtp.Packet{Payload: rtpTestCaseDecrypted(), Header: rtp.Header{SequenceNumber: 5000}} pktRaw, err := pkt.Marshal() if err != nil { t.Fatal(err) } out, err := encryptContext.EncryptRTP(nil, pktRaw, nil) if err != nil { t.Fatal(err) } decrypted, err := decryptContext.DecryptRTP(nil, out, nil) if err != nil { t.Fatal(err) } if !bytes.Equal(decrypted, pktRaw) { t.Errorf("Decrypted % 02x does not match original % 02x", decrypted, pktRaw) } } func TestRTPDecryptShotenedPacket(t *testing.T) { profiles := map[string]ProtectionProfile{ "CTR": profileCTR, "GCM": profileGCM, } for name, profile := range profiles { profile := profile t.Run(name, func(t *testing.T) { for _, testCase := range rtpTestCases() { decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } encryptedPkt := &rtp.Packet{Payload: testCase.encrypted(profile), Header: rtp.Header{SequenceNumber: testCase.sequenceNumber}} encryptedRaw, err := encryptedPkt.Marshal() if err != nil { t.Fatal(err) } for i := 1; i < len(encryptedRaw)-1; i++ { packet := encryptedRaw[:i] assert.NotPanics(t, func() { _, _ = decryptContext.DecryptRTP(nil, packet, nil) }, "Panic on length %d/%d", i, len(encryptedRaw)) } } }) } } func TestRTPMaxPackets(t *testing.T) { profiles := map[string]ProtectionProfile{ "CTR": profileCTR, "GCM": profileGCM, } for name, profile := range profiles { profile := profile t.Run(name, func(t *testing.T) { context, err := buildTestContext(profile) if err != nil { t.Fatal(err) } context.SetROC(1, (1<<32)-1) pkt0 := &rtp.Packet{ Header: rtp.Header{ SSRC: 1, SequenceNumber: 0xffff, 
}, Payload: []byte{0, 1}, } raw0, err0 := pkt0.Marshal() if err0 != nil { t.Fatal(err0) } if _, errEnc := context.EncryptRTP(nil, raw0, nil); errEnc != nil { t.Fatal(errEnc) } pkt1 := &rtp.Packet{ Header: rtp.Header{ SSRC: 1, SequenceNumber: 0x0, }, Payload: []byte{0, 1}, } raw1, err1 := pkt1.Marshal() if err1 != nil { t.Fatal(err1) } if _, errEnc := context.EncryptRTP(nil, raw1, nil); !errors.Is(errEnc, errExceededMaxPackets) { t.Fatalf("Expected error '%v', got '%v'", errExceededMaxPackets, errEnc) } }) } } func TestRTPBurstLossWithSetROC(t *testing.T) { profiles := map[string]ProtectionProfile{ "CTR": profileCTR, "GCM": profileGCM, } for name, profile := range profiles { profile := profile t.Run(name, func(t *testing.T) { assert := assert.New(t) encryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } type packetWithROC struct { pkt rtp.Packet enc []byte raw []byte roc uint32 } var pkts []*packetWithROC encryptContext.SetROC(1, 3) for i := 0x8C00; i < 0x20400; i += 0x100 { p := &packetWithROC{ pkt: rtp.Packet{ Payload: []byte{ byte(i >> 16), byte(i >> 8), byte(i), }, Header: rtp.Header{ Marker: true, SSRC: 1, SequenceNumber: uint16(i), }, }, } b, errMarshal := p.pkt.Marshal() if errMarshal != nil { t.Fatal(errMarshal) } p.raw = b enc, errEnc := encryptContext.EncryptRTP(nil, b, nil) if errEnc != nil { t.Fatal(errEnc) } p.roc, _ = encryptContext.ROC(1) if 0x9000 < i && i < 0x20100 { continue } p.enc = enc pkts = append(pkts, p) } decryptContext, err := buildTestContext(profile) if err != nil { t.Fatal(err) } for _, p := range pkts { decryptContext.SetROC(1, p.roc) pkt, err := decryptContext.DecryptRTP(nil, p.enc, nil) if err != nil { t.Errorf("roc=%d, seq=%d: %v", p.roc, p.pkt.SequenceNumber, err) continue } assert.Equal(p.raw, pkt) } }) } } srtp-2.0.12/stream.go000066400000000000000000000002141437106062400144060ustar00rootroot00000000000000package srtp type readStream interface { init(child streamSession, ssrc uint32) error Read(buf 
[]byte) (int, error) GetSSRC() uint32 } srtp-2.0.12/stream_srtcp.go000066400000000000000000000065511437106062400156330ustar00rootroot00000000000000package srtp import ( "errors" "io" "sync" "time" "github.com/pion/rtcp" "github.com/pion/transport/v2/packetio" ) // Limit the buffer size to 100KB const srtcpBufferSize = 100 * 1000 // ReadStreamSRTCP handles decryption for a single RTCP SSRC type ReadStreamSRTCP struct { mu sync.Mutex isInited bool isClosed chan bool session *SessionSRTCP ssrc uint32 buffer io.ReadWriteCloser } func (r *ReadStreamSRTCP) write(buf []byte) (n int, err error) { n, err = r.buffer.Write(buf) if errors.Is(err, packetio.ErrFull) { // Silently drop data when the buffer is full. return len(buf), nil } return n, err } // Used by getOrCreateReadStream func newReadStreamSRTCP() readStream { return &ReadStreamSRTCP{} } // ReadRTCP reads and decrypts full RTCP packet and its header from the nextConn func (r *ReadStreamSRTCP) ReadRTCP(buf []byte) (int, *rtcp.Header, error) { n, err := r.Read(buf) if err != nil { return 0, nil, err } header := &rtcp.Header{} err = header.Unmarshal(buf[:n]) if err != nil { return 0, nil, err } return n, header, nil } // Read reads and decrypts full RTCP packet from the nextConn func (r *ReadStreamSRTCP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } // SetReadDeadline sets the deadline for the Read operation. // Setting to zero means no deadline. 
func (r *ReadStreamSRTCP) SetReadDeadline(t time.Time) error { if b, ok := r.buffer.(interface { SetReadDeadline(time.Time) error }); ok { return b.SetReadDeadline(t) } return nil } // Close removes the ReadStream from the session and cleans up any associated state func (r *ReadStreamSRTCP) Close() error { r.mu.Lock() defer r.mu.Unlock() if !r.isInited { return errStreamNotInited } select { case <-r.isClosed: return errStreamAlreadyClosed default: err := r.buffer.Close() if err != nil { return err } r.session.removeReadStream(r.ssrc) return nil } } func (r *ReadStreamSRTCP) init(child streamSession, ssrc uint32) error { sessionSRTCP, ok := child.(*SessionSRTCP) r.mu.Lock() defer r.mu.Unlock() if !ok { return errFailedTypeAssertion } else if r.isInited { return errStreamAlreadyInited } r.session = sessionSRTCP r.ssrc = ssrc r.isInited = true r.isClosed = make(chan bool) if r.session.bufferFactory != nil { r.buffer = r.session.bufferFactory(packetio.RTCPBufferPacket, ssrc) } else { // Create a buffer and limit it to 100KB buff := packetio.NewBuffer() buff.SetLimitSize(srtcpBufferSize) r.buffer = buff } return nil } // GetSSRC returns the SSRC we are demuxing for func (r *ReadStreamSRTCP) GetSSRC() uint32 { return r.ssrc } // WriteStreamSRTCP is stream for a single Session that is used to encrypt RTCP type WriteStreamSRTCP struct { session *SessionSRTCP } // WriteRTCP encrypts a RTCP header and its payload to the nextConn func (w *WriteStreamSRTCP) WriteRTCP(header *rtcp.Header, payload []byte) (int, error) { headerRaw, err := header.Marshal() if err != nil { return 0, err } return w.session.write(append(headerRaw, payload...)) } // Write encrypts and writes a full RTCP packets to the nextConn func (w *WriteStreamSRTCP) Write(b []byte) (int, error) { return w.session.write(b) } // SetWriteDeadline sets the deadline for the Write operation. // Setting to zero means no deadline. 
func (w *WriteStreamSRTCP) SetWriteDeadline(t time.Time) error { return w.session.setWriteDeadline(t) } srtp-2.0.12/stream_srtp.go000066400000000000000000000063521437106062400154670ustar00rootroot00000000000000package srtp import ( "errors" "io" "sync" "time" "github.com/pion/rtp" "github.com/pion/transport/v2/packetio" ) // Limit the buffer size to 1MB const srtpBufferSize = 1000 * 1000 // ReadStreamSRTP handles decryption for a single RTP SSRC type ReadStreamSRTP struct { mu sync.Mutex isInited bool isClosed chan bool session *SessionSRTP ssrc uint32 buffer io.ReadWriteCloser } // Used by getOrCreateReadStream func newReadStreamSRTP() readStream { return &ReadStreamSRTP{} } func (r *ReadStreamSRTP) init(child streamSession, ssrc uint32) error { sessionSRTP, ok := child.(*SessionSRTP) r.mu.Lock() defer r.mu.Unlock() if !ok { return errFailedTypeAssertion } else if r.isInited { return errStreamAlreadyInited } r.session = sessionSRTP r.ssrc = ssrc r.isInited = true r.isClosed = make(chan bool) // Create a buffer with a 1MB limit if r.session.bufferFactory != nil { r.buffer = r.session.bufferFactory(packetio.RTPBufferPacket, ssrc) } else { buff := packetio.NewBuffer() buff.SetLimitSize(srtpBufferSize) r.buffer = buff } return nil } func (r *ReadStreamSRTP) write(buf []byte) (n int, err error) { n, err = r.buffer.Write(buf) if errors.Is(err, packetio.ErrFull) { // Silently drop data when the buffer is full. 
return len(buf), nil } return n, err } // Read reads and decrypts full RTP packet from the nextConn func (r *ReadStreamSRTP) Read(buf []byte) (int, error) { return r.buffer.Read(buf) } // ReadRTP reads and decrypts full RTP packet and its header from the nextConn func (r *ReadStreamSRTP) ReadRTP(buf []byte) (int, *rtp.Header, error) { n, err := r.Read(buf) if err != nil { return 0, nil, err } header := &rtp.Header{} _, err = header.Unmarshal(buf[:n]) if err != nil { return 0, nil, err } return n, header, nil } // SetReadDeadline sets the deadline for the Read operation. // Setting to zero means no deadline. func (r *ReadStreamSRTP) SetReadDeadline(t time.Time) error { if b, ok := r.buffer.(interface { SetReadDeadline(time.Time) error }); ok { return b.SetReadDeadline(t) } return nil } // Close removes the ReadStream from the session and cleans up any associated state func (r *ReadStreamSRTP) Close() error { r.mu.Lock() defer r.mu.Unlock() if !r.isInited { return errStreamNotInited } select { case <-r.isClosed: return errStreamAlreadyClosed default: err := r.buffer.Close() if err != nil { return err } r.session.removeReadStream(r.ssrc) return nil } } // GetSSRC returns the SSRC we are demuxing for func (r *ReadStreamSRTP) GetSSRC() uint32 { return r.ssrc } // WriteStreamSRTP is stream for a single Session that is used to encrypt RTP type WriteStreamSRTP struct { session *SessionSRTP } // WriteRTP encrypts a RTP packet and writes to the connection func (w *WriteStreamSRTP) WriteRTP(header *rtp.Header, payload []byte) (int, error) { return w.session.writeRTP(header, payload) } // Write encrypts and writes a full RTP packets to the nextConn func (w *WriteStreamSRTP) Write(b []byte) (int, error) { return w.session.write(b) } // SetWriteDeadline sets the deadline for the Write operation. // Setting to zero means no deadline. 
func (w *WriteStreamSRTP) SetWriteDeadline(t time.Time) error { return w.session.setWriteDeadline(t) } srtp-2.0.12/stream_srtp_test.go000066400000000000000000000112201437106062400165140ustar00rootroot00000000000000package srtp import ( "io" "net" "sync" "testing" "time" "github.com/pion/rtp" "github.com/pion/transport/v2/packetio" "github.com/stretchr/testify/assert" ) type noopConn struct{ closed chan struct{} } func newNoopConn() *noopConn { return &noopConn{closed: make(chan struct{})} } func (c *noopConn) Read(b []byte) (n int, err error) { <-c.closed; return 0, io.EOF } func (c *noopConn) Write(b []byte) (n int, err error) { return len(b), nil } func (c *noopConn) Close() error { close(c.closed); return nil } func (c *noopConn) LocalAddr() net.Addr { return nil } func (c *noopConn) RemoteAddr() net.Addr { return nil } func (c *noopConn) SetDeadline(t time.Time) error { return nil } func (c *noopConn) SetReadDeadline(t time.Time) error { return nil } func (c *noopConn) SetWriteDeadline(t time.Time) error { return nil } func TestBufferFactory(t *testing.T) { wg := sync.WaitGroup{} wg.Add(2) conn := newNoopConn() bf := func(_ packetio.BufferPacketType, _ uint32) io.ReadWriteCloser { wg.Done() return packetio.NewBuffer() } rtpSession, err := NewSessionSRTP(conn, &Config{ Keys: SessionKeys{ LocalMasterKey: make([]byte, 16), LocalMasterSalt: make([]byte, 14), RemoteMasterKey: make([]byte, 16), RemoteMasterSalt: make([]byte, 14), }, BufferFactory: bf, Profile: ProtectionProfileAes128CmHmacSha1_80, }) assert.NoError(t, err) rtcpSession, err := NewSessionSRTCP(conn, &Config{ Keys: SessionKeys{ LocalMasterKey: make([]byte, 16), LocalMasterSalt: make([]byte, 14), RemoteMasterKey: make([]byte, 16), RemoteMasterSalt: make([]byte, 14), }, BufferFactory: bf, Profile: ProtectionProfileAes128CmHmacSha1_80, }) assert.NoError(t, err) _, _ = rtpSession.OpenReadStream(123) _, _ = rtcpSession.OpenReadStream(123) wg.Wait() } func benchmarkWrite(b *testing.B, profile 
ProtectionProfile, size int) { conn := newNoopConn() keyLen, err := profile.keyLen() if err != nil { b.Fatal(err) } saltLen, err := profile.saltLen() if err != nil { b.Fatal(err) } config := &Config{ Keys: SessionKeys{ LocalMasterKey: make([]byte, keyLen), LocalMasterSalt: make([]byte, saltLen), RemoteMasterKey: make([]byte, keyLen), RemoteMasterSalt: make([]byte, saltLen), }, Profile: profile, } session, err := NewSessionSRTP(conn, config) if err != nil { b.Fatal(err) } ws, err := session.OpenWriteStream() if err != nil { b.Fatal(err) } packet := &rtp.Packet{ Header: rtp.Header{ Version: 2, SSRC: 322, }, Payload: make([]byte, size), } packetRaw, err := packet.Marshal() if err != nil { b.Fatal(err) } b.SetBytes(int64(len(packetRaw))) b.ResetTimer() for i := 0; i < b.N; i++ { packet.Header.SequenceNumber++ _, err = ws.Write(packetRaw) if err != nil { b.Fatal(err) } } err = session.Close() if err != nil { b.Fatal(err) } } func BenchmarkWrite(b *testing.B) { b.Run("CTR-100", func(b *testing.B) { benchmarkWrite(b, profileCTR, 100) }) b.Run("CTR-1000", func(b *testing.B) { benchmarkWrite(b, profileCTR, 1000) }) b.Run("GCM-100", func(b *testing.B) { benchmarkWrite(b, profileGCM, 100) }) b.Run("GCM-1000", func(b *testing.B) { benchmarkWrite(b, profileGCM, 1000) }) } func benchmarkWriteRTP(b *testing.B, profile ProtectionProfile, size int) { conn := &noopConn{ closed: make(chan struct{}), } keyLen, err := profile.keyLen() if err != nil { b.Fatal(err) } saltLen, err := profile.saltLen() if err != nil { b.Fatal(err) } config := &Config{ Keys: SessionKeys{ LocalMasterKey: make([]byte, keyLen), LocalMasterSalt: make([]byte, saltLen), RemoteMasterKey: make([]byte, keyLen), RemoteMasterSalt: make([]byte, saltLen), }, Profile: profile, } session, err := NewSessionSRTP(conn, config) if err != nil { b.Fatal(err) } ws, err := session.OpenWriteStream() if err != nil { b.Fatal(err) } header := &rtp.Header{ Version: 2, SSRC: 322, } payload := make([]byte, size) 
b.SetBytes(int64(header.MarshalSize() + len(payload))) b.ResetTimer() for i := 0; i < b.N; i++ { header.SequenceNumber++ _, err = ws.WriteRTP(header, payload) if err != nil { b.Fatal(err) } } err = session.Close() if err != nil { b.Fatal(err) } } func BenchmarkWriteRTP(b *testing.B) { b.Run("CTR-100", func(b *testing.B) { benchmarkWriteRTP(b, profileCTR, 100) }) b.Run("CTR-1000", func(b *testing.B) { benchmarkWriteRTP(b, profileCTR, 1000) }) b.Run("GCM-100", func(b *testing.B) { benchmarkWriteRTP(b, profileGCM, 100) }) b.Run("GCM-1000", func(b *testing.B) { benchmarkWriteRTP(b, profileGCM, 1000) }) } srtp-2.0.12/util.go000066400000000000000000000013501437106062400140720ustar00rootroot00000000000000package srtp import "bytes" // Grow the buffer size to the given number of bytes. func growBufferSize(buf []byte, size int) []byte { if size <= cap(buf) { return buf[:size] } buf2 := make([]byte, size) copy(buf2, buf) return buf2 } // Check if buffers match, if not allocate a new buffer and return it func allocateIfMismatch(dst, src []byte) []byte { if dst == nil { dst = make([]byte, len(src)) copy(dst, src) } else if !bytes.Equal(dst, src) { // bytes.Equal returns on ref equality, no optimization needed extraNeeded := len(src) - len(dst) if extraNeeded > 0 { dst = append(dst, make([]byte, extraNeeded)...) } else if extraNeeded < 0 { dst = dst[:len(dst)+extraNeeded] } copy(dst, src) } return dst }