pax_global_header00006660000000000000000000000064140045175560014521gustar00rootroot0000000000000052 comment=0aca5f2f285d4ced57f6343b8f729d92f5e660b9 go-proxyproto-0.4.2/000077500000000000000000000000001400451755600143745ustar00rootroot00000000000000go-proxyproto-0.4.2/.github/000077500000000000000000000000001400451755600157345ustar00rootroot00000000000000go-proxyproto-0.4.2/.github/workflows/000077500000000000000000000000001400451755600177715ustar00rootroot00000000000000go-proxyproto-0.4.2/.github/workflows/test.yml000066400000000000000000000014641400451755600215000ustar00rootroot00000000000000name: test on: pull_request: push: jobs: build: name: Build runs-on: ubuntu-latest strategy: matrix: go: [ '1.15', '1.14' ] steps: - uses: actions/checkout@v2 - name: Set up Go uses: actions/setup-go@v2 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v1 - name: Get dependencies run: | go get golang.org/x/tools/cmd/cover go get github.com/mattn/goveralls - name: Format run: go fmt - name: Vet run: go vet - name: Test run: go test -v -covermode=count -coverprofile=coverage.out - name: actions-goveralls uses: shogo82148/actions-goveralls@v1.2.2 with: github-token: ${{ secrets.GITHUB_TOKEN }} go-proxyproto-0.4.2/.gitignore000066400000000000000000000001571400451755600163670ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders .idea bin pkg *.out go-proxyproto-0.4.2/LICENSE000066400000000000000000000261151400451755600154060ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the 
Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. 
Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. 
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "{}" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright 2016 Paulo Pires Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. go-proxyproto-0.4.2/README.md000066400000000000000000000071721400451755600156620ustar00rootroot00000000000000# go-proxyproto [![Actions Status](https://github.com/pires/go-proxyproto/workflows/test/badge.svg)](https://github.com/pires/go-proxyproto/actions) [![Coverage Status](https://coveralls.io/repos/github/pires/go-proxyproto/badge.svg?branch=master)](https://coveralls.io/github/pires/go-proxyproto?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/pires/go-proxyproto)](https://goreportcard.com/report/github.com/pires/go-proxyproto) [![](https://godoc.org/github.com/pires/go-proxyproto?status.svg)](https://pkg.go.dev/github.com/pires/go-proxyproto?tab=doc) A Go library implementation of the [PROXY protocol, versions 1 and 2](https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt), which provides, as per specification: > (...) 
a convenient way to safely transport connection > information such as a client's address across multiple layers of NAT or TCP > proxies. It is designed to require little changes to existing components and > to limit the performance impact caused by the processing of the transported > information. This library is to be used in one of or both proxy clients and proxy servers that need to support said protocol. Both protocol versions, 1 (text-based) and 2 (binary-based) are supported. ## Installation ```shell $ go get -u github.com/pires/go-proxyproto ``` ## Usage ### Client ```go package main import ( "io" "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func chkErr(err error) { if err != nil { log.Fatalf("Error: %s", err.Error()) } } func main() { // Dial some proxy listener e.g. https://github.com/mailgun/proxyproto target, err := net.ResolveTCPAddr("tcp", "127.0.0.1:2319") chkErr(err) conn, err := net.DialTCP("tcp", nil, target) chkErr(err) defer conn.Close() // Create a proxyprotocol header or use HeaderProxyFromAddrs() if you // have two conn's header := &proxyproto.Header{ Version: 1, Command: proxyproto.PROXY, TransportProtocol: proxyproto.TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } // After the connection was created write the proxy headers first _, err = header.WriteTo(conn) chkErr(err) // Then your data... 
e.g.: _, err = io.WriteString(conn, "HELO") chkErr(err) } ``` ### Server ```go package main import ( "log" "net" proxyproto "github.com/pires/go-proxyproto" ) func main() { // Create a listener addr := "localhost:9876" list, err := net.Listen("tcp", addr) if err != nil { log.Fatalf("couldn't listen to %q: %q\n", addr, err.Error()) } // Wrap listener in a proxyproto listener proxyListener := &proxyproto.Listener{Listener: list} defer proxyListener.Close() // Wait for a connection and accept it conn, err := proxyListener.Accept() defer conn.Close() // Print connection details if conn.LocalAddr() == nil { log.Fatal("couldn't retrieve local address") } log.Printf("local address: %q", conn.LocalAddr().String()) if conn.RemoteAddr() == nil { log.Fatal("couldn't retrieve remote address") } log.Printf("remote address: %q", conn.RemoteAddr().String()) } ``` ## Special notes ### AWS AWS Network Load Balancer (NLB) does not push the PPV2 header until the client starts sending the data. This is a problem if your server speaks first. e.g. SMTP, FTP, SSH etc. By default, NLB target group attribute `proxy_protocol_v2.client_to_server.header_placement` has the value `on_first_ack_with_payload`. You need to contact AWS support to change it to `on_first_ack`, instead. Just to be clear, you need this fix only if your server is designed to speak first. go-proxyproto-0.4.2/addr_proto.go000066400000000000000000000037121400451755600170630ustar00rootroot00000000000000package proxyproto // AddressFamilyAndProtocol represents address family and transport protocol. type AddressFamilyAndProtocol byte const ( UNSPEC AddressFamilyAndProtocol = '\x00' TCPv4 AddressFamilyAndProtocol = '\x11' UDPv4 AddressFamilyAndProtocol = '\x12' TCPv6 AddressFamilyAndProtocol = '\x21' UDPv6 AddressFamilyAndProtocol = '\x22' UnixStream AddressFamilyAndProtocol = '\x31' UnixDatagram AddressFamilyAndProtocol = '\x32' ) // IsIPv4 returns true if the address family is IPv4 (AF_INET4), false otherwise. 
func (ap AddressFamilyAndProtocol) IsIPv4() bool { return 0x10 == ap&0xF0 } // IsIPv6 returns true if the address family is IPv6 (AF_INET6), false otherwise. func (ap AddressFamilyAndProtocol) IsIPv6() bool { return 0x20 == ap&0xF0 } // IsUnix returns true if the address family is UNIX (AF_UNIX), false otherwise. func (ap AddressFamilyAndProtocol) IsUnix() bool { return 0x30 == ap&0xF0 } // IsStream returns true if the transport protocol is TCP or STREAM (SOCK_STREAM), false otherwise. func (ap AddressFamilyAndProtocol) IsStream() bool { return 0x01 == ap&0x0F } // IsDatagram returns true if the transport protocol is UDP or DGRAM (SOCK_DGRAM), false otherwise. func (ap AddressFamilyAndProtocol) IsDatagram() bool { return 0x02 == ap&0x0F } // IsUnspec returns true if the transport protocol or address family is unspecified, false otherwise. func (ap AddressFamilyAndProtocol) IsUnspec() bool { return (0x00 == ap&0xF0) || (0x00 == ap&0x0F) } func (ap AddressFamilyAndProtocol) toByte() byte { if ap.IsIPv4() && ap.IsStream() { return byte(TCPv4) } else if ap.IsIPv4() && ap.IsDatagram() { return byte(UDPv4) } else if ap.IsIPv6() && ap.IsStream() { return byte(TCPv6) } else if ap.IsIPv6() && ap.IsDatagram() { return byte(UDPv6) } else if ap.IsUnix() && ap.IsStream() { return byte(UnixStream) } else if ap.IsUnix() && ap.IsDatagram() { return byte(UnixDatagram) } return byte(UNSPEC) } go-proxyproto-0.4.2/addr_proto_test.go000066400000000000000000000032311400451755600201160ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestTCPoverIPv4(t *testing.T) { b := byte(TCPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestTCPoverIPv6(t *testing.T) { b := byte(TCPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() 
} } func TestUDPoverIPv4(t *testing.T) { b := byte(UDPv4) if !AddressFamilyAndProtocol(b).IsIPv4() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUDPoverIPv6(t *testing.T) { b := byte(UDPv6) if !AddressFamilyAndProtocol(b).IsIPv6() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixStream(t *testing.T) { b := byte(UnixStream) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsStream() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestUnixDatagram(t *testing.T) { b := byte(UnixDatagram) if !AddressFamilyAndProtocol(b).IsUnix() { t.Fail() } if !AddressFamilyAndProtocol(b).IsDatagram() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } func TestInvalidAddressFamilyAndProtocol(t *testing.T) { b := byte(UNSPEC) if !AddressFamilyAndProtocol(b).IsUnspec() { t.Fail() } if AddressFamilyAndProtocol(b).toByte() != b { t.Fail() } } go-proxyproto-0.4.2/go.mod000066400000000000000000000000571400451755600155040ustar00rootroot00000000000000module github.com/pires/go-proxyproto go 1.13 go-proxyproto-0.4.2/header.go000066400000000000000000000214151400451755600161560ustar00rootroot00000000000000// Package proxyproto implements Proxy Protocol (v1 and v2) parser and writer, as per specification: // https://www.haproxy.org/download/2.3/doc/proxy-protocol.txt package proxyproto import ( "bufio" "bytes" "errors" "io" "net" "time" ) var ( // Protocol SIGV1 = []byte{'\x50', '\x52', '\x4F', '\x58', '\x59'} SIGV2 = []byte{'\x0D', '\x0A', '\x0D', '\x0A', '\x00', '\x0D', '\x0A', '\x51', '\x55', '\x49', '\x54', '\x0A'} ErrLineMustEndWithCrlf = errors.New("proxyproto: header is invalid, must end with \\r\\n") ErrCantReadProtocolVersionAndCommand = errors.New("proxyproto: can't read proxy protocol version and command") 
ErrCantReadAddressFamilyAndProtocol = errors.New("proxyproto: can't read address family or protocol") ErrCantReadLength = errors.New("proxyproto: can't read length") ErrCantResolveSourceUnixAddress = errors.New("proxyproto: can't resolve source Unix address") ErrCantResolveDestinationUnixAddress = errors.New("proxyproto: can't resolve destination Unix address") ErrNoProxyProtocol = errors.New("proxyproto: proxy protocol signature not present") ErrUnknownProxyProtocolVersion = errors.New("proxyproto: unknown proxy protocol version") ErrUnsupportedProtocolVersionAndCommand = errors.New("proxyproto: unsupported proxy protocol version and command") ErrUnsupportedAddressFamilyAndProtocol = errors.New("proxyproto: unsupported address family and protocol") ErrInvalidLength = errors.New("proxyproto: invalid length") ErrInvalidAddress = errors.New("proxyproto: invalid address") ErrInvalidPortNumber = errors.New("proxyproto: invalid port number") ErrSuperfluousProxyHeader = errors.New("proxyproto: upstream connection sent PROXY header but isn't allowed to send one") ) // Header is the placeholder for proxy protocol header. type Header struct { Version byte Command ProtocolVersionAndCommand TransportProtocol AddressFamilyAndProtocol SourceAddr net.Addr DestinationAddr net.Addr rawTLVs []byte } // HeaderProxyFromAddrs creates a new PROXY header from a source and a // destination address. If version is zero, the latest protocol version is // used. // // The header is filled on a best-effort basis: if hints cannot be inferred // from the provided addresses, the header will be left unspecified. 
func HeaderProxyFromAddrs(version byte, sourceAddr, destAddr net.Addr) *Header { if version < 1 || version > 2 { version = 2 } h := &Header{ Version: version, Command: LOCAL, TransportProtocol: UNSPEC, } switch sourceAddr := sourceAddr.(type) { case *net.TCPAddr: if _, ok := destAddr.(*net.TCPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = TCPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = TCPv6 } case *net.UDPAddr: if _, ok := destAddr.(*net.UDPAddr); !ok { break } if len(sourceAddr.IP.To4()) == net.IPv4len { h.TransportProtocol = UDPv4 } else if len(sourceAddr.IP) == net.IPv6len { h.TransportProtocol = UDPv6 } case *net.UnixAddr: if _, ok := destAddr.(*net.UnixAddr); !ok { break } switch sourceAddr.Net { case "unix": h.TransportProtocol = UnixStream case "unixgram": h.TransportProtocol = UnixDatagram } } if h.TransportProtocol != UNSPEC { h.Command = PROXY h.SourceAddr = sourceAddr h.DestinationAddr = destAddr } return h } func (header *Header) TCPAddrs() (sourceAddr, destAddr *net.TCPAddr, ok bool) { if !header.TransportProtocol.IsStream() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) return sourceAddr, destAddr, sourceOK && destOK } func (header *Header) UDPAddrs() (sourceAddr, destAddr *net.UDPAddr, ok bool) { if !header.TransportProtocol.IsDatagram() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UDPAddr) destAddr, destOK := header.DestinationAddr.(*net.UDPAddr) return sourceAddr, destAddr, sourceOK && destOK } func (header *Header) UnixAddrs() (sourceAddr, destAddr *net.UnixAddr, ok bool) { if !header.TransportProtocol.IsUnix() { return nil, nil, false } sourceAddr, sourceOK := header.SourceAddr.(*net.UnixAddr) destAddr, destOK := header.DestinationAddr.(*net.UnixAddr) return sourceAddr, destAddr, sourceOK && destOK } func (header *Header) IPs() (sourceIP, destIP net.IP, ok 
bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.IP, destAddr.IP, true } else { return nil, nil, false } } func (header *Header) Ports() (sourcePort, destPort int, ok bool) { if sourceAddr, destAddr, ok := header.TCPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } else if sourceAddr, destAddr, ok := header.UDPAddrs(); ok { return sourceAddr.Port, destAddr.Port, true } else { return 0, 0, false } } // EqualTo returns true if headers are equivalent, false otherwise. // Deprecated: use EqualsTo instead. This method will eventually be removed. func (header *Header) EqualTo(otherHeader *Header) bool { return header.EqualsTo(otherHeader) } // EqualsTo returns true if headers are equivalent, false otherwise. func (header *Header) EqualsTo(otherHeader *Header) bool { if otherHeader == nil { return false } // TLVs only exist for version 2 if header.Version == 2 && !bytes.Equal(header.rawTLVs, otherHeader.rawTLVs) { return false } if header.Version != otherHeader.Version || header.Command != otherHeader.Command || header.TransportProtocol != otherHeader.TransportProtocol { return false } // Return early for header with LOCAL command, which contains no address information if header.Command == LOCAL { return true } return header.SourceAddr.String() == otherHeader.SourceAddr.String() && header.DestinationAddr.String() == otherHeader.DestinationAddr.String() } // WriteTo renders a proxy protocol header in a format and writes it to an io.Writer. func (header *Header) WriteTo(w io.Writer) (int64, error) { buf, err := header.Format() if err != nil { return 0, err } return bytes.NewBuffer(buf).WriteTo(w) } // Format renders a proxy protocol header in a format to write over the wire. 
func (header *Header) Format() ([]byte, error) { switch header.Version { case 1: return header.formatVersion1() case 2: return header.formatVersion2() default: return nil, ErrUnknownProxyProtocolVersion } } // TLVs returns the TLVs stored into this header, if they exist. TLVs are optional for v2 of the protocol. func (header *Header) TLVs() ([]TLV, error) { return SplitTLVs(header.rawTLVs) } // SetTLVs sets the TLVs stored in this header. This method replaces any // previous TLV. func (header *Header) SetTLVs(tlvs []TLV) error { raw, err := JoinTLVs(tlvs) if err != nil { return err } header.rawTLVs = raw return nil } // Read identifies the proxy protocol version and reads the remaining of // the header, accordingly. // // If proxy protocol header signature is not present, the reader buffer remains untouched // and is safe for reading outside of this code. // // If proxy protocol header signature is present but an error is raised while processing // the remaining header, assume the reader buffer to be in a corrupt state. // Also, this operation will block until enough bytes are available for peeking. func Read(reader *bufio.Reader) (*Header, error) { // In order to improve speed for small non-PROXYed packets, take a peek at the first byte alone. b1, err := reader.Peek(1) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(b1[:1], SIGV1[:1]) || bytes.Equal(b1[:1], SIGV2[:1]) { signature, err := reader.Peek(5) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(signature[:5], SIGV1) { return parseVersion1(reader) } signature, err = reader.Peek(12) if err != nil { if err == io.EOF { return nil, ErrNoProxyProtocol } return nil, err } if bytes.Equal(signature[:12], SIGV2) { return parseVersion2(reader) } } return nil, ErrNoProxyProtocol } // ReadTimeout acts as Read but takes a timeout. If that timeout is reached, it's assumed // there's no proxy protocol header. 
func ReadTimeout(reader *bufio.Reader, timeout time.Duration) (*Header, error) { type header struct { h *Header e error } read := make(chan *header, 1) go func() { h := &header{} h.h, h.e = Read(reader) read <- h }() timer := time.NewTimer(timeout) select { case result := <-read: timer.Stop() return result.h, result.e case <-timer.C: return nil, ErrNoProxyProtocol } } go-proxyproto-0.4.2/header_test.go000066400000000000000000000422471400451755600172230ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "errors" "net" "reflect" "testing" "time" ) // Stuff to be used in both versions tests. const ( NO_PROTOCOL = "There is no spoon" IP4_ADDR = "127.0.0.1" IP6_ADDR = "::1" PORT = 65533 INVALID_PORT = 99999 ) var ( v4ip = net.ParseIP(IP4_ADDR).To4() v6ip = net.ParseIP(IP6_ADDR).To16() v4addr net.Addr = &net.TCPAddr{IP: v4ip, Port: PORT} v6addr net.Addr = &net.TCPAddr{IP: v6ip, Port: PORT} v4UDPAddr net.Addr = &net.UDPAddr{IP: v4ip, Port: PORT} v6UDPAddr net.Addr = &net.UDPAddr{IP: v6ip, Port: PORT} unixStreamAddr net.Addr = &net.UnixAddr{Net: "unix", Name: "socket"} unixDatagramAddr net.Addr = &net.UnixAddr{Net: "unixgram", Name: "socket"} errReadIntentionallyBroken = errors.New("read is intentionally broken") ) type timeoutReader []byte func (t *timeoutReader) Read([]byte) (int, error) { time.Sleep(500 * time.Millisecond) return 0, nil } type errorReader []byte func (e *errorReader) Read([]byte) (int, error) { return 0, errReadIntentionallyBroken } func TestReadTimeoutV1Invalid(t *testing.T) { var b timeoutReader reader := bufio.NewReader(&b) _, err := ReadTimeout(reader, 50*time.Millisecond) if err == nil { t.Fatalf("expected error %s", ErrNoProxyProtocol) } else if err != ErrNoProxyProtocol { t.Fatalf("expected %s, actual %s", ErrNoProxyProtocol, err) } } func TestReadTimeoutPropagatesReadError(t *testing.T) { var e errorReader reader := bufio.NewReader(&e) _, err := ReadTimeout(reader, 50*time.Millisecond) if err == nil { t.Fatalf("expected 
error %s", errReadIntentionallyBroken) } else if err != errReadIntentionallyBroken { t.Fatalf("expected error %s, actual %s", errReadIntentionallyBroken, err) } } func TestEqualsTo(t *testing.T) { var headersEqual = []struct { this, that *Header expected bool }{ { &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, nil, false, }, { &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, false, }, { &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, true, }, } for _, tt := range headersEqual { if actual := tt.this.EqualsTo(tt.that); actual != tt.expected { t.Fatalf("expected %t, actual %t", tt.expected, actual) } } } // This is here just because of coveralls func TestEqualTo(t *testing.T) { TestEqualsTo(t) } func TestGetters(t *testing.T) { var tests = []struct { name string header *Header tcpSourceAddr, tcpDestAddr *net.TCPAddr udpSourceAddr, udpDestAddr *net.UDPAddr unixSourceAddr, unixDestAddr *net.UnixAddr ipSource, ipDest net.IP portSource, portDest int }{ { name: "TCPv4", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: 
net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tcpSourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, tcpDestAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, ipSource: net.ParseIP("10.1.1.1"), ipDest: net.ParseIP("20.2.2.2"), portSource: 1000, portDest: 2000, }, { name: "UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, udpSourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, udpDestAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, ipSource: net.ParseIP("10.1.1.1"), ipDest: net.ParseIP("20.2.2.2"), portSource: 1000, portDest: 2000, }, { name: "UnixStream", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "UnixDatagram", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, unixSourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, unixDestAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, { name: "Unspec", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: UNSPEC, }, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { tcpSourceAddr, tcpDestAddr, _ := test.header.TCPAddrs() if test.tcpSourceAddr != nil && !reflect.DeepEqual(tcpSourceAddr, test.tcpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", tcpSourceAddr, test.tcpSourceAddr) } if test.tcpDestAddr != nil && !reflect.DeepEqual(tcpDestAddr, 
test.tcpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", tcpDestAddr, test.tcpDestAddr) } udpSourceAddr, udpDestAddr, _ := test.header.UDPAddrs() if test.udpSourceAddr != nil && !reflect.DeepEqual(udpSourceAddr, test.udpSourceAddr) { t.Errorf("TCPAddrs() source = %v, want %v", udpSourceAddr, test.udpSourceAddr) } if test.udpDestAddr != nil && !reflect.DeepEqual(udpDestAddr, test.udpDestAddr) { t.Errorf("TCPAddrs() dest = %v, want %v", udpDestAddr, test.udpDestAddr) } unixSourceAddr, unixDestAddr, _ := test.header.UnixAddrs() if test.unixSourceAddr != nil && !reflect.DeepEqual(unixSourceAddr, test.unixSourceAddr) { t.Errorf("UnixAddrs() source = %v, want %v", unixSourceAddr, test.unixSourceAddr) } if test.unixDestAddr != nil && !reflect.DeepEqual(unixDestAddr, test.unixDestAddr) { t.Errorf("UnixAddrs() dest = %v, want %v", unixDestAddr, test.unixDestAddr) } ipSource, ipDest, _ := test.header.IPs() if test.ipSource != nil && !ipSource.Equal(test.ipSource) { t.Errorf("IPs() source = %v, want %v", ipSource, test.ipSource) } if test.ipDest != nil && !ipDest.Equal(test.ipDest) { t.Errorf("IPs() dest = %v, want %v", ipDest, test.ipDest) } portSource, portDest, _ := test.header.Ports() if test.portSource != 0 && portSource != test.portSource { t.Errorf("Ports() source = %v, want %v", portSource, test.portSource) } if test.portDest != 0 && portDest != test.portDest { t.Errorf("Ports() dest = %v, want %v", portDest, test.portDest) } }) } } func TestSetTLVs(t *testing.T) { tests := []struct { header *Header name string tlvs []TLV expectErr bool }{ { name: "add authority TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "add too long TLV", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, 
SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: append(bytes.Repeat([]byte("a"), 0xFFFF), []byte(".example.org")...), }}, expectErr: true, }, } for _, tt := range tests { err := tt.header.SetTLVs(tt.tlvs) if err != nil && !tt.expectErr { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } } func TestWriteTo(t *testing.T) { var buf bytes.Buffer validHeader := &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := validHeader.WriteTo(&buf); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } invalidHeader := &Header{ SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := invalidHeader.WriteTo(&buf); err == nil { t.Fatalf("should have thrown error %q", err.Error()) } } func TestFormat(t *testing.T) { validHeader := &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } if _, err := validHeader.Format(); err != nil { t.Fatalf("shouldn't have thrown error %q", err.Error()) } } func TestFormatInvalid(t *testing.T) { tests := []struct { name string header *Header err error }{ { name: "invalidVersion", header: &Header{ Version: 3, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, err: ErrUnknownProxyProtocolVersion, }, { name: "v2MismatchTCPv4_UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchTCPv4_TCPv6", header: 
&Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v6addr, }, err: ErrInvalidAddress, }, { name: "v2MismatchUnixStream_TCPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: v4addr, DestinationAddr: unixStreamAddr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_TCPv6", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v6addr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, { name: "v1MismatchTCPv4_UDPv4", header: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4addr, }, err: ErrInvalidAddress, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if _, err := test.header.Format(); err == nil { t.Errorf("Header.Format() succeeded, want an error") } else if err != test.err { t.Errorf("Header.Format() = %q, want %q", err, test.err) } }) } } func TestHeaderProxyFromAddrs(t *testing.T) { unspec := &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, } tests := []struct { name string version byte sourceAddr, destAddr net.Addr expected *Header }{ { name: "TCPv4", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "TCPv6", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UDPv4", sourceAddr: &net.UDPAddr{ IP: 
net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "UDPv6", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::372"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("fde7::1"), Port: 2000, }, }, }, { name: "UnixStream", sourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: &net.UnixAddr{ Net: "unix", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unix", Name: "dst", }, }, }, { name: "UnixDatagram", sourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, destAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, expected: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: &net.UnixAddr{ Net: "unixgram", Name: "src", }, DestinationAddr: &net.UnixAddr{ Net: "unixgram", Name: "dst", }, }, }, { name: "Version1", version: 1, sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, }, }, { name: "TCPInvalidIP", sourceAddr: &net.TCPAddr{ IP: nil, Port: 1000, }, destAddr: &net.TCPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "UDPInvalidIP", sourceAddr: &net.UDPAddr{ IP: nil, Port: 
1000, }, destAddr: &net.UDPAddr{ IP: nil, Port: 2000, }, expected: unspec, }, { name: "TCPAddrTypeMismatch", sourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.UDPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, { name: "UDPAddrTypeMismatch", sourceAddr: &net.UDPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, { name: "UnixAddrTypeMismatch", sourceAddr: &net.UnixAddr{ Net: "unix", }, destAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, expected: unspec, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { h := HeaderProxyFromAddrs(tt.version, tt.sourceAddr, tt.destAddr) if !h.EqualsTo(tt.expected) { t.Errorf("expected %+v, actual %+v for source %+v and destination %+v", tt.expected, h, tt.sourceAddr, tt.destAddr) } }) } } go-proxyproto-0.4.2/policy.go000066400000000000000000000106301400451755600162220ustar00rootroot00000000000000package proxyproto import ( "fmt" "net" "strings" ) // PolicyFunc can be used to decide whether to trust the PROXY info from // upstream. If set, the connecting address is passed in as an argument. // // See below for the different policies. // // In case an error is returned the connection is denied. type PolicyFunc func(upstream net.Addr) (Policy, error) // Policy defines how a connection with a PROXY header address is treated. type Policy int const ( // USE address from PROXY header USE Policy = iota // IGNORE address from PROXY header, but accept connection IGNORE // REJECT connection when PROXY header is sent // Note: even though the first read on the connection returns an error if // a PROXY header is present, subsequent reads do not. It is the task of // the code using the connection to handle that case properly. 
REJECT // REQUIRE connection to send PROXY header, reject if not present // Note: even though the first read on the connection returns an error if // a PROXY header is not present, subsequent reads do not. It is the task // of the code using the connection to handle that case properly. REQUIRE ) // WithPolicy adds given policy to a connection when passed as option to NewConn() func WithPolicy(p Policy) func(*Conn) { return func(c *Conn) { c.ProxyHeaderPolicy = p } } // LaxWhiteListPolicy returns a PolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list the proxy // header will be ignored. If one of the provided IP addresses or IP ranges // is invalid it will return an error instead of a PolicyFunc. func LaxWhiteListPolicy(allowed []string) (PolicyFunc, error) { allowFrom, err := parse(allowed) if err != nil { return nil, err } return whitelistPolicy(allowFrom, IGNORE), nil } // MustLaxWhiteListPolicy returns a LaxWhiteListPolicy but will panic if one // of the provided IP addresses or IP ranges is invalid. func MustLaxWhiteListPolicy(allowed []string) PolicyFunc { pfunc, err := LaxWhiteListPolicy(allowed) if err != nil { panic(err) } return pfunc } // StrictWhiteListPolicy returns a PolicyFunc which decides whether the // upstream ip is allowed to send a proxy header based on a list of allowed // IP addresses and IP ranges. In case upstream IP is not in list reading on // the connection will be refused on the first read. Please note: subsequent // reads do not error. It is the task of the code using the connection to // handle that case properly. If one of the provided IP addresses or IP // ranges is invalid it will return an error instead of a PolicyFunc. 
func StrictWhiteListPolicy(allowed []string) (PolicyFunc, error) { allowFrom, err := parse(allowed) if err != nil { return nil, err } return whitelistPolicy(allowFrom, REJECT), nil } // MustStrictWhiteListPolicy returns a StrictWhiteListPolicy but will panic // if one of the provided IP addresses or IP ranges is invalid. func MustStrictWhiteListPolicy(allowed []string) PolicyFunc { pfunc, err := StrictWhiteListPolicy(allowed) if err != nil { panic(err) } return pfunc } func whitelistPolicy(allowed []func(net.IP) bool, def Policy) PolicyFunc { return func(upstream net.Addr) (Policy, error) { upstreamIP, err := ipFromAddr(upstream) if err != nil { // something is wrong with the source IP, better reject the connection return REJECT, err } for _, allowFrom := range allowed { if allowFrom(upstreamIP) { return USE, nil } } return def, nil } } func parse(allowed []string) ([]func(net.IP) bool, error) { a := make([]func(net.IP) bool, len(allowed)) for i, allowFrom := range allowed { if strings.LastIndex(allowFrom, "/") > 0 { _, ipRange, err := net.ParseCIDR(allowFrom) if err != nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP range: %v", allowFrom, err) } a[i] = ipRange.Contains } else { allowed := net.ParseIP(allowFrom) if allowed == nil { return nil, fmt.Errorf("proxyproto: given string %q is not a valid IP address", allowFrom) } a[i] = allowed.Equal } } return a, nil } func ipFromAddr(upstream net.Addr) (net.IP, error) { upstreamString, _, err := net.SplitHostPort(upstream.String()) if err != nil { return nil, err } upstreamIP := net.ParseIP(upstreamString) if nil == upstreamIP { return nil, fmt.Errorf("proxyproto: invalid IP address") } return upstreamIP, nil } go-proxyproto-0.4.2/policy_test.go000066400000000000000000000111431400451755600172610ustar00rootroot00000000000000package proxyproto import ( "net" "testing" ) type failingAddr struct{} func (f failingAddr) Network() string { return "failing" } func (f failingAddr) String() string { 
return "failing" } func TestWhitelistPolicyReturnsErrorOnInvalidAddress(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"})}, } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { _, err := tc.policy(failingAddr{}) if err == nil { t.Fatal("Expected error, got none") } }) } } func TestStrictWhitelistPolicyReturnsRejectWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { p := MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } policy, err := p(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != REJECT { t.Fatalf("Expected policy REJECT, got %v", policy) } } func TestLaxWhitelistPolicyReturnsIgnoreWhenUpstreamIpAddrNotInWhitelist(t *testing.T) { p := MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4", "10.0.0.0/30"}) upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.5:45738") if err != nil { t.Fatalf("err: %v", err) } policy, err := p(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != IGNORE { t.Fatalf("Expected policy IGNORE, got %v", policy) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelist(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.2", "10.0.0.3", "10.0.0.4"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { 
t.Fatalf("Expected policy USE, got %v", policy) } }) } } func TestWhitelistPolicyReturnsUseWhenUpstreamIpAddrInWhitelistRange(t *testing.T) { var cases = []struct { name string policy PolicyFunc }{ {"strict whitelist policy", MustStrictWhiteListPolicy([]string{"10.0.0.0/29"})}, {"lax whitelist policy", MustLaxWhiteListPolicy([]string{"10.0.0.0/29"})}, } upstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738") if err != nil { t.Fatalf("err: %v", err) } for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { policy, err := tc.policy(upstream) if err != nil { t.Fatalf("err: %v", err) } if policy != USE { t.Fatalf("Expected policy USE, got %v", policy) } }) } } func Test_CreateWhitelistPolicyWithInvalidCidrReturnsError(t *testing.T) { _, err := StrictWhiteListPolicy([]string{"20/80"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateWhitelistPolicyWithInvalidIpAddressReturnsError(t *testing.T) { _, err := StrictWhiteListPolicy([]string{"855.222.233.11"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateLaxPolicyWithInvalidCidrReturnsError(t *testing.T) { _, err := LaxWhiteListPolicy([]string{"20/80"}) if err == nil { t.Error("Expected error, got none") } } func Test_CreateLaxPolicyWithInvalidIpAddresseturnsError(t *testing.T) { _, err := LaxWhiteListPolicy([]string{"855.222.233.11"}) if err == nil { t.Error("Expected error, got none") } } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustLaxWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustLaxWhiteListPolicy([]string{"20/80"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpAddress(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but 
got none") } }() MustStrictWhiteListPolicy([]string{"855.222.233.11"}) } func Test_MustStrictWhiteListPolicyPanicsWithInvalidIpRange(t *testing.T) { defer func() { if r := recover(); r == nil { t.Error("Expected a panic, but got none") } }() MustStrictWhiteListPolicy([]string{"20/80"}) } go-proxyproto-0.4.2/protocol.go000066400000000000000000000144031400451755600165660ustar00rootroot00000000000000package proxyproto import ( "bufio" "net" "sync" "time" ) // Listener is used to wrap an underlying listener, // whose connections may be using the HAProxy Proxy Protocol. // If the connection is using the protocol, the RemoteAddr() will return // the correct client address. type Listener struct { Listener net.Listener Policy PolicyFunc ValidateHeader Validator } // Conn is used to wrap and underlying connection which // may be speaking the Proxy Protocol. If it is, the RemoteAddr() will // return the address of the client instead of the proxy address. type Conn struct { bufReader *bufio.Reader conn net.Conn header *Header once sync.Once ProxyHeaderPolicy Policy Validate Validator readErr error } // Validator receives a header and decides whether it is a valid one // In case the header is not deemed valid it should return an error. type Validator func(*Header) error // ValidateHeader adds given validator for proxy headers to a connection when passed as option to NewConn() func ValidateHeader(v Validator) func(*Conn) { return func(c *Conn) { if v != nil { c.Validate = v } } } // Accept waits for and returns the next connection to the listener. 
func (p *Listener) Accept() (net.Conn, error) { // Get the underlying connection conn, err := p.Listener.Accept() if err != nil { return nil, err } proxyHeaderPolicy := USE if p.Policy != nil { proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr()) if err != nil { // can't decide the policy, we can't accept the connection conn.Close() return nil, err } } newConn := NewConn( conn, WithPolicy(proxyHeaderPolicy), ValidateHeader(p.ValidateHeader), ) return newConn, nil } // Close closes the underlying listener. func (p *Listener) Close() error { return p.Listener.Close() } // Addr returns the underlying listener's network address. func (p *Listener) Addr() net.Addr { return p.Listener.Addr() } // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { pConn := &Conn{ bufReader: bufio.NewReader(conn), conn: conn, } for _, opt := range opts { opt(pConn) } return pConn } // Read is check for the proxy protocol header when doing // the initial scan. If there is an error parsing the header, // it is returned and the socket is closed. func (p *Conn) Read(b []byte) (int, error) { p.once.Do(func() { p.readErr = p.readHeader() }) if p.readErr != nil { return 0, p.readErr } return p.bufReader.Read(b) } // Write wraps original conn.Write func (p *Conn) Write(b []byte) (int, error) { return p.conn.Write(b) } // Close wraps original conn.Close func (p *Conn) Close() error { return p.conn.Close() } // ProxyHeader returns the proxy protocol header, if any. If an error occurs // while reading the proxy header, nil is returned. func (p *Conn) ProxyHeader() *Header { p.once.Do(func() { p.readErr = p.readHeader() }) return p.header } // LocalAddr returns the address of the server if the proxy // protocol is being used, otherwise just returns the address of // the socket server. 
In case an error happens on reading the // proxy header the original LocalAddr is returned, not the one // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) LocalAddr() net.Addr { p.once.Do(func() { p.readErr = p.readHeader() }) if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.LocalAddr() } return p.header.DestinationAddr } // RemoteAddr returns the address of the client if the proxy // protocol is being used, otherwise just returns the address of // the socket peer. In case an error happens on reading the // proxy header the original RemoteAddr is returned, not the one // from the proxy header even if the proxy header itself is // syntactically correct. func (p *Conn) RemoteAddr() net.Addr { p.once.Do(func() { p.readErr = p.readHeader() }) if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil { return p.conn.RemoteAddr() } return p.header.SourceAddr } // Raw returns the underlying connection which can be casted to // a concrete type, allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) Raw() net.Conn { return p.conn } // TCPConn returns the underlying TCP connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) TCPConn() (conn *net.TCPConn, ok bool) { conn, ok = p.conn.(*net.TCPConn) return } // UnixConn returns the underlying Unix socket connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. func (p *Conn) UnixConn() (conn *net.UnixConn, ok bool) { conn, ok = p.conn.(*net.UnixConn) return } // UDPConn returns the underlying UDP connection, // allowing access to specialized functions. // // Use this ONLY if you know exactly what you are doing. 
func (p *Conn) UDPConn() (conn *net.UDPConn, ok bool) { conn, ok = p.conn.(*net.UDPConn) return } // SetDeadline wraps original conn.SetDeadline func (p *Conn) SetDeadline(t time.Time) error { return p.conn.SetDeadline(t) } // SetReadDeadline wraps original conn.SetReadDeadline func (p *Conn) SetReadDeadline(t time.Time) error { return p.conn.SetReadDeadline(t) } // SetWriteDeadline wraps original conn.SetWriteDeadline func (p *Conn) SetWriteDeadline(t time.Time) error { return p.conn.SetWriteDeadline(t) } func (p *Conn) readHeader() error { header, err := Read(p.bufReader) // For the purpose of this wrapper shamefully stolen from armon/go-proxyproto // let's act as if there was no error when PROXY protocol is not present. if err == ErrNoProxyProtocol { // but not if it is required that the connection has one if p.ProxyHeaderPolicy == REQUIRE { return err } return nil } // proxy protocol header was found if err == nil && header != nil { switch p.ProxyHeaderPolicy { case REJECT: // this connection is not allowed to send one return ErrSuperfluousProxyHeader case USE, REQUIRE: if p.Validate != nil { err = p.Validate(header) if err != nil { return err } } p.header = header } } return err } go-proxyproto-0.4.2/protocol_test.go000066400000000000000000000432321400451755600176270ustar00rootroot00000000000000// This file was shamefully stolen from github.com/armon/go-proxyproto. // It has been heavily edited to conform to this lib. 
// // Thanks @armon package proxyproto import ( "bytes" "crypto/tls" "crypto/x509" "fmt" "net" "testing" ) func TestPassthrough(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() conn.Write([]byte("ping")) recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("pong")) { t.Fatalf("bad: %v", recv) } }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } } func TestParse_ipv4(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! 
header.WriteTo(conn) conn.Write([]byte("ping")) recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("pong")) { t.Fatalf("bad: %v", recv) } }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "10.1.1.1" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } } func TestParse_ipv6(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } pl := &Listener{Listener: l} header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("ffff::ffff"), Port: 2000, }, } go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! 
header.WriteTo(conn) conn.Write([]byte("ping")) recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("pong")) { t.Fatalf("bad: %v", recv) } }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "ffff::ffff" { t.Fatalf("bad: %v", addr) } if addr.Port != 1000 { t.Fatalf("bad: %v", addr) } h := conn.(*Conn).ProxyHeader() if !h.EqualsTo(header) { t.Errorf("bad: %v", h) } } func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } expectedErr := fmt.Errorf("failure") policyFunc := func(upstream net.Addr) (Policy, error) { return USE, expectedErr } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() }() conn, err := pl.Accept() if err != expectedErr { t.Fatalf("Expected error %v, got %v", expectedErr, err) } if conn != nil { t.Fatalf("Expected no connection, got %v", conn) } } func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() conn.Write([]byte("ping")) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = 
conn.Read(recv) if err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } } func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REJECT, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } header.WriteTo(conn) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != ErrSuperfluousProxyHeader { t.Fatalf("Expected error %v, received %v", ErrSuperfluousProxyHeader, err) } } func TestIgnorePolicyIgnoresIpFromProxyHeader(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return IGNORE, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! 
header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } header.WriteTo(conn) conn.Write([]byte("ping")) recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("pong")) { t.Fatalf("bad: %v", recv) } }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != nil { t.Fatalf("err: %v", err) } if !bytes.Equal(recv, []byte("ping")) { t.Fatalf("bad: %v", recv) } if _, err := conn.Write([]byte("pong")); err != nil { t.Fatalf("err: %v", err) } // Check the remote addr addr := conn.RemoteAddr().(*net.TCPAddr) if addr.IP.String() != "127.0.0.1" { t.Fatalf("bad: %v", addr) } } func Test_AllOptionsAreRecognized(t *testing.T) { recognizedOpt1 := false opt1 := func(c *Conn) { recognizedOpt1 = true } recognizedOpt2 := false opt2 := func(c *Conn) { recognizedOpt2 = true } server, client := net.Pipe() defer func() { client.Close() }() c := NewConn(server, opt1, opt2) if !recognizedOpt1 { t.Error("Expected option 1 recognized") } if !recognizedOpt2 { t.Error("Expected option 2 recognized") } c.Close() } func TestReadingIsRefusedOnErrorWhenRemoteAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() conn.Write([]byte("ping")) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() _ = conn.RemoteAddr() recv := make([]byte, 4) _, err = conn.Read(recv) if err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", 
ErrNoProxyProtocol, err) } } func TestReadingIsRefusedOnErrorWhenLocalAddrRequestedFirst(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() conn.Write([]byte("ping")) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() _ = conn.LocalAddr() recv := make([]byte, 4) _, err = conn.Read(recv) if err != ErrNoProxyProtocol { t.Fatalf("Expected error %v, received %v", ErrNoProxyProtocol, err) } } func Test_ConnectionCasts(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } policyFunc := func(upstream net.Addr) (Policy, error) { return REQUIRE, nil } pl := &Listener{Listener: l, Policy: policyFunc} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() conn.Write([]byte("ping")) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() proxyprotoConn := conn.(*Conn) _, ok := proxyprotoConn.TCPConn() if !ok { t.Fatal("err: should be a tcp connection") } _, ok = proxyprotoConn.UDPConn() if ok { t.Fatal("err: should be a tcp connection not udp") } _, ok = proxyprotoConn.UnixConn() if ok { t.Fatal("err: should be a tcp connection not unix") } _, ok = proxyprotoConn.Raw().(*net.TCPConn) if !ok { t.Fatal("err: should be a tcp connection") } } func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } validationError := fmt.Errorf("failed to validate") pl := &Listener{Listener: l, ValidateHeader: func(*Header) error { return validationError }} go func() { conn, err := net.Dial("tcp", pl.Addr().String()) if 
err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } header.WriteTo(conn) }() conn, err := pl.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 4) _, err = conn.Read(recv) if err != validationError { t.Fatalf("expected validation error, got %v", err) } } type TestTLSServer struct { Listener net.Listener // TLS is the optional TLS configuration, populated with a new config // after TLS is started. If set on an unstarted server before StartTLS // is called, existing fields are copied into the new config. TLS *tls.Config TLSClientConfig *tls.Config // certificate is a parsed version of the TLS config certificate, if present. certificate *x509.Certificate } func (s *TestTLSServer) Addr() string { return s.Listener.Addr().String() } func (s *TestTLSServer) Close() { s.Listener.Close() } // based on net/http/httptest/Server.StartTLS func NewTestTLSServer(l net.Listener) *TestTLSServer { s := &TestTLSServer{} cert, err := tls.X509KeyPair(LocalhostCert, LocalhostKey) if err != nil { panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) } s.TLS = new(tls.Config) if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) if err != nil { panic(fmt.Sprintf("NewTestTLSServer: %v", err)) } certpool := x509.NewCertPool() certpool.AddCert(s.certificate) s.TLSClientConfig = &tls.Config{ RootCAs: certpool, } s.Listener = tls.NewListener(l, s.TLS) return s } func Test_TLSServer(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: s.Listener, Policy: func(upstream net.Addr) (Policy, error) 
{ return REQUIRE, nil }, } defer s.Close() go func() { conn, err := tls.Dial("tcp", s.Addr(), s.TLSClientConfig) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } header.WriteTo(conn) conn.Write([]byte("test")) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 1024) n, err := conn.Read(recv) if err != nil { t.Fatalf("expected no error, got %v", err) } if string(recv[:n]) != "test" { t.Fatalf("expected \"test\", got \"%s\" %v", recv[:n], recv[:n]) } } func Test_MisconfiguredTLSServerRespondsWithUnderlyingError(t *testing.T) { l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("err: %v", err) } s := NewTestTLSServer(l) s.Listener = &Listener{ Listener: s.Listener, Policy: func(upstream net.Addr) (Policy, error) { return REQUIRE, nil }, } defer s.Close() go func() { // this is not a valid TLS connection, we are // connecting to the TLS endpoint via plain TCP. // // it's an example of a configuration error: // client: HTTP -> PROXY // server: PROXY -> TLS -> HTTP // // we want to bubble up the underlying error, // in this case a tls handshake error, instead // of responding with a non-descript // > "Proxy protocol signature not present". conn, err := net.Dial("tcp", s.Addr()) if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() // Write out the header! 
header := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: &net.TCPAddr{ IP: net.ParseIP("10.1.1.1"), Port: 1000, }, DestinationAddr: &net.TCPAddr{ IP: net.ParseIP("20.2.2.2"), Port: 2000, }, } header.WriteTo(conn) conn.Write([]byte("GET /foo/bar HTTP/1.1")) }() conn, err := s.Listener.Accept() if err != nil { t.Fatalf("err: %v", err) } defer conn.Close() recv := make([]byte, 1024) _, err = conn.Read(recv) if err.Error() != "tls: first record does not look like a TLS handshake" { t.Fatalf("expected tls handshake error, got %s", err) } } // copied from src/net/http/internal/testcert.go // Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // LocalhostCert is a PEM-encoded TLS cert with SAN IPs // "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT. // generated from src/crypto/tls: // go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h var LocalhostCert = []byte(`-----BEGIN CERTIFICATE----- MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4 iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9 tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM fblo6RBxUQ== -----END CERTIFICATE-----`) // LocalhostKey is the private key for localhostCert. 
var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9 SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet 3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA== -----END RSA PRIVATE KEY-----`) go-proxyproto-0.4.2/tlv.go000066400000000000000000000070621400451755600155350ustar00rootroot00000000000000// Type-Length-Value splitting and parsing for proxy protocol V2 // See spec https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt sections 2.2 to 2.7 and package proxyproto import ( "encoding/binary" "errors" "fmt" "math" ) const ( // Section 2.2 PP2_TYPE_ALPN PP2Type = 0x01 PP2_TYPE_AUTHORITY PP2Type = 0x02 PP2_TYPE_CRC32C PP2Type = 0x03 PP2_TYPE_NOOP PP2Type = 0x04 PP2_TYPE_SSL PP2Type = 0x20 PP2_SUBTYPE_SSL_VERSION PP2Type = 0x21 PP2_SUBTYPE_SSL_CN PP2Type = 0x22 PP2_SUBTYPE_SSL_CIPHER PP2Type = 0x23 PP2_SUBTYPE_SSL_SIG_ALG PP2Type = 0x24 PP2_SUBTYPE_SSL_KEY_ALG PP2Type = 0x25 PP2_TYPE_NETNS PP2Type = 0x30 // Section 2.2.7, reserved types PP2_TYPE_MIN_CUSTOM PP2Type = 0xE0 PP2_TYPE_MAX_CUSTOM PP2Type = 0xEF PP2_TYPE_MIN_EXPERIMENT PP2Type = 0xF0 PP2_TYPE_MAX_EXPERIMENT PP2Type = 0xF7 PP2_TYPE_MIN_FUTURE PP2Type = 0xF8 PP2_TYPE_MAX_FUTURE PP2Type = 0xFF ) var ( ErrTruncatedTLV = errors.New("proxyproto: truncated TLV") ErrMalformedTLV = 
errors.New("proxyproto: malformed TLV Value") ErrIncompatibleTLV = errors.New("proxyproto: incompatible TLV type") ) // PP2Type is the proxy protocol v2 type type PP2Type byte // TLV is a uninterpreted Type-Length-Value for V2 protocol, see section 2.2 type TLV struct { Type PP2Type Value []byte } // SplitTLVs splits the Type-Length-Value vector, returns the vector or an error. func SplitTLVs(raw []byte) ([]TLV, error) { var tlvs []TLV for i := 0; i < len(raw); { tlv := TLV{ Type: PP2Type(raw[i]), } if len(raw)-i <= 2 { return nil, ErrTruncatedTLV } tlvLen := int(binary.BigEndian.Uint16(raw[i+1 : i+3])) // Max length = 65K i += 3 if i+tlvLen > len(raw) { return nil, ErrTruncatedTLV } // Ignore no-op padding if tlv.Type != PP2_TYPE_NOOP { tlv.Value = make([]byte, tlvLen) copy(tlv.Value, raw[i:i+tlvLen]) } i += tlvLen tlvs = append(tlvs, tlv) } return tlvs, nil } // JoinTLVs joins multiple Type-Length-Value records. func JoinTLVs(tlvs []TLV) ([]byte, error) { var raw []byte for _, tlv := range tlvs { if len(tlv.Value) > math.MaxUint16 { return nil, fmt.Errorf("proxyproto: cannot format TLV %v with length %d", tlv.Type, len(tlv.Value)) } var length [2]byte binary.BigEndian.PutUint16(length[:], uint16(len(tlv.Value))) raw = append(raw, byte(tlv.Type)) raw = append(raw, length[:]...) raw = append(raw, tlv.Value...) 
} return raw, nil } // Registered is true if the type is registered in the spec, see section 2.2 func (p PP2Type) Registered() bool { switch p { case PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS: return true } return false } // App is true if the type is reserved for application specific data, see section 2.2.7 func (p PP2Type) App() bool { return p >= PP2_TYPE_MIN_CUSTOM && p <= PP2_TYPE_MAX_CUSTOM } // Experiment is true if the type is reserved for temporary experimental use by application developers, see section 2.2.7 func (p PP2Type) Experiment() bool { return p >= PP2_TYPE_MIN_EXPERIMENT && p <= PP2_TYPE_MAX_EXPERIMENT } // Future is true is the type is reserved for future use, see section 2.2.7 func (p PP2Type) Future() bool { return p >= PP2_TYPE_MIN_FUTURE } // Spec is true if the type is covered by the spec, see section 2.2 and 2.2.7 func (p PP2Type) Spec() bool { return p.Registered() || p.App() || p.Experiment() || p.Future() } go-proxyproto-0.4.2/tlv_test.go000066400000000000000000000113601400451755600165700ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "testing" ) var ( fixtureOneByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 1} fixtureTwoByteTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 2, 0x00} fixtureEmptyLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x01} fixturePartialLenTLV = []byte{byte(PP2_TYPE_MIN_CUSTOM) + 3, 0x00, 0x02, 0x00} ) func checkTLVs(t *testing.T, name string, raw []byte, expected []PP2Type) []TLV { header, err := parseVersion2(bufio.NewReader(bytes.NewReader(raw))) if err != nil { t.Fatalf("%s: Unexpected error reading header %#v", name, err) } tlvs, err := header.TLVs() if err != nil { t.Fatalf("%s: Unexpected error splitting TLVS %#v", name, err) } if len(tlvs) != len(expected) { t.Fatalf("%s: Expected %d TLVs, actual %d", name, 
len(expected), len(tlvs)) } for i, et := range expected { if at := tlvs[i].Type; at != et { t.Fatalf("%s: Expected type %X, actual %X", name, et, at) } } return tlvs } var invalidTLVTests = []struct { name string reader *bufio.Reader expectedError error }{ { name: "One byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureOneByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Two byte TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureTwoByteTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Empty Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixtureEmptyLenTLV)...)), expectedError: ErrTruncatedTLV, }, { name: "Partial Len TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, fixturePartialLenTLV)...)), expectedError: ErrTruncatedTLV, }, } func TestValid0Length(t *testing.T) { r := bufio.NewReader(bytes.NewReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureWithTLV(lengthV4Bytes, fixtureIPv4Address, []byte{byte(PP2_TYPE_MIN_CUSTOM), 0x00, 0x00})...))) h, err := Read(r) if err != nil { t.Fatalf("unexpected error: %v", err) } tlvs, err := h.TLVs() if err != nil { t.Fatalf("unexpected error: %v", err) } if len(tlvs) != 1 { t.Fatalf("expected 1 tlv, got %d", len(tlvs)) } if len(tlvs[0].Value) != 0 { t.Fatalf("expected 0 byte tlv value, got %x", tlvs[0].Value) } } func TestInvalidV2TLV(t *testing.T) { for _, tc := range invalidTLVTests { t.Run(tc.name, func(t *testing.T) { if hdr, err := Read(tc.reader); err != nil { t.Fatalf("TestInvalidV2TLV %s: unexpected error reading proxy protocol %#v", tc.name, err) } else if _, err := hdr.TLVs(); err != tc.expectedError { t.Fatalf("TestInvalidV2TLV %s: expected %#v, actual %#v", tc.name, 
tc.expectedError, err) } }) } } func TestV2TLVPP2Registered(t *testing.T) { pp2RegTypes := []PP2Type{ PP2_TYPE_ALPN, PP2_TYPE_AUTHORITY, PP2_TYPE_CRC32C, PP2_TYPE_NOOP, PP2_TYPE_SSL, PP2_SUBTYPE_SSL_VERSION, PP2_SUBTYPE_SSL_CN, PP2_SUBTYPE_SSL_CIPHER, PP2_SUBTYPE_SSL_SIG_ALG, PP2_SUBTYPE_SSL_KEY_ALG, PP2_TYPE_NETNS, } pp2RegMap := make(map[PP2Type]bool) for _, p := range pp2RegTypes { pp2RegMap[p] = true if !p.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x should be registered", p) } if !p.Spec() { t.Fatalf("TestV2TLVPP2Registered: type %x should be in spec", p) } if p.App() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly app", p) } if p.Experiment() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly experiment", p) } if p.Future() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly future", p) } } lastType := PP2Type(0xFF) for i := PP2Type(0x00); i < lastType; i++ { if !pp2RegMap[i] { if i.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", i) } } } if lastType.Registered() { t.Fatalf("TestV2TLVPP2Registered: type %x unexpectedly registered", lastType) } } func TestJoinTLVs(t *testing.T) { tests := []struct { name string raw []byte tlvs []TLV }{ { name: "authority TLV", raw: append([]byte{byte(PP2_TYPE_AUTHORITY), 0x00, 0x0B}, []byte("example.org")...), tlvs: []TLV{{ Type: PP2_TYPE_AUTHORITY, Value: []byte("example.org"), }}, }, { name: "empty TLV", raw: []byte{byte(PP2_TYPE_NOOP), 0x00, 0x00}, tlvs: []TLV{{ Type: PP2_TYPE_NOOP, Value: nil, }}, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if raw, err := JoinTLVs(tc.tlvs); err != nil { t.Fatalf("unexpected error: %v", err) } else if !bytes.Equal(raw, tc.raw) { t.Errorf("expected %#v, got %#v", tc.raw, raw) } }) } } 
go-proxyproto-0.4.2/tlvparse/000077500000000000000000000000001400451755600162345ustar00rootroot00000000000000go-proxyproto-0.4.2/tlvparse/aws.go000066400000000000000000000021301400451755600173510ustar00rootroot00000000000000// Amazon's application extension to TLVs for NLB VPC endpoint services // https://docs.aws.amazon.com/elasticloadbalancing/latest/network/load-balancer-target-groups.html#proxy-protocol package tlvparse import ( "regexp" "github.com/pires/go-proxyproto" ) const ( // Amazon's extension PP2_TYPE_AWS = 0xEA PP2_SUBTYPE_AWS_VPCE_ID = 0x01 ) var vpceRe = regexp.MustCompile("^[A-Za-z0-9-]*$") func IsAWSVPCEndpointID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AWS && len(tlv.Value) > 0 && tlv.Value[0] == PP2_SUBTYPE_AWS_VPCE_ID } func AWSVPCEndpointID(tlv proxyproto.TLV) (string, error) { if !IsAWSVPCEndpointID(tlv) { return "", proxyproto.ErrIncompatibleTLV } vpce := string(tlv.Value[1:]) if !vpceRe.MatchString(vpce) { return "", proxyproto.ErrMalformedTLV } return vpce, nil } // FindAWSVPCEndpointID returns the first AWS VPC ID in the TLV if it exists and is well-formed. 
func FindAWSVPCEndpointID(tlvs []proxyproto.TLV) string { for _, tlv := range tlvs { if vpc, err := AWSVPCEndpointID(tlv); err == nil && vpc != "" { return vpc } } return "" } go-proxyproto-0.4.2/tlvparse/aws_test.go000066400000000000000000000170301400451755600204150ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "testing" "github.com/pires/go-proxyproto" ) var awsTestCases = []struct { name string raw []byte types []proxyproto.PP2Type valid func(*testing.T, string, []proxyproto.TLV) }{ { name: "VPCE example", // https://github.com/aws/elastic-load-balancing-tools/blob/c8eee30ab991ab4c57dc37d1c58f09f67bd534aa/proprot/tst/com/amazonaws/proprot/Compatibility_AwsNetworkLoadBalancerTest.java#L41..L67 raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, /* Start of Sig */ 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, /* End of Sig */ 0x21, 0x11, 0x00, 0x54, /* ver_cmd, fam and len */ 0xac, 0x1f, 0x07, 0x71, /* Caller src ip */ 0xac, 0x1f, 0x0a, 0x1f, /* Endpoint dst ip */ 0xc8, 0xf2, 0x00, 0x50, /* Proxy src port & dst port */ 0x03, 0x00, 0x04, 0xe8, /* CRC TLV start */ 0xd6, 0x89, 0x2d, 0xea, /* CRC TLV cont, VPCE id TLV start */ 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x38, 0x64, 0x32, 0x62, 0x66, 0x31, 0x35, 0x66, 0x61, 0x63, 0x35, 0x30, 0x30, 0x31, 0x63, 0x39, 0x04, 0x00, 0x24, /* VPCE id TLV end, NOOP TLV start*/ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* NOOP TLV end */ }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWSVPCEndpointID type", name) } vpce := "vpce-08d2bf15fac5001c9" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { t.Fatalf("TestParseV2TLV %s: 
Unexpected error when parsing AWSVPCEndpointID", name) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find AWSVPCEndpointID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected AWSVPCEndpointID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, { name: "VPCE capture", raw: []byte{ 0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, 0x54, 0x0a, 0x21, 0x11, 0x00, 0x54, 0xc0, 0xa8, 0x2c, 0x0a, 0xc0, 0xa8, 0x2c, 0x07, 0xcc, 0x3e, 0x24, 0x1b, 0x03, 0x00, 0x04, 0xb9, 0x28, 0x6f, 0xa6, 0xea, 0x00, 0x17, 0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x30, 0x30, 0x65, 0x61, 0x66, 0x63, 0x34, 0x35, 0x38, 0x65, 0x63, 0x39, 0x37, 0x62, 0x38, 0x33, 0x33, 0x04, 0x00, 0x24, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, types: []proxyproto.PP2Type{proxyproto.PP2_TYPE_CRC32C, PP2_TYPE_AWS, proxyproto.PP2_TYPE_NOOP}, valid: func(t *testing.T, name string, tlvs []proxyproto.TLV) { if !IsAWSVPCEndpointID(tlvs[1]) { t.Fatalf("TestParseV2TLV %s: Expected tlvs[1] to be an AWS VPC endpoint ID type", name) } vpce := "vpce-00eafc458ec97b833" if vpca, err := AWSVPCEndpointID(tlvs[1]); err != nil { t.Fatalf("TestParseV2TLV %s: Unexpected error when parsing AWS VPC ID", name) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from tlvs[1] expected %#v, actual %#v", name, vpce, vpca) } if vpca := FindAWSVPCEndpointID(tlvs); vpca == "" { t.Fatalf("TestParseV2TLV %s: Expected to find VPC ID %#v in TLVs", name, vpce) } else if vpca != vpce { t.Fatalf("TestParseV2TLV %s: Unexpected VPC ID from header expected %#v, actual %#v", name, vpce, vpca) } }, }, } func 
TestV2TLVAWSVPCEBadChars(t *testing.T) { badVPCE := "vcpe-!?***&&&&&&&" rawTLVs := vpceTLV(badVPCE) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrMalformedTLV { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPCEndpointID unexpectedly found") } rawTLVs = vpceTLV("") tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected TLV parsing error %#v", err) } parsedVPCE, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSVPCEBadChars: unexpected error actual: %#v", err) } if parsedVPCE != "" { t.Fatalf("TestV2TLVAWSVPCEBadChars: found non-empty vpce, actual: %#v", parsedVPCE) } parsedVPCE = FindAWSVPCEndpointID(tlvs) if parsedVPCE != "" { t.Fatal("TestV2TLVAWSVPCEBadChars: AWSVPECID unexpectedly found") } } func TestParseAWSVPCEndpointIDTLVs(t *testing.T) { for _, tc := range awsTestCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestV2TLVAWSUnknownSubtype(t *testing.T) { vpce := "vpce-abc1234" rawTLVs := vpceTLV(vpce) tlvs, err := proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } avpce, err := AWSVPCEndpointID(tlvs[0]) if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error actual: %#v", err) } if avpce != vpce { 
t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected vpce value expected: %#v, actual: %#v", vpce, avpce) } avpce = FindAWSVPCEndpointID(tlvs) if avpce == "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly missing") } if avpce != vpce { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID value expected: %#v, actual: %#v", vpce, avpce) } subtypeIndex := 3 // Sanity check if rawTLVs[subtypeIndex] != PP2_SUBTYPE_AWS_VPCE_ID { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected subtype expected %x, actual %x", PP2_SUBTYPE_AWS_VPCE_ID, rawTLVs[subtypeIndex]) } rawTLVs[subtypeIndex] = PP2_SUBTYPE_AWS_VPCE_ID + 1 tlvs, err = proxyproto.SplitTLVs(rawTLVs) if len(tlvs) != 1 { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV length expected: %#v, actual: %#v", 1, tlvs) } if err != nil { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected TLV parsing error %#v", err) } if IsAWSVPCEndpointID(tlvs[0]) { t.Fatalf("TestV2TLVAWSUnknownSubtype: AWSVPCEType() unexpectedly true after changing subtype") } _, err = AWSVPCEndpointID(tlvs[0]) if err != proxyproto.ErrIncompatibleTLV { t.Fatalf("TestV2TLVAWSUnknownSubtype: unexpected AWSVPCEndpointID error expected %#v, actual: %#v", proxyproto.ErrIncompatibleTLV, err) } if FindAWSVPCEndpointID(tlvs) != "" { t.Fatal("TestV2TLVAWSUnknownSubtype: AWSVPCEndpointID unexpectedly exists despite invalid subtype") } } func vpceTLV(vpce string) []byte { tlv := []byte{ PP2_TYPE_AWS, 0x00, 0x00, PP2_SUBTYPE_AWS_VPCE_ID, } binary.BigEndian.PutUint16(tlv[1:3], uint16(len(vpce)+1)) // +1 for subtype return append(tlv, []byte(vpce)...) 
} go-proxyproto-0.4.2/tlvparse/azure.go000066400000000000000000000032531400451755600177140ustar00rootroot00000000000000// Azure's application extension to TLVs for Private Link Services // https://docs.microsoft.com/en-us/azure/private-link/private-link-service-overview#getting-connection-information-using-tcp-proxy-v2 package tlvparse import ( "encoding/binary" "github.com/pires/go-proxyproto" ) const ( // Azure's extension PP2_TYPE_AZURE = 0xEE PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID = 0x01 ) // IsAzurePrivateEndpointLinkID returns true if given TLV matches Azure Private Endpoint LinkID format func isAzurePrivateEndpointLinkID(tlv proxyproto.TLV) bool { return tlv.Type == PP2_TYPE_AZURE && len(tlv.Value) == 5 && tlv.Value[0] == PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID } // AzurePrivateEndpointLinkID returns linkID if given TLV matches Azure Private Endpoint LinkID format // // Format description: // Field Length (Octets) Description // Type 1 PP2_TYPE_AZURE (0xEE) // Length 2 Length of value // Value 1 PP2_SUBTYPE_AZURE_PRIVATEENDPOINT_LINKID (0x01) // 4 UINT32 (4 bytes) representing the LINKID of the private endpoint. Encoded in little endian format. func azurePrivateEndpointLinkID(tlv proxyproto.TLV) (uint32, error) { if !isAzurePrivateEndpointLinkID(tlv) { return 0, proxyproto.ErrIncompatibleTLV } linkID := binary.LittleEndian.Uint32(tlv.Value[1:]) return linkID, nil } // FindAzurePrivateEndpointLinkID returns the first Azure Private Endpoint LinkID if it exists in the TLV collection // and a boolean indicating if it was found. 
func FindAzurePrivateEndpointLinkID(tlvs []proxyproto.TLV) (uint32, bool) { for _, tlv := range tlvs { if linkID, err := azurePrivateEndpointLinkID(tlv); err == nil { return linkID, true } } return 0, false } go-proxyproto-0.4.2/tlvparse/azure_test.go000066400000000000000000000044351400451755600207560ustar00rootroot00000000000000package tlvparse import ( "testing" "github.com/pires/go-proxyproto" ) func TestFindAzurePrivateEndpointLinkID(t *testing.T) { tests := []struct { name string tlvs []proxyproto.TLV wantLinkID uint32 wantFound bool }{ { name: "nil TLVs", tlvs: nil, wantLinkID: 0, wantFound: false, }, { name: "empty TLVs", tlvs: []proxyproto.TLV{}, wantLinkID: 0, wantFound: false, }, { name: "AWS VPC endpoint ID", tlvs: []proxyproto.TLV{ { Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong subtype", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure but wrong length", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, }, wantLinkID: 0, wantFound: false, }, { name: "Azure link ID", tlvs: []proxyproto.TLV{ { Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, { name: "Multiple TLVs", tlvs: []proxyproto.TLV{ { // AWS Type: 0xEA, Value: []byte{0x01, 0x76, 0x70, 0x63, 0x65, 0x2d, 0x61, 0x62, 0x63, 0x31, 0x32, 0x33}, }, { // Azure but wrong subtype Type: 0xEE, Value: []byte{0x02, 0x01, 0x01, 0x01, 0x01}, }, { // Azure but wrong length Type: 0xEE, Value: []byte{0x02, 0x01, 0x01}, }, { // Correct Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x21}, }, { // Also correct, but second in line Type: 0xEE, Value: []byte{0x1, 0xc1, 0x45, 0x0, 0x22}, }, }, wantLinkID: 0x210045c1, wantFound: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { gotLinkID, gotFound := 
FindAzurePrivateEndpointLinkID(tt.tlvs) if gotFound != tt.wantFound { t.Errorf("FindAzurePrivateEndpointLinkID() got1 = %v, want %v", gotFound, tt.wantFound) } if gotLinkID != tt.wantLinkID { t.Errorf("FindAzurePrivateEndpointLinkID() got = %v, want %v", gotLinkID, tt.wantLinkID) } }) } } go-proxyproto-0.4.2/tlvparse/ssl.go000066400000000000000000000116321400451755600173670ustar00rootroot00000000000000package tlvparse import ( "encoding/binary" "unicode" "unicode/utf8" "github.com/pires/go-proxyproto" ) const ( // pp2_tlv_ssl.client bit fields PP2_BITFIELD_CLIENT_SSL uint8 = 0x01 PP2_BITFIELD_CLIENT_CERT_CONN = 0x02 PP2_BITFIELD_CLIENT_CERT_SESS = 0x04 tlvSSLMinLen = 5 // len(pp2_tlv_ssl.client) + len(pp2_tlv_ssl.verify) ) // 2.2.5. The PP2_TYPE_SSL type and subtypes /* struct pp2_tlv_ssl { uint8_t client; uint32_t verify; struct pp2_tlv sub_tlv[0]; }; */ type PP2SSL struct { Client uint8 // The field is made of a bit field from the following values, // indicating which element is present: PP2_BITFIELD_CLIENT_SSL, // PP2_BITFIELD_CLIENT_CERT_CONN, PP2_BITFIELD_CLIENT_CERT_SESS Verify uint32 // Verify will be zero if the client presented a certificate // and it was successfully verified, and non-zero otherwise. TLV []proxyproto.TLV } // Verified is true if the client presented a certificate and it was successfully verified func (s PP2SSL) Verified() bool { return s.Verify == 0 } // ClientSSL indicates that the client connected over SSL/TLS. When true, SSLVersion will return the version. func (s PP2SSL) ClientSSL() bool { return s.Client&PP2_BITFIELD_CLIENT_SSL == PP2_BITFIELD_CLIENT_SSL } // ClientCertConn indicates that the client provided a certificate over the current connection. func (s PP2SSL) ClientCertConn() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_CONN == PP2_BITFIELD_CLIENT_CERT_CONN } // ClientCertSess indicates that the client provided a certificate at least once over the TLS session this // connection belongs to. 
func (s PP2SSL) ClientCertSess() bool { return s.Client&PP2_BITFIELD_CLIENT_CERT_SESS == PP2_BITFIELD_CLIENT_CERT_SESS } // SSLVersion returns the US-ASCII string representation of the TLS version and whether that extension exists. func (s PP2SSL) SSLVersion() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_VERSION { return string(tlv.Value), true } } return "", false } // Marshal formats the PP2SSL structure as a TLV. func (s PP2SSL) Marshal() (proxyproto.TLV, error) { v := make([]byte, 5) v[0] = s.Client binary.BigEndian.PutUint32(v[1:5], s.Verify) tlvs, err := proxyproto.JoinTLVs(s.TLV) if err != nil { return proxyproto.TLV{}, err } v = append(v, tlvs...) return proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: v, }, nil } // ClientCN returns the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client // certificate's Distinguished Name and whether that extension exists. func (s PP2SSL) ClientCN() (string, bool) { for _, tlv := range s.TLV { if tlv.Type == proxyproto.PP2_SUBTYPE_SSL_CN { return string(tlv.Value), true } } return "", false } // SSLType is true if the TLV is type SSL func IsSSL(t proxyproto.TLV) bool { return t.Type == proxyproto.PP2_TYPE_SSL && len(t.Value) >= tlvSSLMinLen } // SSL returns the pp2_tlv_ssl from section 2.2.5 or errors with ErrIncompatibleTLV or ErrMalformedTLV func SSL(t proxyproto.TLV) (PP2SSL, error) { ssl := PP2SSL{} if !IsSSL(t) { return ssl, proxyproto.ErrIncompatibleTLV } if len(t.Value) < tlvSSLMinLen { return ssl, proxyproto.ErrMalformedTLV } ssl.Client = t.Value[0] ssl.Verify = binary.BigEndian.Uint32(t.Value[1:5]) var err error ssl.TLV, err = proxyproto.SplitTLVs(t.Value[5:]) if err != nil { return PP2SSL{}, err } versionFound := !ssl.ClientSSL() for _, tlv := range ssl.TLV { switch tlv.Type { case proxyproto.PP2_SUBTYPE_SSL_VERSION: /* The PP2_CLIENT_SSL flag indicates that the client connected over SSL/TLS. 
When this field is present, the US-ASCII string representation of the TLS version is appended at the end of the field in the TLV format using the type PP2_SUBTYPE_SSL_VERSION. */ if len(tlv.Value) == 0 || !isASCII(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } versionFound = true case proxyproto.PP2_SUBTYPE_SSL_CN: /* In all cases, the string representation (in UTF8) of the Common Name field (OID: 2.5.4.3) of the client certificate's Distinguished Name, is appended using the TLV format and the type PP2_SUBTYPE_SSL_CN. E.g. "example.com". */ if len(tlv.Value) == 0 || !utf8.Valid(tlv.Value) { return PP2SSL{}, proxyproto.ErrMalformedTLV } } } if !versionFound { return PP2SSL{}, proxyproto.ErrMalformedTLV } return ssl, nil } // SSL returns the first PP2SSL if it exists and is well formed as well as bool indicating if it was found. func FindSSL(tlvs []proxyproto.TLV) (PP2SSL, bool) { for _, t := range tlvs { if ssl, err := SSL(t); err == nil { return ssl, true } } return PP2SSL{}, false } // isASCII checks whether a byte slice has all characters that fit in the ascii character set, including the null byte. 
// isASCII reports whether every byte of b lies within the ASCII range
// (0x00 through unicode.MaxASCII), which includes the null byte. An empty or
// nil slice is considered ASCII.
func isASCII(b []byte) bool {
	for i := 0; i < len(b); i++ {
		if b[i] <= unicode.MaxASCII {
			continue
		}
		return false
	}
	return true
}
asslVer != esslVer { t.Fatalf("TestParseV2TLV %s: Unexpected SSLVersion expected %#v, actual %#v", name, esslVer, asslVer) } if !ssl.Verified() { t.Fatalf("TestParseV2TLV %s: Expected Verified to be true", name) } }, }, } func TestParseV2TLV(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { tlvs := checkTLVs(t, tc.name, tc.raw, tc.types) tc.valid(t, tc.name, tlvs) }) } } func TestPP2SSLMarshal(t *testing.T) { ver := "TLSv1.3" cn := "example.org" pp2 := PP2SSL{ Client: PP2_BITFIELD_CLIENT_SSL, Verify: 0, TLV: []proxyproto.TLV{ { Type: proxyproto.PP2_SUBTYPE_SSL_VERSION, Value: []byte(ver), }, { Type: proxyproto.PP2_SUBTYPE_SSL_CN, Value: []byte(cn), }, }, } raw := []byte{0x1, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x7, 0x54, 0x4c, 0x53, 0x76, 0x31, 0x2e, 0x33, 0x22, 0x0, 0xb, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x6f, 0x72, 0x67} want := proxyproto.TLV{ Type: proxyproto.PP2_TYPE_SSL, Value: raw, } tlv, err := pp2.Marshal() if err != nil { t.Fatalf("PP2SSL.Marshal() = %v", err) } if !reflect.DeepEqual(tlv, want) { t.Errorf("PP2SSL.Marshal() = %#v, want %#v", tlv, want) } } go-proxyproto-0.4.2/tlvparse/test.go000066400000000000000000000013371400451755600175460ustar00rootroot00000000000000package tlvparse import ( "bufio" "bytes" "testing" "github.com/pires/go-proxyproto" ) func checkTLVs(t *testing.T, name string, raw []byte, expected []proxyproto.PP2Type) []proxyproto.TLV { header, err := proxyproto.Read(bufio.NewReader(bytes.NewReader(raw))) if err != nil { t.Fatalf("%s: Unexpected error reading header %#v", name, err) } tlvs, err := header.TLVs() if err != nil { t.Fatalf("%s: Unexpected error splitting TLVS %#v", name, err) } if len(tlvs) != len(expected) { t.Fatalf("%s: Expected %d TLVs, actual %d", name, len(expected), len(tlvs)) } for i, et := range expected { if at := tlvs[i].Type; at != et { t.Fatalf("%s: Expected type %X, actual %X", name, et, at) } } return tlvs } 
go-proxyproto-0.4.2/v1.go000066400000000000000000000102271400451755600152530ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "net" "strconv" "strings" ) const ( crlf = "\r\n" separator = " " ) func initVersion1() *Header { header := new(Header) header.Version = 1 // Command doesn't exist in v1 header.Command = PROXY return header } func parseVersion1(reader *bufio.Reader) (*Header, error) { // Read until LF shows up, otherwise fail. // At this point, can't be sure CR precedes LF which will be validated next. line, err := reader.ReadString('\n') if err != nil { return nil, ErrLineMustEndWithCrlf } if !strings.HasSuffix(line, crlf) { return nil, ErrLineMustEndWithCrlf } // Check full signature. tokens := strings.Split(line[:len(line)-2], separator) // Expect at least 2 tokens: "PROXY" and the transport protocol. if len(tokens) < 2 { return nil, ErrCantReadAddressFamilyAndProtocol } // Read address family and protocol var transportProtocol AddressFamilyAndProtocol switch tokens[1] { case "TCP4": transportProtocol = TCPv4 case "TCP6": transportProtocol = TCPv6 case "UNKNOWN": transportProtocol = UNSPEC // doesn't exist in v1 but fits UNKNOWN default: return nil, ErrCantReadAddressFamilyAndProtocol } // Expect 6 tokens only when UNKNOWN is not present. if transportProtocol != UNSPEC && len(tokens) < 6 { return nil, ErrCantReadAddressFamilyAndProtocol } // When a signature is found, allocate a v1 header with Command set to PROXY. // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() // Transport protocol has been processed already. 
header.TransportProtocol = transportProtocol // When UNKNOWN, set the command to LOCAL and return early if header.TransportProtocol == UNSPEC { header.Command = LOCAL return header, nil } // Otherwise, continue to read addresses and ports sourceIP, err := parseV1IPAddress(header.TransportProtocol, tokens[2]) if err != nil { return nil, err } destIP, err := parseV1IPAddress(header.TransportProtocol, tokens[3]) if err != nil { return nil, err } sourcePort, err := parseV1PortNumber(tokens[4]) if err != nil { return nil, err } destPort, err := parseV1PortNumber(tokens[5]) if err != nil { return nil, err } header.SourceAddr = &net.TCPAddr{ IP: sourceIP, Port: sourcePort, } header.DestinationAddr = &net.TCPAddr{ IP: destIP, Port: destPort, } return header, nil } func (header *Header) formatVersion1() ([]byte, error) { // As of version 1, only "TCP4" ( \x54 \x43 \x50 \x34 ) for TCP over IPv4, // and "TCP6" ( \x54 \x43 \x50 \x36 ) for TCP over IPv6 are allowed. var proto string switch header.TransportProtocol { case TCPv4: proto = "TCP4" case TCPv6: proto = "TCP6" default: // Unknown connection (short form) return []byte("PROXY UNKNOWN" + crlf), nil } sourceAddr, sourceOK := header.SourceAddr.(*net.TCPAddr) destAddr, destOK := header.DestinationAddr.(*net.TCPAddr) if !sourceOK || !destOK { return nil, ErrInvalidAddress } sourceIP, destIP := sourceAddr.IP, destAddr.IP switch header.TransportProtocol { case TCPv4: sourceIP = sourceIP.To4() destIP = destIP.To4() case TCPv6: sourceIP = sourceIP.To16() destIP = destIP.To16() } if sourceIP == nil || destIP == nil { return nil, ErrInvalidAddress } buf := bytes.NewBuffer(make([]byte, 0, 108)) buf.Write(SIGV1) buf.WriteString(separator) buf.WriteString(proto) buf.WriteString(separator) buf.WriteString(sourceIP.String()) buf.WriteString(separator) buf.WriteString(destIP.String()) buf.WriteString(separator) buf.WriteString(strconv.Itoa(sourceAddr.Port)) buf.WriteString(separator) buf.WriteString(strconv.Itoa(destAddr.Port)) 
buf.WriteString(crlf) return buf.Bytes(), nil } func parseV1PortNumber(portStr string) (int, error) { port, err := strconv.Atoi(portStr) if err != nil || port < 0 || port > 65535 { return 0, ErrInvalidPortNumber } return port, nil } func parseV1IPAddress(protocol AddressFamilyAndProtocol, addrStr string) (addr net.IP, err error) { addr = net.ParseIP(addrStr) tryV4 := addr.To4() if (protocol == TCPv4 && tryV4 == nil) || (protocol == TCPv6 && tryV4 != nil) { err = ErrInvalidAddress } return } go-proxyproto-0.4.2/v1_test.go000066400000000000000000000114061400451755600163120ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "strconv" "strings" "testing" ) var ( IPv4AddressesAndPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) IPv4AddressesAndInvalidPorts = strings.Join([]string{IP4_ADDR, IP4_ADDR, strconv.Itoa(INVALID_PORT), strconv.Itoa(INVALID_PORT)}, separator) IPv6AddressesAndPorts = strings.Join([]string{IP6_ADDR, IP6_ADDR, strconv.Itoa(PORT), strconv.Itoa(PORT)}, separator) fixtureTCP4V1 = "PROXY TCP4 " + IPv4AddressesAndPorts + crlf + "GET /" fixtureTCP6V1 = "PROXY TCP6 " + IPv6AddressesAndPorts + crlf + "GET /" fixtureUnknown = "PROXY UNKNOWN" + crlf fixtureUnknownWithAddresses = "PROXY UNKNOWN " + IPv4AddressesAndInvalidPorts + crlf ) var invalidParseV1Tests = []struct { desc string reader *bufio.Reader expectedError error }{ { desc: "no signature", reader: newBufioReader([]byte(NO_PROTOCOL)), expectedError: ErrNoProxyProtocol, }, { desc: "prox", reader: newBufioReader([]byte("PROX")), expectedError: ErrNoProxyProtocol, }, { desc: "proxy lf", reader: newBufioReader([]byte("PROXY \n")), expectedError: ErrLineMustEndWithCrlf, }, { desc: "proxy crlf", reader: newBufioReader([]byte("PROXY " + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "proxy no space crlf", reader: newBufioReader([]byte("PROXY" + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { 
desc: "proxy something crlf", reader: newBufioReader([]byte("PROXY SOMETHING" + crlf)), expectedError: ErrCantReadAddressFamilyAndProtocol, }, { desc: "incomplete signature TCP4", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndPorts)), expectedError: ErrLineMustEndWithCrlf, }, { desc: "TCP6 with IPv4 addresses", reader: newBufioReader([]byte("PROXY TCP6 " + IPv4AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with IPv6 addresses", reader: newBufioReader([]byte("PROXY TCP4 " + IPv6AddressesAndPorts + crlf)), expectedError: ErrInvalidAddress, }, { desc: "TCP4 with invalid port", reader: newBufioReader([]byte("PROXY TCP4 " + IPv4AddressesAndInvalidPorts + crlf)), expectedError: ErrInvalidPortNumber, }, } func TestReadV1Invalid(t *testing.T) { for _, tt := range invalidParseV1Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); err != tt.expectedError { t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error()) } }) } } var validParseAndWriteV1Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header }{ { desc: "TCP4", reader: bufio.NewReader(strings.NewReader(fixtureTCP4V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "TCP6", reader: bufio.NewReader(strings.NewReader(fixtureTCP6V1)), expectedHeader: &Header{ Version: 1, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "unknown", reader: bufio.NewReader(strings.NewReader(fixtureUnknown)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "unknown with addresses and ports", reader: bufio.NewReader(strings.NewReader(fixtureUnknownWithAddresses)), expectedHeader: &Header{ Version: 1, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, } func TestParseV1Valid(t *testing.T) 
{ for _, tt := range validParseAndWriteV1Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV1Valid(t *testing.T) { for _, tt := range validParseAndWriteV1Tests { t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } w.Flush() // Read written bytes to validate written header r := bufio.NewReader(&b) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } go-proxyproto-0.4.2/v2.go000066400000000000000000000165031400451755600152570ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "encoding/binary" "errors" "io" "net" ) var ( lengthUnspec = uint16(0) lengthV4 = uint16(12) lengthV6 = uint16(36) lengthUnix = uint16(216) lengthUnspecBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnspec) return a }() lengthV4Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV4) return a }() lengthV6Bytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthV6) return a }() lengthUnixBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthUnix) return a }() errUint16Overflow = errors.New("proxyproto: uint16 overflow") ) type _ports struct { SrcPort uint16 DstPort uint16 } type _addr4 struct { Src [4]byte Dst [4]byte SrcPort uint16 DstPort uint16 } type _addr6 struct { Src [16]byte Dst [16]byte _ports } type _addrUnix struct { Src [108]byte Dst [108]byte } func parseVersion2(reader *bufio.Reader) (header *Header, err error) { // Skip first 12 bytes (signature) for i := 0; i < 12; i++ { if _, err 
= reader.ReadByte(); err != nil { return nil, ErrCantReadProtocolVersionAndCommand } } header = new(Header) header.Version = 2 // Read the 13th byte, protocol version and command b13, err := reader.ReadByte() if err != nil { return nil, ErrCantReadProtocolVersionAndCommand } header.Command = ProtocolVersionAndCommand(b13) if _, ok := supportedCommand[header.Command]; !ok { return nil, ErrUnsupportedProtocolVersionAndCommand } // Read the 14th byte, address family and protocol b14, err := reader.ReadByte() if err != nil { return nil, ErrCantReadAddressFamilyAndProtocol } header.TransportProtocol = AddressFamilyAndProtocol(b14) // UNSPEC is only supported when LOCAL is set. if header.TransportProtocol == UNSPEC && header.Command != LOCAL { return nil, ErrUnsupportedAddressFamilyAndProtocol } // Make sure there are bytes available as specified in length var length uint16 if err := binary.Read(io.LimitReader(reader, 2), binary.BigEndian, &length); err != nil { return nil, ErrCantReadLength } if !header.validateLength(length) { return nil, ErrInvalidLength } // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC. if length == 0 { return header, nil } if _, err := reader.Peek(int(length)); err != nil { return nil, ErrInvalidLength } // Length-limited reader for payload section payloadReader := io.LimitReader(reader, int64(length)).(*io.LimitedReader) // Read addresses and ports for protocols other than UNSPEC. // Ignore address information for UNSPEC, and skip straight to read TLVs, // since the length is greater than zero. 
if header.TransportProtocol != UNSPEC { if header.TransportProtocol.IsIPv4() { var addr _addr4 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsIPv6() { var addr _addr6 if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } header.SourceAddr = newIPAddr(header.TransportProtocol, addr.Src[:], addr.SrcPort) header.DestinationAddr = newIPAddr(header.TransportProtocol, addr.Dst[:], addr.DstPort) } else if header.TransportProtocol.IsUnix() { var addr _addrUnix if err := binary.Read(payloadReader, binary.BigEndian, &addr); err != nil { return nil, ErrInvalidAddress } network := "unix" if header.TransportProtocol.IsDatagram() { network = "unixgram" } header.SourceAddr = &net.UnixAddr{ Net: network, Name: parseUnixName(addr.Src[:]), } header.DestinationAddr = &net.UnixAddr{ Net: network, Name: parseUnixName(addr.Dst[:]), } } } // Copy bytes for optional Type-Length-Value vector header.rawTLVs = make([]byte, payloadReader.N) // Allocate minimum size slice if _, err = io.ReadFull(payloadReader, header.rawTLVs); err != nil && err != io.EOF { return nil, err } return header, nil } func (header *Header) formatVersion2() ([]byte, error) { var buf bytes.Buffer buf.Write(SIGV2) buf.WriteByte(header.Command.toByte()) buf.WriteByte(header.TransportProtocol.toByte()) if header.TransportProtocol.IsUnspec() { // For UNSPEC, write no addresses and ports but only TLVs if they are present hdrLen, err := addTLVLen(lengthUnspecBytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) } else { var addrSrc, addrDst []byte if header.TransportProtocol.IsIPv4() { hdrLen, err := addTLVLen(lengthV4Bytes, len(header.rawTLVs)) if err != nil { return nil, err } 
buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To4() addrDst = destIP.To4() } else if header.TransportProtocol.IsIPv6() { hdrLen, err := addTLVLen(lengthV6Bytes, len(header.rawTLVs)) if err != nil { return nil, err } buf.Write(hdrLen) sourceIP, destIP, _ := header.IPs() addrSrc = sourceIP.To16() addrDst = destIP.To16() } else if header.TransportProtocol.IsUnix() { buf.Write(lengthUnixBytes) sourceAddr, destAddr, ok := header.UnixAddrs() if !ok { return nil, ErrInvalidAddress } addrSrc = formatUnixName(sourceAddr.Name) addrDst = formatUnixName(destAddr.Name) } if addrSrc == nil || addrDst == nil { return nil, ErrInvalidAddress } buf.Write(addrSrc) buf.Write(addrDst) if sourcePort, destPort, ok := header.Ports(); ok { portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, uint16(sourcePort)) buf.Write(portBytes) binary.BigEndian.PutUint16(portBytes, uint16(destPort)) buf.Write(portBytes) } } if len(header.rawTLVs) > 0 { buf.Write(header.rawTLVs) } return buf.Bytes(), nil } func (header *Header) validateLength(length uint16) bool { if header.TransportProtocol.IsIPv4() { return length >= lengthV4 } else if header.TransportProtocol.IsIPv6() { return length >= lengthV6 } else if header.TransportProtocol.IsUnix() { return length >= lengthUnix } else if header.TransportProtocol.IsUnspec() { return length >= lengthUnspec } return false } // addTLVLen adds the length of the TLV to the header length or errors on uint16 overflow. 
func addTLVLen(cur []byte, tlvLen int) ([]byte, error) { if tlvLen == 0 { return cur, nil } curLen := binary.BigEndian.Uint16(cur) newLen := int(curLen) + tlvLen if newLen >= 1<<16 { return nil, errUint16Overflow } a := make([]byte, 2) binary.BigEndian.PutUint16(a, uint16(newLen)) return a, nil } func newIPAddr(transport AddressFamilyAndProtocol, ip net.IP, port uint16) net.Addr { if transport.IsStream() { return &net.TCPAddr{IP: ip, Port: int(port)} } else if transport.IsDatagram() { return &net.UDPAddr{IP: ip, Port: int(port)} } else { return nil } } func parseUnixName(b []byte) string { i := bytes.IndexByte(b, 0) if i < 0 { return string(b) } return string(b[:i]) } func formatUnixName(name string) []byte { n := int(lengthUnix) / 2 if len(name) >= n { return []byte(name[:n]) } pad := make([]byte, n-len(name)) return append([]byte(name), pad...) } go-proxyproto-0.4.2/v2_test.go000066400000000000000000000345071400451755600163220ustar00rootroot00000000000000package proxyproto import ( "bufio" "bytes" "encoding/binary" "math/rand" "reflect" "testing" ) var ( invalidRune = byte('\x99') // Lengths to use in tests lengthPadded = uint16(84) lengthEmptyBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, 0) return a }() lengthPaddedBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, lengthPadded) return a }() // If life gives you lemons, make mojitos portBytes = func() []byte { a := make([]byte, 2) binary.BigEndian.PutUint16(a, PORT) return a }() unixBytes = pad([]byte("socket"), 108) // Tests don't care if source and destination addresses and ports are the same addressesIPv4 = append(v4ip.To4(), v4ip.To4()...) addressesIPv6 = append(v6ip.To16(), v6ip.To16()...) ports = append(portBytes, portBytes...) // Fixtures to use in tests fixtureIPv4Address = append(addressesIPv4, ports...) fixtureIPv4V2 = append(lengthV4Bytes, fixtureIPv4Address...) 
// pad appends zero bytes to b until the result is n bytes long and returns the
// extended slice. n must not be smaller than len(b).
func pad(b []byte, n int) []byte {
	return append(b, make([]byte, n-len(b))...)
}
byte(TCPv4)), lengthEmptyBytes...), fixtureIPv6Address...)), expectedError: ErrInvalidLength, }, { desc: "TCPv6 with IPv6 length but IPv4 address and ports", reader: newBufioReader(append(append(append(SIGV2, byte(PROXY), byte(TCPv6)), lengthV6Bytes...), fixtureIPv4Address...)), expectedError: ErrInvalidLength, }, { desc: "unspec length greater than zero but no TLVs", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV[:2]...)), expectedError: ErrInvalidLength, }, } func TestParseV2Invalid(t *testing.T) { for _, tt := range invalidParseV2Tests { t.Run(tt.desc, func(t *testing.T) { if _, err := Read(tt.reader); err != tt.expectedError { t.Fatalf("expected %s, actual %s", tt.expectedError, err.Error()) } }) } } var validParseAndWriteV2Tests = []struct { desc string reader *bufio.Reader expectedHeader *Header }{ { desc: "local", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "local unspec", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), lengthUnspecBytes...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, }, }, { desc: "proxy TCPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, }, }, { desc: "proxy TCPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, }, }, { desc: "proxy TCPv4 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...)), expectedHeader: &Header{ Version: 2, 
Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: fixtureTLV, }, }, { desc: "proxy TCPv6 with TLV", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2TLV...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: fixtureTLV, }, }, { desc: "local unspec with TLV", reader: newBufioReader(append(append(SIGV2, byte(LOCAL), byte(UNSPEC)), fixtureUnspecTLV...)), expectedHeader: &Header{ Version: 2, Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: fixtureTLV, }, }, { desc: "proxy UDPv4", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4UDPAddr, DestinationAddr: v4UDPAddr, }, }, { desc: "proxy UDPv6", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6UDPAddr, DestinationAddr: v6UDPAddr, }, }, { desc: "proxy unix stream", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixStream)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixStream, SourceAddr: unixStreamAddr, DestinationAddr: unixStreamAddr, }, }, { desc: "proxy unix datagram", reader: newBufioReader(append(append(SIGV2, byte(PROXY), byte(UnixDatagram)), fixtureUnixV2...)), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UnixDatagram, SourceAddr: unixDatagramAddr, DestinationAddr: unixDatagramAddr, }, }, } func TestParseV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { header, err := Read(tt.reader) if err != nil { t.Fatal("unexpected error", err.Error()) } if !header.EqualsTo(tt.expectedHeader) { t.Fatalf("expected 
%#v, actual %#v", tt.expectedHeader, header) } }) } } func TestWriteV2Valid(t *testing.T) { for _, tt := range validParseAndWriteV2Tests { t.Run(tt.desc, func(t *testing.T) { var b bytes.Buffer w := bufio.NewWriter(&b) if _, err := tt.expectedHeader.WriteTo(w); err != nil { t.Fatal("unexpected error ", err) } w.Flush() // Read written bytes to validate written header r := bufio.NewReader(&b) newHeader, err := Read(r) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } }) } } var validParseV2PaddedTests = []struct { desc string value []byte expectedHeader *Header }{ { desc: "proxy TCPv4", value: append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy TCPv6", value: append(append(SIGV2, byte(PROXY), byte(TCPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, { desc: "proxy UDPv4", value: append(append(SIGV2, byte(PROXY), byte(UDPv4)), fixtureIPv4V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, lengthPadded-lengthV4), }, }, { desc: "proxy UDPv6", value: append(append(SIGV2, byte(PROXY), byte(UDPv6)), fixtureIPv6V2Padded...), expectedHeader: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, lengthPadded-lengthV6), }, }, } func TestParseV2Padded(t *testing.T) { for _, tt := range validParseV2PaddedTests { t.Run(tt.desc, func(t *testing.T) { reader := newBufioReader(append(tt.value, arbitraryTailBytes...)) newHeader, 
err := Read(reader) if err != nil { t.Fatal("unexpected error ", err) } if !newHeader.EqualsTo(tt.expectedHeader) { t.Fatalf("expected %#v, actual %#v", tt.expectedHeader, newHeader) } // Check that remaining padding bytes have been flushed nextBytes, err := reader.Peek(len(arbitraryTailBytes)) if err != nil { t.Fatal("unexpected error ", err) } if !reflect.DeepEqual(nextBytes, arbitraryTailBytes) { t.Fatalf("expected %#v, actual %#v", arbitraryTailBytes, nextBytes) } }) } } func TestV2EqualsToTLV(t *testing.T) { eHdr := &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, } hdr, err := Read(newBufioReader(append(append(SIGV2, byte(PROXY), byte(TCPv4)), fixtureIPv4V2TLV...))) if err != nil { t.Fatal("unexpected error ", err) } if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs = fixtureTLV[:] if !eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly unequal after tlv copy created: %#v, parsed: %#v", eHdr, hdr) } eHdr.rawTLVs[0] = eHdr.rawTLVs[0] + 1 if eHdr.EqualsTo(hdr) { t.Fatalf("unexpectedly equal after changing tlv created: %#v, parsed: %#v", eHdr, hdr) } } var tlvFormatTests = []struct { desc string header *Header }{ { desc: "proxy TCPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy TCPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: TCPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv4", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv4, SourceAddr: v4addr, DestinationAddr: v4addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "proxy UDPv6", header: &Header{ Version: 2, Command: PROXY, TransportProtocol: UDPv6, SourceAddr: v6addr, DestinationAddr: v6addr, rawTLVs: make([]byte, 1<<16), }, }, { desc: "local unspec", header: &Header{ Version: 2, 
Command: LOCAL, TransportProtocol: UNSPEC, SourceAddr: nil, DestinationAddr: nil, rawTLVs: make([]byte, 1<<16), }, }, } func TestV2TLVFormatTooLargeTLV(t *testing.T) { for _, tt := range tlvFormatTests { t.Run(tt.desc, func(t *testing.T) { if _, err := tt.header.Format(); err != errUint16Overflow { t.Fatalf("missing or expected error when formatting too-large TLV %#v", err) } }) } } func newBufioReader(b []byte) *bufio.Reader { return bufio.NewReader(bytes.NewReader(b)) } func fixtureWithTLV(cur []byte, addr []byte, tlv []byte) []byte { tlen, err := addTLVLen(cur, len(tlv)) if err != nil { panic(err) } return append(append(tlen, addr...), tlv...) } go-proxyproto-0.4.2/version_cmd.go000066400000000000000000000030521400451755600172330ustar00rootroot00000000000000package proxyproto // ProtocolVersionAndCommand represents the command in proxy protocol v2. // Command doesn't exist in v1 but it should be set since other parts of // this library may rely on it for determining connection details. type ProtocolVersionAndCommand byte const ( // LOCAL represents the LOCAL command in v2 or UNKNOWN transport in v1, // in which case no address information is expected. LOCAL ProtocolVersionAndCommand = '\x20' // PROXY represents the PROXY command in v2 or transport is not UNKNOWN in v1, // in which case valid local/remote address and port information is expected. PROXY ProtocolVersionAndCommand = '\x21' ) var supportedCommand = map[ProtocolVersionAndCommand]bool{ LOCAL: true, PROXY: true, } // IsLocal returns true if the command in v2 is LOCAL or the transport in v1 is UNKNOWN, // i.e. when no address information is expected, false otherwise. func (pvc ProtocolVersionAndCommand) IsLocal() bool { return LOCAL == pvc } // IsProxy returns true if the command in v2 is PROXY or the transport in v1 is not UNKNOWN, // i.e. when valid local/remote address and port information is expected, false otherwise. 
func (pvc ProtocolVersionAndCommand) IsProxy() bool { return PROXY == pvc } // IsUnspec returns true if the command is unspecified, false otherwise. func (pvc ProtocolVersionAndCommand) IsUnspec() bool { return !(pvc.IsLocal() || pvc.IsProxy()) } func (pvc ProtocolVersionAndCommand) toByte() byte { if pvc.IsLocal() { return byte(LOCAL) } else if pvc.IsProxy() { return byte(PROXY) } return byte(LOCAL) } go-proxyproto-0.4.2/version_cmd_test.go000066400000000000000000000013511400451755600202720ustar00rootroot00000000000000package proxyproto import ( "testing" ) func TestLocal(t *testing.T) { b := byte(LOCAL) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if !ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestProxy(t *testing.T) { b := byte(PROXY) if ProtocolVersionAndCommand(b).IsUnspec() { t.Fail() } if ProtocolVersionAndCommand(b).IsLocal() { t.Fail() } if !ProtocolVersionAndCommand(b).IsProxy() { t.Fail() } if ProtocolVersionAndCommand(b).toByte() != b { t.Fail() } } func TestInvalidProtocolVersion(t *testing.T) { if !ProtocolVersionAndCommand(0x00).IsUnspec() { t.Fail() } }