pax_global_header00006660000000000000000000000064144043200150014503gustar00rootroot0000000000000052 comment=864358a1e7a141febccde06762a1643bf872ff80 go-proxyproto-0.7.0/000077500000000000000000000000001440432001500143575ustar00rootroot00000000000000go-proxyproto-0.7.0/.github/000077500000000000000000000000001440432001500157175ustar00rootroot00000000000000go-proxyproto-0.7.0/.github/FUNDING.yml000066400000000000000000000000161440432001500175310ustar00rootroot00000000000000github: pires go-proxyproto-0.7.0/.github/workflows/000077500000000000000000000000001440432001500177545ustar00rootroot00000000000000go-proxyproto-0.7.0/.github/workflows/golangci-lint.yml000066400000000000000000000030521440432001500232260ustar00rootroot00000000000000name: golangci-lint on: push: tags: - v* branches: - main pull_request: permissions: contents: read # Optional: allow read access to pull request. Use with `only-new-issues` option. # pull-requests: read jobs: golangci: name: lint runs-on: ubuntu-latest strategy: matrix: go: ['1.19', '1.20'] steps: - uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - uses: actions/checkout@v3 - name: Format run: go fmt - name: Vet run: go vet - name: lint uses: golangci/golangci-lint-action@v3 #with: # Optional: version of golangci-lint to use in form of v1.2 or v1.2.3 or `latest` to use the latest version #version: v1.29 # Optional: working directory, useful for monorepos # working-directory: somedir # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 # Optional: show only new issues if it's a pull request. The default value is `false`. # only-new-issues: true # Optional: if set to true then the all caching functionality will be complete disabled, # takes precedence over all other caching options. # skip-cache: true # Optional: if set to true then the action don't cache or restore ~/go/pkg. # skip-pkg-cache: true # Optional: if set to true then the action don't cache or restore ~/.cache/go-build. 
# skip-build-cache: true go-proxyproto-0.7.0/.github/workflows/release.yml000066400000000000000000000004101440432001500221120ustar00rootroot00000000000000name: release on: push: tags: - "v*.*.*" jobs: release: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Release uses: softprops/action-gh-release@v1 with: generate_release_notes: true go-proxyproto-0.7.0/.github/workflows/test.yml000066400000000000000000000016601440432001500214610ustar00rootroot00000000000000name: test on: pull_request: push: jobs: test: runs-on: ubuntu-latest strategy: fail-fast: false matrix: go: ['1.19', '1.20'] steps: - uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - uses: actions/checkout@v3 - name: Get dependencies run: | go get golang.org/x/tools/cmd/cover go get github.com/mattn/goveralls - name: Test run: go test -race -v -covermode=atomic -coverprofile=coverage.out - name: Send coverage uses: shogo82148/actions-goveralls@v1 with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-profile: coverage.out flag-name: Go-${{ matrix.go }} parallel: true # notifies that all test jobs are finished. finish: needs: test runs-on: ubuntu-latest steps: - uses: shogo82148/actions-goveralls@v1 with: parallel-finished: true go-proxyproto-0.7.0/.gitignore000066400000000000000000000001571440432001500163520ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders .idea bin pkg *.out go-proxyproto-0.7.0/LICENSE000066400000000000000000000261151440432001500153710ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2016 Paulo Pires Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. go-proxyproto-0.7.0/README.md000066400000000000000000000100351440432001500156350ustar00rootroot00000000000000# go-proxyproto [![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions) [![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=master)](https://coveralls.io/github/pires/go-proxyproto?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto) [![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc) A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt), which provides, as per specification: > (...) 
a convenient way to safely transport connection > information such as a client's address across multiple layers of NAT or TCP > proxies. It is designed to require little changes to existing components and > to limit the performance impact caused by the processing of the transported > information. This library is to be used in one of or both proxy clients and proxy servers that need to support said protocol. Both protocol versions, 1 (text-based) and 2 (binary-based) are supported. ## Installation ```shell $ go get -u github.com/pires/go-proxyproto ``` ## Usage ### Client ```go package main import ( "io" "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func chkErr(err error) { if err != nil { log.Fatalf("Error: %s", err.Error()) } } func main() { // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319") chkErr(err) conn, err := net.DialTCP("tcp", nil, target) chkErr(err) defer conn.Close() // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you // have two conn's header := &proxyproto.Header{ Version: 1, Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } // After the connection was created write the proxy headers first _, err = header.WriteTo(conn) chkErr(err) // Then your data... 
e.g.: _, err = io.WriteString(conn, "HELO") chkErr(err) } ``` ### Server ```go package main import ( "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func main() { // Create a listener addr := "localhost:9876" list, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) } // Wrap listener in a proxyproto listener proxyListener := &proxyproto.Listener{Listener: list} defer proxyListener.Close() // Wait for a connection and accept it conn, err := proxyListener.Accept() defer conn.Close() // Print connection details if conn.LocalAddr() == nil { log.Fatal("couldn't retrieve local address") } log.Printf("local address: %q", conn.LocalAddr().String()) if conn.RemoteAddr() == nil { log.Fatal("couldn't retrieve remote address") } log.Printf("remote address: %q", conn.RemoteAddr().String()) } ``` ### HTTP Server ```go package main import ( "net" "net/http" "time" "github.com/pires/go-proxyproto" ) func main() { server := http.Server{ Addr: ":8080", } ln, err := net.Listen("tcp", server.Addr) if err != nil { panic(err) } proxyListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: 10 * time.Second, } defer proxyListener.Close() server.Serve(proxyListener) } ``` ## Special notes ### AWS AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending the data. This is a problem if your server speaks first. e.g. SMTP, FTP, SSH etc. By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack`, instead. Just to be clear, you need this fix only if your server is designed to speak first. go-proxyproto-0.7.0/addr_proto.go000066400000000000000000000037121440432001500170460ustar00rootroot00000000000000package proxyproto // AddressFamilyAndProtocol represents address family and transport protocol. 
type AddressFamilyAndProtocol byte const ( UNSPEC AddressFamilyAndProtocol = '\x00' TCPv4 AddressFamilyAndProtocol = '\x11' UDPv4 AddressFamilyAndProtocol = '\x12' TCPv6 AddressFamilyAndProtocol = '\x21' UDPv6 AddressFamilyAndProtocol = '\x22' UnixStream AddressFamilyAndProtocol = '\x31' UnixDatagram AddressFamilyAndProtocol = '\x32' ) // IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. func (ap AddressFamilyAndProtocol) IsIPv4() bool { return ap&0xF0 == 0x10 } // IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. func (ap AddressFamilyAndProtocol) IsIPv6() bool { return ap&0xF0 == 0x20 } // IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. func (ap AddressFamilyAndProtocol) IsUnix() bool { return ap&0xF0 == 0x30 } // IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. func (ap AddressFamilyAndProtocol) IsStream() bool { return ap&0x0F == 0x01 } // IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. func (ap AddressFamilyAndProtocol) IsDatagram() bool { return ap&0x0F == 0x02 } // IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. 
func (ap AddressFamilyAndProtocol) IsUnspec() bool { return (ap&0xF0 == 0x00) || (ap&0x0F == 0x00) } func (ap AddressFamilyAndProtocol) toByte() byte { if ap.IsIPv4() && ap.IsStream() { return byte(TCPv4) } else if ap.IsIPv4() && ap.IsDatagram() { return byte(UDPv4) } else if ap.IsIPv6() && ap.IsStream() { return byte(TCPv6) } else if ap.IsIPv6() && ap.IsDatagram() { return byte(UDPv6) } else if ap.IsUnix() && ap.IsStream() { return byte(UnixStream) } else if ap.IsUnix() && ap.IsDatagram() { return byte(UnixDatagram) } return byte(UNSPEC) } go-proxyproto-0.7.0/addr_proto_test.go000066400000000000000000000032311440432001500201010ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestTCPoverIPv4(t *testing.T) { b := byte(TCPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestTCPoverIPv6(t *testing.T) { b := byte(TCPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUDPoverIPv4(t *testing.T) { b := byte(UDPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUDPoverIPv6(t *testing.T) { b := byte(UDPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixStream(t *testing.T) { b := byte(UnixStream) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixDatagram(t *testing.T) { b := byte(UnixDatagram) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { 
t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestInvalidAddressFamilyAndProtocol(t *testing.T) { b := byte(UNSPEC) if !AddressFamilyAndProtocol(b).IsUnspec() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } go-proxyproto-0.7.0/examples/000077500000000000000000000000001440432001500161755ustar00rootroot00000000000000go-proxyproto-0.7.0/examples/client/000077500000000000000000000000001440432001500174535ustar00rootroot00000000000000go-proxyproto-0.7.0/examples/client/client.go000066400000000000000000000017611440432001500212650ustar00rootroot00000000000000package main import ( "io" "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func chkErr(err error) { if err != nil { log.Fatalf("Error: %s", err.Error()) } } func main() { // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:9876") chkErr(err) conn, err := net.DialTCP("tcp", nil, target) chkErr(err) defer conn.Close() // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you // have two conn's header := &proxyproto.Header{ Version: 1, Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } // After the connection was created write the proxy headers first _, err = header.WriteTo(conn) chkErr(err) // Then your data... 
e.g.: _, err = io.WriteString(conn, "HELO") chkErr(err) } go-proxyproto-0.7.0/examples/httpserver/000077500000000000000000000000001440432001500204035ustar00rootroot00000000000000go-proxyproto-0.7.0/examples/httpserver/httpserver.go000066400000000000000000000013461440432001500231440ustar00rootroot00000000000000package main import ( "log" "net" "net/http" "time" "github.com/pires/go-proxyproto" ) // TODO: add httpclient example func main() { server := http.Server{ Addr: ":8080", ConnState: func(c net.Conn, s http.ConnState) { if s == http.StateNew { log.Printf("[ConnState] %s -> %s", c.LocalAddr().String(), c.RemoteAddr().String()) } }, Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Printf("[Handler] remote ip %q", r.RemoteAddr) }), } ln, err := net.Listen("tcp", server.Addr) if err != nil { panic(err) } proxyListener := &proxyproto.Listener{ Listener: ln, ReadHeaderTimeout: 10 * time.Second, } defer proxyListener.Close() server.Serve(proxyListener) } go-proxyproto-0.7.0/examples/server/000077500000000000000000000000001440432001500175035ustar00rootroot00000000000000go-proxyproto-0.7.0/examples/server/server.go000066400000000000000000000014541440432001500213440ustar00rootroot00000000000000package main import ( "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func main() { // Create a listener addr := "localhost:9876" list, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) } // Wrap listener in a proxyproto listener proxyListener := &proxyproto.Listener{Listener: list} defer proxyListener.Close() // Wait for a connection and accept it conn, err := proxyListener.Accept() defer conn.Close() // Print connection details if conn.LocalAddr() == nil { log.Fatal("couldn't retrieve local address") } log.Printf("local address: %q", conn.LocalAddr().String()) if conn.RemoteAddr() == nil { log.Fatal("couldn't retrieve remote address") } log.Printf("remote address: %q", 
conn.RemoteAddr().String()) } go-proxyproto-0.7.0/go.mod000066400000000000000000000000571440432001500154670ustar00rootroot00000000000000module github.com/pires/go-proxyproto go 1.18 go-proxyproto-0.7.0/header.go000066400000000000000000000217501440432001500161430ustar00rootroot00000000000000// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification: // https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt package proxyproto import ( "bufio" "bytes" "errors" "io" "net" "time" ) var ( // Protocol SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'} SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'} ErrCantReadVersion1Header = errors.New("proxyproto: can't read version 1 header") ErrVersion1HeaderTooLong = errors.New("proxyproto: version 1 header must be 107 bytes or less") ErrLineMustEndWithCrlf = errors.New("proxyproto: version 1 header is invalid, must end with \\r\\n") ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command") ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol") ErrCantReadLength = errors.New("proxyproto: can't read length") ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address") ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address") ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present") ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version") ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command") ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol") ErrInvalidLength = errors.New("proxyproto: invalid length") ErrInvalidAddress = errors.New("proxyproto: invalid address") 
ErrInvalidPortNumber = errors.New("proxyproto: invalid port number") ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one") ) // Header is the placeholder for proxy protocol header. type Header struct { Version byte Command ProtocolVersionAndCommand TransportProtocol AddressFamilyAndProtocol SourceAddr net.Addr DestinationAddr net.Addr rawTLVs []byte } // HeaderProxyFromAddrs creates a new PROXY header from a source and a // destination address. If version is zero, the latest protocol version is // used. // // The header is filled on a best-effort basis: if hints cannot be inferred // from the provided addresses, the header will be left unspecified. func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { if version < 1 || version > 2 { version = 2 } h := &Header{ Version: version, Command: LOCAL, TransportProtocol: UNSPEC, } switch sourceAddr := sourceAddr.(type) { case *net.TCPAddr: if _, ok := destAddr.(*net.TCPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = TCPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = TCPv6 } case *net.UDPAddr: if _, ok := destAddr.(*net.UDPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = UDPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = UDPv6 } case *net.UnixAddr: if _, ok := destAddr.(*net.UnixAddr); !ok { break } switch sourceAddr.Net { case "unix": h.TransportProtocol = UnixStream case "unixgram": h.TransportProtocol = UnixDatagram } } if h.TransportProtocol != UNSPEC { h.Command = PROXY h.SourceAddr = sourceAddr h.DestinationAddr = destAddr } return h } func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { if !header.TransportProtocol.IsStream() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) return 
sourceAddr, destAddr, sourceOK && destOK } func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { if !header.TransportProtocol.IsDatagram() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) return sourceAddr, destAddr, sourceOK && destOK } func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { if !header.TransportProtocol.IsUnix() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) return sourceAddr, destAddr, sourceOK && destOK } func (header *Header) IPs() (sourceIP, destIP net.IP, ok bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } else { return nil, nil, false } } func (header *Header) Ports() (sourcePort, destPort int, ok bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } else { return 0, 0, false } } // EqualTo returns true if headers are equivalent, false otherwise. // Deprecated: use EqualsTo instead. This method will eventually be removed. func (header *Header) EqualTo(otherHeader *Header) bool { return header.EqualsTo(otherHeader) } // EqualsTo returns true if headers are equivalent, false otherwise. 
func (header *Header) EqualsTo(otherHeader *Header) bool { if otherHeader == nil { return false } // TLVs only exist for version 2 if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { return false } if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { return false } // Return early for header with LOCAL command, which contains no address information if header.Command == LOCAL { return true } return header.SourceAddr.String() == otherHeader.SourceAddr.String() && header.DestinationAddr.String() == otherHeader.DestinationAddr.String() } // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. func (header *Header) WriteTo(w io.Writer) (int64, error) { buf, err := header.Format() if err != nil { return 0, err } return bytes.NewBuffer(buf).WriteTo(w) } // Format renders a proxy protocol header in a format to write over the wire. func (header *Header) Format() ([]byte, error) { switch header.Version { case 1: return header.formatVersion1() case 2: return header.formatVersion2() default: return nil, ErrUnknownProxyProtocolVersion } } // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. func (header *Header) TLVs() ([]TLV, error) { return SplitTLVs(header.rawTLVs) } // SetTLVs sets the TLVs stored in this header. This method replaces any // previous TLV. func (header *Header) SetTLVs(tlvs []TLV) error { raw, err := JoinTLVs(tlvs) if err != nil { return err } header.rawTLVs = raw return nil } // Read identifies the proxy protocol version and reads the remaining of // the header, accordingly. // // If proxy protocol header signature is not present, the reader buffer remains untouched // and is safe for reading outside of this code. 
// // If proxy protocol header signature is present but an error is raised while processing // the remaining header, assume the reader buffer to be in a corrupt state. // Also, this operation will block until enough bytes are available for peeking. func Read(reader *bufio.Reader) (*Header, error) { // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone. b1, err := reader.Peek(1) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) { signature, err := reader.Peek(5) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(signature[:5], SIGV1) { return parseVersion1(reader) } signature, err = reader.Peek(12) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(signature[:12], SIGV2) { return parseVersion2(reader) } } return nil, ErrNoProxyProtocol } // ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed // there's no proxy protocol header. func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) { type header struct { h *Header e error } read := make(chan *header, 1) go func() { h := &header{} h.h, h.e = Read(reader) read <- h }() timer := time.NewTimer(timeout) select { case result := <-read: timer.Stop() return result.h, result.e case <-timer.C: return nil, ErrNoProxyProtocol } } go-proxyproto-0.7.0/header_test.go000066400000000000000000000424131440432001500172010ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "errors" "net" "reflect" "testing" "time" ) // Stuff to be used in both versions tests. 
const ( NO_PROTOCOL = "There is no spoon" IP4_ADDR = "127.0.0.1" IP4IN6_ADDR = "::ffff:127.0.0.1" IP6_ADDR = "::1" IP6_LONG_ADDR = "1234:5678:9abc:def0:cafe:babe:dead:2bad" PORT = 65533 INVALID_PORT = 99999 ) var ( v4ip = net.ParseIP(IP4_ADDR).To4() v6ip = net.ParseIP(IP6_ADDR).To16() v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT} v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT} v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT} v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT} unixStreamAddr net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"} unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"} errReadIntentionallyBroken = errors.New("read is intentionally broken") ) type timeoutReader []byte func (t *timeoutReader) Read([]byte) (int, error) { time.Sleep(500 * time.Millisecond) return 0, nil } type errorReader []byte func (e *errorReader) Read([]byte) (int, error) { return 0, errReadIntentionallyBroken } func TestReadTimeoutV1Invalid(t *testing.T) { var b timeoutReader reader := bufio.NewReader(&b) _, err := ReadTimeout(reader, 50*time.Millisecond) if err == nil { t.Fatalf("expected error %s", ErrNoProxyProtocol) } else if err != ErrNoProxyProtocol { t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err) } } func TestReadTimeoutPropagatesReadError(t *testing.T) { var e errorReader reader := bufio.NewReader(&e) _, err := ReadTimeout(reader, 50*time.Millisecond) if err == nil { t.Fatalf("expected error %s", errReadIntentionallyBroken) } else if err != errReadIntentionallyBroken { t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err) } } func TestEqualsTo(t *testing.T) { var headersEqual = []struct { this, that *Header expected bool }{ { &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, nil, false, }, { &Header{ Version: 1, Command: PROXY, 
TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, false, }, { &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, true, }, } for _, tt := range headersEqual { if actual := tt.this.EqualsTo(tt.that); actual != tt.expected { t.Fatalf("expected %t, actual %t", tt.expected, actual) } } } // This is here just because of coveralls func TestEqualTo(t *testing.T) { TestEqualsTo(t) } func TestGetters(t *testing.T) { var tests = []struct { name string header *Header tcpSourceAddr, tcpDestAddr *net.TCPAddr udpSourceAddr, udpDestAddr *net.UDPAddr unixSourceAddr, unixDestAddr *net.UnixAddr ipSource, ipDest net.IP portSource, portDest int }{ { name: "TCPv4", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tcpSourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, tcpDestAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, ipSource: net.ParseIP("10.1.1.1"), ipDest: net.ParseIP("20.2.2.2"), portSource: 1000, portDest: 2000, }, { name: "UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, 
DestinationAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, udpSourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, udpDestAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, ipSource: net.ParseIP("10.1.1.1"), ipDest: net.ParseIP("20.2.2.2"), portSource: 1000, portDest: 2000, }, { name: "UnixStream", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "UnixDatagram", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "Unspec", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: UNSPEC, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) } if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, test.tcpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) } udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) } if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) } unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() if 
test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) } if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) } ipSource, ipDest, _ := test.header.IPs() if test.ipSource != nil && !ipSource.Equal(test.ipSource) { t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) } if test.ipDest != nil && !ipDest.Equal(test.ipDest) { t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) } portSource, portDest, _ := test.header.Ports() if test.portSource != 0 && portSource != test.portSource { t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) } if test.portDest != 0 && portDest != test.portDest { t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) } }) } } func TestSetTLVs(t *testing.T) { tests := []struct { header *Header name string tlvs []TLV expectErr bool }{ { name: "add authority TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "add too long TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), }}, expectErr: true, }, } for _, tt := range tests { err := tt.header.SetTLVs(tt.tlvs) if err != nil && !tt.expectErr { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } } func TestWriteTo(t *testing.T) { var buf bytes.Buffer validHeader := &Header{ Version: 1, Command: 
PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := validHeader.WriteTo(&buf); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } invalidHeader := &Header{ SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := invalidHeader.WriteTo(&buf); err == nil { t.Fatalf("should have thrown error %q", err.Error()) } } func TestFormat(t *testing.T) { validHeader := &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := validHeader.Format(); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } func TestFormatInvalid(t *testing.T) { tests := []struct { name string header *Header err error }{ { name: "invalidVersion", header: &Header{ Version: 3, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, err: ErrUnknownProxyProtocolVersion, }, { name: "v2MismatchTCPv4_UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchTCPv4_TCPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v6addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchUnixStream_TCPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: v4addr, DestinationAddr: unixStreamAddr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_TCPv6", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v6addr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_UDPv4", 
header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if _, err := test.header.Format(); err == nil { t.Errorf("Header.Format() succeeded, want an error") } else if err != test.err { t.Errorf("Header.Format() = %q, want %q", err, test.err) } }) } } func TestHeaderProxyFromAddrs(t *testing.T) { unspec := &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, } tests := []struct { name string version byte sourceAddr, destAddr net.Addr expected *Header }{ { name: "TCPv4", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "TCPv6", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UDPv4", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "UDPv6", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, 
SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UnixStream", sourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, }, { name: "UnixDatagram", sourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, }, }, { name: "Version1", version: 1, sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "TCPInvalidIP", sourceAddr: &net.TCPAddr{ IP: nil, Port: 1000, }, destAddr: &net.TCPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "UDPInvalidIP", sourceAddr: &net.UDPAddr{ IP: nil, Port: 1000, }, destAddr: &net.UDPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "TCPAddrTypeMismatch", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, { name: "UDPAddrTypeMismatch", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, { name: "UnixAddrTypeMismatch", sourceAddr: &net.UnixAddr{ Net: "unix", }, destAddr: &net.TCPAddr{ IP: 
net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := HeaderProxyFromAddrs(tt.version, tt.sourceAddr, tt.destAddr) if !h.EqualsTo(tt.expected) { t.Errorf("expected %+v, actual %+v for source %+v and destination %+v", tt.expected, h, tt.sourceAddr, tt.destAddr) } }) } } go-proxyproto-0.7.0/policy.go000066400000000000000000000122011440432001500162010ustar00rootroot00000000000000package proxyproto import ( "fmt" "net" "strings" ) // PolicyFunc can be used to decide whether to trust the PROXY info from // upstream. If set, the connecting address is passed in as an argument. // // See below for the different policies. // // In case an error is returned the connection is denied. type PolicyFunc func(upstream net.Addr) (Policy, error) // Policy defines how a connection with a PROXY header address is treated. type Policy int const ( // USE address from PROXY header USE Policy = iota // IGNORE address from PROXY header, but accept connection IGNORE // REJECT connection when PROXY header is sent // Note: even though the first read on the connection returns an error if // a PROXY header is present, subsequent reads do not. It is the task of // the code using the connection to handle that case properly. REJECT // REQUIRE connection to send PROXY header, reject if not present // Note: even though the first read on the connection returns an error if // a PROXY header is not present, subsequent reads do not. It is the task // of the code using the connection to handle that case properly. REQUIRE // SKIP accepts a connection without requiring the PROXY header // Note: an example usage can be found in the SkipProxyHeaderForCIDR // function. SKIP ) // SkipProxyHeaderForCIDR returns a PolicyFunc which can be used to accept a // connection from a skipHeaderCIDR without requiring a PROXY header, e.g. // Kubernetes pods local traffic. 
The def is a policy to use when an upstream
// address doesn't match the skipHeaderCIDR.
func SkipProxyHeaderForCIDR(skipHeaderCIDR *net.IPNet, def Policy) PolicyFunc {
	return func(upstream net.Addr) (Policy, error) {
		upstreamIP, err := ipFromAddr(upstream)
		if err != nil {
			// The default policy is handed back together with the error.
			return def, err
		}

		// A nil CIDR never matches, so every upstream gets the default.
		if skipHeaderCIDR == nil || !skipHeaderCIDR.Contains(upstreamIP) {
			return def, nil
		}

		return SKIP, nil
	}
}

// WithPolicy returns a NewConn() option that stores the given policy on the
// connection as its ProxyHeaderPolicy.
func WithPolicy(p Policy) func(*Conn) {
	return func(target *Conn) {
		target.ProxyHeaderPolicy = p
	}
}

// LaxWhiteListPolicy builds a PolicyFunc from a list of allowed IP addresses
// and IP ranges. Upstreams matching the list may send a proxy header;
// for any other upstream the proxy header is ignored. If an entry of the list
// is not a valid IP address or IP range, an error is returned instead of a
// PolicyFunc.
func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) {
	matchers, err := parse(allowed)
	if err != nil {
		return nil, err
	}

	return whitelistPolicy(matchers, IGNORE), nil
}

// MustLaxWhiteListPolicy is LaxWhiteListPolicy for callers that treat an
// invalid IP address or IP range as fatal: it panics instead of returning
// the error.
func MustLaxWhiteListPolicy(allowed []string) PolicyFunc {
	policy, err := LaxWhiteListPolicy(allowed)
	if err != nil {
		panic(err)
	}

	return policy
}

// StrictWhiteListPolicy returns a PolicyFunc which decides whether the
// upstream ip is allowed to send a proxy header based on a list of allowed
// IP addresses and IP ranges. In case upstream IP is not in list reading on
// the connection will be refused on the first read. Please note: subsequent
// reads do not error. It is the task of the code using the connection to
// handle that case properly. If one of the provided IP addresses or IP
// ranges is invalid it will return an error instead of a PolicyFunc.
func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { allowFrom, err := parse(allowed) if err != nil { return nil, err } return whitelistPolicy(allowFrom, REJECT), nil } // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { pfunc, err := StrictWhiteListPolicy(allowed) if err != nil { panic(err) } return pfunc } func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc { return func(upstream net.Addr) (Policy, error) { upstreamIP, err := ipFromAddr(upstream) if err != nil { // something is wrong with the source IP, better reject the connection return REJECT, err } for _, allowFrom := range allowed { if allowFrom(upstreamIP) { return USE, nil } } return def, nil } } func parse(allowed []string) ([]func(net.IP) bool, error) { a := make([]func(net.IP) bool, len(allowed)) for i, allowFrom := range allowed { if strings.LastIndex(allowFrom, "/") > 0 { _, ipRange, err := net.ParseCIDR(allowFrom) if err != nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err) } a[i] = ipRange.Contains } else { allowed := net.ParseIP(allowFrom) if allowed == nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom) } a[i] = allowed.Equal } } return a, nil } func ipFromAddr(upstream net.Addr) (net.IP, error) { upstreamString, _, err := net.SplitHostPort(upstream.String()) if err != nil { return nil, err } upstreamIP := net.ParseIP(upstreamString) if nil == upstreamIP { return nil, fmt.Errorf("proxyproto: invalid IP address") } return upstreamIP, nil } go-proxyproto-0.7.0/policy_test.go000066400000000000000000000122661440432001500172530ustar00rootroot00000000000000package proxyproto import ( "net" "testing" ) type failingAddr struct{} func (f failingAddr) Network() string { return "failing" } func (f failingAddr) String() string { 
return "failing" } func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { _, err := tc.policy(failingAddr{}) if err == nil { t.Fatal("Expected error, got none") } }) } } func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { p := MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } policy, err := p(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != REJECT { t.Fatalf("Expected policy REJECT, got %v", policy) } } func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { p := MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } policy, err := p(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != IGNORE { t.Fatalf("Expected policy IGNORE, got %v", policy) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { 
t.Fatalf("Expected policy USE, got %v", policy) } }) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.0/29"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.0/29"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } }) } } func Test_CreateWhitelistPolicyWithInvalidCidrReturnsError(t *testing.T) { _, err := StrictWhiteListPolicy([]string{"20/80"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateLaxPolicyWithInvalidCidrReturnsError(t *testing.T) { _, err := LaxWhiteListPolicy([]string{"20/80"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateLaxPolicyWithInvalidIpAddresseturnsError(t *testing.T) { _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) if err == nil { t.Error("Expected error, got none") } } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"20/80"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but 
got none") } }() MustStrictWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustStrictWhiteListPolicy([]string{"20/80"}) } func TestSkipProxyHeaderForCIDR(t *testing.T) { _, cidr, _ := net.ParseCIDR("192.0.2.1/24") f := SkipProxyHeaderForCIDR(cidr, REJECT) upstream, _ := net.ResolveTCPAddr("tcp", "192.0.2.255:12345") policy, err := f(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != SKIP { t.Errorf("Expected a SKIP policy for the %s address", upstream) } upstream, _ = net.ResolveTCPAddr("tcp", "8.8.8.8:12345") policy, err = f(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != REJECT { t.Errorf("Expected a REJECT policy for the %s address", upstream) } } go-proxyproto-0.7.0/protocol.go000066400000000000000000000221641440432001500165540ustar00rootroot00000000000000package proxyproto import ( "bufio" "io" "net" "sync" "sync/atomic" "time" ) // DefaultReadHeaderTimeout is how long header processing waits for header to // be read from the wire, if Listener.ReaderHeaderTimeout is not set. // It's kept as a global variable so to make it easier to find and override, // e.g. go build -ldflags -X "github.com/pires/go-proxyproto.DefaultReadHeaderTimeout=1s" var DefaultReadHeaderTimeout = 10 * time.Second // Listener is used to wrap an underlying listener, // whose connections may be using the HAProxy Proxy Protocol. // If the connection is using the protocol, the RemoteAddr() will return // the correct client address. ReadHeaderTimeout will be applied to all // connections in order to prevent blocking operations. If no ReadHeaderTimeout // is set, a default of 200ms will be used. This can be disabled by setting the // timeout to < 0. 
type Listener struct { Listener net.Listener Policy PolicyFunc ValidateHeader Validator ReadHeaderTimeout time.Duration } // Conn is used to wrap and underlying connection which // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // return the address of the client instead of the proxy address. Each connection // will have its own readHeaderTimeout and readDeadline set by the Accept() call. type Conn struct { readDeadline atomic.Value // time.Time once sync.Once readErr error conn net.Conn Validate Validator bufReader *bufio.Reader header *Header ProxyHeaderPolicy Policy readHeaderTimeout time.Duration } // Validator receives a header and decides whether it is a valid one // In case the header is not deemed valid it should return an error. type Validator func(*Header) error // ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn() func ValidateHeader(v Validator) func(*Conn) { return func(c *Conn) { if v != nil { c.Validate = v } } } // Accept waits for and returns the next connection to the listener. func (p *Listener) Accept() (net.Conn, error) { // Get the underlying connection conn, err := p.Listener.Accept() if err != nil { return nil, err } proxyHeaderPolicy := USE if p.Policy != nil { proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) if err != nil { // can't decide the policy, we can't accept the connection conn.Close() return nil, err } // Handle a connection as a regular one if proxyHeaderPolicy == SKIP { return conn, nil } } newConn := NewConn( conn, WithPolicy(proxyHeaderPolicy), ValidateHeader(p.ValidateHeader), ) // If the ReadHeaderTimeout for the listener is unset, use the default timeout. if p.ReadHeaderTimeout == 0 { p.ReadHeaderTimeout = DefaultReadHeaderTimeout } // Set the readHeaderTimeout of the new conn to the value of the listener newConn.readHeaderTimeout = p.ReadHeaderTimeout return newConn, nil } // Close closes the underlying listener. 
func (p *Listener) Close() error {
	inner := p.Listener
	return inner.Close()
}

// Addr reports the network address of the wrapped listener.
func (p *Listener) Addr() net.Addr {
	inner := p.Listener
	return inner.Addr()
}

// NewConn wraps conn, which may or may not be speaking the proxy protocol,
// in a proxyproto.Conn. Every option in opts is applied, in order, to the
// new value before it is returned.
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
	wrapped := new(Conn)
	wrapped.bufReader = bufio.NewReader(conn)
	wrapped.conn = conn

	for i := range opts {
		opts[i](wrapped)
	}
	return wrapped
}

// Read makes sure the proxy protocol header has been processed (this happens
// once per connection) and then reads payload bytes from the buffered reader.
// If processing the header failed, that error is returned by this and every
// later call, and no data is read.
func (p *Conn) Read(b []byte) (int, error) {
	p.once.Do(func() {
		p.readErr = p.readHeader()
	})

	if headerErr := p.readErr; headerErr != nil {
		return 0, headerErr
	}
	return p.bufReader.Read(b)
}

// Write forwards b to the wrapped connection's Write.
func (p *Conn) Write(b []byte) (int, error) {
	n, err := p.conn.Write(b)
	return n, err
}

// Close forwards to the wrapped connection's Close.
func (p *Conn) Close() error {
	err := p.conn.Close()
	return err
}

// ProxyHeader returns the proxy protocol header accepted for this connection,
// processing it first if that has not happened yet. It returns nil when no
// header was stored, which includes the case where reading it failed.
func (p *Conn) ProxyHeader() *Header {
	p.once.Do(func() {
		p.readErr = p.readHeader()
	})
	return p.header
}

// LocalAddr returns the destination address carried by the proxy header when
// one is in use. The wrapped connection's own LocalAddr is returned instead
// when there is no header, when the header's command is LOCAL, or when
// reading the header produced an error (even if the header itself was
// syntactically correct).
func (p *Conn) LocalAddr() net.Addr {
	p.once.Do(func() {
		p.readErr = p.readHeader()
	})

	hdr := p.header
	if hdr == nil {
		return p.conn.LocalAddr()
	}
	if hdr.Command.IsLocal() || p.readErr != nil {
		return p.conn.LocalAddr()
	}
	return hdr.DestinationAddr
}

// RemoteAddr returns the address of the client if the proxy
// protocol is being used, otherwise just returns the address of
// the socket peer.
// In case an error happens on reading the
// proxy header the original RemoteAddr is returned, not the one
// from the proxy header even if the proxy header itself is
// syntactically correct.
func (p *Conn) RemoteAddr() net.Addr {
	// Process the header at most once per connection.
	p.once.Do(func() { p.readErr = p.readHeader() })
	// Fall back to the socket peer when there is no header, the header is a
	// LOCAL command, or header processing failed.
	if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
		return p.conn.RemoteAddr()
	}

	return p.header.SourceAddr
}

// Raw returns the underlying connection which can be casted to
// a concrete type, allowing access to specialized functions.
//
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) Raw() net.Conn {
	return p.conn
}

// TCPConn returns the underlying TCP connection,
// allowing access to specialized functions.
// ok is false when the wrapped connection is not a *net.TCPConn.
//
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) {
	conn, ok = p.conn.(*net.TCPConn)
	return
}

// UnixConn returns the underlying Unix socket connection,
// allowing access to specialized functions.
// ok is false when the wrapped connection is not a *net.UnixConn.
//
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) {
	conn, ok = p.conn.(*net.UnixConn)
	return
}

// UDPConn returns the underlying UDP connection,
// allowing access to specialized functions.
// ok is false when the wrapped connection is not a *net.UDPConn.
//
// Use this ONLY if you know exactly what you are doing.
func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) {
	conn, ok = p.conn.(*net.UDPConn)
	return
}

// SetDeadline wraps original conn.SetDeadline. The deadline is also recorded
// as the desired read deadline so that readHeader can restore it.
func (p *Conn) SetDeadline(t time.Time) error {
	p.readDeadline.Store(t)
	return p.conn.SetDeadline(t)
}

// SetReadDeadline wraps original conn.SetReadDeadline and records the
// requested deadline.
func (p *Conn) SetReadDeadline(t time.Time) error {
	// Set a local var that tells us the desired deadline. This is
	// needed in order to reset the read deadline to the one that is
	// desired by the user, rather than an empty deadline.
	p.readDeadline.Store(t)
	return p.conn.SetReadDeadline(t)
}

// SetWriteDeadline wraps original conn.SetWriteDeadline
func (p *Conn) SetWriteDeadline(t time.Time) error {
	return p.conn.SetWriteDeadline(t)
}

// readHeader reads a proxy protocol header from the buffered reader and
// applies ProxyHeaderPolicy to the outcome. It returns nil when no header is
// present and the policy is not REQUIRE, and when a header is read without
// error (it is stored in p.header only under USE or REQUIRE, after passing
// Validate if one is set). It returns ErrNoProxyProtocol when the header is
// missing under REQUIRE, ErrSuperfluousProxyHeader when one arrives under
// REJECT, the validator's error when validation fails, and otherwise the
// error produced by Read or by setting the read deadline.
func (p *Conn) readHeader() error {
	// If the connection's readHeaderTimeout is more than 0,
	// push our deadline back to now plus the timeout. This should only
	// run on the connection, as we don't want to override the previous
	// read deadline the user may have used.
	if p.readHeaderTimeout > 0 {
		if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil {
			return err
		}
	}

	header, err := Read(p.bufReader)

	// If the connection's readHeaderTimeout is more than 0, undo the change to the
	// deadline that we made above. Because we retain the readDeadline as part of our
	// SetReadDeadline override, we know the user's desired deadline so we use that.
	// Therefore, we check whether the error is a net.Timeout and if it is, we decide
	// the proxy proto does not exist and set the error accordingly.
	if p.readHeaderTimeout > 0 {
		t := p.readDeadline.Load()
		if t == nil {
			// No deadline was ever stored: the zero time.Time is passed to
			// SetReadDeadline below.
			t = time.Time{}
		}
		// A failure to restore the deadline is returned as is and replaces
		// whatever Read reported.
		if err := p.conn.SetReadDeadline(t.(time.Time)); err != nil {
			return err
		}
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			err = ErrNoProxyProtocol
		}
	}

	// For the purpose of this wrapper shamefully stolen from armon/go-proxyproto
	// let's act as if there was no error when PROXY protocol is not present.
	if err == ErrNoProxyProtocol {
		// but not if it is required that the connection has one
		if p.ProxyHeaderPolicy == REQUIRE {
			return err
		}

		return nil
	}

	// proxy protocol header was found
	if err == nil && header != nil {
		switch p.ProxyHeaderPolicy {
		case REJECT:
			// this connection is not allowed to send one
			return ErrSuperfluousProxyHeader
		case USE, REQUIRE:
			if p.Validate != nil {
				err = p.Validate(header)
				if err != nil {
					return err
				}
			}

			p.header = header
		}
		// Any other policy value falls through without storing the header.
	}

	return err
}

// ReadFrom implements the io.ReaderFrom ReadFrom method. Data read from r is
// written to the wrapped connection: through its own ReadFrom when it
// implements io.ReaderFrom, through io.Copy otherwise.
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
	if rf, ok := p.conn.(io.ReaderFrom); ok {
		return rf.ReadFrom(r)
	}
	return io.Copy(p.conn, r)
}

// WriteTo implements io.WriterTo. The proxy header is processed first (once
// per connection); if that failed, the stored error is returned and nothing
// is written. Otherwise the buffered reader's contents are written to w.
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
	p.once.Do(func() { p.readErr = p.readHeader() })
	if p.readErr != nil {
		return 0, p.readErr
	}
	return p.bufReader.WriteTo(w)
}
go-proxyproto-0.7.0/protocol_test.go000066400000000000000000001047651440432001500176230ustar00rootroot00000000000000// This file was shamefully stolen from github.com/armon/go-proxyproto.
// It has been heavily edited to conform to this lib.
// // Thanks @armon package proxyproto import ( "bytes" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "io/ioutil" "net" "testing" "time" ) func TestPassthrough(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestRequiredWithReadHeaderTimeout will iterate through 3 different timeouts to see // whether using a REQUIRE policy for a listener would cause an error if the timeout // is triggerred without a proxy protocol header being defined. 
func TestRequiredWithReadHeaderTimeout(t *testing.T) { for _, duration := range []int{100, 200, 400} { t.Run(fmt.Sprint(duration), func(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: time.Millisecond * time.Duration(duration), Policy: func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Read blocks forever if there is no ReadHeaderTimeout and the policy is not REQUIRE recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil && !errors.Is(err, ErrNoProxyProtocol) && time.Since(start)-pl.ReadHeaderTimeout > 10*time.Millisecond { t.Fatal("proxy proto should not be found and time should be close to read timeout") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } }) } } // TestUseWithReadHeaderTimeout will iterate through 3 different timeouts to see // whether using a USE policy for a listener would not cause an error if the timeout // is triggerred without a proxy protocol header being defined. 
func TestUseWithReadHeaderTimeout(t *testing.T) { for _, duration := range []int{100, 200, 400} { t.Run(fmt.Sprint(duration), func(t *testing.T) { start := time.Now() l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: time.Millisecond * time.Duration(duration), Policy: func(upstream net.Addr) (Policy, error) { return USE, nil }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // 2 times the ReadHeaderTimeout because the first timeout // should occur (the one set on the listener) and allow for the second to follow up if err := conn.SetDeadline(time.Now().Add(pl.ReadHeaderTimeout * 2)); err != nil { t.Fatalf("err: %v", err) } // Read blocks forever if there is no ReadHeaderTimeout recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil && !errors.Is(err, ErrNoProxyProtocol) && (time.Since(start)-(pl.ReadHeaderTimeout*2)) > 10*time.Millisecond { t.Fatal("proxy proto should not be found and time should be close to read timeout") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } }) } } func TestReadHeaderTimeoutIsReset(t *testing.T) { const timeout = time.Millisecond * 250 l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: timeout, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } // Sleep here longer than the configured timeout. time.Sleep(timeout * 2) if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err := conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Set our deadlines higher than our ReadHeaderTimeout if err := conn.SetReadDeadline(time.Now().Add(timeout * 3)); err != nil { t.Fatalf("err: %v", err) } if err := conn.SetWriteDeadline(time.Now().Add(timeout * 3)); err != nil { t.Fatalf("err: %v", err) } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "10.1.1.1" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestReadHeaderTimeoutIsEmpty ensures the default is set if it is empty. // Because the default is 200ms and we wait longer than that to send a message, // we expect the actual address and port to be returned, // rather than the ProxyHeader we defined. 
func TestReadHeaderTimeoutIsEmpty(t *testing.T) { DefaultReadHeaderTimeout = 200 * time.Millisecond l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Sleep here longer than the configured timeout. time.Sleep(250 * time.Millisecond) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() == "10.1.1.1" { t.Fatalf("bad: %v", addr) } if addr.Port == 1000 { t.Fatalf("bad: %v", addr) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } // TestReadHeaderTimeoutIsNegative does the same as above except // with a negative timeout. Therefore, we expect the right ProxyHeader // to be returned. 
func TestReadHeaderTimeoutIsNegative(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{ Listener: l, ReadHeaderTimeout: -1, } header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Sleep here longer than the configured timeout. time.Sleep(250 * time.Millisecond) // Write out the header! if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "10.1.1.1" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestParse_ipv4(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "10.1.1.1" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestParse_ipv6(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 2000, }, } cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "ffff::ffff" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } expectedErr := fmt.Errorf("failure") policyFunc := func(upstream net.Addr) (Policy, error) { return USE, expectedErr } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() close(cliResult) }() conn, err := pl.Accept() if err != expectedErr { t.Fatalf("Expected error %v, got %v", expectedErr, err) } if conn != nil { t.Fatalf("Expected no connection, got %v", conn) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } 
pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrSuperfluousProxyHeader { t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return IGNORE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go 
func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { cliResult <- err return } if !bytes.Equal(recv, []byte("pong")) { cliResult <- fmt.Errorf("bad: %v", recv) return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "127.0.0.1" { t.Fatalf("bad: %v", addr) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_AllOptionsAreRecognized(t *testing.T) { recognizedOpt1 := false opt1 := func(c *Conn) { recognizedOpt1 = true } recognizedOpt2 := false opt2 := func(c *Conn) { recognizedOpt2 = true } server, client := net.Pipe() defer func() { client.Close() }() c := NewConn(server, opt1, opt2) if !recognizedOpt1 { t.Error("Expected option 1 recognized") } if !recognizedOpt2 { t.Error("Expected option 2 recognized") } c.Close() } func TestReadingIsRefusedOnErrorWhenRemoteAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan 
error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() _ = conn.RemoteAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() _ = conn.LocalAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestSkipProxyProtocolPolicy(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil } pl := &Listener{ Listener: l, Policy: policyFunc, } cliResult := make(chan error) ping := []byte("ping") go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write(ping); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != 
nil { t.Fatalf("err: %v", err) } defer conn.Close() _, ok := conn.(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } _ = conn.LocalAddr() recv := make([]byte, 4) if _, err = conn.Read(recv); err != nil { t.Fatalf("Unexpected read error: %v", err) } if !bytes.Equal(ping, recv) { t.Fatalf("Unexpected %s data while expected %s", recv, ping) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_ConnectionCasts(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() if _, err := conn.Write([]byte("ping")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() proxyprotoConn := conn.(*Conn) _, ok := proxyprotoConn.TCPConn() if !ok { t.Fatal("err: should be a tcp connection") } _, ok = proxyprotoConn.UDPConn() if ok { t.Fatal("err: should be a tcp connection not udp") } _, ok = proxyprotoConn.UnixConn() if ok { t.Fatal("err: should be a tcp connection not unix") } _, ok = proxyprotoConn.Raw().(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } validationError := fmt.Errorf("failed to validate") pl := &Listener{Listener: l, ValidateHeader: func(*Header) error { return validationError }} cliResult := make(chan error) go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! 
header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } close(cliResult) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) if _, err = conn.Read(recv); err != validationError { t.Fatalf("expected validation error, got %v", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } type TestTLSServer struct { Listener net.Listener // TLS is the optional TLS configuration, populated with a new config // after TLS is started. If set on an unstarted server before StartTLS // is called, existing fields are copied into the new config. TLS *tls.Config TLSClientConfig *tls.Config // certificate is a parsed version of the TLS config certificate, if present. certificate *x509.Certificate } func (s *TestTLSServer) Addr() string { return s.Listener.Addr().String() } func (s *TestTLSServer) Close() { s.Listener.Close() } // based on net/http/httptest/Server.StartTLS func NewTestTLSServer(l net.Listener) *TestTLSServer { s := &TestTLSServer{} cert, err := tls.X509KeyPair(LocalhostCert, LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } s.TLS = new(tls.Config) if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) if err != nil { panic(fmt.Sprintf("NewTestTLSServer: %v", err)) } certpool := x509.NewCertPool() certpool.AddCert(s.certificate) s.TLSClientConfig = &tls.Config{ RootCAs: certpool, } s.Listener = tls.NewListener(l, s.TLS) return s } func Test_TLSServer(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: 
s.Listener, Policy: func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, } defer s.Close() cliResult := make(chan error) go func() { conn, err := tls.Dial("tcp", s.Addr(), s.TLSClientConfig) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("test")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 1024) n, err := conn.Read(recv) if err != nil { t.Fatalf("expected no error, got %v", err) } if string(recv[:n]) != "test" { t.Fatalf("expected \"test\", got \"%s\" %v", recv[:n], recv[:n]) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: s.Listener, Policy: func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, } defer s.Close() cliResult := make(chan error) go func() { // this is not a valid TLS connection, we are // connecting to the TLS endpoint via plain TCP. // // it's an example of a configuration error: // client: HTTP -> PROXY // server: PROXY -> TLS -> HTTP // // we want to bubble up the underlying error, // in this case a tls handshake error, instead // of responding with a non-descript // > "Proxy protocol signature not present". conn, err := net.Dial("tcp", s.Addr()) if err != nil { cliResult <- err return } defer conn.Close() // Write out the header! 
header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := header.WriteTo(conn); err != nil { cliResult <- err return } if _, err := conn.Write([]byte("GET /foo/bar HTTP/1.1")); err != nil { cliResult <- err return } close(cliResult) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 1024) if _, err = conn.Read(recv); err.Error() != "tls: first record does not look like a TLS handshake" { t.Fatalf("expected tls handshake error, got %s", err) } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } type testConn struct { readFromCalledWith io.Reader reads int net.Conn // nil; crash on any unexpected use } func (c *testConn) ReadFrom(r io.Reader) (int64, error) { c.readFromCalledWith = r b, err := ioutil.ReadAll(r) return int64(len(b)), err } func (c *testConn) Write(p []byte) (int, error) { return len(p), nil } func (c *testConn) Read(p []byte) (int, error) { if c.reads == 0 { return 0, io.EOF } c.reads-- return 1, nil } func TestCopyToWrappedConnection(t *testing.T) { innerConn := &testConn{} wrappedConn := NewConn(innerConn) dummySrc := &testConn{reads: 1} if _, err := io.Copy(wrappedConn, dummySrc); err != nil { t.Fatalf("err: %v", err) } if innerConn.readFromCalledWith != dummySrc { t.Error("Expected io.Copy to delegate to ReadFrom function of inner destination connection") } } func TestCopyFromWrappedConnection(t *testing.T) { wrappedConn := NewConn(&testConn{reads: 1}) dummyDst := &testConn{} if _, err := io.Copy(dummyDst, wrappedConn); err != nil { t.Fatalf("err: %v", err) } if dummyDst.readFromCalledWith != wrappedConn.conn { t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom method of destination") } } func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) { innerConn1 
:= &testConn{reads: 1} wrappedConn1 := NewConn(innerConn1) innerConn2 := &testConn{} wrappedConn2 := NewConn(innerConn2) if _, err := io.Copy(wrappedConn1, wrappedConn2); err != nil { t.Fatalf("err: %v", err) } if innerConn1.readFromCalledWith != innerConn2 { t.Errorf("Expected io.Copy to pass inner source connection to ReadFrom of inner destination connection") } } func benchmarkTCPProxy(size int, b *testing.B) { //create and start the echo backend backend, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatalf("err: %v", err) } defer backend.Close() go func() { for { conn, err := backend.Accept() if err != nil { break } _, err = io.Copy(conn, conn) // Can't defer since we keep accepting on each for iteration. _ = conn.Close() if err != nil { panic(fmt.Sprintf("Failed to read entire payload: %v", err)) } } }() //start the proxyprotocol enabled tcp proxy l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { b.Fatalf("err: %v", err) } defer l.Close() pl := &Listener{Listener: l} go func() { for { conn, err := pl.Accept() if err != nil { break } bConn, err := net.Dial("tcp", backend.Addr().String()) if err != nil { panic(fmt.Sprintf("failed to dial backend: %v", err)) } go func() { _, err = io.Copy(bConn, conn) _ = bConn.(*net.TCPConn).CloseWrite() if err != nil { panic(fmt.Sprintf("Failed to proxy incoming data to backend: %v", err)) } }() _, err = io.Copy(conn, bConn) if err != nil { panic(fmt.Sprintf("Failed to proxy data from backend: %v", err)) } _ = conn.Close() _ = bConn.Close() } }() data := make([]byte, size) header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } //now for the actual benchmark b.ResetTimer() for n := 0; n < b.N; n++ { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { b.Fatalf("err: %v", err) } // Write out the header! 
if _, err := header.WriteTo(conn); err != nil { b.Fatalf("err: %v", err) } //send data go func() { _, err = conn.Write(data) _ = conn.(*net.TCPConn).CloseWrite() if err != nil { panic(fmt.Sprintf("Failed to write data: %v", err)) } }() //receive data n, err := io.Copy(ioutil.Discard, conn) if n != int64(len(data)) { b.Fatalf("Expected to receive %d bytes, got %d", len(data), n) } if err != nil { b.Fatalf("Failed to read data: %v", err) } conn.Close() } } func BenchmarkTCPProxy16KB(b *testing.B) { benchmarkTCPProxy(16*1024, b) } func BenchmarkTCPProxy32KB(b *testing.B) { benchmarkTCPProxy(32*1024, b) } func BenchmarkTCPProxy64KB(b *testing.B) { benchmarkTCPProxy(64*1024, b) } func BenchmarkTCPProxy128KB(b *testing.B) { benchmarkTCPProxy(128*1024, b) } func BenchmarkTCPProxy256KB(b *testing.B) { benchmarkTCPProxy(256*1024, b) } func BenchmarkTCPProxy512KB(b *testing.B) { benchmarkTCPProxy(512*1024, b) } func BenchmarkTCPProxy1024KB(b *testing.B) { benchmarkTCPProxy(1024*1024, b) } func BenchmarkTCPProxy2048KB(b *testing.B) { benchmarkTCPProxy(2048*1024, b) } // copied from src/net/http/internal/testcert.go // Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // LocalhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. 
// generated from src/crypto/tls: // go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM fblo6RBxUQ== -----END CERTIFICATE-----`) // LocalhostKey is the private key for localhostCert. 
var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet 3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== -----END RSA PRIVATE KEY-----`) go-proxyproto-0.7.0/tlv.go000066400000000000000000000071601440432001500155170ustar00rootroot00000000000000// Type-Length-Value splitting and parsing for proxy protocol V2 // See spec https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7 and package proxyproto import ( "encoding/binary" "errors" "fmt" "math" ) const ( // Section 2.2 PP2_TYPE_ALPN PP2Type = 0x01 PP2_TYPE_AUTHORITY PP2Type = 0x02 PP2_TYPE_CRC32C PP2Type = 0x03 PP2_TYPE_NOOP PP2Type = 0x04 PP2_TYPE_UNIQUE_ID PP2Type = 0x05 PP2_TYPE_SSL PP2Type = 0x20 PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 PP2_SUBTYPE_SSL_CN PP2Type = 0x22 PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 PP2_TYPE_NETNS PP2Type = 0x30 // Section 2.2.7, reserved types PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 PP2_TYPE_MAX_FUTURE PP2Type = 0xFF ) var ( ErrTruncatedTLV = errors.New("proxyproto: truncated 
TLV") ErrMalformedTLV = errors.New("proxyproto: malformed TLV Value") ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") ) // PP2Type is the proxy protocol v2 type type PP2Type byte // TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2 type TLV struct { Type PP2Type Value []byte } // SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. func SplitTLVs(raw []byte) ([]TLV, error) { var tlvs []TLV for i := 0; i < len(raw); { tlv := TLV{ Type: PP2Type(raw[i]), } if len(raw)-i <= 2 { return nil, ErrTruncatedTLV } tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K i += 3 if i+tlvLen > len(raw) { return nil, ErrTruncatedTLV } // Ignore no-op padding if tlv.Type != PP2_TYPE_NOOP { tlv.Value = make([]byte, tlvLen) copy(tlv.Value, raw[i:i+tlvLen]) } i += tlvLen tlvs = append(tlvs, tlv) } return tlvs, nil } // JoinTLVs joins multiple Type-Length-Value records. func JoinTLVs(tlvs []TLV) ([]byte, error) { var raw []byte for _, tlv := range tlvs { if len(tlv.Value) > math.MaxUint16 { return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) } var length [2]byte binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value))) raw = append(raw, byte(tlv.Type)) raw = append(raw, length[:]...) raw = append(raw, tlv.Value...) 
} return raw, nil } // Registered is true if the type is registered in the spec, see section 2.2 func (p PP2Type) Registered() bool { switch p { case PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_UNIQUE_ID, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS: return true } return false } // App is true if the type is reserved for application specific data, see section 2.2.7 func (p PP2Type) App() bool { return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM } // Experiment is true if the type is reserved for temporary experimental use by application developers, see section 2.2.7 func (p PP2Type) Experiment() bool { return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT } // Future is true is the type is reserved for future use, see section 2.2.7 func (p PP2Type) Future() bool { return p >= PP2_TYPE_MIN_FUTURE } // Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7 func (p PP2Type) Spec() bool { return p.Registered() || p.App() || p.Experiment() || p.Future() } go-proxyproto-0.7.0/tlv_test.go000066400000000000000000000102321440432001500165500ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "testing" ) var ( fixtureOneByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 1} fixtureTwoByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 2, 0x00} fixtureEmptyLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x01} fixturePartialLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x02, 0x00} ) var invalidTLVTests = []struct { name string reader *bufio.Reader expectedError error }{ { name: "One byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureOneByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Two byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), 
fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTwoByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Empty Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureEmptyLenTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Partial Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixturePartialLenTLV)...)), expectedError: ErrTruncatedTLV, }, } func TestValid0Length(t *testing.T) { r := bufio.NewReader(bytes.NewReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, []byte{byte(PP2_TYPE_MIN_CUSTOM), 0x00, 0x00})...))) h, err := Read(r) if err != nil { t.Fatalf("unexpected error: %v", err) } tlvs, err := h.TLVs() if err != nil { t.Fatalf("unexpected error: %v", err) } if len(tlvs) != 1 { t.Fatalf("expected 1 tlv, got %d", len(tlvs)) } if len(tlvs[0].Value) != 0 { t.Fatalf("expected 0 byte tlv value, got %x", tlvs[0].Value) } } func TestInvalidV2TLV(t *testing.T) { for _, tc := range invalidTLVTests { t.Run(tc.name, func(t *testing.T) { if hdr, err := Read(tc.reader); err != nil { t.Fatalf("TestInvalidV2TLV %s: unexpected error reading proxy protocol %#v", tc.name, err) } else if _, err := hdr.TLVs(); err != tc.expectedError { t.Fatalf("TestInvalidV2TLV %s: expected %#v, actual %#v", tc.name, tc.expectedError, err) } }) } } func TestV2TLVPP2Registered(t *testing.T) { pp2RegTypes := []PP2Type{ PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_UNIQUE_ID, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS, } pp2RegMap := make(map[PP2Type]bool) for _, p := range pp2RegTypes { pp2RegMap[p] = true if !p.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x should be registered", p) } if !p.Spec() { 
t.Fatalf("TestV2TLVPP2Registered: type %x should be in spec", p) } if p.App() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly app", p) } if p.Experiment() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly experiment", p) } if p.Future() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly future", p) } } lastType := PP2Type(0xFF) for i := PP2Type(0x00); i < lastType; i++ { if !pp2RegMap[i] { if i.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", i) } } } if lastType.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", lastType) } } func TestJoinTLVs(t *testing.T) { tests := []struct { name string raw []byte tlvs []TLV }{ { name: "authority TLV", raw: append([]byte{byte(PP2_TYPE_AUTHORITY), 0x00, 0x0B}, []byte("example.org")...), tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "empty TLV", raw: []byte{byte(PP2_TYPE_NOOP), 0x00, 0x00}, tlvs: []TLV{{ Type: PP2_TYPE_NOOP, Value: nil, }}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if raw, err := JoinTLVs(tc.tlvs); err != nil { t.Fatalf("unexpected error: %v", err) } else if !bytes.Equal(raw, tc.raw) { t.Errorf("expected %#v, got %#v", tc.raw, raw) } }) } } go-proxyproto-0.7.0/tlvparse/000077500000000000000000000000001440432001500162175ustar00rootroot00000000000000go-proxyproto-0.7.0/tlvparse/aws.go000066400000000000000000000021301440432001500173340ustar00rootroot00000000000000// Amazon's application extension to TLVs for NLB VPC endpoint services // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html#proxy-protocol package tlvparse import ( "regexp" "github.com/pires/go-proxyproto" ) const ( // Amazon's extension PP2_TYPE_AWS = 0xEA PP2_SUBTYPE_AWS_VPCE_ID = 0x01 ) var vpceRe = regexp.MustCompile("^[A-Za-z0-9-]*$") func IsAWSVPCEndpointID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AWS && len(tlv.Value) > 0 && tlv.Value[0] 
== PP2_SUBTYPE_AWS_VPCE_ID } func AWSVPCEndpointID(tlv proxyproto.TLV) (string, error) { if !IsAWSVPCEndpointID(tlv) { return "", proxyproto.ErrIncompatibleTLV } vpce := string(tlv.Value[1:]) if !vpceRe.MatchString(vpce) { return "", proxyproto.ErrMalformedTLV } return vpce, nil } // FindAWSVPCEndpointID returns the first AWS VPC ID in the TLV if it exists and is well-formed. func FindAWSVPCEndpointID(tlvs []proxyproto.TLV) string { for _, tlv := range tlvs { if vpc, err := AWSVPCEndpointID(tlv); err == nil && vpc != "" { return vpc } } return "" } go-proxyproto-0.7.0/tlvparse/aws_test.go000066400000000000000000000170301440432001500204000ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "testing" "github.com/pires/go-proxyproto" ) var awsTestCases = []struct { name string raw []byte types []proxyproto.PP2Type valid func(*testing.T, string, []proxyproto.TLV) }{ { name: "VPCE example", // https://github.com/aws/elastic-load-balancing-tools/blob/c8eee30ab991ab4c57dc37d1c58f09f67bd534aa/proprot/tst/com/amazonaws/proprot/Compatibility_AwsNetworkLoadBalancerTest.java#L41..L67 raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, /* Start of Sig */ 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, /* End of Sig */ 0x21, 0x11, 0x00, 0x54, /* ver_cmd, fam and len */ 0xac, 0x1f, 0x07, 0x71, /* Caller src ip */ 0xac, 0x1f, 0x0a, 0x1f, /* Endpoint dst ip */ 0xc8, 0xf2, 0x00, 0x50, /* Proxy src port & dst port */ 0x03, 0x00, 0x04, 0xe8, /* CRC TLV start */ 0xd6, 0x89, 0x2d, 0xea, /* CRC TLV cont, VPCE id TLV start */ 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x38, 0x64, 0x32, 0x62, 0x66, 0x31, 0x35, 0x66, 0x61, 0x63, 0x35, 0x30, 0x30, 0x31, 0x63, 0x39, 0x04, 0x00, 0x24, /* VPCE id TLV end, NOOP TLV start*/ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* NOOP TLV end */ }, types: 
[]proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWSVPCEndpointID type", name) } vpce := "vpce-08d2bf15fac5001c9" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWSVPCEndpointID", name) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find AWSVPCEndpointID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected AWSVPCEndpointID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, { name: "VPCE capture", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x54, 0xc0, 0xa8, 0x2c, 0x0a, 0xc0, 0xa8, 0x2c, 0x07, 0xcc, 0x3e, 0x24, 0x1b, 0x03, 0x00, 0x04, 0xb9, 0x28, 0x6f, 0xa6, 0xea, 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x30, 0x65, 0x61, 0x66, 0x63, 0x34, 0x35, 0x38, 0x65, 0x63, 0x39, 0x37, 0x62, 0x38, 0x33, 0x33, 0x04, 0x00, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWS VPC endpoint ID type", name) } vpce := "vpce-00eafc458ec97b833" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWS VPC ID", name) } else if vpca != vpce { 
t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find VPC ID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, } func TestV2TLVAWSVPCEBadChars(t *testing.T) { badVPCE := "vcpe-!?***&&&&&&&" rawTLVs := vpceTLV(badVPCE) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrMalformedTLV { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPCEndpointID unexpectedly found") } rawTLVs = vpceTLV("") tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } parsedVPCE, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if parsedVPCE != "" { t.Fatalf("TestV2TLVAWSVPCEBadChars: found non-empty vpce, actual: %#v", parsedVPCE) } parsedVPCE = FindAWSVPCEndpointID(tlvs) if parsedVPCE != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPECID unexpectedly found") } } func TestParseAWSVPCEndpointIDTLVs(t *testing.T) { for _, tc := range awsTestCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestV2TLVAWSUnknownSubtype(t *testing.T) { vpce := "vpce-abc1234" rawTLVs := vpceTLV(vpce) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if 
len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } avpce, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error actual: %#v", err) } if avpce != vpce { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected vpce value expected: %#v, actual: %#v", vpce, avpce) } avpce = FindAWSVPCEndpointID(tlvs) if avpce == "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly missing") } if avpce != vpce { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID value expected: %#v, actual: %#v", vpce, avpce) } subtypeIndex := 3 // Sanity check if rawTLVs[subtypeIndex] != PP2_SUBTYPE_AWS_VPCE_ID { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected subtype expected %x, actual %x", PP2_SUBTYPE_AWS_VPCE_ID, rawTLVs[subtypeIndex]) } rawTLVs[subtypeIndex] = PP2_SUBTYPE_AWS_VPCE_ID + 1 tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } if IsAWSVPCEndpointID(tlvs[0]) { t.Fatalf("TestV2TLVAWSUnknownSubtype: AWSVPCEType() unexpectedly true after changing subtype") } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrIncompatibleTLV { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error expected %#v, actual: %#v", proxyproto.ErrIncompatibleTLV, err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly exists despite invalid subtype") } } func vpceTLV(vpce string) []byte { tlv := []byte{ PP2_TYPE_AWS, 0x00, 0x00, PP2_SUBTYPE_AWS_VPCE_ID, } binary.BigEndian.PutUint16(tlv[1:3], uint16(len(vpce)+1)) // +1 for subtype return append(tlv, []byte(vpce)...) 
} go-proxyproto-0.7.0/tlvparse/azure.go000066400000000000000000000032531440432001500176770ustar00rootroot00000000000000// Azure's application extension to TLVs for Private Link Services // https://docs.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2 package tlvparse import ( "encoding/binary" "github.com/pires/go-proxyproto" ) const ( // Azure's extension PP2_TYPE_AZURE = 0xEE PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID = 0x01 ) // IsAzurePrivateEndpointLinkID returns true if given TLV matches Azure Private Endpoint LinkID format func isAzurePrivateEndpointLinkID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AZURE && len(tlv.Value) == 5 && tlv.Value[0] == PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID } // AzurePrivateEndpointLinkID returns linkID if given TLV matches Azure Private Endpoint LinkID format // // Format description: // Field Length (Octets) Description // Type 1 PP2_TYPE_AZURE (0xEE) // Length 2 Length of value // Value 1 PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID (0x01) // 4 UINT32 (4 bytes) representing the LINKID of the private endpoint. Encoded in little endian format. func azurePrivateEndpointLinkID(tlv proxyproto.TLV) (uint32, error) { if !isAzurePrivateEndpointLinkID(tlv) { return 0, proxyproto.ErrIncompatibleTLV } linkID := binary.LittleEndian.Uint32(tlv.Value[1:]) return linkID, nil } // FindAzurePrivateEndpointLinkID returns the first Azure Private Endpoint LinkID if it exists in the TLV collection // and a boolean indicating if it was found. 
func FindAzurePrivateEndpointLinkID(tlvs []proxyproto.TLV) (uint32, bool) { for _, tlv := range tlvs { if linkID, err := azurePrivateEndpointLinkID(tlv); err == nil { return linkID, true } } return 0, false } go-proxyproto-0.7.0/tlvparse/azure_test.go000066400000000000000000000044351440432001500207410ustar00rootroot00000000000000package tlvparse import ( "testing" "github.com/pires/go-proxyproto" ) func TestFindAzurePrivateEndpointLinkID(t *testing.T) { tests := []struct { name string tlvs []proxyproto.TLV wantLinkID uint32 wantFound bool }{ { name: "nil TLVs", tlvs: nil, wantLinkID: 0, wantFound: false, }, { name: "empty TLVs", tlvs: []proxyproto.TLV{}, wantLinkID: 0, wantFound: false, }, { name: "AWS VPC endpoint ID", tlvs: []proxyproto.TLV{ { Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong subtype", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong length", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure link ID", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, { name: "Multiple TLVs", tlvs: []proxyproto.TLV{ { // AWS Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, { // Azure but wrong subtype Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, { // Azure but wrong length Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, { // Correct Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, { // Also correct, but second in line Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x22}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotLinkID, gotFound := 
FindAzurePrivateEndpointLinkID(tt.tlvs) if gotFound != tt.wantFound { t.Errorf("FindAzurePrivateEndpointLinkID() got1 = %v, want %v", gotFound, tt.wantFound) } if gotLinkID != tt.wantLinkID { t.Errorf("FindAzurePrivateEndpointLinkID() got = %v, want %v", gotLinkID, tt.wantLinkID) } }) } } go-proxyproto-0.7.0/tlvparse/gcp.go000066400000000000000000000026261440432001500173250ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "github.com/pires/go-proxyproto" ) const ( // PP2_TYPE_GCP indicates a Google Cloud Platform header PP2_TYPE_GCP proxyproto.PP2Type = 0xE0 ) // ExtractPSCConnectionID returns the first PSC Connection ID in the TLV if it exists and is well-formed and // a bool indicating one was found. func ExtractPSCConnectionID(tlvs []proxyproto.TLV) (uint64, bool) { for _, tlv := range tlvs { if linkID, err := pscConnectionID(tlv); err == nil { return linkID, true } } return 0, false } // pscConnectionID returns the ID of a GCP PSC extension TLV or errors with ErrIncompatibleTLV or // ErrMalformedTLV if it's the wrong TLV type or is malformed. // // Field Length (bytes) Description // Type 1 PP2_TYPE_GCP (0xE0) // Length 2 Length of value (always 0x0008) // Value 8 The 8-byte PSC Connection ID (decode to uint64; big endian) // // For example proxyproto.TLV{Type:0xea, Length:8, Value:[]byte{0xff, 0xff, 0xff, 0xff, 0xc0, 0xa8, 0x64, 0x02}} // will be decoded as 18446744072646845442. 
// // See https://cloud.google.com/vpc/docs/configure-private-service-connect-producer func pscConnectionID(t proxyproto.TLV) (uint64, error) { if !isPSCConnectionID(t) { return 0, proxyproto.ErrIncompatibleTLV } linkID := binary.BigEndian.Uint64(t.Value) return linkID, nil } func isPSCConnectionID(t proxyproto.TLV) bool { return t.Type == PP2_TYPE_GCP && len(t.Value) == 8 } go-proxyproto-0.7.0/tlvparse/gcp_test.go000066400000000000000000000035561440432001500203670ustar00rootroot00000000000000package tlvparse import ( "testing" "github.com/pires/go-proxyproto" ) func TestExtractPSCConnectionID(t *testing.T) { tests := []struct { name string tlvs []proxyproto.TLV wantPSCConnectionID uint64 wantFound bool }{ { name: "nil TLVs", tlvs: nil, wantFound: false, }, { name: "empty TLVs", tlvs: []proxyproto.TLV{}, wantFound: false, }, { name: "AWS VPC endpoint ID", tlvs: []proxyproto.TLV{ { Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, }, wantFound: false, }, { name: "GCP link ID", tlvs: []proxyproto.TLV{ { Type: PP2_TYPE_GCP, Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, }, }, wantPSCConnectionID: 18446744072646845442, wantFound: true, }, { name: "Multiple TLVs", tlvs: []proxyproto.TLV{ { // AWS Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, { // Azure Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, { // GCP but wrong length Type: 0xE0, Value: []byte{0xff, 0xff, 0xff}, }, { // Correct Type: 0xE0, Value: []byte{'\xff', '\xff', '\xff', '\xff', '\xc0', '\xa8', '\x64', '\x02'}, }, }, wantPSCConnectionID: 18446744072646845442, wantFound: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { linkID, hasLinkID := ExtractPSCConnectionID(tt.tlvs) if hasLinkID != tt.wantFound { t.Errorf("ExtractPSCConnectionID() got1 = %v, want %v", hasLinkID, tt.wantFound) } if linkID != tt.wantPSCConnectionID { 
t.Errorf("ExtractPSCConnectionID() got = %v, want %v", linkID, tt.wantPSCConnectionID) } }) } } go-proxyproto-0.7.0/tlvparse/ssl.go000066400000000000000000000127741440432001500173620ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "unicode" "unicode/utf8" "github.com/pires/go-proxyproto" ) const ( // pp2_tlv_ssl.client bit fields PP2_BITFIELD_CLIENT_SSL uint8 = 0x01 PP2_BITFIELD_CLIENT_CERT_CONN uint8 = 0x02 PP2_BITFIELD_CLIENT_CERT_SESS uint8 = 0x04 tlvSSLMinLen = 5 // len(pp2_tlv_ssl.client) + len(pp2_tlv_ssl.verify) ) // 2.2.5. The PP2_TYPE_SSL type and subtypes /* struct pp2_tlv_ssl { uint8_t client; uint32_t verify; struct pp2_tlv sub_tlv[0]; }; */ type PP2SSL struct { Client uint8 // The field is made of a bit field from the following values, // indicating which element is present: PP2_BITFIELD_CLIENT_SSL, // PP2_BITFIELD_CLIENT_CERT_CONN, PP2_BITFIELD_CLIENT_CERT_SESS Verify uint32 // Verify will be zero if the client presented a certificate // and it was successfully verified, and non-zero otherwise. TLV []proxyproto.TLV } // Verified is true if the client presented a certificate and it was successfully verified func (s PP2SSL) Verified() bool { return s.Verify == 0 } // ClientSSL indicates that the client connected over SSL/TLS. When true, SSLVersion will return the version. func (s PP2SSL) ClientSSL() bool { return s.Client&PP2_BITFIELD_CLIENT_SSL == PP2_BITFIELD_CLIENT_SSL } // ClientCertConn indicates that the client provided a certificate over the current connection. func (s PP2SSL) ClientCertConn() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_CONN == PP2_BITFIELD_CLIENT_CERT_CONN } // ClientCertSess indicates that the client provided a certificate at least once over the TLS session this // connection belongs to. 
func (s PP2SSL) ClientCertSess() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_SESS == PP2_BITFIELD_CLIENT_CERT_SESS } // SSLVersion returns the US-ASCII string representation of the TLS version and whether that extension exists. func (s PP2SSL) SSLVersion() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_VERSION { return string(tlv.Value), true } } return "", false } // SSLCipher returns the US-ASCII string representation of the used TLS cipher and whether that extension exists. func (s PP2SSL) SSLCipher() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CIPHER { return string(tlv.Value), true } } return "", false } // Marshal formats the PP2SSL structure as a TLV. func (s PP2SSL) Marshal() (proxyproto.TLV, error) { v := make([]byte, 5) v[0] = s.Client binary.BigEndian.PutUint32(v[1:5], s.Verify) tlvs, err := proxyproto.JoinTLVs(s.TLV) if err != nil { return proxyproto.TLV{}, err } v = append(v, tlvs...) return proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: v, }, nil } // ClientCN returns the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client // certificate's Distinguished Name and whether that extension exists. 
func (s PP2SSL) ClientCN() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CN { return string(tlv.Value), true } } return "", false } // SSLType is true if the TLV is type SSL func IsSSL(t proxyproto.TLV) bool { return t.Type == proxyproto.PP2_TYPE_SSL && len(t.Value) >= tlvSSLMinLen } // SSL returns the pp2_tlv_ssl from section 2.2.5 or errors with ErrIncompatibleTLV or ErrMalformedTLV func SSL(t proxyproto.TLV) (PP2SSL, error) { ssl := PP2SSL{} if !IsSSL(t) { return ssl, proxyproto.ErrIncompatibleTLV } if len(t.Value) < tlvSSLMinLen { return ssl, proxyproto.ErrMalformedTLV } ssl.Client = t.Value[0] ssl.Verify = binary.BigEndian.Uint32(t.Value[1:5]) var err error ssl.TLV, err = proxyproto.SplitTLVs(t.Value[5:]) if err != nil { return PP2SSL{}, err } versionFound := !ssl.ClientSSL() for _, tlv := range ssl.TLV { switch tlv.Type { case proxyproto.PP2_SUBTYPE_SSL_VERSION: /* The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. When this field is present, the US-ASCII string representation of the TLS version is appended at the end of the field in the TLV format using the type PP2_SUBTYPE_SSL_VERSION. */ if len(tlv.Value) == 0 || !isASCII(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } versionFound = true case proxyproto.PP2_SUBTYPE_SSL_CN: /* In all cases, the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client certificate's Distinguished Name, is appended using the TLV format and the type PP2_SUBTYPE_SSL_CN. E.g. "example.com". */ if len(tlv.Value) == 0 || !utf8.Valid(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } case proxyproto.PP2_SUBTYPE_SSL_CIPHER: /* The second level TLV PP2_SUBTYPE_SSL_CIPHER provides the US-ASCII string name of the used cipher, for example "ECDHE-RSA-AES128-GCM-SHA256". 
*/ if len(tlv.Value) == 0 || !isASCII(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } } } if !versionFound { return PP2SSL{}, proxyproto.ErrMalformedTLV } return ssl, nil } // SSL returns the first PP2SSL if it exists and is well formed as well as bool indicating if it was found. func FindSSL(tlvs []proxyproto.TLV) (PP2SSL, bool) { for _, t := range tlvs { if ssl, err := SSL(t); err == nil { return ssl, true } } return PP2SSL{}, false } // isASCII checks whether a byte slice has all characters that fit in the ascii character set, including the null byte. func isASCII(b []byte) bool { for _, c := range b { if c > unicode.MaxASCII { return false } } return true } go-proxyproto-0.7.0/tlvparse/ssl_test.go000066400000000000000000000124061440432001500204110ustar00rootroot00000000000000package tlvparse import ( "reflect" "testing" "github.com/pires/go-proxyproto" ) var testCases = []struct { name string raw []byte types []proxyproto.PP2Type valid func(*testing.T, string, []proxyproto.TLV) }{ { name: "SSL haproxy cn", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x40, 0x7f, 0x00, 0x00, 0x01, 0x7f, 0x00, 0x00, 0x01, 0xcc, 0x8a, 0x23, 0x2e, 0x20, 0x00, 0x31, 0x07, 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x07, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x00, 0x1f, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x20, 0x43, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x20, 0x4e, 0x61, 0x6d, 0x65, 0x20, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x20, 0x43, 0x65, 0x72, 0x74, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsSSL(tlvs[0]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) } ssl, err := SSL(tlvs[0]) if err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, err) } if !ssl.ClientSSL() { t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to be true", name) } if !ssl.ClientCertConn() { 
t.Fatalf("TestParseV2TLV %s: Expected ClientCertConn() to be true", name) } if !ssl.ClientCertSess() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be true", name) } ecn := "Example Common Name Client Cert" if acn, ok := ssl.ClientCN(); !ok { t.Fatalf("TestParseV2TLV %s: Expected ClientCN to exist", name) } else if acn != ecn { t.Fatalf("TestParseV2TLV %s: Unexpected ClientCN expected %#v, actual %#v", name, ecn, acn) } esslVer := "TLSv1.3" if asslVer, ok := ssl.SSLVersion(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) } else if asslVer != esslVer { t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) } if _, ok := ssl.SSLCipher(); ok { t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher", name) } if !ssl.Verified() { t.Fatalf("TestParseV2TLV %s: Expected Verified to be true", name) } }, }, { name: "SSL haproxy cipher", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x21, 0x00, 0x4f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x0a, 0x01, 0x5b, 0x0e, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x0a, 0x01, 0x01, 0x9f, 0xf4, 0x7c, 0x01, 0xbb, 0x20, 0x00, 0x28, 0x01, 0x00, 0x00, 0x00, 0x00, 0x21, 0x00, 0x07, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x23, 0x00, 0x16, 0x54, 0x4c, 0x53, 0x5f, 0x41, 0x45, 0x53, 0x5f, 0x32, 0x35, 0x36, 0x5f, 0x47, 0x43, 0x4d, 0x5f, 0x53, 0x48, 0x41, 0x33, 0x38, 0x34, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_SSL}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsSSL(tlvs[0]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[0] to be the SSL type", name) } ssl, err := SSL(tlvs[0]) if err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing SSL %#v", name, err) } if !ssl.ClientSSL() { t.Fatalf("TestParseV2TLV %s: Expected ClientSSL() to be true", name) } if ssl.ClientCertConn() { t.Fatalf("TestParseV2TLV %s: Expected 
ClientCertConn() to be false", name) } if ssl.ClientCertSess() { t.Fatalf("TestParseV2TLV %s: Expected ClientCertSess() to be false", name) } if _, ok := ssl.ClientCN(); ok { t.Fatalf("TestParseV2TLV %s: Expected ClientCN to not exist", name) } esslVer := "TLSv1.3" if asslVer, ok := ssl.SSLVersion(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLVersion to exist", name) } else if asslVer != esslVer { t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) } esslCipher := "TLS_AES_256_GCM_SHA384" if asslCipher, ok := ssl.SSLCipher(); !ok { t.Fatalf("TestParseV2TLV %s: Expected SSLCipher to exist", name) } else if asslCipher != esslCipher { t.Fatalf("TestParseV2TLV %s: Unexpected SSLCipher expected %#v, actual %#v", name, esslCipher, asslCipher) } }, }, } func TestParseV2TLV(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestPP2SSLMarshal(t *testing.T) { ver := "TLSv1.3" cn := "example.org" pp2 := PP2SSL{ Client: PP2_BITFIELD_CLIENT_SSL, Verify: 0, TLV: []proxyproto.TLV{ { Type: proxyproto.PP2_SUBTYPE_SSL_VERSION, Value: []byte(ver), }, { Type: proxyproto.PP2_SUBTYPE_SSL_CN, Value: []byte(cn), }, }, } raw := []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x7, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x0, 0xb, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x6f, 0x72, 0x67} want := proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: raw, } tlv, err := pp2.Marshal() if err != nil { t.Fatalf("PP2SSL.Marshal() = %v", err) } if !reflect.DeepEqual(tlv, want) { t.Errorf("PP2SSL.Marshal() = %#v, want %#v", tlv, want) } } go-proxyproto-0.7.0/tlvparse/test.go000066400000000000000000000013371440432001500175310ustar00rootroot00000000000000package tlvparse import ( "bufio" "bytes" "testing" "github.com/pires/go-proxyproto" ) func checkTLVs(t *testing.T, name string, raw []byte, expected 
[]proxyproto.PP2Type) []proxyproto.TLV { header, err := proxyproto.Read(bufio.NewReader(bytes.NewReader(raw))) if err != nil { t.Fatalf("%s: Unexpected error reading header %#v", name, err) } tlvs, err := header.TLVs() if err != nil { t.Fatalf("%s: Unexpected error splitting TLVS %#v", name, err) } if len(tlvs) != len(expected) { t.Fatalf("%s: Expected %d TLVs, actual %d", name, len(expected), len(tlvs)) } for i, et := range expected { if at := tlvs[i].Type; at != et { t.Fatalf("%s: Expected type %X, actual %X", name, et, at) } } return tlvs } go-proxyproto-0.7.0/v1.go000066400000000000000000000157111440432001500152410ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "fmt" "net" "net/netip" "strconv" "strings" ) const ( crlf = "\r\n" separator = " " ) func initVersion1() *Header { header := new(Header) header.Version = 1 // Command doesn't exist in v1 header.Command = PROXY return header } func parseVersion1(reader *bufio.Reader) (*Header, error) { //The header cannot be more than 107 bytes long. Per spec: // // (...) // - worst case (optional fields set to 0xff) : // "PROXY UNKNOWN ffff:f...f:ffff ffff:f...f:ffff 65535 65535\r\n" // => 5 + 1 + 7 + 1 + 39 + 1 + 39 + 1 + 5 + 1 + 5 + 2 = 107 chars // // So a 108-byte buffer is always enough to store all the line and a // trailing zero for string processing. // // It must also be CRLF terminated, as above. The header does not otherwise // contain a CR or LF byte. // // ISSUE #69 // We can't use Peek here as it will block trying to fill the buffer, which // will never happen if the header is TCP4 or TCP6 (max. 56 and 104 bytes // respectively) and the server is expected to speak first. // // Similarly, we can't use ReadString or ReadBytes as these will keep reading // until the delimiter is found; an abusive client could easily disrupt a // server by sending a large amount of data that do not contain a LF byte. 
// Another means of attack would be to start connections and simply not send // data after the initial PROXY signature bytes, accumulating a large // number of blocked goroutines on the server. ReadSlice will also block for // a delimiter when the internal buffer does not fill up. // // A plain Read is also problematic since we risk reading past the end of the // header without being able to easily put the excess bytes back into the reader's // buffer (with the current implementation's design). // // So we use a ReadByte loop, which solves the overflow problem and avoids // reading beyond the end of the header. However, we need one more trick to harden // against partial header attacks (slow loris) - per spec: // // (..) The sender must always ensure that the header is sent at once, so that // the transport layer maintains atomicity along the path to the receiver. The // receiver may be tolerant to partial headers or may simply drop the connection // when receiving a partial header. Recommendation is to be tolerant, but // implementation constraints may not always easily permit this. // // We are subject to such implementation constraints. So we return an error if // the header cannot be fully extracted with a single read of the underlying // reader. buf := make([]byte, 0, 107) for { b, err := reader.ReadByte() if err != nil { return nil, fmt.Errorf(ErrCantReadVersion1Header.Error()+": %v", err) } buf = append(buf, b) if b == '\n' { // End of header found break } if len(buf) == 107 { // No delimiter in first 107 bytes return nil, ErrVersion1HeaderTooLong } if reader.Buffered() == 0 { // Header was not buffered in a single read. Since we can't // differentiate between genuine slow writers and DoS agents, // we abort. On healthy networks, this should never happen. return nil, ErrCantReadVersion1Header } } // Check for CR before LF. if len(buf) < 2 || buf[len(buf)-2] != '\r' { return nil, ErrLineMustEndWithCrlf } // Check full signature. 
tokens := strings.Split(string(buf[:len(buf)-2]), separator) // Expect at least 2 tokens: "PROXY" and the transport protocol. if len(tokens) < 2 { return nil, ErrCantReadAddressFamilyAndProtocol } // Read address family and protocol var transportProtocol AddressFamilyAndProtocol switch tokens[1] { case "TCP4": transportProtocol = TCPv4 case "TCP6": transportProtocol = TCPv6 case "UNKNOWN": transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN default: return nil, ErrCantReadAddressFamilyAndProtocol } // Expect 6 tokens only when UNKNOWN is not present. if transportProtocol != UNSPEC && len(tokens) < 6 { return nil, ErrCantReadAddressFamilyAndProtocol } // When a signature is found, allocate a v1 header with Command set to PROXY. // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() // Transport protocol has been processed already. header.TransportProtocol = transportProtocol // When UNKNOWN, set the command to LOCAL and return early if header.TransportProtocol == UNSPEC { header.Command = LOCAL return header, nil } // Otherwise, continue to read addresses and ports sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) if err != nil { return nil, err } destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) if err != nil { return nil, err } sourcePort, err := parseV1PortNumber(tokens[4]) if err != nil { return nil, err } destPort, err := parseV1PortNumber(tokens[5]) if err != nil { return nil, err } header.SourceAddr = &net.TCPAddr{ IP: sourceIP, Port: sourcePort, } header.DestinationAddr = &net.TCPAddr{ IP: destIP, Port: destPort, } return header, nil } func (header *Header) formatVersion1() ([]byte, error) { // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. 
var proto string switch header.TransportProtocol { case TCPv4: proto = "TCP4" case TCPv6: proto = "TCP6" default: // Unknown connection (short form) return []byte("PROXY UNKNOWN" + crlf), nil } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) if !sourceOK || !destOK { return nil, ErrInvalidAddress } sourceIP, destIP := sourceAddr.IP, destAddr.IP switch header.TransportProtocol { case TCPv4: sourceIP = sourceIP.To4() destIP = destIP.To4() case TCPv6: sourceIP = sourceIP.To16() destIP = destIP.To16() } if sourceIP == nil || destIP == nil { return nil, ErrInvalidAddress } buf := bytes.NewBuffer(make([]byte, 0, 108)) buf.Write(SIGV1) buf.WriteString(separator) buf.WriteString(proto) buf.WriteString(separator) buf.WriteString(sourceIP.String()) buf.WriteString(separator) buf.WriteString(destIP.String()) buf.WriteString(separator) buf.WriteString(strconv.Itoa(sourceAddr.Port)) buf.WriteString(separator) buf.WriteString(strconv.Itoa(destAddr.Port)) buf.WriteString(crlf) return buf.Bytes(), nil } func parseV1PortNumber(portStr string) (int, error) { port, err := strconv.Atoi(portStr) if err != nil || port < 0 || port > 65535 { return 0, ErrInvalidPortNumber } return port, nil } func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (net.IP, error) { addr, err := netip.ParseAddr(addrStr) if err != nil { return nil, ErrInvalidAddress } switch protocol { case TCPv4: if addr.Is4() { return net.IP(addr.AsSlice()), nil } case TCPv6: if addr.Is6() || addr.Is4In6() { return net.IP(addr.AsSlice()), nil } } return nil, ErrInvalidAddress } go-proxyproto-0.7.0/v1_test.go000066400000000000000000000231521440432001500162760ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "errors" "fmt" "io" "net" "strconv" "strings" "testing" "time" ) var ( IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) IPv4In6AddressesAndPorts 
= strings.Join([]string{IP4IN6_ADDR, IP4IN6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator) IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) IPv6LongAddressesAndPorts = strings.Join([]string{IP6_LONG_ADDR, IP6_LONG_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /" fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /" fixtureTCP4IN6V1 = "PROXY TCP6 " + IPv4In6AddressesAndPorts + crlf + "GET /" fixtureTCP6V1Overflow = "PROXY TCP6 " + IPv6LongAddressesAndPorts fixtureUnknown = "PROXY UNKNOWN" + crlf fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf ) var invalidParseV1Tests = []struct { desc string reader *bufio.Reader expectedError error }{ { desc: "no signature", reader: newBufioReader([]byte(NO_PROTOCOL)), expectedError: ErrNoProxyProtocol, }, { desc: "prox", reader: newBufioReader([]byte("PROX")), expectedError: ErrNoProxyProtocol, }, { desc: "proxy lf", reader: newBufioReader([]byte("PROXY \n")), expectedError: ErrLineMustEndWithCrlf, }, { desc: "proxy crlf", reader: newBufioReader([]byte("PROXY " + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "proxy no space crlf", reader: newBufioReader([]byte("PROXY" + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "proxy something crlf", reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "incomplete signature TCP4", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)), expectedError: ErrCantReadVersion1Header, }, { desc: "invalid IP address", reader: newBufioReader([]byte("PROXY TCP4 invalid invalid 65533 65533" + crlf)), expectedError: 
ErrInvalidAddress, }, { desc: "TCP6 with IPv4 addresses", reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with IPv6 addresses", reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with IPv4 mapped addresses", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4In6AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with invalid port", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)), expectedError: ErrInvalidPortNumber, }, { desc: "header too long", reader: newBufioReader([]byte("PROXY UNKNOWN " + IPv6LongAddressesAndPorts + " " + crlf)), expectedError: ErrVersion1HeaderTooLong, }, } func TestReadV1Invalid(t *testing.T) { for _, tt := range invalidParseV1Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); err != tt.expectedError { t.Fatalf("expected %s, actual %v", tt.expectedError, err) } }) } } var validParseAndWriteV1Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header skipWrite bool }{ { desc: "TCP4", reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "TCP6", reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "TCP4IN6", reader: bufio.NewReader(strings.NewReader(fixtureTCP4IN6V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v4addr, DestinationAddr: v4addr, }, // we skip write test because net.ParseIP converts ::ffff:127.0.0.1 to v4 // instead of preserving the v4 in v6 form, so, after serializing the header, // we end up with v6 protocol and a v4 IP which is invalid 
skipWrite: true, }, { desc: "unknown", reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "unknown with addresses and ports", reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, } func TestParseV1Valid(t *testing.T) { for _, tt := range validParseAndWriteV1Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV1Valid(t *testing.T) { for _, tt := range validParseAndWriteV1Tests { if tt.skipWrite { continue } t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } w.Flush() // Read written bytes to validate written header r := bufio.NewReader(&b) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } // Tests for parseVersion1 overflow - issue #69. 
type dataSource struct { NBytes int NRead int } func (ds *dataSource) Read(b []byte) (int, error) { if ds.NRead >= ds.NBytes { return 0, io.EOF } avail := ds.NBytes - ds.NRead if len(b) < avail { avail = len(b) } for i := 0; i < avail; i++ { b[i] = 0x20 } ds.NRead += avail return avail, nil } func TestParseVersion1Overflow(t *testing.T) { ds := &dataSource{} reader := bufio.NewReader(ds) bufSize := reader.Size() ds.NBytes = bufSize * 16 _, _ = parseVersion1(reader) if ds.NRead > bufSize { t.Fatalf("read: expected max %d bytes, actual %d\n", bufSize, ds.NRead) } } func listen(t *testing.T) *Listener { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } return &Listener{Listener: l} } func client(t *testing.T, addr, header string, length int, terminate bool, wait time.Duration, done chan struct{}, result chan error, ) { c, err := net.Dial("tcp", addr) if err != nil { result <- fmt.Errorf("dial: %w", err) return } defer c.Close() if terminate && length < 2 { length = 2 } buf := make([]byte, len(header)+length) copy(buf, []byte(header)) for i := 0; i < length-2; i++ { buf[i+len(header)] = 0x20 } if terminate { copy(buf[len(header)+length-2:], []byte(crlf)) } n, err := c.Write(buf) if err != nil { result <- fmt.Errorf("write: %w", err) return } if n != len(buf) { result <- errors.New("write; short write") return } close(result) time.Sleep(wait) close(done) } func TestVersion1Overflow(t *testing.T) { done := make(chan struct{}) cliResult := make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, true, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } b := []byte{} _, err = c.Read(b) if err == nil { t.Fatalf("net.Conn: no error reported for oversized header") } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestVersion1SlowLoris(t *testing.T) { done := make(chan struct{}) cliResult := make(chan error) timeout := 
make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 0, false, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } go func() { b := []byte{} _, err = c.Read(b) timeout <- err }() select { case <-done: t.Fatalf("net.Conn: reader still blocked after 10 seconds") case err := <-timeout: if err == nil { t.Fatalf("net.Conn: no error reported for incomplete header") } } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } func TestVersion1SlowLorisOverflow(t *testing.T) { done := make(chan struct{}) cliResult := make(chan error) timeout := make(chan error) l := listen(t) go client(t, l.Addr().String(), fixtureTCP6V1Overflow, 10240, false, 10*time.Second, done, cliResult) c, err := l.Accept() if err != nil { t.Fatalf("accept: %v", err) } go func() { b := []byte{} _, err = c.Read(b) timeout <- err }() select { case <-done: t.Fatalf("net.Conn: reader still blocked after 10 seconds") case err := <-timeout: if err == nil { t.Fatalf("net.Conn: no error reported for incomplete and overflowed header") } } err = <-cliResult if err != nil { t.Fatalf("client error: %v", err) } } go-proxyproto-0.7.0/v2.go000066400000000000000000000165031440432001500152420ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "encoding/binary" "errors" "io" "net" ) var ( lengthUnspec = uint16(0) lengthV4 = uint16(12) lengthV6 = uint16(36) lengthUnix = uint16(216) lengthUnspecBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnspec) return a }() lengthV4Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV4) return a }() lengthV6Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV6) return a }() lengthUnixBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnix) return a }() errUint16Overflow = errors.New("proxyproto: uint16 overflow") ) type _ports struct { SrcPort 
uint16 DstPort uint16 } type _addr4 struct { Src [4]byte Dst [4]byte SrcPort uint16 DstPort uint16 } type _addr6 struct { Src [16]byte Dst [16]byte _ports } type _addrUnix struct { Src [108]byte Dst [108]byte } func parseVersion2(reader *bufio.Reader) (header *Header, err error) { // Skip first 12 bytes (signature) for i := 0; i < 12; i++ { if _, err = reader.ReadByte(); err != nil { return nil, ErrCantReadProtocolVersionAndCommand } } header = new(Header) header.Version = 2 // Read the 13th byte, protocol version and command b13, err := reader.ReadByte() if err != nil { return nil, ErrCantReadProtocolVersionAndCommand } header.Command = ProtocolVersionAndCommand(b13) if _, ok := supportedCommand[header.Command]; !ok { return nil, ErrUnsupportedProtocolVersionAndCommand } // Read the 14th byte, address family and protocol b14, err := reader.ReadByte() if err != nil { return nil, ErrCantReadAddressFamilyAndProtocol } header.TransportProtocol = AddressFamilyAndProtocol(b14) // UNSPEC is only supported when LOCAL is set. if header.TransportProtocol == UNSPEC && header.Command != LOCAL { return nil, ErrUnsupportedAddressFamilyAndProtocol } // Make sure there are bytes available as specified in length var length uint16 if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { return nil, ErrCantReadLength } if !header.validateLength(length) { return nil, ErrInvalidLength } // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC. if length == 0 { return header, nil } if _, err := reader.Peek(int(length)); err != nil { return nil, ErrInvalidLength } // Length-limited reader for payload section payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) // Read addresses and ports for protocols other than UNSPEC. // Ignore address information for UNSPEC, and skip straight to read TLVs, // since the length is greater than zero. 
if header.TransportProtocol != UNSPEC { if header.TransportProtocol.IsIPv4() { var addr _addr4 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsIPv6() { var addr _addr6 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsUnix() { var addr _addrUnix if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } network := "unix" if header.TransportProtocol.IsDatagram() { network = "unixgram" } header.SourceAddr = &net.UnixAddr{ Net: network, Name: parseUnixName(addr.Src[:]), } header.DestinationAddr = &net.UnixAddr{ Net: network, Name: parseUnixName(addr.Dst[:]), } } } // Copy bytes for optional Type-Length-Value vector header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { return nil, err } return header, nil } func (header *Header) formatVersion2() ([]byte, error) { var buf bytes.Buffer buf.Write(SIGV2) buf.WriteByte(header.Command.toByte()) buf.WriteByte(header.TransportProtocol.toByte()) if header.TransportProtocol.IsUnspec() { // For UNSPEC, write no addresses and ports but only TLVs if they are present hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) } else { var addrSrc, addrDst []byte if header.TransportProtocol.IsIPv4() { hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) if err != nil { return nil, err } 
buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To4() addrDst = destIP.To4() } else if header.TransportProtocol.IsIPv6() { hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To16() addrDst = destIP.To16() } else if header.TransportProtocol.IsUnix() { buf.Write(lengthUnixBytes) sourceAddr, destAddr, ok := header.UnixAddrs() if !ok { return nil, ErrInvalidAddress } addrSrc = formatUnixName(sourceAddr.Name) addrDst = formatUnixName(destAddr.Name) } if addrSrc == nil || addrDst == nil { return nil, ErrInvalidAddress } buf.Write(addrSrc) buf.Write(addrDst) if sourcePort, destPort, ok := header.Ports(); ok { portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) buf.Write(portBytes) binary.BigEndian.PutUint16(portBytes, uint16(destPort)) buf.Write(portBytes) } } if len(header.rawTLVs) > 0 { buf.Write(header.rawTLVs) } return buf.Bytes(), nil } func (header *Header) validateLength(length uint16) bool { if header.TransportProtocol.IsIPv4() { return length >= lengthV4 } else if header.TransportProtocol.IsIPv6() { return length >= lengthV6 } else if header.TransportProtocol.IsUnix() { return length >= lengthUnix } else if header.TransportProtocol.IsUnspec() { return length >= lengthUnspec } return false } // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. 
func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { if tlvLen == 0 { return cur, nil } curLen := binary.BigEndian.Uint16(cur) newLen := int(curLen) + tlvLen if newLen >= 1<<16 { return nil, errUint16Overflow } a := make([]byte, 2) binary.BigEndian.PutUint16(a, uint16(newLen)) return a, nil } func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { if transport.IsStream() { return &net.TCPAddr{IP: ip, Port: int(port)} } else if transport.IsDatagram() { return &net.UDPAddr{IP: ip, Port: int(port)} } else { return nil } } func parseUnixName(b []byte) string { i := bytes.IndexByte(b, 0) if i < 0 { return string(b) } return string(b[:i]) } func formatUnixName(name string) []byte { n := int(lengthUnix) / 2 if len(name) >= n { return []byte(name[:n]) } pad := make([]byte, n-len(name)) return append([]byte(name), pad...) } go-proxyproto-0.7.0/v2_test.go000066400000000000000000000345071440432001500163050ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "encoding/binary" "math/rand" "reflect" "testing" ) var ( invalidRune = byte('\x99') // Lengths to use in tests lengthPadded = uint16(84) lengthEmptyBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, 0) return a }() lengthPaddedBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthPadded) return a }() // If life gives you lemons, make mojitos portBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, PORT) return a }() unixBytes = pad([]byte("socket"), 108) // Tests don't care if source and destination addresses and ports are the same addressesIPv4 = append(v4ip.To4(), v4ip.To4()...) addressesIPv6 = append(v6ip.To16(), v6ip.To16()...) ports = append(portBytes, portBytes...) // Fixtures to use in tests fixtureIPv4Address = append(addressesIPv4, ports...) fixtureIPv4V2 = append(lengthV4Bytes, fixtureIPv4Address...) 
fixtureIPv4V2Padded = append(append(lengthPaddedBytes, fixtureIPv4Address...), make([]byte, lengthPadded-lengthV4)...) fixtureIPv6Address = append(addressesIPv6, ports...) fixtureIPv6V2 = append(lengthV6Bytes, fixtureIPv6Address...) fixtureIPv6V2Padded = append(append(lengthPaddedBytes, fixtureIPv6Address...), make([]byte, lengthPadded-lengthV6)...) fixtureUnixAddress = append(unixBytes, unixBytes...) fixtureUnixV2 = append(lengthUnixBytes, fixtureUnixAddress...) fixtureTLV = func() []byte { tlv := make([]byte, 2+rand.Intn(1<<12)) // Not enough to overflow, at least size two rand.Read(tlv) return tlv }() fixtureIPv4V2TLV = fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTLV) fixtureIPv6V2TLV = fixtureWithTLV(lengthV6Bytes, fixtureIPv6Address, fixtureTLV) fixtureUnspecTLV = fixtureWithTLV(lengthUnspecBytes, []byte{}, fixtureTLV) // Arbitrary bytes following proxy bytes arbitraryTailBytes = []byte{'\x99', '\x97', '\x98'} ) func pad(b []byte, n int) []byte { padding := make([]byte, n-len(b)) return append(b, padding...) 
} var invalidParseV2Tests = []struct { desc string reader *bufio.Reader expectedError error }{ { desc: "no signature", reader: newBufioReader([]byte(NO_PROTOCOL)), expectedError: ErrNoProxyProtocol, }, { desc: "truncated v2 signature", reader: newBufioReader(SIGV2[2:]), expectedError: ErrNoProxyProtocol, }, { desc: "v2 signature and nothing else", reader: newBufioReader(SIGV2), expectedError: ErrCantReadProtocolVersionAndCommand, }, { desc: "v2 signature with invalid command", reader: newBufioReader(append(SIGV2, invalidRune)), expectedError: ErrUnsupportedProtocolVersionAndCommand, }, { desc: "v2 signature with command but nothing else", reader: newBufioReader(append(SIGV2, byte(PROXY))), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "command proxy but inet family unspec", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(UNSPEC))), expectedError: ErrUnsupportedAddressFamilyAndProtocol, }, { desc: "v2 signature with command and invalid inet family", // translated to UNSPEC reader: newBufioReader(append(SIGV2, byte(PROXY), invalidRune)), expectedError: ErrCantReadLength, }, { desc: "TCPv4 but no length", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4))), expectedError: ErrCantReadLength, }, { desc: "TCPv4 but invalid length", reader: newBufioReader(append(SIGV2, byte(PROXY), byte(TCPv4), invalidRune)), expectedError: ErrCantReadLength, }, { desc: "unspec but no length", reader: newBufioReader(append(SIGV2, byte(LOCAL), byte(UNSPEC))), expectedError: ErrCantReadLength, }, { desc: "TCPv4 with mismatching length", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), lengthV4Bytes...)), expectedError: ErrInvalidLength, }, { desc: "TCPv6 with mismatching length", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...)), expectedError: ErrInvalidLength, }, { desc: "TCPv4 length zero but with address and ports", reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), 
byte(TCPv4)), lengthEmptyBytes...), fixtureIPv6Address...)), expectedError: ErrInvalidLength, }, { desc: "TCPv6 with IPv6 length but IPv4 address and ports", reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)), expectedError: ErrInvalidLength, }, { desc: "unspec length greater than zero but no TLVs", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)), expectedError: ErrInvalidLength, }, } func TestParseV2Invalid(t *testing.T) { for _, tt := range invalidParseV2Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); err != tt.expectedError { t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error()) } }) } } var validParseAndWriteV2Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header }{ { desc: "local", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "local unspec", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "proxy TCPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "proxy TCPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "proxy TCPv4 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)), expectedHeader: &Header{ Version: 2, 
Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: fixtureTLV, }, }, { desc: "proxy TCPv6 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2TLV...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: fixtureTLV, }, }, { desc: "local unspec with TLV", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: fixtureTLV, }, }, { desc: "proxy UDPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4UDPAddr, }, }, { desc: "proxy UDPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6UDPAddr, DestinationAddr: v6UDPAddr, }, }, { desc: "proxy unix stream", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixStream)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: unixStreamAddr, DestinationAddr: unixStreamAddr, }, }, { desc: "proxy unix datagram", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: unixDatagramAddr, DestinationAddr: unixDatagramAddr, }, }, } func TestParseV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected 
%#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } w.Flush() // Read written bytes to validate written header r := bufio.NewReader(&b) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } var validParseV2PaddedTests = []struct { desc string value []byte expectedHeader *Header }{ { desc: "proxy TCPv4", value: append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy TCPv6", value: append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, { desc: "proxy UDPv4", value: append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy UDPv6", value: append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, } func TestParseV2Padded(t *testing.T) { for _, tt := range validParseV2PaddedTests { t.Run(tt.desc, func(t *testing.T) { reader := newBufioReader(append(tt.value, arbitraryTailBytes...)) newHeader, 
err := Read(reader) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } // Check that remaining padding bytes have been flushed nextBytes, err := reader.Peek(len(arbitraryTailBytes)) if err != nil { t.Fatal("unexpected error ", err) } if !reflect.DeepEqual(nextBytes, arbitraryTailBytes) { t.Fatalf("expected %#v, actual %#v", arbitraryTailBytes, nextBytes) } }) } } func TestV2EqualsToTLV(t *testing.T) { eHdr := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, } hdr, err := Read(newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...))) if err != nil { t.Fatal("unexpected error ", err) } if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs = fixtureTLV[:] if !eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly unequal after tlv copy created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs[0] = eHdr.rawTLVs[0] + 1 if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal after changing tlv created: %#v, parsed: %#v", eHdr, hdr) } } var tlvFormatTests = []struct { desc string header *Header }{ { desc: "proxy TCPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy TCPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "local unspec", header: &Header{ Version: 2, 
Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: make([]byte, 1<<16), }, }, } func TestV2TLVFormatTooLargeTLV(t *testing.T) { for _, tt := range tlvFormatTests { t.Run(tt.desc, func(t *testing.T) { if _, err := tt.header.Format(); err != errUint16Overflow { t.Fatalf("missing or expected error when formatting too-large TLV %#v", err) } }) } } func newBufioReader(b []byte) *bufio.Reader { return bufio.NewReader(bytes.NewReader(b)) } func fixtureWithTLV(cur []byte, addr []byte, tlv []byte) []byte { tlen, err := addTLVLen(cur, len(tlv)) if err != nil { panic(err) } return append(append(tlen, addr...), tlv...) } go-proxyproto-0.7.0/version_cmd.go000066400000000000000000000030521440432001500172160ustar00rootroot00000000000000package proxyproto // ProtocolVersionAndCommand represents the command in proxy protocol v2. // Command doesn't exist in v1 but it should be set since other parts of // this library may rely on it for determining connection details. type ProtocolVersionAndCommand byte const ( // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, // in which case no address information is expected. LOCAL ProtocolVersionAndCommand = '\x20' // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, // in which case valid local/remote address and port information is expected. PROXY ProtocolVersionAndCommand = '\x21' ) var supportedCommand = map[ProtocolVersionAndCommand]bool{ LOCAL: true, PROXY: true, } // IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, // i.e. when no address information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsLocal() bool { return LOCAL == pvc } // IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, // i.e. when valid local/remote address and port information is expected, false otherwise. 
func (pvc ProtocolVersionAndCommand) IsProxy() bool { return PROXY == pvc } // IsUnspec returns true if the command is unspecified, false otherwise. func (pvc ProtocolVersionAndCommand) IsUnspec() bool { return !(pvc.IsLocal() || pvc.IsProxy()) } func (pvc ProtocolVersionAndCommand) toByte() byte { if pvc.IsLocal() { return byte(LOCAL) } else if pvc.IsProxy() { return byte(PROXY) } return byte(LOCAL) } go-proxyproto-0.7.0/version_cmd_test.go000066400000000000000000000013511440432001500202550ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestLocal(t *testing.T) { b := byte(LOCAL) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if !ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestProxy(t *testing.T) { b := byte(PROXY) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if !ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestInvalidProtocolVersion(t *testing.T) { if !ProtocolVersionAndCommand(0x00).IsUnspec() { t.Fail() } }