pax_global_header00006660000000000000000000000064144232252620014514gustar00rootroot0000000000000052 comment=657f7da1d9bd78d246ae610e0b5efc63a6a5f9e4 netlink-1.7.2/000077500000000000000000000000001442322526200131675ustar00rootroot00000000000000netlink-1.7.2/.github/000077500000000000000000000000001442322526200145275ustar00rootroot00000000000000netlink-1.7.2/.github/workflows/000077500000000000000000000000001442322526200165645ustar00rootroot00000000000000netlink-1.7.2/.github/workflows/linux-integration-test.yml000066400000000000000000000020501442322526200237410ustar00rootroot00000000000000name: Linux Integration Test on: push: branches: - "*" pull_request: branches: - "*" jobs: build: strategy: matrix: go-version: ["1.18", "1.19", "1.20"] runs-on: ubuntu-latest steps: - name: Set up Go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} id: go - name: Check out code into the Go module directory uses: actions/checkout@v3 - name: Create a network namespace for privileged tests run: sudo ip netns add unpriv0 - name: Build netlink test binary for privileged tests run: go test -c -race . - name: Run privileged netlink tests run: sudo ./netlink.test -test.v -test.run TestIntegration* - name: Build integration test binary for privileged tests working-directory: ./internal/integration run: go test -c -race . 
- name: Run privileged integration tests working-directory: ./internal/integration run: sudo ./integration.test -test.v netlink-1.7.2/.github/workflows/linux-test.yml000066400000000000000000000011701442322526200214220ustar00rootroot00000000000000name: Linux Test on: push: branches: - "*" pull_request: branches: - "*" jobs: build: strategy: matrix: go-version: ["1.18", "1.19", "1.20"] runs-on: ubuntu-latest steps: - name: Set up Go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} id: go - name: Check out code into the Go module directory uses: actions/checkout@v3 - name: Create a network namespace for unprivileged tests run: sudo ip netns add unpriv0 - name: Run tests run: go test -v -race -tags gofuzz ./... netlink-1.7.2/.github/workflows/macos-test.yml000066400000000000000000000007611442322526200213720ustar00rootroot00000000000000name: macOS Test on: push: branches: - "*" pull_request: branches: - "*" jobs: build: strategy: matrix: go-version: ["1.20"] runs-on: macos-latest steps: - name: Set up Go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} id: go - name: Check out code into the Go module directory uses: actions/checkout@v3 - name: Run tests run: go test -v -race ./... netlink-1.7.2/.github/workflows/static-analysis.yml000066400000000000000000000013341442322526200224200ustar00rootroot00000000000000name: Static Analysis on: push: branches: - "*" pull_request: branches: - "*" jobs: build: strategy: matrix: go-version: ["1.20"] runs-on: ubuntu-latest steps: - name: Set up Go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} id: go - name: Check out code into the Go module directory uses: actions/checkout@v3 - name: Install staticcheck run: go install honnef.co/go/tools/cmd/staticcheck@latest - name: Print staticcheck version run: staticcheck -version - name: Run staticcheck run: staticcheck ./... - name: Run go vet run: go vet ./... 
netlink-1.7.2/.gitignore000066400000000000000000000001161442322526200151550ustar00rootroot00000000000000internal/integration/integration.test netlink.test netlink-fuzz.zip testdata/ netlink-1.7.2/CHANGELOG.md000066400000000000000000000172511442322526200150060ustar00rootroot00000000000000# CHANGELOG ## v1.7.2 - [Improvement]: updated dependencies, test with Go 1.20. ## v1.7.1 - [Bug Fix]: test only changes to avoid failures on big endian machines. ## v1.7.0 **This is the first release of package netlink that only supports Go 1.18+. Users on older versions of Go must use v1.6.2.** - [Improvement]: drop support for older versions of Go so we can begin using modern versions of `x/sys` and other dependencies. ## v1.6.2 **This is the last release of package netlink that supports Go 1.17 and below.** - [Bug Fix] [commit](https://github.com/mdlayher/netlink/commit/9f7f860d9865069cd1a6b4dee32a3095f0b841fc): undo update to `golang.org/x/sys` which would force the minimum Go version of this package to Go 1.17 due to use of `unsafe.Slice`. We encourage users to use the latest stable version of Go where possible, but continue to maintain some compatibility with older versions of Go as long as it is reasonable to do so. ## v1.6.1 - [Deprecation] [commit](https://github.com/mdlayher/netlink/commit/d1b69ea8697d721415c259ef8513ab699c6d3e96): the `netlink.Socket` interface has been marked as deprecated. The abstraction is awkward to use properly and disables much of the functionality of the Conn type when the basic interface is implemented. Do not use. ## v1.6.0 **This is the first release of package netlink that only supports Go 1.13+. Users on older versions of Go must use v1.5.0.** - [New API] [commit](https://github.com/mdlayher/netlink/commit/ad9e2c41caa993e3f4b68831d6cb2cb05818275d): the `netlink.Config.Strict` field can be used to apply a more strict default set of options to a `netlink.Conn`. 
This is recommended for applications running on modern Linux kernels, but cannot be enabled by default because the options may require a more recent kernel than the minimum kernel version that Go supports. See the documentation for details. - [Improvement]: broke some integration tests into a separate Go module so the default `go.mod` for package `netlink` has fewer dependencies. ## v1.5.0 **This is the last release of package netlink that supports Go 1.12.** - [New API] [commit](https://github.com/mdlayher/netlink/commit/53a1c10065e51077659ceedf921c8f0807abe8c0): the `netlink.Config.PID` field can be used to specify an explicit port ID when binding the netlink socket. This is intended for advanced use cases and most callers should leave this field set to 0. - [Improvement]: more low-level functionality ported to `github.com/mdlayher/socket`, reducing package complexity. ## v1.4.2 - [Documentation] [commit](https://github.com/mdlayher/netlink/commit/177e6364fb170d465d681c7c8a6283417a6d3e49): the `netlink.Config.DisableNSLockThread` now properly uses Go's deprecated identifier convention. This option has been a noop for a long time and should not be used. - [Improvement] [#189](https://github.com/mdlayher/netlink/pull/189): the package now uses Go 1.17's `//go:build` identifiers. Thanks @tklauser. - [Bug Fix] [commit](https://github.com/mdlayher/netlink/commit/fe6002e030928bd1f2a446c0b6c65e8f2df4ed5e): the `netlink.AttributeEncoder`'s `Bytes`, `String`, and `Do` methods now properly reject byte slices and strings which are too large to fit in the value of a netlink attribute. Thanks @ubiquitousbyte for the report. ## v1.4.1 - [Improvement]: significant runtime network poller integration cleanup through the use of `github.com/mdlayher/socket`. 
## v1.4.0 - [New API] [#185](https://github.com/mdlayher/netlink/pull/185): the `netlink.AttributeDecoder` and `netlink.AttributeEncoder` types now have methods for dealing with signed integers: `Int8`, `Int16`, `Int32`, and `Int64`. These are necessary for working with rtnetlink's XDP APIs. Thanks @fbegyn. ## v1.3.2 - [Improvement] [commit](https://github.com/mdlayher/netlink/commit/ebc6e2e28bcf1a0671411288423d8116ff924d6d): `github.com/google/go-cmp` is no longer a (non-test) dependency of this module. ## v1.3.1 - [Improvement]: many internal cleanups and simplifications. The library is now slimmer and features less internal indirection. There are no user-facing changes in this release. ## v1.3.0 - [New API] [#176](https://github.com/mdlayher/netlink/pull/176): `netlink.OpError` now has `Message` and `Offset` fields which are populated when the kernel returns netlink extended acknowledgement data along with an error code. The caller can turn on this option by using `netlink.Conn.SetOption(netlink.ExtendedAcknowledge, true)`. - [New API] [commit](https://github.com/mdlayher/netlink/commit/beba85e0372133b6d57221191d2c557727cd1499): the `netlink.GetStrictCheck` option can be used to tell the kernel to be more strict when parsing requests. This enables more safety checks and can allow the kernel to perform more advanced request filtering in subsystems such as route netlink. ## v1.2.1 - [Bug Fix] [commit](https://github.com/mdlayher/netlink/commit/d81418f81b0bfa2465f33790a85624c63d6afe3d): `netlink.SetBPF` will no longer panic if an empty BPF filter is set. - [Improvement] [commit](https://github.com/mdlayher/netlink/commit/8014f9a7dbf4fd7b84a1783dd7b470db9113ff36): the library now uses https://github.com/josharian/native to provide the system's native endianness at compile time, rather than re-computing it many times at runtime. ## v1.2.0 **This is the first release of package netlink that only supports Go 1.12+. 
Users on older versions of Go must use v1.1.1.** - [Improvement] [#173](https://github.com/mdlayher/netlink/pull/173): support for Go 1.11 and below has been dropped. All users are highly recommended to use a stable and supported release of Go for their applications. - [Performance] [#171](https://github.com/mdlayher/netlink/pull/171): `netlink.Conn` no longer requires a locked OS thread for the vast majority of operations, which should result in a significant speedup for highly concurrent callers. Thanks @ti-mo. - [Bug Fix] [#169](https://github.com/mdlayher/netlink/pull/169): calls to `netlink.Conn.Close` are now able to unblock concurrent calls to `netlink.Conn.Receive` and other blocking operations. ## v1.1.1 **This is the last release of package netlink that supports Go 1.11.** - [Improvement] [#165](https://github.com/mdlayher/netlink/pull/165): `netlink.Conn` `SetReadBuffer` and `SetWriteBuffer` methods now attempt the `SO_*BUFFORCE` socket options to possibly ignore system limits given elevated caller permissions. Thanks @MarkusBauer. - [Note] [commit](https://github.com/mdlayher/netlink/commit/c5f8ab79aa345dcfcf7f14d746659ca1b80a0ecc): `netlink.Conn.Close` has had a long-standing bug [#162](https://github.com/mdlayher/netlink/pull/162) related to internal concurrency handling where a call to `Close` is not sufficient to unblock pending reads. To effectively fix this issue, it is necessary to drop support for Go 1.11 and below. This will be fixed in a future release, but a workaround is noted in the method documentation as of now. ## v1.1.0 - [New API] [#157](https://github.com/mdlayher/netlink/pull/157): the `netlink.AttributeDecoder.TypeFlags` method enables retrieval of the type bits stored in a netlink attribute's type field, because the existing `Type` method masks away these bits. Thanks @ti-mo! 
- [Performance] [#157](https://github.com/mdlayher/netlink/pull/157): `netlink.AttributeDecoder` now decodes netlink attributes on demand, enabling callers who only need a limited number of attributes to exit early from decoding loops. Thanks @ti-mo! - [Improvement] [#161](https://github.com/mdlayher/netlink/pull/161): `netlink.Conn` system calls are now ready for Go 1.14+'s changes to goroutine preemption. See the PR for details. ## v1.0.0 - Initial stable commit. netlink-1.7.2/LICENSE.md000066400000000000000000000020631442322526200145740ustar00rootroot00000000000000# MIT License Copyright (C) 2016-2022 Matt Layher Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
netlink-1.7.2/README.md000066400000000000000000000136031442322526200144510ustar00rootroot00000000000000# netlink [![Test Status](https://github.com/mdlayher/netlink/workflows/Linux%20Test/badge.svg)](https://github.com/mdlayher/netlink/actions) [![Go Reference](https://pkg.go.dev/badge/github.com/mdlayher/netlink.svg)](https://pkg.go.dev/github.com/mdlayher/netlink) [![Go Report Card](https://goreportcard.com/badge/github.com/mdlayher/netlink)](https://goreportcard.com/report/github.com/mdlayher/netlink) Package `netlink` provides low-level access to Linux netlink sockets (`AF_NETLINK`). MIT Licensed. For more information about how netlink works, check out my blog series on [Linux, Netlink, and Go](https://mdlayher.com/blog/linux-netlink-and-go-part-1-netlink/). If you have any questions or you'd like some guidance, please join us on [Gophers Slack](https://invite.slack.golangbridge.org) in the `#networking` channel! ## Stability See the [CHANGELOG](./CHANGELOG.md) file for a description of changes between releases. This package has a stable v1 API and any future breaking changes will prompt the release of a new major version. Features and bug fixes will continue to occur in the v1.x.x series. This package only supports the two most recent major versions of Go, mirroring Go's own release policy. Older versions of Go may lack critical features and bug fixes which are necessary for this package to function correctly. ## Design A number of netlink packages are already available for Go, but I wasn't able to find one that aligned with what I wanted in a netlink package: - Straightforward, idiomatic API - Well tested - Well documented - Doesn't use package/global variables or state - Doesn't necessarily need root to work My goal for this package is to use it as a building block for the creation of other netlink family packages. ## Ecosystem Over time, an ecosystem of Go packages has developed around package `netlink`. 
Many of these packages provide building blocks for further interactions with various netlink families, such as `NETLINK_GENERIC` or `NETLINK_ROUTE`. To have your package included in this diagram, please send a pull request! ```mermaid flowchart LR netlink["github.com/mdlayher/netlink"] click netlink "https://github.com/mdlayher/netlink" subgraph "NETLINK_CONNECTOR" direction LR garlic["github.com/fearful-symmetry/garlic"] click garlic "https://github.com/fearful-symmetry/garlic" end subgraph "NETLINK_CRYPTO" direction LR cryptonl["github.com/mdlayher/cryptonl"] click cryptonl "https://github.com/mdlayher/cryptonl" end subgraph "NETLINK_GENERIC" direction LR genetlink["github.com/mdlayher/genetlink"] click genetlink "https://github.com/mdlayher/genetlink" devlink["github.com/mdlayher/devlink"] click devlink "https://github.com/mdlayher/devlink" ethtool["github.com/mdlayher/ethtool"] click ethtool "https://github.com/mdlayher/ethtool" go-openvswitch["github.com/digitalocean/go-openvswitch"] click go-openvswitch "https://github.com/digitalocean/go-openvswitch" ipvs["github.com/cloudflare/ipvs"] click ipvs "https://github.com/cloudflare/ipvs" l2tp["github.com/axatrax/l2tp"] click l2tp "https://github.com/axatrax/l2tp" nbd["github.com/Merovius/nbd"] click nbd "https://github.com/Merovius/nbd" quota["github.com/mdlayher/quota"] click quota "https://github.com/mdlayher/quota" router7["github.com/rtr7/router7"] click router7 "https://github.com/rtr7/router7" taskstats["github.com/mdlayher/taskstats"] click taskstats "https://github.com/mdlayher/taskstats" u-bmc["github.com/u-root/u-bmc"] click u-bmc "https://github.com/u-root/u-bmc" wgctrl["golang.zx2c4.com/wireguard/wgctrl"] click wgctrl "https://golang.zx2c4.com/wireguard/wgctrl" wifi["github.com/mdlayher/wifi"] click wifi "https://github.com/mdlayher/wifi" devlink & ethtool & go-openvswitch & ipvs --> genetlink l2tp & nbd & quota & router7 & taskstats --> genetlink u-bmc & wgctrl & wifi --> genetlink end subgraph 
"NETLINK_KOBJECT_UEVENT" direction LR kobject["github.com/mdlayher/kobject"] click kobject "https://github.com/mdlayher/kobject" end subgraph "NETLINK_NETFILTER" direction LR go-conntrack["github.com/florianl/go-conntrack"] click go-conntrack "https://github.com/florianl/go-conntrack" go-nflog["github.com/florianl/go-nflog"] click go-nflog "https://github.com/florianl/go-nflog" go-nfqueue["github.com/florianl/go-nfqueue"] click go-nfqueue "https://github.com/florianl/go-nfqueue" netfilter["github.com/ti-mo/netfilter"] click netfilter "https://github.com/ti-mo/netfilter" nftables["github.com/google/nftables"] click nftables "https://github.com/google/nftables" conntrack["github.com/ti-mo/conntrack"] click conntrack "https://github.com/ti-mo/conntrack" conntrack --> netfilter end subgraph "NETLINK_ROUTE" direction LR go-tc["github.com/florianl/go-tc"] click go-tc "https://github.com/florianl/go-tc" qdisc["github.com/ema/qdisc"] click qdisc "https://github.com/ema/qdisc" rtnetlink["github.com/jsimonetti/rtnetlink"] click rtnetlink "https://github.com/jsimonetti/rtnetlink" rtnl["gitlab.com/mergetb/tech/rtnl"] click rtnl "https://gitlab.com/mergetb/tech/rtnl" end subgraph "NETLINK_W1" direction LR go-onewire["github.com/SpComb/go-onewire"] click go-onewire "https://github.com/SpComb/go-onewire" end NETLINK_CONNECTOR --> netlink NETLINK_CRYPTO --> netlink NETLINK_GENERIC --> netlink NETLINK_KOBJECT_UEVENT --> netlink NETLINK_NETFILTER --> netlink NETLINK_ROUTE --> netlink NETLINK_W1 --> netlink ``` netlink-1.7.2/align.go000066400000000000000000000021771442322526200146170ustar00rootroot00000000000000package netlink import "unsafe" // Functions and values used to properly align netlink messages, headers, // and attributes. Definitions taken from Linux kernel source. 
// #define NLMSG_ALIGNTO 4U const nlmsgAlignTo = 4 // #define NLMSG_ALIGN(len) ( ((len)+NLMSG_ALIGNTO-1) & ~(NLMSG_ALIGNTO-1) ) func nlmsgAlign(len int) int { return ((len) + nlmsgAlignTo - 1) & ^(nlmsgAlignTo - 1) } // #define NLMSG_LENGTH(len) ((len) + NLMSG_HDRLEN) func nlmsgLength(len int) int { return len + nlmsgHeaderLen } // #define NLMSG_HDRLEN ((int) NLMSG_ALIGN(sizeof(struct nlmsghdr))) var nlmsgHeaderLen = nlmsgAlign(int(unsafe.Sizeof(Header{}))) // #define NLA_ALIGNTO 4 const nlaAlignTo = 4 // #define NLA_ALIGN(len) (((len) + NLA_ALIGNTO - 1) & ~(NLA_ALIGNTO - 1)) func nlaAlign(len int) int { return ((len) + nlaAlignTo - 1) & ^(nlaAlignTo - 1) } // Because this package's Attribute type contains a byte slice, unsafe.Sizeof // can't be used to determine the correct length. const sizeofAttribute = 4 // #define NLA_HDRLEN ((int) NLA_ALIGN(sizeof(struct nlattr))) var nlaHeaderLen = nlaAlign(sizeofAttribute) netlink-1.7.2/align_test.go000066400000000000000000000023101442322526200156430ustar00rootroot00000000000000package netlink import ( "strconv" "testing" ) func Test_nlmsgAlign(t *testing.T) { tests := []struct { in int out int }{ { in: 0, out: 0, }, { in: 1, out: 4, }, { in: 2, out: 4, }, { in: 3, out: 4, }, { in: 4, out: 4, }, { in: 5, out: 8, }, { in: 6, out: 8, }, { in: 7, out: 8, }, { in: 8, out: 8, }, } for _, tt := range tests { t.Run(strconv.Itoa(tt.in), func(t *testing.T) { if want, got := tt.out, nlmsgAlign(tt.in); want != got { t.Fatalf("unexpected output:\n- want: %v\n- got: %v", want, got) } }) } } func Test_nlaAlign(t *testing.T) { tests := []struct { in int out int }{ { in: 0, out: 0, }, { in: 1, out: 4, }, { in: 2, out: 4, }, { in: 3, out: 4, }, { in: 4, out: 4, }, { in: 5, out: 8, }, { in: 6, out: 8, }, { in: 7, out: 8, }, { in: 8, out: 8, }, } for _, tt := range tests { t.Run(strconv.Itoa(tt.in), func(t *testing.T) { if want, got := tt.out, nlaAlign(tt.in); want != got { t.Fatalf("unexpected output:\n- want: %v\n- got: %v", want, got) } }) 
} } netlink-1.7.2/attribute.go000066400000000000000000000424771442322526200155370ustar00rootroot00000000000000package netlink import ( "encoding/binary" "errors" "fmt" "math" "github.com/josharian/native" "github.com/mdlayher/netlink/nlenc" ) // errInvalidAttribute specifies if an Attribute's length is incorrect. var errInvalidAttribute = errors.New("invalid attribute; length too short or too large") // An Attribute is a netlink attribute. Attributes are packed and unpacked // to and from the Data field of Message for some netlink families. type Attribute struct { // Length of an Attribute, including this field and Type. Length uint16 // The type of this Attribute, typically matched to a constant. Note that // flags such as Nested and NetByteOrder must be handled manually when // working with Attribute structures directly. Type uint16 // An arbitrary payload which is specified by Type. Data []byte } // marshal marshals the contents of a into b and returns the number of bytes // written to b, including attribute alignment padding. func (a *Attribute) marshal(b []byte) (int, error) { if int(a.Length) < nlaHeaderLen { return 0, errInvalidAttribute } nlenc.PutUint16(b[0:2], a.Length) nlenc.PutUint16(b[2:4], a.Type) n := copy(b[nlaHeaderLen:], a.Data) return nlaHeaderLen + nlaAlign(n), nil } // unmarshal unmarshals the contents of a byte slice into an Attribute. 
func (a *Attribute) unmarshal(b []byte) error { if len(b) < nlaHeaderLen { return errInvalidAttribute } a.Length = nlenc.Uint16(b[0:2]) a.Type = nlenc.Uint16(b[2:4]) if int(a.Length) > len(b) { return errInvalidAttribute } switch { // No length, no data case a.Length == 0: a.Data = make([]byte, 0) // Not enough length for any data case int(a.Length) < nlaHeaderLen: return errInvalidAttribute // Data present case int(a.Length) >= nlaHeaderLen: a.Data = make([]byte, len(b[nlaHeaderLen:a.Length])) copy(a.Data, b[nlaHeaderLen:a.Length]) } return nil } // MarshalAttributes packs a slice of Attributes into a single byte slice. // In most cases, the Length field of each Attribute should be set to 0, so it // can be calculated and populated automatically for each Attribute. // // It is recommend to use the AttributeEncoder type where possible instead of // calling MarshalAttributes and using package nlenc functions directly. func MarshalAttributes(attrs []Attribute) ([]byte, error) { // Count how many bytes we should allocate to store each attribute's contents. var c int for _, a := range attrs { c += nlaHeaderLen + nlaAlign(len(a.Data)) } // Advance through b with idx to place attribute data at the correct offset. var idx int b := make([]byte, c) for _, a := range attrs { // Infer the length of attribute if zero. if a.Length == 0 { a.Length = uint16(nlaHeaderLen + len(a.Data)) } // Marshal a into b and advance idx to show many bytes are occupied. n, err := a.marshal(b[idx:]) if err != nil { return nil, err } idx += n } return b, nil } // UnmarshalAttributes unpacks a slice of Attributes from a single byte slice. // // It is recommend to use the AttributeDecoder type where possible instead of calling // UnmarshalAttributes and using package nlenc functions directly. func UnmarshalAttributes(b []byte) ([]Attribute, error) { ad, err := NewAttributeDecoder(b) if err != nil { return nil, err } // Return a nil slice when there are no attributes to decode. 
if ad.Len() == 0 { return nil, nil } attrs := make([]Attribute, 0, ad.Len()) for ad.Next() { if ad.a.Length != 0 { attrs = append(attrs, ad.a) } } if err := ad.Err(); err != nil { return nil, err } return attrs, nil } // An AttributeDecoder provides a safe, iterator-like, API around attribute // decoding. // // It is recommend to use an AttributeDecoder where possible instead of calling // UnmarshalAttributes and using package nlenc functions directly. // // The Err method must be called after the Next method returns false to determine // if any errors occurred during iteration. type AttributeDecoder struct { // ByteOrder defines a specific byte order to use when processing integer // attributes. ByteOrder should be set immediately after creating the // AttributeDecoder: before any attributes are parsed. // // If not set, the native byte order will be used. ByteOrder binary.ByteOrder // The current attribute being worked on. a Attribute // The slice of input bytes and its iterator index. b []byte i int length int // Any error encountered while decoding attributes. err error } // NewAttributeDecoder creates an AttributeDecoder that unpacks Attributes // from b and prepares the decoder for iteration. func NewAttributeDecoder(b []byte) (*AttributeDecoder, error) { ad := &AttributeDecoder{ // By default, use native byte order. ByteOrder: native.Endian, b: b, } var err error ad.length, err = ad.available() if err != nil { return nil, err } return ad, nil } // Next advances the decoder to the next netlink attribute. It returns false // when no more attributes are present, or an error was encountered. func (ad *AttributeDecoder) Next() bool { if ad.err != nil { // Hit an error, stop iteration. return false } // Exit if array pointer is at or beyond the end of the slice. if ad.i >= len(ad.b) { return false } if err := ad.a.unmarshal(ad.b[ad.i:]); err != nil { ad.err = err return false } // Advance the pointer by at least one header's length. 
if int(ad.a.Length) < nlaHeaderLen { ad.i += nlaHeaderLen } else { ad.i += nlaAlign(int(ad.a.Length)) } return true } // Type returns the Attribute.Type field of the current netlink attribute // pointed to by the decoder. // // Type masks off the high bits of the netlink attribute type which may contain // the Nested and NetByteOrder flags. These can be obtained by calling TypeFlags. func (ad *AttributeDecoder) Type() uint16 { // Mask off any flags stored in the high bits. return ad.a.Type & attrTypeMask } // TypeFlags returns the two high bits of the Attribute.Type field of the current // netlink attribute pointed to by the decoder. // // These bits of the netlink attribute type are used for the Nested and NetByteOrder // flags, available as the Nested and NetByteOrder constants in this package. func (ad *AttributeDecoder) TypeFlags() uint16 { return ad.a.Type & ^attrTypeMask } // Len returns the number of netlink attributes pointed to by the decoder. func (ad *AttributeDecoder) Len() int { return ad.length } // count scans the input slice to count the number of netlink attributes // that could be decoded by Next(). func (ad *AttributeDecoder) available() (int, error) { var count int for i := 0; i < len(ad.b); { // Make sure there's at least a header's worth // of data to read on each iteration. if len(ad.b[i:]) < nlaHeaderLen { return 0, errInvalidAttribute } // Extract the length of the attribute. l := int(nlenc.Uint16(ad.b[i : i+2])) // Ignore zero-length attributes. if l != 0 { count++ } // Advance by at least a header's worth of bytes. if l < nlaHeaderLen { l = nlaHeaderLen } i += nlaAlign(l) } return count, nil } // data returns the Data field of the current Attribute pointed to by the decoder. func (ad *AttributeDecoder) data() []byte { return ad.a.Data } // Err returns the first error encountered by the decoder. func (ad *AttributeDecoder) Err() error { return ad.err } // Bytes returns the raw bytes of the current Attribute's data. 
func (ad *AttributeDecoder) Bytes() []byte { src := ad.data() dest := make([]byte, len(src)) copy(dest, src) return dest } // String returns the string representation of the current Attribute's data. func (ad *AttributeDecoder) String() string { if ad.err != nil { return "" } return nlenc.String(ad.data()) } // Uint8 returns the uint8 representation of the current Attribute's data. func (ad *AttributeDecoder) Uint8() uint8 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 1 { ad.err = fmt.Errorf("netlink: attribute %d is not a uint8; length: %d", ad.Type(), len(b)) return 0 } return uint8(b[0]) } // Uint16 returns the uint16 representation of the current Attribute's data. func (ad *AttributeDecoder) Uint16() uint16 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 2 { ad.err = fmt.Errorf("netlink: attribute %d is not a uint16; length: %d", ad.Type(), len(b)) return 0 } return ad.ByteOrder.Uint16(b) } // Uint32 returns the uint32 representation of the current Attribute's data. func (ad *AttributeDecoder) Uint32() uint32 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 4 { ad.err = fmt.Errorf("netlink: attribute %d is not a uint32; length: %d", ad.Type(), len(b)) return 0 } return ad.ByteOrder.Uint32(b) } // Uint64 returns the uint64 representation of the current Attribute's data. func (ad *AttributeDecoder) Uint64() uint64 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 8 { ad.err = fmt.Errorf("netlink: attribute %d is not a uint64; length: %d", ad.Type(), len(b)) return 0 } return ad.ByteOrder.Uint64(b) } // Int8 returns the Int8 representation of the current Attribute's data. func (ad *AttributeDecoder) Int8() int8 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 1 { ad.err = fmt.Errorf("netlink: attribute %d is not a int8; length: %d", ad.Type(), len(b)) return 0 } return int8(b[0]) } // Int16 returns the Int16 representation of the current Attribute's data. 
func (ad *AttributeDecoder) Int16() int16 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 2 { ad.err = fmt.Errorf("netlink: attribute %d is not a int16; length: %d", ad.Type(), len(b)) return 0 } return int16(ad.ByteOrder.Uint16(b)) } // Int32 returns the Int32 representation of the current Attribute's data. func (ad *AttributeDecoder) Int32() int32 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 4 { ad.err = fmt.Errorf("netlink: attribute %d is not a int32; length: %d", ad.Type(), len(b)) return 0 } return int32(ad.ByteOrder.Uint32(b)) } // Int64 returns the Int64 representation of the current Attribute's data. func (ad *AttributeDecoder) Int64() int64 { if ad.err != nil { return 0 } b := ad.data() if len(b) != 8 { ad.err = fmt.Errorf("netlink: attribute %d is not a int64; length: %d", ad.Type(), len(b)) return 0 } return int64(ad.ByteOrder.Uint64(b)) } // Flag returns a boolean representing the Attribute. func (ad *AttributeDecoder) Flag() bool { if ad.err != nil { return false } b := ad.data() if len(b) != 0 { ad.err = fmt.Errorf("netlink: attribute %d is not a flag; length: %d", ad.Type(), len(b)) return false } return true } // Do is a general purpose function which allows access to the current data // pointed to by the AttributeDecoder. // // Do can be used to allow parsing arbitrary data within the context of the // decoder. Do is most useful when dealing with nested attributes, attribute // arrays, or decoding arbitrary types (such as C structures) which don't fit // cleanly into a typical unsigned integer value. // // The function fn should not retain any reference to the data b outside of the // scope of the function. func (ad *AttributeDecoder) Do(fn func(b []byte) error) { if ad.err != nil { return } b := ad.data() if err := fn(b); err != nil { ad.err = err } } // Nested decodes data into a nested AttributeDecoder to handle nested netlink // attributes. 
When calling Nested, the Err method does not need to be called on // the nested AttributeDecoder. // // The nested AttributeDecoder nad inherits the same ByteOrder setting as the // top-level AttributeDecoder ad. func (ad *AttributeDecoder) Nested(fn func(nad *AttributeDecoder) error) { // Because we are wrapping Do, there is no need to check ad.err immediately. ad.Do(func(b []byte) error { nad, err := NewAttributeDecoder(b) if err != nil { return err } nad.ByteOrder = ad.ByteOrder if err := fn(nad); err != nil { return err } return nad.Err() }) } // An AttributeEncoder provides a safe way to encode attributes. // // It is recommended to use an AttributeEncoder where possible instead of // calling MarshalAttributes or using package nlenc directly. // // Errors from intermediate encoding steps are returned in the call to // Encode. type AttributeEncoder struct { // ByteOrder defines a specific byte order to use when processing integer // attributes. ByteOrder should be set immediately after creating the // AttributeEncoder: before any attributes are encoded. // // If not set, the native byte order will be used. ByteOrder binary.ByteOrder attrs []Attribute err error } // NewAttributeEncoder creates an AttributeEncoder that encodes Attributes. func NewAttributeEncoder() *AttributeEncoder { return &AttributeEncoder{ByteOrder: native.Endian} } // Uint8 encodes uint8 data into an Attribute specified by typ. func (ae *AttributeEncoder) Uint8(typ uint16, v uint8) { if ae.err != nil { return } ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: []byte{v}, }) } // Uint16 encodes uint16 data into an Attribute specified by typ. func (ae *AttributeEncoder) Uint16(typ uint16, v uint16) { if ae.err != nil { return } b := make([]byte, 2) ae.ByteOrder.PutUint16(b, v) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Uint32 encodes uint32 data into an Attribute specified by typ. 
func (ae *AttributeEncoder) Uint32(typ uint16, v uint32) { if ae.err != nil { return } b := make([]byte, 4) ae.ByteOrder.PutUint32(b, v) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Uint64 encodes uint64 data into an Attribute specified by typ. func (ae *AttributeEncoder) Uint64(typ uint16, v uint64) { if ae.err != nil { return } b := make([]byte, 8) ae.ByteOrder.PutUint64(b, v) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Int8 encodes int8 data into an Attribute specified by typ. func (ae *AttributeEncoder) Int8(typ uint16, v int8) { if ae.err != nil { return } ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: []byte{uint8(v)}, }) } // Int16 encodes int16 data into an Attribute specified by typ. func (ae *AttributeEncoder) Int16(typ uint16, v int16) { if ae.err != nil { return } b := make([]byte, 2) ae.ByteOrder.PutUint16(b, uint16(v)) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Int32 encodes int32 data into an Attribute specified by typ. func (ae *AttributeEncoder) Int32(typ uint16, v int32) { if ae.err != nil { return } b := make([]byte, 4) ae.ByteOrder.PutUint32(b, uint32(v)) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Int64 encodes int64 data into an Attribute specified by typ. func (ae *AttributeEncoder) Int64(typ uint16, v int64) { if ae.err != nil { return } b := make([]byte, 8) ae.ByteOrder.PutUint64(b, uint64(v)) ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Flag encodes a flag into an Attribute specified by typ. func (ae *AttributeEncoder) Flag(typ uint16, v bool) { // Only set flag on no previous error or v == true. if ae.err != nil || !v { return } // Flags have no length or data fields. ae.attrs = append(ae.attrs, Attribute{Type: typ}) } // String encodes string s as a null-terminated string into an Attribute // specified by typ. 
func (ae *AttributeEncoder) String(typ uint16, s string) { if ae.err != nil { return } // Length checking, thanks ubiquitousbyte on GitHub. if len(s) > math.MaxUint16-nlaHeaderLen { ae.err = errors.New("string is too large to fit in a netlink attribute") return } ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: nlenc.Bytes(s), }) } // Bytes embeds raw byte data into an Attribute specified by typ. func (ae *AttributeEncoder) Bytes(typ uint16, b []byte) { if ae.err != nil { return } if len(b) > math.MaxUint16-nlaHeaderLen { ae.err = errors.New("byte slice is too large to fit in a netlink attribute") return } ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Do is a general purpose function to encode arbitrary data into an attribute // specified by typ. // // Do is especially helpful in encoding nested attributes, attribute arrays, // or encoding arbitrary types (such as C structures) which don't fit cleanly // into an unsigned integer value. func (ae *AttributeEncoder) Do(typ uint16, fn func() ([]byte, error)) { if ae.err != nil { return } b, err := fn() if err != nil { ae.err = err return } if len(b) > math.MaxUint16-nlaHeaderLen { ae.err = errors.New("byte slice produced by Do is too large to fit in a netlink attribute") return } ae.attrs = append(ae.attrs, Attribute{ Type: typ, Data: b, }) } // Nested embeds data produced by a nested AttributeEncoder and flags that data // with the Nested flag. When calling Nested, the Encode method should not be // called on the nested AttributeEncoder. // // The nested AttributeEncoder nae inherits the same ByteOrder setting as the // top-level AttributeEncoder ae. func (ae *AttributeEncoder) Nested(typ uint16, fn func(nae *AttributeEncoder) error) { // Because we are wrapping Do, there is no need to check ae.err immediately. 
ae.Do(Nested|typ, func() ([]byte, error) { nae := NewAttributeEncoder() nae.ByteOrder = ae.ByteOrder if err := fn(nae); err != nil { return nil, err } return nae.Encode() }) } // Encode returns the encoded bytes representing the attributes. func (ae *AttributeEncoder) Encode() ([]byte, error) { if ae.err != nil { return nil, ae.err } return MarshalAttributes(ae.attrs) } netlink-1.7.2/attribute_test.go000066400000000000000000000531151442322526200165650ustar00rootroot00000000000000package netlink import ( "bytes" "encoding/binary" "errors" "math" "reflect" "testing" "unsafe" "github.com/google/go-cmp/cmp" "github.com/josharian/native" "github.com/mdlayher/netlink/nlenc" ) func TestMarshalAttributes(t *testing.T) { skipBigEndian(t) tests := []struct { name string attrs []Attribute b []byte err error }{ { name: "one attribute, short length", attrs: []Attribute{{ Length: 3, Type: 1, }}, err: errInvalidAttribute, }, { name: "one attribute, no data", attrs: []Attribute{{ Length: 4, Type: 1, Data: make([]byte, 0), }}, b: []byte{ 0x04, 0x00, 0x01, 0x00, }, }, { name: "one attribute, no data, length calculated", attrs: []Attribute{{ Type: 1, Data: make([]byte, 0), }}, b: []byte{ 0x04, 0x00, 0x01, 0x00, }, }, { name: "one attribute, padded", attrs: []Attribute{{ Length: 5, Type: 1, Data: []byte{0xff}, }}, b: []byte{ 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, }, }, { name: "one attribute, padded, length calculated", attrs: []Attribute{{ Type: 1, Data: []byte{0xff}, }}, b: []byte{ 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, }, }, { name: "one attribute, aligned", attrs: []Attribute{{ Length: 8, Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }}, b: []byte{ 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, }, }, { name: "one attribute, aligned, length calculated", attrs: []Attribute{{ Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }}, b: []byte{ 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, }, }, { name: "multiple attributes", attrs: []Attribute{ { Length: 5, Type: 1, 
Data: []byte{0xff}, }, { Length: 8, Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }, { Length: 4, Type: 3, Data: make([]byte, 0), }, { Length: 16, Type: 4, Data: []byte{ 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, }, }, b: []byte{ // 1 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, // 2 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, // 3 0x04, 0x00, 0x03, 0x00, // 4 0x10, 0x00, 0x04, 0x00, 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, }, { name: "multiple attributes, length calculated", attrs: []Attribute{ { Type: 1, Data: []byte{0xff}, }, { Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }, { Type: 3, Data: make([]byte, 0), }, { Type: 4, Data: []byte{ 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, }, }, b: []byte{ // 1 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, // 2 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, // 3 0x04, 0x00, 0x03, 0x00, // 4 0x10, 0x00, 0x04, 0x00, 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, }, { name: "max type space, length 0", attrs: []Attribute{ { Length: 4, Type: 0xffff, Data: make([]byte, 0), }, }, b: []byte{ 0x04, 0x00, 0xff, 0xff, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := MarshalAttributes(tt.attrs) if want, got := tt.err, err; want != got { t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) } if err != nil { return } if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestUnmarshalAttributes(t *testing.T) { skipBigEndian(t) tests := []struct { name string b []byte attrs []Attribute err error }{ { name: "empty slice", }, { name: "short slice", b: make([]byte, 3), err: errInvalidAttribute, }, { name: "length too short (<4 bytes)", b: []byte{ 0x03, 0x00, 0x00, }, err: errInvalidAttribute, }, { name: "length too long", b: []byte{ 0xff, 0xff, 0x00, 0x00, }, err: 
errInvalidAttribute, }, { name: "one attribute, not aligned", b: []byte{ 0x05, 0x00, 0x01, 0x00, 0xff, }, attrs: []Attribute{{ Length: 5, Type: 1, Data: []byte{0xff}, }}, }, { name: "fuzz crasher: length 1, too short", b: []byte("\x01\x0000"), err: errInvalidAttribute, }, { name: "no attributes, length 0", b: []byte{ 0x00, 0x00, 0x00, 0x00, }, }, { name: "one attribute, no data", b: []byte{ 0x04, 0x00, 0x01, 0x00, }, attrs: []Attribute{{ Length: 4, Type: 1, Data: make([]byte, 0), }}, }, { name: "one attribute, padded", b: []byte{ 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, }, attrs: []Attribute{{ Length: 5, Type: 1, Data: []byte{0xff}, }}, }, { name: "one attribute, aligned", b: []byte{ 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, }, attrs: []Attribute{{ Length: 8, Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }}, }, { name: "multiple attributes", b: []byte{ // 1 0x05, 0x00, 0x01, 0x00, 0xff, 0x00, 0x00, 0x00, // 2 0x08, 0x00, 0x02, 0x00, 0xaa, 0xbb, 0xcc, 0xdd, // 3 0x04, 0x00, 0x03, 0x00, // 4 0x10, 0x00, 0x04, 0x00, 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, attrs: []Attribute{ { Length: 5, Type: 1, Data: []byte{0xff}, }, { Length: 8, Type: 2, Data: []byte{0xaa, 0xbb, 0xcc, 0xdd}, }, { Length: 4, Type: 3, Data: make([]byte, 0), }, { Length: 16, Type: 4, Data: []byte{ 0x11, 0x11, 0x11, 0x11, 0x22, 0x22, 0x22, 0x22, 0x33, 0x33, 0x33, 0x33, }, }, }, }, { name: "max type space, length 0", b: []byte{ 0x04, 0x00, 0xff, 0xff, }, attrs: []Attribute{ { Length: 4, Type: 0xffff, Data: make([]byte, 0), }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { attrs, err := UnmarshalAttributes(tt.b) if want, got := tt.err, err; want != got { t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) } if err != nil { return } if want, got := tt.attrs, attrs; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected attributes:\n- want: %v\n- got: %v", want, got) } }) } } func TestAttributeDecoderError(t *testing.T) { 
bad := []Attribute{{ Type: 1, // Doesn't fit any integer types. Data: []byte{0xe, 0xad, 0xbe}, }} skipBigEndian(t) tests := []struct { name string attrs []Attribute fn func(ad *AttributeDecoder) }{ { name: "uint8", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Uint8() ad.Next() ad.Uint8() }, }, { name: "uint16", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Uint16() ad.Next() ad.Uint16() }, }, { name: "uint32", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Uint32() ad.Next() ad.Uint32() }, }, { name: "uint64", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Uint64() ad.Next() ad.Uint64() }, }, { name: "int8", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Int8() ad.Next() ad.Int8() }, }, { name: "int16", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Int16() ad.Next() ad.Int16() }, }, { name: "int32", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Int32() ad.Next() ad.Int32() }, }, { name: "int64", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Int64() ad.Next() ad.Int64() }, }, { name: "do", attrs: bad, fn: func(ad *AttributeDecoder) { ad.Do(func(_ []byte) error { return errors.New("some error") }) ad.Do(func(_ []byte) error { panic("shouldn't be called") }) }, }, { name: "flag", attrs: []Attribute{{ Type: 1, // Flag data is not empty. 
Data: []byte{0xff}, }}, fn: func(ad *AttributeDecoder) { ad.Flag() ad.Next() ad.Flag() }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := MarshalAttributes(tt.attrs) if err != nil { t.Fatalf("failed to marshal attributes: %v", err) } ad, err := NewAttributeDecoder(b) if err != nil { t.Fatalf("failed to create attribute decoder: %v", err) } for ad.Next() { tt.fn(ad) } if err := ad.Err(); err == nil { t.Fatal("expected an error, but none occurred") } }) } } func TestAttributeDecoderOK(t *testing.T) { skipBigEndian(t) tests := []struct { name string attrs []Attribute fn func(ad *AttributeDecoder) }{ { name: "empty", attrs: nil, fn: func(_ *AttributeDecoder) { panic("should not be called") }, }, { name: "uint-int native endian", attrs: adEndianAttrs(native.Endian), fn: adEndianTest(native.Endian), }, { name: "uint-int little endian", attrs: adEndianAttrs(binary.LittleEndian), fn: adEndianTest(binary.LittleEndian), }, { name: "uint-int big endian", attrs: adEndianAttrs(binary.BigEndian), fn: adEndianTest(binary.BigEndian), }, { name: "bytes", attrs: []Attribute{{ Type: 1, Data: []byte{0xde, 0xad}, }}, fn: func(ad *AttributeDecoder) { var b []byte switch t := ad.Type(); t { case 1: b = ad.Bytes() default: panicf("unhandled attribute type: %d", t) } if diff := cmp.Diff([]byte{0xde, 0xad}, b); diff != "" { panicf("unexpected attribute value (-want +got):\n%s", diff) } b[0] = 0xff if diff := cmp.Diff(b, ad.Bytes()); diff == "" { panic("expected attribute value to be copied and different") } }, }, { name: "string", attrs: []Attribute{{ Type: 1, // The string should be able to contain extra trailing NULL // bytes which will all be removed automatically. 
Data: nlenc.Bytes("hello world\x00\x00\x00"), }}, fn: func(ad *AttributeDecoder) { var s string switch t := ad.Type(); t { case 1: s = ad.String() default: panicf("unhandled attribute type: %d", t) } if diff := cmp.Diff("hello world", s); diff != "" { panicf("unexpected attribute value (-want +got):\n%s", diff) } }, }, { name: "flag", attrs: []Attribute{{ Type: 1, }}, fn: func(ad *AttributeDecoder) { var flag bool switch t := ad.Type(); t { case 1: flag = ad.Flag() default: panicf("unhandled attribute type: %d", t) } if !flag { panic("flag was not set") } }, }, { name: "do", attrs: []Attribute{ // Arbitrary C-like structure. { Type: 1, Data: []byte{ // uint16 0xde, 0xad, // uint8 0xbe, // padding 0x00, }, }, // Nested attributes. { Type: 2, Data: func() []byte { b, err := MarshalAttributes([]Attribute{{ Type: 2, Data: nlenc.Uint16Bytes(2), }}) if err != nil { panicf("failed to marshal test attributes: %v", err) } return b }(), }, }, fn: func(ad *AttributeDecoder) { switch t := ad.Type(); t { case 1: type cstruct struct { A uint16 B uint8 } want := cstruct{ // Little-endian is the worst. A: 0xadde, B: 0xbe, } ad.Do(func(b []byte) error { // unsafe invariant check. 
if want, got := int(unsafe.Sizeof(cstruct{})), len(b); want != got { panicf("unexpected struct size: want: %d, got: %d", want, got) } got := *(*cstruct)(unsafe.Pointer(&b[0])) if diff := cmp.Diff(want, got); diff != "" { panicf("unexpected struct (-want +got):\n%s", diff) } return nil }) case 2: ad.Do(func(b []byte) error { adi, err := NewAttributeDecoder(b) if err != nil { return err } var got int first := true for adi.Next() { if !first { panic("loop iterated too many times") } first = false if adi.Type() != 2 { panicf("unhandled attribute type: %d", t) } got = int(adi.Uint16()) } if diff := cmp.Diff(2, got); diff != "" { panicf("unexpected nested attribute value (-want +got):\n%s", diff) } return adi.Err() }) default: panicf("unhandled attribute type: %d", t) } }, }, { name: "nested", attrs: []Attribute{ // Nested attributes. { Type: Nested | 1, Data: func() []byte { nb, err := MarshalAttributes([]Attribute{{ Type: 1, Data: nlenc.Uint32Bytes(2), }}) if err != nil { panicf("failed to marshal nested test attributes: %v", err) } b, err := MarshalAttributes([]Attribute{ { Type: 1, Data: nlenc.Uint16Bytes(1), }, { Type: Nested | 2, Data: nb, }, }) if err != nil { panicf("failed to marshal test attributes: %v", err) } return b }(), }, }, fn: func(ad *AttributeDecoder) { if diff := cmp.Diff(uint16(1), ad.Type()); diff != "" { panicf("unexpected attribute type (-want +got):\n%s", diff) } ad.Nested(func(nad *AttributeDecoder) error { for nad.Next() { switch t := nad.Type(); t { case 1: if diff := cmp.Diff(uint16(1), nad.Uint16()); diff != "" { panicf("unexpected nested uint16 (-want +got):\n%s", diff) } case 2: nad.Nested(func(nnad *AttributeDecoder) error { for nad.Next() { if diff := cmp.Diff(uint16(1), nnad.Type()); diff != "" { panicf("unexpected nested attribute type (-want +got):\n%s", diff) } if diff := cmp.Diff(uint32(2), nnad.Uint32()); diff != "" { panicf("unexpected nested uint32 (-want +got):\n%s", diff) } } return nil }) default: panicf("unhandled nested 
attribute type: %d", t) } } return nil }) }, }, { name: "typeflags", attrs: []Attribute{{ Type: 0xffff, }}, fn: func(ad *AttributeDecoder) { if diff := cmp.Diff(ad.Type(), uint16(0x3fff)); diff != "" { panicf("unexpected Type (-want +got):\n%s", diff) } if diff := cmp.Diff(ad.TypeFlags(), uint16(0xc000)); diff != "" { panicf("unexpected TypeFlags (-want +got):\n%s", diff) } }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := MarshalAttributes(tt.attrs) if err != nil { t.Fatalf("failed to marshal attributes: %v", err) } ad, err := NewAttributeDecoder(b) if err != nil { t.Fatalf("failed to create attribute decoder: %v", err) } // Len should always report the same number of input attributes. if diff := cmp.Diff(len(tt.attrs), ad.Len()); diff != "" { t.Fatalf("unexpected (-want +got):\n%s", diff) } for ad.Next() { tt.fn(ad) } if err := ad.Err(); err != nil { t.Fatalf("failed to decode attributes: %v", err) } }) } } func adEndianAttrs(order binary.ByteOrder) []Attribute { return []Attribute{ { Type: 1, Data: func() []byte { return []byte{1} }(), }, { Type: 2, Data: func() []byte { b := make([]byte, 2) order.PutUint16(b, 2) return b }(), }, { Type: 3, Data: func() []byte { b := make([]byte, 4) order.PutUint32(b, 3) return b }(), }, { Type: 4, Data: func() []byte { b := make([]byte, 8) order.PutUint64(b, 4) return b }(), }, { Type: 5, Data: func() []byte { return []byte{uint8(int8(5))} }(), }, { Type: 6, Data: func() []byte { b := make([]byte, 2) order.PutUint16(b, uint16(int16(6))) return b }(), }, { Type: 7, Data: func() []byte { b := make([]byte, 4) order.PutUint32(b, uint32(int32(7))) return b }(), }, { Type: 8, Data: func() []byte { b := make([]byte, 8) order.PutUint64(b, uint64(int64(8))) return b }(), }, } } func adEndianTest(order binary.ByteOrder) func(ad *AttributeDecoder) { return func(ad *AttributeDecoder) { ad.ByteOrder = order var ( t uint16 v int ) switch t = ad.Type(); t { case 1: v = int(ad.Uint8()) case 2: v = 
int(ad.Uint16()) case 3: v = int(ad.Uint32()) case 4: v = int(ad.Uint64()) case 5: v = int(ad.Int8()) case 6: v = int(ad.Int16()) case 7: v = int(ad.Int32()) case 8: v = int(ad.Int64()) default: panicf("unhandled attribute type: %d", t) } if diff := cmp.Diff(int(t), v); diff != "" { panicf("unexpected attribute value (-want +got):\n%s", diff) } } } func TestAttributeEncoderError(t *testing.T) { skipBigEndian(t) tests := []struct { name string fn func(ae *AttributeEncoder) }{ { name: "bytes length", fn: func(ae *AttributeEncoder) { ae.Bytes(1, make([]byte, math.MaxUint16)) }, }, { name: "string length", fn: func(ae *AttributeEncoder) { ae.String(1, string(make([]byte, math.MaxUint16))) }, }, { name: "do length", fn: func(ae *AttributeEncoder) { ae.Do(1, func() ([]byte, error) { return make([]byte, math.MaxUint16), nil }) }, }, { name: "do function", fn: func(ae *AttributeEncoder) { ae.Do(1, func() ([]byte, error) { return nil, errors.New("testing error") }) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ae := NewAttributeEncoder() tt.fn(ae) _, err := ae.Encode() if err == nil { t.Fatal("expected an error, but none occurred") } }) } } func TestAttributeEncoderOK(t *testing.T) { skipBigEndian(t) tests := []struct { name string attrs []Attribute endian binary.ByteOrder fn func(ae *AttributeEncoder) }{ { name: "empty", attrs: nil, fn: func(_ *AttributeEncoder) { }, }, { name: "uint-int native endian", attrs: adEndianAttrs(native.Endian), fn: aeEndianTest(native.Endian), }, { name: "uint-int little endian", attrs: adEndianAttrs(binary.LittleEndian), endian: binary.LittleEndian, fn: aeEndianTest(binary.LittleEndian), }, { name: "uint-int big endian", attrs: adEndianAttrs(binary.BigEndian), endian: binary.BigEndian, fn: aeEndianTest(binary.BigEndian), }, { name: "flag true", attrs: []Attribute{{Type: 1}}, fn: func(ae *AttributeEncoder) { ae.Flag(1, true) }, }, { name: "flag false", attrs: []Attribute{}, fn: func(ae *AttributeEncoder) { ae.Flag(1, 
false) }, }, { name: "string", attrs: []Attribute{{ Type: 1, Data: nlenc.Bytes("hello netlink"), }}, fn: func(ae *AttributeEncoder) { ae.String(1, "hello netlink") }, }, { name: "byte", attrs: []Attribute{ { Type: 1, Data: []byte{0xde, 0xad}, }, }, fn: func(ae *AttributeEncoder) { ae.Bytes(1, []byte{0xde, 0xad}) }, }, { name: "do", attrs: []Attribute{ // Arbitrary C-like structure. { Type: 1, Data: []byte{0xde, 0xad, 0xbe}, }, // Nested attributes. { Type: 2, Data: func() []byte { b, err := MarshalAttributes([]Attribute{{ Type: 2, Data: nlenc.Uint16Bytes(2), }}) if err != nil { panicf("failed to marshal test attributes: %v", err) } return b }(), }, }, fn: func(ae *AttributeEncoder) { ae.Do(1, func() ([]byte, error) { return []byte{0xde, 0xad, 0xbe}, nil }) ae.Do(2, func() ([]byte, error) { ae1 := NewAttributeEncoder() ae1.Uint16(2, 2) return ae1.Encode() }) }, }, { name: "nested", attrs: []Attribute{ // Nested attributes. { Type: Nested | 1, Data: func() []byte { nb, err := MarshalAttributes([]Attribute{{ Type: 1, Data: nlenc.Uint32Bytes(2), }}) if err != nil { panicf("failed to marshal nested test attributes: %v", err) } b, err := MarshalAttributes([]Attribute{ { Type: 1, Data: nlenc.Uint16Bytes(1), }, { Type: Nested | 2, Data: nb, }, }) if err != nil { panicf("failed to marshal test attributes: %v", err) } return b }(), }, }, fn: func(ae *AttributeEncoder) { ae.Nested(1, func(nae *AttributeEncoder) error { nae.Uint16(1, 1) nae.Nested(2, func(nnae *AttributeEncoder) error { nnae.Uint32(1, 2) return nil }) return nil }) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := MarshalAttributes(tt.attrs) if err != nil { t.Fatalf("failed to marshal attributes: %v", err) } ae := NewAttributeEncoder() tt.fn(ae) got, err := ae.Encode() if err != nil { t.Fatalf("failed to encode attributes: %v", err) } if diff := cmp.Diff(got, b); diff != "" { t.Fatalf("unexpected attribute encoding (-want +got):\n%s", diff) } }) } } func aeEndianTest(order 
binary.ByteOrder) func(ae *AttributeEncoder) { return func(ae *AttributeEncoder) { ae.ByteOrder = order ae.Uint8(1, uint8(1)) ae.Uint16(2, uint16(2)) ae.Uint32(3, uint32(3)) ae.Uint64(4, uint64(4)) ae.Int8(5, int8(5)) ae.Int16(6, int16(6)) ae.Int32(7, int32(7)) ae.Int64(8, int64(8)) } } netlink-1.7.2/bench_test.go000066400000000000000000000025241442322526200156370ustar00rootroot00000000000000package netlink_test import ( "testing" "github.com/mdlayher/netlink" ) var attrBench = []struct { name string attrs []netlink.Attribute }{ { name: "0", }, { name: "1", attrs: makeAttributes(1), }, { name: "8", attrs: makeAttributes(8), }, { name: "64", attrs: makeAttributes(64), }, { name: "512", attrs: makeAttributes(512), }, } func BenchmarkMarshalAttributes(b *testing.B) { for _, tt := range attrBench { b.Run(tt.name, func(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { if _, err := netlink.MarshalAttributes(tt.attrs); err != nil { b.Fatalf("failed to marshal: %v", err) } } }) } } func BenchmarkUnmarshalAttributes(b *testing.B) { for _, tt := range attrBench { b.Run(tt.name, func(b *testing.B) { buf, err := netlink.MarshalAttributes(tt.attrs) if err != nil { b.Fatalf("failed to marshal: %v", err) } b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { if _, err := netlink.UnmarshalAttributes(buf); err != nil { b.Fatalf("failed to unmarshal: %v", err) } } }) } } func makeAttributes(n int) []netlink.Attribute { attrs := make([]netlink.Attribute, 0, n) for i := 0; i < n; i++ { attrs = append(attrs, netlink.Attribute{ Type: uint16(i), Data: make([]byte, n), }) } return attrs } netlink-1.7.2/conn.go000066400000000000000000000414671442322526200144670ustar00rootroot00000000000000package netlink import ( "math/rand" "sync" "sync/atomic" "syscall" "time" "golang.org/x/net/bpf" ) // A Conn is a connection to netlink. A Conn can be used to send and // receives messages to and from netlink. 
// // A Conn is safe for concurrent use, but to avoid contention in // high-throughput applications, the caller should almost certainly create a // pool of Conns and distribute them among workers. // // A Conn is capable of manipulating netlink subsystems from within a specific // Linux network namespace, but special care must be taken when doing so. See // the documentation of Config for details. type Conn struct { // Atomics must come first. // // seq is an atomically incremented integer used to provide sequence // numbers when Conn.Send is called. seq uint32 // mu serializes access to the netlink socket for the request/response // transaction within Execute. mu sync.RWMutex // sock is the operating system-specific implementation of // a netlink sockets connection. sock Socket // pid is the PID assigned by netlink. pid uint32 // d provides debugging capabilities for a Conn if not nil. d *debugger } // A Socket is an operating-system specific implementation of netlink // sockets used by Conn. // // Deprecated: the intent of Socket was to provide an abstraction layer for // testing, but this abstraction is awkward to use properly and disables much of // the functionality of the Conn type. Do not use. type Socket interface { Close() error Send(m Message) error SendMessages(m []Message) error Receive() ([]Message, error) } // Dial dials a connection to netlink, using the specified netlink family. // Config specifies optional configuration for Conn. If config is nil, a default // configuration will be used. func Dial(family int, config *Config) (*Conn, error) { // TODO(mdlayher): plumb in netlink.OpError wrapping? // Use OS-specific dial() to create Socket. c, pid, err := dial(family, config) if err != nil { return nil, err } return NewConn(c, pid), nil } // NewConn creates a Conn using the specified Socket and PID for netlink // communications. // // NewConn is primarily useful for tests. Most applications should use // Dial instead. 
func NewConn(sock Socket, pid uint32) *Conn { // Seed the sequence number using a random number generator. r := rand.New(rand.NewSource(time.Now().UnixNano())) seq := r.Uint32() // Configure a debugger if arguments are set. var d *debugger if len(debugArgs) > 0 { d = newDebugger(debugArgs) } return &Conn{ seq: seq, sock: sock, pid: pid, d: d, } } // debug executes fn with the debugger if the debugger is not nil. func (c *Conn) debug(fn func(d *debugger)) { if c.d == nil { return } fn(c.d) } // Close closes the connection and unblocks any pending read operations. func (c *Conn) Close() error { // Close does not acquire a lock because it must be able to interrupt any // blocked system calls, such as when Receive is waiting on a multicast // group message. // // We rely on the kernel to deal with concurrent operations to the netlink // socket itself. return newOpError("close", c.sock.Close()) } // Execute sends a single Message to netlink using Send, receives one or more // replies using Receive, and then checks the validity of the replies against // the request using Validate. // // Execute acquires a lock for the duration of the function call which blocks // concurrent calls to Send, SendMessages, and Receive, in order to ensure // consistency between netlink request/reply messages. // // See the documentation of Send, Receive, and Validate for details about // each function. func (c *Conn) Execute(m Message) ([]Message, error) { // Acquire the write lock and invoke the internal implementations of Send // and Receive which require the lock already be held. c.mu.Lock() defer c.mu.Unlock() req, err := c.lockedSend(m) if err != nil { return nil, err } res, err := c.lockedReceive() if err != nil { return nil, err } if err := Validate(req, res); err != nil { return nil, err } return res, nil } // SendMessages sends multiple Messages to netlink. The handling of // a Header's Length, Sequence and PID fields is the same as when // calling Send. 
func (c *Conn) SendMessages(msgs []Message) ([]Message, error) { // Wait for any concurrent calls to Execute to finish before proceeding. c.mu.RLock() defer c.mu.RUnlock() for i := range msgs { c.fixMsg(&msgs[i], nlmsgLength(len(msgs[i].Data))) } c.debug(func(d *debugger) { for _, m := range msgs { d.debugf(1, "send msgs: %+v", m) } }) if err := c.sock.SendMessages(msgs); err != nil { c.debug(func(d *debugger) { d.debugf(1, "send msgs: err: %v", err) }) return nil, newOpError("send-messages", err) } return msgs, nil } // Send sends a single Message to netlink. In most cases, a Header's Length, // Sequence, and PID fields should be set to 0, so they can be populated // automatically before the Message is sent. On success, Send returns a copy // of the Message with all parameters populated, for later validation. // // If Header.Length is 0, it will be automatically populated using the // correct length for the Message, including its payload. // // If Header.Sequence is 0, it will be automatically populated using the // next sequence number for this connection. // // If Header.PID is 0, it will be automatically populated using a PID // assigned by netlink. func (c *Conn) Send(m Message) (Message, error) { // Wait for any concurrent calls to Execute to finish before proceeding. c.mu.RLock() defer c.mu.RUnlock() return c.lockedSend(m) } // lockedSend implements Send, but must be called with c.mu acquired for reading. // We rely on the kernel to deal with concurrent reads and writes to the netlink // socket itself. func (c *Conn) lockedSend(m Message) (Message, error) { c.fixMsg(&m, nlmsgLength(len(m.Data))) c.debug(func(d *debugger) { d.debugf(1, "send: %+v", m) }) if err := c.sock.Send(m); err != nil { c.debug(func(d *debugger) { d.debugf(1, "send: err: %v", err) }) return Message{}, newOpError("send", err) } return m, nil } // Receive receives one or more messages from netlink. 
Multi-part messages are // handled transparently and returned as a single slice of Messages, with the // final empty "multi-part done" message removed. // // If any of the messages indicate a netlink error, that error will be returned. func (c *Conn) Receive() ([]Message, error) { // Wait for any concurrent calls to Execute to finish before proceeding. c.mu.RLock() defer c.mu.RUnlock() return c.lockedReceive() } // lockedReceive implements Receive, but must be called with c.mu acquired for reading. // We rely on the kernel to deal with concurrent reads and writes to the netlink // socket itself. func (c *Conn) lockedReceive() ([]Message, error) { msgs, err := c.receive() if err != nil { c.debug(func(d *debugger) { d.debugf(1, "recv: err: %v", err) }) return nil, err } c.debug(func(d *debugger) { for _, m := range msgs { d.debugf(1, "recv: %+v", m) } }) // When using nltest, it's possible for zero messages to be returned by receive. if len(msgs) == 0 { return msgs, nil } // Trim the final message with multi-part done indicator if // present. if m := msgs[len(msgs)-1]; m.Header.Flags&Multi != 0 && m.Header.Type == Done { return msgs[:len(msgs)-1], nil } return msgs, nil } // receive is the internal implementation of Conn.Receive, which can be called // recursively to handle multi-part messages. func (c *Conn) receive() ([]Message, error) { // NB: All non-nil errors returned from this function *must* be of type // OpError in order to maintain the appropriate contract with callers of // this package. // // This contract also applies to functions called within this function, // such as checkMessage. var res []Message for { msgs, err := c.sock.Receive() if err != nil { return nil, newOpError("receive", err) } // If this message is multi-part, we will need to continue looping to // drain all the messages from the socket. var multi bool for _, m := range msgs { if err := checkMessage(m); err != nil { return nil, err } // Does this message indicate a multi-part message? 
if m.Header.Flags&Multi == 0 { // No, check the next messages. continue } // Does this message indicate the last message in a series of // multi-part messages from a single read? multi = m.Header.Type != Done } res = append(res, msgs...) if !multi { // No more messages coming. return res, nil } } } // A groupJoinLeaver is a Socket that supports joining and leaving // netlink multicast groups. type groupJoinLeaver interface { Socket JoinGroup(group uint32) error LeaveGroup(group uint32) error } // JoinGroup joins a netlink multicast group by its ID. func (c *Conn) JoinGroup(group uint32) error { conn, ok := c.sock.(groupJoinLeaver) if !ok { return notSupported("join-group") } return newOpError("join-group", conn.JoinGroup(group)) } // LeaveGroup leaves a netlink multicast group by its ID. func (c *Conn) LeaveGroup(group uint32) error { conn, ok := c.sock.(groupJoinLeaver) if !ok { return notSupported("leave-group") } return newOpError("leave-group", conn.LeaveGroup(group)) } // A bpfSetter is a Socket that supports setting and removing BPF filters. type bpfSetter interface { Socket bpf.Setter RemoveBPF() error } // SetBPF attaches an assembled BPF program to a Conn. func (c *Conn) SetBPF(filter []bpf.RawInstruction) error { conn, ok := c.sock.(bpfSetter) if !ok { return notSupported("set-bpf") } return newOpError("set-bpf", conn.SetBPF(filter)) } // RemoveBPF removes a BPF filter from a Conn. func (c *Conn) RemoveBPF() error { conn, ok := c.sock.(bpfSetter) if !ok { return notSupported("remove-bpf") } return newOpError("remove-bpf", conn.RemoveBPF()) } // A deadlineSetter is a Socket that supports setting deadlines. type deadlineSetter interface { Socket SetDeadline(time.Time) error SetReadDeadline(time.Time) error SetWriteDeadline(time.Time) error } // SetDeadline sets the read and write deadlines associated with the connection. 
func (c *Conn) SetDeadline(t time.Time) error { conn, ok := c.sock.(deadlineSetter) if !ok { return notSupported("set-deadline") } return newOpError("set-deadline", conn.SetDeadline(t)) } // SetReadDeadline sets the read deadline associated with the connection. func (c *Conn) SetReadDeadline(t time.Time) error { conn, ok := c.sock.(deadlineSetter) if !ok { return notSupported("set-read-deadline") } return newOpError("set-read-deadline", conn.SetReadDeadline(t)) } // SetWriteDeadline sets the write deadline associated with the connection. func (c *Conn) SetWriteDeadline(t time.Time) error { conn, ok := c.sock.(deadlineSetter) if !ok { return notSupported("set-write-deadline") } return newOpError("set-write-deadline", conn.SetWriteDeadline(t)) } // A ConnOption is a boolean option that may be set for a Conn. type ConnOption int // Possible ConnOption values. These constants are equivalent to the Linux // setsockopt boolean options for netlink sockets. const ( PacketInfo ConnOption = iota BroadcastError NoENOBUFS ListenAllNSID CapAcknowledge ExtendedAcknowledge GetStrictCheck ) // An optionSetter is a Socket that supports setting netlink options. type optionSetter interface { Socket SetOption(option ConnOption, enable bool) error } // SetOption enables or disables a netlink socket option for the Conn. func (c *Conn) SetOption(option ConnOption, enable bool) error { conn, ok := c.sock.(optionSetter) if !ok { return notSupported("set-option") } return newOpError("set-option", conn.SetOption(option, enable)) } // A bufferSetter is a Socket that supports setting connection buffer sizes. type bufferSetter interface { Socket SetReadBuffer(bytes int) error SetWriteBuffer(bytes int) error } // SetReadBuffer sets the size of the operating system's receive buffer // associated with the Conn. 
func (c *Conn) SetReadBuffer(bytes int) error { conn, ok := c.sock.(bufferSetter) if !ok { return notSupported("set-read-buffer") } return newOpError("set-read-buffer", conn.SetReadBuffer(bytes)) } // SetWriteBuffer sets the size of the operating system's transmit buffer // associated with the Conn. func (c *Conn) SetWriteBuffer(bytes int) error { conn, ok := c.sock.(bufferSetter) if !ok { return notSupported("set-write-buffer") } return newOpError("set-write-buffer", conn.SetWriteBuffer(bytes)) } // A syscallConner is a Socket that supports syscall.Conn. type syscallConner interface { Socket syscall.Conn } var _ syscall.Conn = &Conn{} // SyscallConn returns a raw network connection. This implements the // syscall.Conn interface. // // SyscallConn is intended for advanced use cases, such as getting and setting // arbitrary socket options using the netlink socket's file descriptor. // // Once invoked, it is the caller's responsibility to ensure that operations // performed using Conn and the syscall.RawConn do not conflict with // each other. func (c *Conn) SyscallConn() (syscall.RawConn, error) { sc, ok := c.sock.(syscallConner) if !ok { return nil, notSupported("syscall-conn") } // TODO(mdlayher): mutex or similar to enforce syscall.RawConn contract of // FD remaining valid for duration of calls? return sc.SyscallConn() } // fixMsg updates the fields of m using the logic specified in Send. func (c *Conn) fixMsg(m *Message, ml int) { if m.Header.Length == 0 { m.Header.Length = uint32(nlmsgAlign(ml)) } if m.Header.Sequence == 0 { m.Header.Sequence = c.nextSequence() } if m.Header.PID == 0 { m.Header.PID = c.pid } } // nextSequence atomically increments Conn's sequence number and returns // the incremented value. func (c *Conn) nextSequence() uint32 { return atomic.AddUint32(&c.seq, 1) } // Validate validates one or more reply Messages against a request Message, // ensuring that they contain matching sequence numbers and PIDs. 
//
// The first mismatch found is returned, wrapped by newOpError with the
// operation name "validate"; nil is returned when every reply matches.
func Validate(request Message, replies []Message) error {
	for _, m := range replies {
		// Check for mismatched sequence, unless:
		//   - request had no sequence, meaning we are probably validating
		//     a multicast reply
		if m.Header.Sequence != request.Header.Sequence && request.Header.Sequence != 0 {
			return newOpError("validate", errMismatchedSequence)
		}

		// Check for mismatched PID, unless:
		//   - request had no PID, meaning we are either:
		//     - validating a multicast reply
		//     - netlink has not yet assigned us a PID
		//   - response had no PID, meaning it's from the kernel as a multicast reply
		if m.Header.PID != request.Header.PID && request.Header.PID != 0 && m.Header.PID != 0 {
			return newOpError("validate", errMismatchedPID)
		}
	}

	return nil
}

// Config contains options for a Conn.
type Config struct {
	// Groups is a bitmask which specifies multicast groups. If set to 0,
	// no multicast group subscriptions will be made.
	//
	// On Linux this value is used as the Groups field of the netlink bind
	// address.
	Groups uint32

	// NetNS specifies the network namespace the Conn will operate in.
	//
	// If set (non-zero), Conn will enter the specified network namespace and
	// an error will occur in Dial if the operation fails.
	//
	// If not set (zero), a best-effort attempt will be made to enter the
	// network namespace of the calling thread: this means that any changes made
	// to the calling thread's network namespace will also be reflected in Conn.
	// If this operation fails (due to lack of permissions or because network
	// namespaces are disabled by kernel configuration), Dial will not return
	// an error, and the Conn will operate in the default network namespace of
	// the process. This enables non-privileged use of Conn in applications
	// which do not require elevated privileges.
	//
	// Entering a network namespace is a privileged operation (root or
	// CAP_SYS_ADMIN are required), and most applications should leave this set
	// to 0.
	NetNS int

	// DisableNSLockThread is a no-op.
	//
	// Deprecated: internal changes have made this option obsolete and it has no
	// effect. Do not use.
	DisableNSLockThread bool

	// PID specifies the port ID used to bind the netlink socket. If set to 0,
	// the kernel will assign a port ID on the caller's behalf.
	//
	// Most callers should leave this field set to 0. This option is intended
	// for advanced use cases where the kernel expects a fixed unicast address
	// destination for netlink messages.
	PID uint32

	// Strict applies a more strict default set of options to the Conn,
	// including:
	//   - ExtendedAcknowledge: true
	//     - provides more useful error messages when supported by the kernel
	//   - GetStrictCheck: true
	//     - more strictly enforces request validation for some families such
	//       as rtnetlink which were historically misused
	//
	// If any of the options specified by Strict cannot be configured due to an
	// outdated kernel or similar, an error will be returned.
	//
	// When possible, setting Strict to true is recommended for applications
	// running on modern Linux kernels.
	Strict bool
}
netlink-1.7.2/conn_linux.go000066400000000000000000000145311442322526200156760ustar00rootroot00000000000000
//go:build linux
// +build linux

package netlink

import (
	"context"
	"os"
	"syscall"
	"time"
	"unsafe"

	"github.com/mdlayher/socket"
	"golang.org/x/net/bpf"
	"golang.org/x/sys/unix"
)

// Compile-time check that *conn implements the package's Socket interface.
var _ Socket = &conn{}

// A conn is the Linux implementation of a netlink sockets connection.
type conn struct {
	// s is the underlying raw socket from github.com/mdlayher/socket.
	s *socket.Conn
}

// dial is the entry point for Dial. dial opens a netlink socket using
// system calls, and returns its PID.
//
// A nil config is treated as the zero Config. family is passed as the
// protocol argument of the AF_NETLINK/SOCK_RAW socket; binding and option
// setup are delegated to newConn.
func dial(family int, config *Config) (*conn, uint32, error) {
	if config == nil {
		config = &Config{}
	}

	// Prepare the netlink socket.
	s, err := socket.Socket(
		unix.AF_NETLINK,
		unix.SOCK_RAW,
		family,
		"netlink",
		&socket.Config{NetNS: config.NetNS},
	)
	if err != nil {
		return nil, 0, err
	}

	return newConn(s, config)
}

// newConn binds a connection to netlink using the input *socket.Conn.
func newConn(s *socket.Conn, config *Config) (*conn, uint32, error) { if config == nil { config = &Config{} } addr := &unix.SockaddrNetlink{ Family: unix.AF_NETLINK, Groups: config.Groups, Pid: config.PID, } // Socket must be closed in the event of any system call errors, to avoid // leaking file descriptors. if err := s.Bind(addr); err != nil { _ = s.Close() return nil, 0, err } sa, err := s.Getsockname() if err != nil { _ = s.Close() return nil, 0, err } c := &conn{s: s} if config.Strict { // The caller has requested the strict option set. Historically we have // recommended checking for ENOPROTOOPT if the kernel does not support // the option in question, but that may result in a silent failure and // unexpected behavior for the user. // // Treat any error here as a fatal error, and require the caller to deal // with it. for _, o := range []ConnOption{ExtendedAcknowledge, GetStrictCheck} { if err := c.SetOption(o, true); err != nil { _ = c.Close() return nil, 0, err } } } return c, sa.(*unix.SockaddrNetlink).Pid, nil } // SendMessages serializes multiple Messages and sends them to netlink. func (c *conn) SendMessages(messages []Message) error { var buf []byte for _, m := range messages { b, err := m.MarshalBinary() if err != nil { return err } buf = append(buf, b...) } sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK} _, err := c.s.Sendmsg(context.Background(), buf, nil, sa, 0) return err } // Send sends a single Message to netlink. func (c *conn) Send(m Message) error { b, err := m.MarshalBinary() if err != nil { return err } sa := &unix.SockaddrNetlink{Family: unix.AF_NETLINK} _, err = c.s.Sendmsg(context.Background(), b, nil, sa, 0) return err } // Receive receives one or more Messages from netlink. func (c *conn) Receive() ([]Message, error) { b := make([]byte, os.Getpagesize()) for { // Peek at the buffer to see how many bytes are available. // // TODO(mdlayher): deal with OOB message data if available, such as // when PacketInfo ConnOption is true. 
n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, unix.MSG_PEEK) if err != nil { return nil, err } // Break when we can read all messages if n < len(b) { break } // Double in size if not enough bytes b = make([]byte, len(b)*2) } // Read out all available messages n, _, _, _, err := c.s.Recvmsg(context.Background(), b, nil, 0) if err != nil { return nil, err } raw, err := syscall.ParseNetlinkMessage(b[:nlmsgAlign(n)]) if err != nil { return nil, err } msgs := make([]Message, 0, len(raw)) for _, r := range raw { m := Message{ Header: sysToHeader(r.Header), Data: r.Data, } msgs = append(msgs, m) } return msgs, nil } // Close closes the connection. func (c *conn) Close() error { return c.s.Close() } // JoinGroup joins a multicast group by ID. func (c *conn) JoinGroup(group uint32) error { return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_ADD_MEMBERSHIP, int(group)) } // LeaveGroup leaves a multicast group by ID. func (c *conn) LeaveGroup(group uint32) error { return c.s.SetsockoptInt(unix.SOL_NETLINK, unix.NETLINK_DROP_MEMBERSHIP, int(group)) } // SetBPF attaches an assembled BPF program to a conn. func (c *conn) SetBPF(filter []bpf.RawInstruction) error { return c.s.SetBPF(filter) } // RemoveBPF removes a BPF filter from a conn. func (c *conn) RemoveBPF() error { return c.s.RemoveBPF() } // SetOption enables or disables a netlink socket option for the Conn. func (c *conn) SetOption(option ConnOption, enable bool) error { o, ok := linuxOption(option) if !ok { // Return the typical Linux error for an unknown ConnOption. 
return os.NewSyscallError("setsockopt", unix.ENOPROTOOPT) } var v int if enable { v = 1 } return c.s.SetsockoptInt(unix.SOL_NETLINK, o, v) } func (c *conn) SetDeadline(t time.Time) error { return c.s.SetDeadline(t) } func (c *conn) SetReadDeadline(t time.Time) error { return c.s.SetReadDeadline(t) } func (c *conn) SetWriteDeadline(t time.Time) error { return c.s.SetWriteDeadline(t) } // SetReadBuffer sets the size of the operating system's receive buffer // associated with the Conn. func (c *conn) SetReadBuffer(bytes int) error { return c.s.SetReadBuffer(bytes) } // SetReadBuffer sets the size of the operating system's transmit buffer // associated with the Conn. func (c *conn) SetWriteBuffer(bytes int) error { return c.s.SetWriteBuffer(bytes) } // SyscallConn returns a raw network connection. func (c *conn) SyscallConn() (syscall.RawConn, error) { return c.s.SyscallConn() } // linuxOption converts a ConnOption to its Linux value. func linuxOption(o ConnOption) (int, bool) { switch o { case PacketInfo: return unix.NETLINK_PKTINFO, true case BroadcastError: return unix.NETLINK_BROADCAST_ERROR, true case NoENOBUFS: return unix.NETLINK_NO_ENOBUFS, true case ListenAllNSID: return unix.NETLINK_LISTEN_ALL_NSID, true case CapAcknowledge: return unix.NETLINK_CAP_ACK, true case ExtendedAcknowledge: return unix.NETLINK_EXT_ACK, true case GetStrictCheck: return unix.NETLINK_GET_STRICT_CHK, true default: return 0, false } } // sysToHeader converts a syscall.NlMsghdr to a Header. func sysToHeader(r syscall.NlMsghdr) Header { // NB: the memory layout of Header and syscall.NlMsgHdr must be // exactly the same for this unsafe cast to work return *(*Header)(unsafe.Pointer(&r)) } // newError converts an error number from netlink into the appropriate // system call error for Linux. 
func newError(errno int) error { return syscall.Errno(errno) } netlink-1.7.2/conn_linux_error_test.go000066400000000000000000000047171442322526200201530ustar00rootroot00000000000000//go:build linux // +build linux package netlink_test import ( "encoding/binary" "os" "testing" "github.com/google/go-cmp/cmp" "github.com/josharian/native" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" "golang.org/x/sys/unix" ) func TestConnReceiveErrorLinux(t *testing.T) { skipBigEndian(t) // Note: using *Conn instead of Linux-only *conn, to test // error handling logic in *Conn.Receive. // // This test also verifies the contractual behavior of OpError wrapping // errors from system calls in os.SyscallError, but NOT wrapping netlink // error codes. tests := []struct { name string msgs []netlink.Message in error want error }{ { name: "netlink message ENOENT", msgs: []netlink.Message{{ Header: netlink.Header{ Length: 20, Type: netlink.Error, Sequence: 1, PID: 1, }, // -2, little endian (ENOENT) Data: []byte{0xfe, 0xff, 0xff, 0xff}, }}, want: &netlink.OpError{ Op: "receive", Err: unix.ENOENT, }, }, { name: "syscall error ENOENT", in: unix.ENOENT, want: &netlink.OpError{ Op: "receive", Err: os.NewSyscallError("recvmsg", unix.ENOENT), }, }, { name: "multipart done without error", msgs: []netlink.Message{ { Header: netlink.Header{ Flags: netlink.Multi, }, }, { Header: netlink.Header{ Type: netlink.Done, Flags: netlink.Multi, }, }, }, }, { name: "multipart done with error", msgs: []netlink.Message{ { Header: netlink.Header{ Flags: netlink.Multi, }, }, { Header: netlink.Header{ Type: netlink.Done, Flags: netlink.Multi, }, // -2, little endian (ENOENT) Data: []byte{0xfe, 0xff, 0xff, 0xff}, }, }, want: &netlink.OpError{ Op: "receive", Err: unix.ENOENT, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return tt.msgs, tt.in }) defer c.Close() // Need to prepopulate nltest's internal 
buffers by invoking the // function once. _, _ = c.Send(netlink.Message{}) _, got := c.Receive() if diff := cmp.Diff(tt.want, got); diff != "" { t.Fatalf("unexpected error (-want +got):\n%s", diff) } }) } } func skipBigEndian(t *testing.T) { if binary.ByteOrder(native.Endian) == binary.BigEndian { t.Skip("skipping test on big-endian system") } } netlink-1.7.2/conn_linux_integration_test.go000066400000000000000000000520321442322526200213360ustar00rootroot00000000000000//go:build linux // +build linux package netlink_test import ( "errors" "fmt" "math/rand" "net" "os" "os/exec" "os/user" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nlenc" "golang.org/x/net/bpf" "golang.org/x/sys/unix" ) func TestIntegrationConn(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } // Ask to send us an acknowledgement, which will contain an // error code (or success) and a copy of the payload we sent in req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } // Perform a request, receive replies, and validate the replies msgs, err := c.Execute(req) if err != nil { t.Fatalf("failed to execute request: %v", err) } if want, got := 1, len(msgs); want != got { t.Fatalf("unexpected message count from netlink:\n- want: %v\n- got: %v", want, got) } if err := c.Close(); err != nil { t.Fatalf("error closing netlink connection: %v", err) } m := msgs[0] if want, got := 0, int(nlenc.Uint32(m.Data[0:4])); want != got { t.Fatalf("unexpected error code:\n- want: %v\n- got: %v", want, got) } if want, got := 36, int(m.Header.Length); want != got { t.Fatalf("unexpected header length:\n- want: %v\n- got: %v", want, got) } if want, got := netlink.Error, m.Header.Type; want != got { t.Fatalf("unexpected header type:\n- want: %v\n- got: %v", want, got) } // Recent kernel 
versions (> 4.14) return a 256 here instead of a 0 if want, wantAlt, got := 0, 256, int(m.Header.Flags); want != got && wantAlt != got { t.Fatalf("unexpected header flags:\n- want: %v or %v\n- got: %v", want, wantAlt, got) } // Sequence number is not checked because we assign one at random when // a Conn is created. PID is not checked because running tests in parallel // results in only the first socket getting assigned the process's PID as // its netlink PID. // Skip error code and unmarshal the copy of request sent back by // skipping the success code at bytes 0-4 var reply netlink.Message if err := (&reply).UnmarshalBinary(m.Data[4:]); err != nil { t.Fatalf("failed to unmarshal reply: %v", err) } if want, got := req.Header.Flags, reply.Header.Flags; want != got { t.Fatalf("unexpected copy header flags:\n- want: %v\n- got: %v", want, got) } if want, got := len(req.Data), len(reply.Data); want != got { t.Fatalf("unexpected copy header data length:\n- want: %v\n- got: %v", want, got) } } func TestIntegrationConnConcurrentManyConns(t *testing.T) { t.Parallel() skipShort(t) // Execute many concurrent operations on several netlink.Conns to ensure // the kernel is sending and receiving netlink messages to/from the correct // file descriptor. // // See: http://lists.infradead.org/pipermail/libnl/2017-February/002293.html. 
execN := func(n int) { c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { panicf("failed to dial generic netlink: %v", err) } defer c.Close() req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } for i := 0; i < n; i++ { msgs, err := c.Execute(req) if err != nil { panicf("failed to send request: %v", err) } if l := len(msgs); l != 1 { panicf("unexpected number of reply messages: %d", l) } } } const ( workers = 16 iterations = 10000 ) var wg sync.WaitGroup wg.Add(workers) for i := 0; i < workers; i++ { go func() { defer wg.Done() execN(iterations) }() } wg.Wait() } func TestIntegrationConnConcurrentOneConn(t *testing.T) { t.Parallel() skipShort(t) // Execute many concurrent operations on a single netlink.Conn. c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } execN := func(n int) { req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } var res netlink.Message for i := 0; i < n; i++ { // Don't expect a "valid" request/reply because we are not serializing // our Send/Receive calls via Execute or with an external lock. // // Just verify that we don't trigger the race detector, we got a // valid netlink response, and it can be decoded as a valid // netlink message. 
if _, err := c.Send(req); err != nil { panicf("failed to send request: %v", err) } msgs, err := c.Receive() if err != nil { panicf("failed to receive reply: %v", err) } if l := len(msgs); l != 1 { panicf("unexpected number of reply messages: %d", l) } if err := res.UnmarshalBinary(msgs[0].Data[4:]); err != nil { panicf("failed to unmarshal reply: %v", err) } } } const ( workers = 16 iterations = 10000 ) var wg sync.WaitGroup wg.Add(workers) defer wg.Wait() for i := 0; i < workers; i++ { go func() { defer wg.Done() execN(iterations) }() } } func TestIntegrationConnConcurrentClosePreventsReceive(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } // Verify this test cannot block indefinitely due to Receive hanging after // a call to Close is completed. timer := time.AfterFunc(10*time.Second, func() { panic("test took too long") }) defer timer.Stop() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() // The intent of this test is to schedule Close before Receive can ever // happen, resulting in EBADF. The test below covers the opposite case. sigC := make(chan struct{}) go func() { defer wg.Done() <-sigC _, err := c.Receive() if err == nil { panicf("expected an error, but none occurred") } // Expect an error due to file descriptor being closed. serr := err.(*netlink.OpError).Err.(*os.SyscallError).Err if diff := cmp.Diff(unix.EBADF, serr); diff != "" { panicf("unexpected error from receive (-want +got):\n%s", diff) } }() if err := c.Close(); err != nil { t.Fatalf("failed to close: %v", err) } close(sigC) } func TestIntegrationConnConcurrentCloseUnblocksReceive(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } // Verify this test cannot block indefinitely due to Receive hanging after // a call to Close is completed. 
timer := time.AfterFunc(10*time.Second, func() { panic("test took too long") }) defer timer.Stop() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() // Try to enforce that Receive is scheduled before Close. sigC := make(chan struct{}) go func() { defer wg.Done() // Multiple Close operations should be a no-op. <-sigC for i := 0; i < 5; i++ { time.Sleep(50 * time.Millisecond) if err := c.Close(); err != nil { panicf("failed to close: %v", err) } } }() close(sigC) _, err = c.Receive() if err == nil { t.Fatalf("expected an error, but none occurred") } // Expect an error due to the use of a closed file descriptor. Unfortunately // there doesn't seem to be a typed error for this. // // Previous versions of this code would wrap the internal/poll error which // *os.SyscallError which technically was incorrect. If necessary, revert // this behavior. serr := err.(*netlink.OpError).Err if diff := cmp.Diff("use of closed file", serr.Error()); diff != "" { t.Fatalf("unexpected error from receive (-want +got):\n%s", diff) } } func TestIntegrationConnConcurrentSerializeExecute(t *testing.T) { t.Parallel() skipShort(t) c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } execN := func(n int) { req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } for i := 0; i < n; i++ { // Execute will internally call Validate to ensure its // request/response transaction is serialized appropriately, and // any errors doing so will be reported here. 
if _, err := c.Execute(req); err != nil { panicf("failed to execute: %v", err) } } } const ( workers = 4 iterations = 2000 ) var wg sync.WaitGroup wg.Add(workers) defer wg.Wait() for i := 0; i < workers; i++ { go func() { defer wg.Done() execN(iterations) }() } } func TestIntegrationConnSetBuffersSyscallConn(t *testing.T) { tests := []struct { name string check func(t *testing.T) }{ // This test verifies both the force/non-force socket options depending // on the caller's privileges. { name: "unprivileged", check: skipPrivileged, }, { name: "privileged", check: skipUnprivileged, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.check(t) c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } defer c.Close() const ( set = 8192 // Per man 7 socket: // // "The kernel doubles this value (to allow space for // book‐keeping overhead) when it is set using setsockopt(2), // and this doubled value is returned by getsockopt(2)."" want = set * 2 ) if err := c.SetReadBuffer(set); err != nil { t.Fatalf("failed to set read buffer size: %v", err) } if err := c.SetWriteBuffer(set); err != nil { t.Fatalf("failed to set write buffer size: %v", err) } // Now that we've set the buffers, we can check the size by asking the // kernel using SyscallConn and getsockopt. 
rc, err := c.SyscallConn() if err != nil { t.Fatalf("failed to get syscall conn: %v", err) } mustSize := func(opt int) int { var ( value int serr error ) err := rc.Control(func(fd uintptr) { value, serr = unix.GetsockoptInt(int(fd), unix.SOL_SOCKET, opt) }) if err != nil { t.Fatalf("failed to call control: %v", err) } if serr != nil { t.Fatalf("failed to call getsockopt: %v", serr) } return value } if diff := cmp.Diff(want, mustSize(unix.SO_RCVBUF)); diff != "" { t.Fatalf("unexpected read buffer size (-want +got):\n%s", diff) } if diff := cmp.Diff(want, mustSize(unix.SO_SNDBUF)); diff != "" { t.Fatalf("unexpected write buffer size (-want +got):\n%s", diff) } }) } } func TestIntegrationConnSetBPFEmpty(t *testing.T) { c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } defer c.Close() if err := c.SetBPF(nil); err == nil { t.Fatal("expected an error, but none occurred") } } func TestIntegrationConnSetBPF(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } defer c.Close() // The sequence number which will be permitted by the BPF filter. // Using max uint32 helps us avoid dealing with host (netlink) vs // network (BPF) endianness during this test. const sequence uint32 = 0xffffffff prog, err := bpf.Assemble(testBPFProgram(sequence)) if err != nil { t.Fatalf("failed to assemble BPF program: %v", err) } if err := c.SetBPF(prog); err != nil { t.Fatalf("failed to attach BPF program to socket: %v", err) } req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } sequences := []struct { seq uint32 ok bool }{ // OK, bad, OK. Expect two messages to be received. 
{seq: sequence, ok: true}, {seq: 10, ok: false}, {seq: sequence, ok: true}, } for _, s := range sequences { req.Header.Sequence = s.seq if _, err := c.Send(req); err != nil { t.Fatalf("failed to send with sequence %d: %v", s.seq, err) } if !s.ok { continue } msgs, err := c.Receive() if err != nil { t.Fatalf("failed to receive with sequence %d: %v", s.seq, err) } // Make sure the received message has the expected sequence number. if l := len(msgs); l != 1 { t.Fatalf("unexpected number of messages: %d", l) } if want, got := s.seq, msgs[0].Header.Sequence; want != got { t.Fatalf("unexpected reply sequence number:\n- want: %v\n- got: %v", want, got) } } if err := c.RemoveBPF(); err != nil { t.Fatalf("failed to remove BPF filter: %v", err) } } func Test_testBPFProgram(t *testing.T) { // Verify the validity of our test BPF program. vm, err := bpf.NewVM(testBPFProgram(0xffffffff)) if err != nil { t.Fatalf("failed to create BPF VM: %v", err) } msg := []byte{ 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, // Allowed sequence number. 0xff, 0xff, 0xff, 0xff, 0x01, 0x00, 0x00, 0x00, } out, err := vm.Run(msg) if err != nil { t.Fatalf("failed to execute OK input: %v", err) } if out == 0 { t.Fatal("BPF filter dropped OK input") } msg = []byte{ 0x10, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, // Bad sequence number. 0x00, 0x11, 0x22, 0x33, 0x01, 0x00, 0x00, 0x00, } out, err = vm.Run(msg) if err != nil { t.Fatalf("failed to execute bad input: %v", err) } if out != 0 { t.Fatal("BPF filter did not drop bad input") } } // testBPFProgram returns a BPF program which only allows frames with the // input sequence number. 
func testBPFProgram(allowSequence uint32) []bpf.Instruction { return []bpf.Instruction{ bpf.LoadAbsolute{ Off: 8, Size: 4, }, bpf.JumpIf{ Cond: bpf.JumpEqual, Val: allowSequence, SkipTrue: 1, }, bpf.RetConstant{ Val: 0, }, bpf.RetConstant{ Val: 128, }, } } func TestIntegrationConnExplicitPID(t *testing.T) { t.Parallel() // Compute a random uint32 PID and explicitly bind using it. We expect this // PID will be used in messages that are sent to and received from the // kernel. rng := rand.New(rand.NewSource(time.Now().UnixNano())) pid := rng.Uint32() c, err := netlink.Dial(unix.NETLINK_GENERIC, &netlink.Config{PID: pid}) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } defer c.Close() req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } msg, err := c.Send(req) if err != nil { t.Fatalf("failed to send message: %v", err) } msgs, err := c.Receive() if err != nil { t.Fatalf("failed to receive messages: %v", err) } // Verify both the request and response messages contain the same PID. for _, m := range append([]netlink.Message{msg}, msgs...) { if diff := cmp.Diff(pid, m.Header.PID); diff != "" { t.Fatalf("unexpected message PID (-want +got):\n%s", diff) } } } func TestIntegrationConnNetNSUnprivileged(t *testing.T) { t.Parallel() skipPrivileged(t) // Created in CI build environment. 
const ns = "unpriv0" f, err := os.Open("/var/run/netns/" + ns) if err != nil { if os.IsNotExist(err) { t.Skipf("skipping, expected %s namespace to exist", ns) } t.Fatalf("failed to open namespace file: %v", err) } defer f.Close() _, err = netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ NetNS: int(f.Fd()), }) if !os.IsPermission(err) { t.Fatalf("expected permission denied, but got: %v", err) } } func TestIntegrationConnSendTimeout(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial: %v", err) } defer c.Close() if err := c.SetWriteDeadline(time.Unix(0, 1)); err != nil { t.Fatalf("failed to set deadline: %v", err) } _, err = c.Send(netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, }) mustBeTimeoutNetError(t, err) } func TestIntegrationConnReceiveTimeout(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial: %v", err) } defer c.Close() if err := c.SetReadDeadline(time.Unix(0, 1)); err != nil { t.Fatalf("failed to set deadline: %v", err) } _, err = c.Receive() mustBeTimeoutNetError(t, err) } func TestIntegrationConnExecuteTimeout(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial: %v", err) } defer c.Close() if err := c.SetDeadline(time.Unix(0, 1)); err != nil { t.Fatalf("failed to set deadline: %v", err) } req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, }, } _, err = c.Execute(req) if err == nil { t.Fatal("expected an error, but none occurred") } mustBeTimeoutNetError(t, err) } func TestOpErrorUnwrapLinux(t *testing.T) { tests := []struct { name string err error target error ok bool }{ { name: "ENOBUFS", err: unix.ENOBUFS, target: os.ErrNotExist, }, { name: "OpError ENOBUFS", err: &netlink.OpError{ Op: "receive", Err: unix.ENOBUFS, }, target: os.ErrNotExist, }, { name: "OpError 
os.SyscallError ENOBUFS", err: &netlink.OpError{ Op: "receive", Err: os.NewSyscallError("recvmsg", unix.ENOBUFS), }, target: os.ErrNotExist, }, { name: "ENOENT", err: unix.ENOENT, target: os.ErrNotExist, ok: true, }, { name: "OpError ENOENT", err: &netlink.OpError{ Op: "receive", Err: unix.ENOENT, }, target: os.ErrNotExist, ok: true, }, { name: "OpError os.SyscallError ENOENT", err: &netlink.OpError{ Op: "receive", Err: os.NewSyscallError("recvmsg", unix.ENOENT), }, target: os.ErrNotExist, ok: true, }, { name: "OpError os.SyscallError EEXIST", err: &netlink.OpError{ Op: "receive", Err: os.NewSyscallError("recvmsg", unix.EEXIST), }, target: os.ErrExist, ok: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := errors.Is(tt.err, tt.target) if diff := cmp.Diff(tt.ok, got); diff != "" { t.Fatalf("unexpected result (-want +got):\n%s", diff) } }) } } func TestIntegrationConnClosedConn(t *testing.T) { t.Parallel() c, err := netlink.Dial(unix.NETLINK_GENERIC, nil) if err != nil { t.Fatalf("failed to dial netlink: %v", err) } // Close the connection immediately and ensure that future calls get EBADF. 
if err := c.Close(); err != nil { t.Fatalf("failed to close: %v", err) } tests := []struct { name string fn func() error }{ { name: "receive", fn: func() error { _, err := c.Receive() return err }, }, { name: "send", fn: func() error { _, err := c.Send(netlink.Message{}) return err }, }, { name: "set option", fn: func() error { return c.SetOption(netlink.ExtendedAcknowledge, true) }, }, { name: "syscall conn", fn: func() error { _, err := c.SyscallConn() return err }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if diff := cmp.Diff(unix.EBADF, tt.fn(), cmpopts.EquateErrors()); diff != "" { t.Fatalf("unexpected error (-want +got):\n%s", diff) } }) } } func TestIntegrationConnStrict(t *testing.T) { c, err := netlink.Dial(unix.NETLINK_GENERIC, &netlink.Config{Strict: true}) if err != nil { if errors.Is(err, unix.ENOPROTOOPT) { t.Skipf("skipping, strict options not supported by this kernel: %v", err) } t.Fatalf("failed to dial netlink: %v", err) } defer c.Close() sc, err := c.SyscallConn() if err != nil { t.Fatalf("failed to open syscall conn: %v", err) } // Strict mode applies a series of socket options. Check each applied option // and update the map to true if we found it set to true. Any options which // were not applied as expected will result in the test failing. opts := map[int]bool{ unix.NETLINK_EXT_ACK: false, unix.NETLINK_GET_STRICT_CHK: false, } err = sc.Control(func(fd uintptr) { for k := range opts { // The kernel returns a non-zero value for true. 
if v, err := unix.GetsockoptInt(int(fd), unix.SOL_NETLINK, k); err == nil && v != 0 { opts[k] = true } } }) if err != nil { t.Fatalf("failed to call control: %v", err) } for k, v := range opts { if !v { t.Errorf("socket option %d was not set to true", k) } } } func mustBeTimeoutNetError(t *testing.T, err error) { t.Helper() nerr, ok := err.(net.Error) if !ok { t.Fatalf("expected net.Error, but got: %T", err) } if !nerr.Timeout() { t.Fatalf("error did not indicate a timeout") } } func skipPrivileged(t *testing.T) { u, err := user.Current() if err != nil { t.Fatalf("failed to get user: %v", err) } if u.Uid == "0" { t.Skip("skipping, test must be run as non-root user") } } func skipUnprivileged(t *testing.T) { const ifName = "nlprobe0" shell(t, "ip", "tuntap", "add", ifName, "mode", "tun") shell(t, "ip", "link", "del", ifName) } func skipShort(t *testing.T) { t.Helper() if testing.Short() { t.Skip("skipping in short test mode") } } func shell(t *testing.T, name string, arg ...string) { t.Helper() t.Logf("$ %s %v", name, arg) cmd := exec.Command(name, arg...) if err := cmd.Start(); err != nil { t.Fatalf("failed to start command %q: %v", name, err) } if err := cmd.Wait(); err != nil { // Shell operations in these tests require elevated privileges. if cmd.ProcessState.ExitCode() == int(unix.EPERM) { t.Skipf("skipping, permission denied: %v", err) } t.Fatalf("failed to wait for command %q: %v", name, err) } } func panicf(format string, a ...interface{}) { panic(fmt.Sprintf(format, a...)) } netlink-1.7.2/conn_others.go000066400000000000000000000017121442322526200160400ustar00rootroot00000000000000//go:build !linux // +build !linux package netlink import ( "fmt" "runtime" ) // errUnimplemented is returned by all functions on platforms that // cannot make use of netlink sockets. var errUnimplemented = fmt.Errorf("netlink: not implemented on %s/%s", runtime.GOOS, runtime.GOARCH) var _ Socket = &conn{} // A conn is the no-op implementation of a netlink sockets connection. 
type conn struct{} // All cross-platform functions and Socket methods are unimplemented outside // of Linux. func dial(_ int, _ *Config) (*conn, uint32, error) { return nil, 0, errUnimplemented } func newError(_ int) error { return errUnimplemented } func (c *conn) Send(_ Message) error { return errUnimplemented } func (c *conn) SendMessages(_ []Message) error { return errUnimplemented } func (c *conn) Receive() ([]Message, error) { return nil, errUnimplemented } func (c *conn) Close() error { return errUnimplemented } netlink-1.7.2/conn_others_test.go000066400000000000000000000016701442322526200171020ustar00rootroot00000000000000//go:build !linux // +build !linux package netlink import "testing" func TestOthersConnUnimplemented(t *testing.T) { c := &conn{} want := errUnimplemented if got := newError(0); want != got { t.Fatalf("unexpected error during newError:\n- want: %v\n- got: %v", want, got) } if _, _, got := dial(0, nil); want != got { t.Fatalf("unexpected error during dial:\n- want: %v\n- got: %v", want, got) } if got := c.Send(Message{}); want != got { t.Fatalf("unexpected error during c.Send:\n- want: %v\n- got: %v", want, got) } if got := c.SendMessages(nil); want != got { t.Fatalf("unexpected error during c.SendMessages:\n- want: %v\n- got: %v", want, got) } if _, got := c.Receive(); want != got { t.Fatalf("unexpected error during c.Receive:\n- want: %v\n- got: %v", want, got) } if got := c.Close(); want != got { t.Fatalf("unexpected error during c.Close:\n- want: %v\n- got: %v", want, got) } } netlink-1.7.2/conn_test.go000066400000000000000000000133151442322526200155150ustar00rootroot00000000000000package netlink_test import ( "errors" "io" "reflect" "strings" "testing" "time" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" ) func TestConnExecute(t *testing.T) { req := netlink.Message{ Header: netlink.Header{ Flags: netlink.Request | netlink.Acknowledge, Sequence: 1, }, } replies := []netlink.Message{{ Header: netlink.Header{ Type: 
netlink.Error, Sequence: 1, PID: 1, }, // Error code "success", no need to echo request back in this test Data: make([]byte, 4), }} c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return replies, nil }) defer c.Close() msgs, err := c.Execute(req) if err != nil { t.Fatalf("failed to execute: %v", err) } // Fill in fields for comparison req.Header.Length = 16 if want, got := replies, msgs; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected replies:\n- want: %#v\n- got: %#v", want, got) } } func TestConnSend(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, errors.New("should not be received") }) defer c.Close() // Let Conn.Send populate length, sequence, PID m := netlink.Message{} out, err := c.Send(m) if err != nil { t.Fatalf("failed to send message: %v", err) } // Make the same changes that Conn.Send should m = netlink.Message{ Header: netlink.Header{ Length: 16, Sequence: out.Header.Sequence, PID: 1, }, } if want, got := m, out; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected output message from Conn.Send:\n- want: %#v\n- got: %#v", want, got) } // Keep sending to verify sequence number increment seq := m.Header.Sequence for i := 0; i < 100; i++ { out, err := c.Send(netlink.Message{}) if err != nil { t.Fatalf("failed to send message: %v", err) } seq++ if want, got := seq, out.Header.Sequence; want != got { t.Fatalf("unexpected sequence number:\n- want: %v\n- got: %v", want, got) } } } func TestConnExecuteMultipart(t *testing.T) { msg := netlink.Message{ Header: netlink.Header{ Sequence: 1, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, } c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nltest.Multipart([]netlink.Message{ msg, // Will be filled with multipart done information. 
{}, }) }) defer c.Close() msgs, err := c.Execute(msg) if err != nil { t.Fatalf("failed to receive messages: %v", err) } msg.Header.Flags |= netlink.Multi if want, got := []netlink.Message{msg}, msgs; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected output messages from Conn.Receive:\n- want: %#v\n- got: %#v", want, got) } } func TestConnExecuteNoMessages(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, io.EOF }) defer c.Close() msgs, err := c.Execute(netlink.Message{}) if err != nil { t.Fatalf("failed to execute: %v", err) } if l := len(msgs); l > 0 { t.Fatalf("expected no messages, but got: %d", l) } } func TestConnReceiveNoMessages(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, io.EOF }) defer c.Close() msgs, err := c.Receive() if err != nil { t.Fatalf("failed to execute: %v", err) } if l := len(msgs); l > 0 { t.Fatalf("expected no messages, but got: %d", l) } } func TestConnReceiveShortErrorNumber(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return []netlink.Message{{ Header: netlink.Header{ Length: 20, Type: netlink.Error, }, Data: []byte{0x01}, }}, nil }) defer c.Close() _, err := c.Receive() if !strings.Contains(err.Error(), "not enough data") { t.Fatalf("unexpected error: %v", err) } } func TestConnReceiveShortErrorAcknowledgementHeader(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return []netlink.Message{{ Header: netlink.Header{ Length: 20, Type: netlink.Error, Flags: netlink.AcknowledgeTLVs, }, Data: []byte{ // errno. 
0x01, 0x00, 0x00, 0x00, // nlmsghdr 0xff, }, }}, nil }) defer c.Close() _, err := c.Receive() if !strings.Contains(err.Error(), "not enough data") { t.Fatalf("unexpected error: %v", err) } } func TestConnJoinLeaveGroupUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() ops := []func(group uint32) error{ c.JoinGroup, c.LeaveGroup, } for _, op := range ops { err := op(0) if !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } } func TestConnSetBPFUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() err := c.SetBPF(nil) if !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } func TestConnSetDeadlineUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() err := c.SetDeadline(time.Now()) if !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } func TestConnSetOptionUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() err := c.SetOption(0, false) if !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } func TestConnSetBuffersUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() ops := []func(n int) error{ c.SetReadBuffer, c.SetWriteBuffer, } for _, op := range ops { err := op(0) if !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } } func TestConnSyscallConnUnsupported(t *testing.T) { c := nltest.Dial(nil) defer c.Close() if _, err := c.SyscallConn(); !strings.Contains(err.Error(), "not supported") { t.Fatalf("unexpected error: %v", err) } } netlink-1.7.2/debug.go000066400000000000000000000024401442322526200146040ustar00rootroot00000000000000package netlink import ( "fmt" "log" "os" "strconv" "strings" ) // Arguments used to create a debugger. var debugArgs []string func init() { // Is netlink debugging enabled? 
s := os.Getenv("NLDEBUG") if s == "" { return } debugArgs = strings.Split(s, ",") } // A debugger is used to provide debugging information about a netlink connection. type debugger struct { Log *log.Logger Level int } // newDebugger creates a debugger by parsing key=value arguments. func newDebugger(args []string) *debugger { d := &debugger{ Log: log.New(os.Stderr, "nl: ", 0), Level: 1, } for _, a := range args { kv := strings.Split(a, "=") if len(kv) != 2 { // Ignore malformed pairs and assume callers wants defaults. continue } switch kv[0] { // Select the log level for the debugger. case "level": level, err := strconv.Atoi(kv[1]) if err != nil { panicf("netlink: invalid NLDEBUG level: %q", a) } d.Level = level } } return d } // debugf prints debugging information at the specified level, if d.Level is // high enough to print the message. func (d *debugger) debugf(level int, format string, v ...interface{}) { if d.Level >= level { d.Log.Printf(format, v...) } } func panicf(format string, a ...interface{}) { panic(fmt.Sprintf(format, a...)) } netlink-1.7.2/doc.go000066400000000000000000000022141442322526200142620ustar00rootroot00000000000000// Package netlink provides low-level access to Linux netlink sockets // (AF_NETLINK). // // If you have any questions or you'd like some guidance, please join us on // Gophers Slack (https://invite.slack.golangbridge.org) in the #networking // channel! // // # Network namespaces // // This package is aware of Linux network namespaces, and can enter different // network namespaces either implicitly or explicitly, depending on // configuration. The Config structure passed to Dial to create a Conn controls // these behaviors. See the documentation of Config.NetNS for details. // // # Debugging // // This package supports rudimentary netlink connection debugging support. To // enable this, run your binary with the NLDEBUG environment variable set. // Debugging information will be output to stderr with a prefix of "nl:". 
// // To use the debugging defaults, use: // // $ NLDEBUG=1 ./nlctl // // To configure individual aspects of the debugger, pass key/value options such // as: // // $ NLDEBUG=level=1 ./nlctl // // Available key/value debugger options include: // // level=N: specify the debugging level (only "1" is currently supported) package netlink netlink-1.7.2/errors.go000066400000000000000000000071751442322526200150440ustar00rootroot00000000000000package netlink import ( "errors" "fmt" "net" "os" "strings" ) // Error messages which can be returned by Validate. var ( errMismatchedSequence = errors.New("mismatched sequence in netlink reply") errMismatchedPID = errors.New("mismatched PID in netlink reply") errShortErrorMessage = errors.New("not enough data for netlink error code") ) // Errors which can be returned by a Socket that does not implement // all exposed methods of Conn. var errNotSupported = errors.New("operation not supported") // notSupported provides a concise constructor for "not supported" errors. func notSupported(op string) error { return newOpError(op, errNotSupported) } // IsNotExist determines if an error is produced as the result of querying some // file, object, resource, etc. which does not exist. // // Deprecated: use errors.Unwrap and/or `errors.Is(err, os.Permission)` in Go // 1.13+. func IsNotExist(err error) bool { switch err := err.(type) { case *OpError: // Unwrap the inner error and use the stdlib's logic. return os.IsNotExist(err.Err) default: return os.IsNotExist(err) } } var ( _ error = &OpError{} _ net.Error = &OpError{} // Ensure compatibility with Go 1.13+ errors package. _ interface{ Unwrap() error } = &OpError{} ) // An OpError is an error produced as the result of a failed netlink operation. type OpError struct { // Op is the operation which caused this OpError, such as "send" // or "receive". Op string // Err is the underlying error which caused this OpError. 
// // If Err was produced by a system call error, Err will be of type // *os.SyscallError. If Err was produced by an error code in a netlink // message, Err will contain a raw error value type such as a unix.Errno. // // Most callers should inspect Err using errors.Is from the standard // library. Err error // Message and Offset contain additional error information provided by the // kernel when the ExtendedAcknowledge option is set on a Conn and the // kernel indicates the AcknowledgeTLVs flag in a response. If this option // is not set, both of these fields will be empty. Message string Offset int } // newOpError is a small wrapper for creating an OpError. As a convenience, it // returns nil if the input err is nil: akin to os.NewSyscallError. func newOpError(op string, err error) error { if err == nil { return nil } return &OpError{ Op: op, Err: err, } } func (e *OpError) Error() string { if e == nil { return "" } var sb strings.Builder _, _ = sb.WriteString(fmt.Sprintf("netlink %s: %v", e.Op, e.Err)) if e.Message != "" || e.Offset != 0 { _, _ = sb.WriteString(fmt.Sprintf(", offset: %d, message: %q", e.Offset, e.Message)) } return sb.String() } // Unwrap unwraps the internal Err field for use with errors.Unwrap. func (e *OpError) Unwrap() error { return e.Err } // Portions of this code taken from the Go standard library: // // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. type timeout interface { Timeout() bool } // Timeout reports whether the error was caused by an I/O timeout. func (e *OpError) Timeout() bool { if ne, ok := e.Err.(*os.SyscallError); ok { t, ok := ne.Err.(timeout) return ok && t.Timeout() } t, ok := e.Err.(timeout) return ok && t.Timeout() } type temporary interface { Temporary() bool } // Temporary reports whether an operation may succeed if retried. 
func (e *OpError) Temporary() bool { if ne, ok := e.Err.(*os.SyscallError); ok { t, ok := ne.Err.(temporary) return ok && t.Temporary() } t, ok := e.Err.(temporary) return ok && t.Temporary() } netlink-1.7.2/example_attributedecoder_test.go000066400000000000000000000044271442322526200216300ustar00rootroot00000000000000package netlink_test import ( "fmt" "log" "github.com/mdlayher/netlink" ) // decodeNested is a nested structure within decodeOut. type decodeNested struct { A, B uint32 } // decodeOut is an example structure we will use to unpack netlink attributes. type decodeOut struct { Number uint16 String string Nested decodeNested } // decode is an example function used to adapt the ad.Nested method to decode an // arbitrary structure. func (n *decodeNested) decode(ad *netlink.AttributeDecoder) error { // Iterate over the attributes, checking the type of each attribute and // decoding them as appropriate. for ad.Next() { switch ad.Type() { // A and B are both uint32 values, so decode them as such. case 1: n.A = ad.Uint32() case 2: n.B = ad.Uint32() } } // No need to call ad.Err directly. return nil } // This example demonstrates using a netlink.AttributeDecoder to decode packed // netlink attributes in a message payload. func ExampleAttributeDecoder_decode() { // Create a netlink.AttributeDecoder using some example attribute bytes // that are prepared for this example. ad, err := netlink.NewAttributeDecoder(exampleAttributes()) if err != nil { log.Fatalf("failed to create attribute decoder: %v", err) } // Iterate attributes until completion, checking the type of each and // decoding them as appropriate. var out decodeOut for ad.Next() { // Check the type of the current attribute with ad.Type. Typically you // will find netlink attribute types and data values in C headers as // constants. switch ad.Type() { case 1: // Number is a uint16. out.Number = ad.Uint16() case 2: // String is a string. 
out.String = ad.String() case 3: // Nested is a nested structure, so we will use a method on the // nested type along with ad.Do to decode it in a concise way. ad.Nested(out.Nested.decode) } } // Any errors encountered during decoding (including any errors from // decoding the nested attributes) will be returned here. if err := ad.Err(); err != nil { log.Fatalf("failed to decode attributes: %v", err) } fmt.Printf(`Number: %d String: %q Nested: - A: %d - B: %d`, out.Number, out.String, out.Nested.A, out.Nested.B, ) // Output: // Number: 1 // String: "hello world" // Nested: // - A: 2 // - B: 3 } netlink-1.7.2/example_attributeencoder_test.go000066400000000000000000000042721442322526200216400ustar00rootroot00000000000000package netlink_test import ( "fmt" "log" "github.com/mdlayher/netlink" ) // encodeNested is a nested structure within out. type encodeNested struct { A, B uint32 } // encodeOut is an example structure we will use to pack netlink attributes. type encodeOut struct { Number uint16 String string Nested encodeNested } // encode is an example function used to adapt the ae.Nested method // to encode an arbitrary structure. func (n encodeNested) encode(ae *netlink.AttributeEncoder) error { // Encode the fields of the nested structure. ae.Uint32(1, n.A) ae.Uint32(2, n.B) return nil } func ExampleAttributeEncoder_encode() { // Create a netlink.AttributeEncoder that encodes to the same message // as that decoded by the netlink.AttributeDecoder example. ae := netlink.NewAttributeEncoder() o := encodeOut{ Number: 1, String: "hello world", Nested: encodeNested{ A: 2, B: 3, }, } // Encode the Number attribute as a uint16. ae.Uint16(1, o.Number) // Encode the String attribute as a string. ae.String(2, o.String) // Nested is a nested structure, so we will use the encodeNested type's // encode method with ae.Nested to encode it in a concise way. 
ae.Nested(3, o.Nested.encode) // Any errors encountered during encoding (including any errors from // encoding nested attributes) will be returned here. b, err := ae.Encode() if err != nil { log.Fatalf("failed to encode attributes: %v", err) } // Now decode the attributes again to verify the contents. ad, err := netlink.NewAttributeDecoder(b) if err != nil { log.Fatalf("failed to decode attributes: %v", err) } // Walk the attributes and print each out. for ad.Next() { switch ad.Type() { case 1: fmt.Println("uint16:", ad.Uint16()) case 2: fmt.Println("string:", ad.String()) case 3: fmt.Println("nested:") // Nested attributes use their own nested decoder. ad.Nested(func(nad *netlink.AttributeDecoder) error { for nad.Next() { switch nad.Type() { case 1: fmt.Println(" - A:", nad.Uint32()) case 2: fmt.Println(" - B:", nad.Uint32()) } } return nil }) } } // Output: uint16: 1 // string: hello world // nested: // - A: 2 // - B: 3 } netlink-1.7.2/example_test.go000066400000000000000000000047761442322526200162260ustar00rootroot00000000000000package netlink_test import ( "log" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nlenc" "github.com/mdlayher/netlink/nltest" ) // This example demonstrates using a netlink.Conn to execute requests against // netlink. 
func ExampleConn_execute() { // Speak to generic netlink using netlink const familyGeneric = 16 c, err := netlink.Dial(familyGeneric, nil) if err != nil { log.Fatalf("failed to dial netlink: %v", err) } defer c.Close() // Ask netlink to send us an acknowledgement, which will contain // a copy of the header we sent to it req := netlink.Message{ Header: netlink.Header{ // Package netlink will automatically set header fields // which are set to zero Flags: netlink.Request | netlink.Acknowledge, }, } // Perform a request, receive replies, and validate the replies msgs, err := c.Execute(req) if err != nil { log.Fatalf("failed to execute request: %v", err) } if c := len(msgs); c != 1 { log.Fatalf("expected 1 message, but got: %d", c) } // Decode the copied request header, starting after 4 bytes // indicating "success" var res netlink.Message if err := (&res).UnmarshalBinary(msgs[0].Data[4:]); err != nil { log.Fatalf("failed to unmarshal response: %v", err) } log.Printf("res: %+v", res) } // This example demonstrates using a netlink.Conn to listen for multicast group // messages generated by the addition and deletion of network interfaces. 
func ExampleConn_listenMulticast() { const ( // Speak to route netlink using netlink familyRoute = 0 // Listen for events triggered by addition or deletion of // network interfaces rtmGroupLink = 0x1 ) c, err := netlink.Dial(familyRoute, &netlink.Config{ // Groups is a bitmask; more than one group can be specified // by OR'ing multiple group values together Groups: rtmGroupLink, }) if err != nil { log.Fatalf("failed to dial netlink: %v", err) } defer c.Close() for { // Listen for netlink messages triggered by multicast groups msgs, err := c.Receive() if err != nil { log.Fatalf("failed to receive messages: %v", err) } log.Printf("msgs: %+v", msgs) } } func exampleAttributes() []byte { return nltest.MustMarshalAttributes([]netlink.Attribute{ { Type: 1, Data: nlenc.Uint16Bytes(1), }, { Type: 2, Data: nlenc.Bytes("hello world"), }, { Type: 3, Data: nltest.MustMarshalAttributes([]netlink.Attribute{ { Type: 1, Data: nlenc.Uint32Bytes(2), }, { Type: 2, Data: nlenc.Uint32Bytes(3), }, }), }, }) } netlink-1.7.2/fuzz.go000066400000000000000000000036521442322526200145220ustar00rootroot00000000000000//go:build gofuzz // +build gofuzz package netlink import "github.com/google/go-cmp/cmp" func fuzz(b1 []byte) int { // 1. unmarshal, marshal, unmarshal again to check m1 and m2 for equality // after a round trip. checkMessage is also used because there is a fair // amount of tricky logic around testing for presence of error headers and // extended acknowledgement attributes. var m1 Message if err := m1.UnmarshalBinary(b1); err != nil { return 0 } if err := checkMessage(m1); err != nil { return 0 } b2, err := m1.MarshalBinary() if err != nil { panicf("failed to marshal m1: %v", err) } var m2 Message if err := m2.UnmarshalBinary(b2); err != nil { panicf("failed to unmarshal m2: %v", err) } if err := checkMessage(m2); err != nil { panicf("failed to check m2: %v", err) } if diff := cmp.Diff(m1, m2); diff != "" { panicf("unexpected Message (-want +got):\n%s", diff) } // 2. 
marshal again and compare b2 and b3 (b1 may have reserved bytes set // which we ignore and fill with zeros when marshaling) for equality. b3, err := m2.MarshalBinary() if err != nil { panicf("failed to marshal m2: %v", err) } if diff := cmp.Diff(b2, b3); diff != "" { panicf("unexpected message bytes (-want +got):\n%s", diff) } // 3. unmarshal any possible attributes from m1's data and marshal them // again for comparison. a1, err := UnmarshalAttributes(m1.Data) if err != nil { return 0 } ab1, err := MarshalAttributes(a1) if err != nil { panicf("failed to marshal a1: %v", err) } a2, err := UnmarshalAttributes(ab1) if err != nil { panicf("failed to unmarshal a2: %v", err) } if diff := cmp.Diff(a1, a2); diff != "" { panicf("unexpected Attributes (-want +got):\n%s", diff) } ab2, err := MarshalAttributes(a2) if err != nil { panicf("failed to marshal a2: %v", err) } if diff := cmp.Diff(ab1, ab2); diff != "" { panicf("unexpected attribute bytes (-want +got):\n%s", diff) } return 1 } netlink-1.7.2/fuzz_test.go000066400000000000000000000004031442322526200155500ustar00rootroot00000000000000//go:build gofuzz // +build gofuzz package netlink import "testing" func Test_fuzz(t *testing.T) { tests := []struct { name string s string }{} for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _ = fuzz([]byte(tt.s)) }) } } netlink-1.7.2/go.mod000066400000000000000000000004011442322526200142700ustar00rootroot00000000000000module github.com/mdlayher/netlink go 1.18 require ( github.com/google/go-cmp v0.5.9 github.com/josharian/native v1.1.0 github.com/mdlayher/socket v0.4.1 golang.org/x/net v0.9.0 golang.org/x/sys v0.7.0 ) require golang.org/x/sync v0.1.0 // indirect netlink-1.7.2/go.sum000066400000000000000000000017061442322526200143260ustar00rootroot00000000000000github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/josharian/native v1.1.0 
h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= netlink-1.7.2/internal/000077500000000000000000000000001442322526200150035ustar00rootroot00000000000000netlink-1.7.2/internal/integration/000077500000000000000000000000001442322526200173265ustar00rootroot00000000000000netlink-1.7.2/internal/integration/go.mod000066400000000000000000000012641442322526200204370ustar00rootroot00000000000000module github.com/mdlayher/netlink/internal/integration go 1.18 require ( github.com/google/go-cmp v0.5.9 github.com/jsimonetti/rtnetlink v1.3.2 github.com/mdlayher/ethtool v0.0.0-20221212131811-ba3b4bc2e02c golang.org/x/net v0.9.0 golang.org/x/sys v0.7.0 ) require ( github.com/josharian/native v1.1.0 // indirect github.com/mdlayher/genetlink v1.3.1 // indirect github.com/mdlayher/socket v0.4.1 // indirect golang.org/x/sync v0.1.0 // indirect ) // We require a recent release, but in reality the integration tests should // always use the netlink module at the root of the repository. 
require github.com/mdlayher/netlink v1.7.1 replace github.com/mdlayher/netlink => ../../ netlink-1.7.2/internal/integration/go.sum000066400000000000000000000031401442322526200204570ustar00rootroot00000000000000github.com/cilium/ebpf v0.10.0 h1:nk5HPMeoBXtOzbkZBWym+ZWq1GIiHUsBFXxwewXAHLQ= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/jsimonetti/rtnetlink v1.3.2 h1:dcn0uWkfxycEEyNy0IGfx3GrhQ38LH7odjxAghimsVI= github.com/jsimonetti/rtnetlink v1.3.2/go.mod h1:BBu4jZCpTjP6Gk0/wfrO8qcqymnN3g0hoFqObRmUo6U= github.com/mdlayher/ethtool v0.0.0-20221212131811-ba3b4bc2e02c h1:Y7LoKqIgD7vmqJ7+6ZVnADuwUO+m3tGXbf2lK0OvjIw= github.com/mdlayher/ethtool v0.0.0-20221212131811-ba3b4bc2e02c/go.mod h1:i0nPbE+sL2G3OtdIb9SXxW/T4UiAwh6rxPW7zcuX+KQ= github.com/mdlayher/genetlink v1.3.1 h1:roBiPnual+eqtRkKX2Jb8UQN5ZPWnhDCGj/wR6Jlz2w= github.com/mdlayher/genetlink v1.3.1/go.mod h1:uaIPxkWmGk753VVIzDtROxQ8+T+dkHqOI0vB1NA9S/Q= github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U= github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= golang.org/x/net v0.9.0/go.mod h1:d48xBJpPfHeWQsugry2m+kC02ZBRGRgulfHnEXEuWns= golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.7.0 h1:3jlCCIQZPdOYu1h8BkNvLz8Kgwtae2cagcG/VamtZRU= golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= netlink-1.7.2/internal/integration/integration_linux_test.go000066400000000000000000000246061442322526200244660ustar00rootroot00000000000000//go:build linux 
// +build linux package integration_test import ( "errors" "fmt" "net" "os" "os/exec" "sync" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/jsimonetti/rtnetlink" "github.com/mdlayher/ethtool" "github.com/mdlayher/netlink" "golang.org/x/net/nettest" "golang.org/x/sys/unix" ) func TestIntegrationConnMulticast(t *testing.T) { skipUnprivileged(t) c, done := rtnlDial(t, 0) defer done() // Create an interface to trigger a notification, and remove it at the end // of the test. const ifName = "nltest0" defer shell(t, "ip", "link", "del", ifName) ifi := rtnlReceive(t, c, func() { shell(t, "ip", "tuntap", "add", ifName, "mode", "tun") }) if diff := cmp.Diff(ifName, ifi); diff != "" { t.Fatalf("unexpected interface name (-want +got):\n%s", diff) } } func TestIntegrationConnNetNSExplicit(t *testing.T) { skipUnprivileged(t) // Create a network namespace for use within this test. const ns = "nltest0" shell(t, "ip", "netns", "add", ns) defer shell(t, "ip", "netns", "del", ns) f, err := os.Open("/var/run/netns/" + ns) if err != nil { t.Fatalf("failed to open namespace file: %v", err) } defer f.Close() // Create a connection in each the host namespace and the new network // namespace. We will use these to validate that a namespace was entered // and that an interface creation notification was only visible to the // connection within the namespace. hostC, hostDone := rtnlDial(t, 0) defer hostDone() nsC, nsDone := rtnlDial(t, int(f.Fd())) defer nsDone() var wg sync.WaitGroup wg.Add(1) defer wg.Wait() go func() { defer wg.Done() _, err := hostC.Receive() if err == nil { panic("received netlink message in host namespace") } // Timeout means we were interrupted, so return. if nerr, ok := err.(net.Error); ok && nerr.Timeout() { return } panicf("failed to receive in host namespace: %v", err) }() // Create a temporary interface within the new network namespace. 
const ifName = "nltestns0" defer shell(t, "ip", "netns", "exec", ns, "ip", "link", "del", ifName) ifi := rtnlReceive(t, nsC, func() { // Trigger a notification in the new namespace. shell(t, "ip", "netns", "exec", ns, "ip", "tuntap", "add", ifName, "mode", "tun") }) // And finally interrupt the host connection so it can exit its // receive goroutine. if err := hostC.SetDeadline(time.Unix(1, 0)); err != nil { t.Fatalf("failed to interrupt host connection: %v", err) } if diff := cmp.Diff(ifName, ifi); diff != "" { t.Fatalf("unexpected interface name (-want +got):\n%s", diff) } } func TestIntegrationRTNetlinkStrictCheckExtendedAcknowledge(t *testing.T) { c, err := netlink.Dial(unix.NETLINK_ROUTE, nil) if err != nil { t.Fatalf("failed to open rtnetlink socket: %s", err) } defer c.Close() // Turn on extended acknowledgements and strict checking so rtnetlink // reports detailed error information regarding our invalid dump request. setStrictCheck(t, c) if err := c.SetOption(netlink.ExtendedAcknowledge, true); err != nil { t.Fatalf("failed to set extended acknowledge option: %v", err) } // The kernel will complain that this field isn't valid for a filtered dump // request. b, err := (&rtnetlink.RouteMessage{SrcLength: 1}).MarshalBinary() if err != nil { t.Fatalf("failed to marshal request: %v", err) } _, err = c.Execute(netlink.Message{ Header: netlink.Header{ Type: unix.RTM_GETROUTE, Flags: netlink.Request | netlink.Dump, }, Data: b, }) oerr, ok := err.(*netlink.OpError) if !ok { t.Fatalf("expected *netlink.OpError, but got: %T", err) } // Assume the message contents will be relatively static but don't hardcode // offset just in case things change. 
want := &netlink.OpError{ Op: "receive", Err: unix.EINVAL, Message: "Invalid values in header for FIB dump request", } if diff := cmp.Diff(want, oerr); diff != "" { t.Fatalf("unexpected *netlink.OpError (-want +got):\n%s", diff) } } func TestIntegrationRTNetlinkRouteManipulation(t *testing.T) { skipUnprivileged(t) c, err := netlink.Dial(unix.NETLINK_ROUTE, nil) if err != nil { t.Fatalf("failed to open rtnetlink socket: %s", err) } defer c.Close() // Required for kernel route dump filtering. setStrictCheck(t, c) lo, err := nettest.LoopbackInterface() if err != nil { t.Fatalf("failed to get loopback: %v", err) } // Install synthetic routes in documentation ranges into a non-default table // which we will later dump. const ( table = 100 ip4Mask = 32 ip6Mask = 128 ) var ( ip4 = &net.IPNet{ IP: net.IPv4(192, 2, 2, 1), Mask: net.CIDRMask(ip4Mask, ip4Mask), } ip6 = &net.IPNet{ IP: net.ParseIP("2001:db8::1"), Mask: net.CIDRMask(ip6Mask, ip6Mask), } want = []*net.IPNet{ip4, ip6} ) rtmsgs := []rtnetlink.RouteMessage{ { Family: unix.AF_INET, DstLength: ip4Mask, Protocol: unix.RTPROT_STATIC, Scope: unix.RT_SCOPE_UNIVERSE, Type: unix.RTN_UNICAST, Attributes: rtnetlink.RouteAttributes{ Dst: ip4.IP, OutIface: uint32(lo.Index), Table: table, }, }, { Family: unix.AF_INET6, DstLength: ip6Mask, Protocol: unix.RTPROT_STATIC, Scope: unix.RT_SCOPE_UNIVERSE, Type: unix.RTN_UNICAST, Attributes: rtnetlink.RouteAttributes{ Dst: ip6.IP, OutIface: uint32(lo.Index), Table: table, }, }, } // Verify we can send a batch of updates in one syscall. var msgs []netlink.Message for _, m := range rtmsgs { b, err := m.MarshalBinary() if err != nil { t.Fatalf("failed to marshal: %v", err) } msgs = append(msgs, netlink.Message{ Header: netlink.Header{ Type: unix.RTM_NEWROUTE, Flags: netlink.Request | netlink.Create | netlink.Replace, }, Data: b, }) } if _, err := c.SendMessages(msgs); err != nil { t.Fatalf("failed to add routes: %v", err) } // Only dump routes from the specified table. 
b, err := (&rtnetlink.RouteMessage{ Attributes: rtnetlink.RouteAttributes{Table: table}, }).MarshalBinary() if err != nil { t.Fatalf("failed to marshal request: %v", err) } routes, err := c.Execute( netlink.Message{ Header: netlink.Header{ Type: unix.RTM_GETROUTE, Flags: netlink.Request | netlink.Dump, }, Data: b, }, ) if err != nil { t.Fatalf("failed to dump routes: %v", err) } // Parse the routes back to Go structures. got := make([]*net.IPNet, 0, len(routes)) for _, r := range routes { var rtm rtnetlink.RouteMessage if err := rtm.UnmarshalBinary(r.Data); err != nil { t.Fatalf("failed to unmarshal route: %v", err) } got = append(got, &net.IPNet{ IP: rtm.Attributes.Dst, Mask: net.CIDRMask(int(rtm.DstLength), int(rtm.DstLength)), }) } // Now clear the routes and verify they're removed before ensuring we got // the expected routes. for i := range msgs { msgs[i].Header.Type = unix.RTM_DELROUTE msgs[i].Header.Flags = netlink.Request | netlink.Acknowledge } if _, err := c.SendMessages(msgs); err != nil { t.Fatalf("failed to send: %v", err) } if _, err := c.Receive(); err != nil { t.Fatalf("failed to receive: %v", err) } if diff := cmp.Diff(want, got); diff != "" { t.Fatalf("unexpected routes (-want +got):\n%s", diff) } } func TestIntegrationEthtoolExtendedAcknowledge(t *testing.T) { t.Parallel() // The ethtool package uses extended acknowledgements and should populate // all of netlink.OpError's fields when unwrapped. c, err := ethtool.New() if err != nil { if errors.Is(err, os.ErrNotExist) { t.Skip("skipping, ethtool genetlink not available on this system") } t.Fatalf("failed to open ethtool genetlink: %v", err) } _, err = c.LinkInfo(ethtool.Interface{Name: "notexist0"}) if err == nil { t.Fatal("expected an error, but none occurred") } var oerr *netlink.OpError if !errors.As(err, &oerr) { t.Fatalf("expected wrapped *netlink.OpError, but got: %T", err) } // Assume the message contents will be relatively static but don't hardcode // offset just in case things change. 
if oerr.Offset == 0 { t.Fatal("no offset specified in *netlink.OpError") } oerr.Offset = 0 want := &netlink.OpError{ Op: "receive", Err: unix.ENODEV, Message: "no device matches name", } if diff := cmp.Diff(want, oerr); diff != "" { t.Fatalf("unexpected *netlink.OpError (-want +got):\n%s", diff) } } func rtnlDial(t *testing.T, netNS int) (*netlink.Conn, func()) { t.Helper() timer := time.AfterFunc(10*time.Second, func() { panic("test took too long") }) c, err := netlink.Dial(unix.NETLINK_ROUTE, &netlink.Config{ Groups: unix.RTMGRP_LINK, NetNS: netNS, }) if err != nil { t.Fatalf("failed to dial rtnetlink: %v", err) } return c, func() { if err := c.Close(); err != nil { t.Fatalf("failed to close rtnetlink connection: %v", err) } // Stop the timer to prevent a panic if other tests run for a long time. timer.Stop() } } func setStrictCheck(t *testing.T, c *netlink.Conn) { if err := c.SetOption(netlink.GetStrictCheck, true); err != nil { if errors.Is(err, unix.ENOPROTOOPT) { t.Skipf("skipping, netlink strict checking is not supported on this kernel") } t.Fatalf("failed to set strict check option: %v", err) } } func rtnlReceive(t *testing.T, c *netlink.Conn, do func()) string { t.Helper() // Receive messages in goroutine. msgC := make(chan rtnetlink.LinkMessage) go func() { msgs, err := c.Receive() if err != nil { panicf("failed to receive rtnetlink messages: %v", err) } var rtmsg rtnetlink.LinkMessage if err := rtmsg.UnmarshalBinary(msgs[0].Data); err != nil { panicf("failed to unmarshal rtnetlink message: %v", err) } msgC <- rtmsg }() // Execute the function which will generate messages, and then wait for // a message. do() m := <-msgC return m.Attributes.Name } func skipUnprivileged(t *testing.T) { const ifName = "nlprobe0" shell(t, "ip", "tuntap", "add", ifName, "mode", "tun") shell(t, "ip", "link", "del", ifName) } func shell(t *testing.T, name string, arg ...string) { t.Helper() t.Logf("$ %s %v", name, arg) cmd := exec.Command(name, arg...) 
if err := cmd.Start(); err != nil { t.Fatalf("failed to start command %q: %v", name, err) } if err := cmd.Wait(); err != nil { // Shell operations in these tests require elevated privileges. if cmd.ProcessState.ExitCode() == int(unix.EPERM) { t.Skipf("skipping, permission denied: %v", err) } t.Fatalf("failed to wait for command %q: %v", name, err) } } func panicf(format string, a ...interface{}) { panic(fmt.Sprintf(format, a...)) } netlink-1.7.2/message.go000066400000000000000000000220021442322526200151360ustar00rootroot00000000000000package netlink import ( "errors" "fmt" "unsafe" "github.com/mdlayher/netlink/nlenc" ) // Flags which may apply to netlink attribute types when communicating with // certain netlink families. const ( Nested uint16 = 0x8000 NetByteOrder uint16 = 0x4000 // attrTypeMask masks off Type bits used for the above flags. attrTypeMask uint16 = 0x3fff ) // Various errors which may occur when attempting to marshal or unmarshal // a Message to and from its binary form. var ( errIncorrectMessageLength = errors.New("netlink message header length incorrect") errShortMessage = errors.New("not enough data to create a netlink message") errUnalignedMessage = errors.New("input data is not properly aligned for netlink message") ) // HeaderFlags specify flags which may be present in a Header. type HeaderFlags uint16 const ( // General netlink communication flags. // Request indicates a request to netlink. Request HeaderFlags = 1 // Multi indicates a multi-part message, terminated by Done on the // last message. Multi HeaderFlags = 2 // Acknowledge requests that netlink reply with an acknowledgement // using Error and, if needed, an error code. Acknowledge HeaderFlags = 4 // Echo requests that netlink echo this request back to the sender. Echo HeaderFlags = 8 // DumpInterrupted indicates that a dump was inconsistent due to a // sequence change. DumpInterrupted HeaderFlags = 16 // DumpFiltered indicates that a dump was filtered as requested. 
DumpFiltered HeaderFlags = 32 // Flags used to retrieve data from netlink. // Root requests that netlink return a complete table instead of a // single entry. Root HeaderFlags = 0x100 // Match requests that netlink return a list of all matching entries. Match HeaderFlags = 0x200 // Atomic requests that netlink send an atomic snapshot of its entries. // Requires CAP_NET_ADMIN or an effective UID of 0. Atomic HeaderFlags = 0x400 // Dump requests that netlink return a complete list of all entries. Dump HeaderFlags = Root | Match // Flags used to create objects. // Replace indicates request replaces an existing matching object. Replace HeaderFlags = 0x100 // Excl indicates request does not replace the object if it already exists. Excl HeaderFlags = 0x200 // Create indicates request creates an object if it doesn't already exist. Create HeaderFlags = 0x400 // Append indicates request adds to the end of the object list. Append HeaderFlags = 0x800 // Flags for extended acknowledgements. // Capped indicates the size of a request was capped in an extended // acknowledgement. Capped HeaderFlags = 0x100 // AcknowledgeTLVs indicates the presence of netlink extended // acknowledgement TLVs in a response. AcknowledgeTLVs HeaderFlags = 0x200 ) // String returns the string representation of a HeaderFlags. func (f HeaderFlags) String() string { names := []string{ "request", "multi", "acknowledge", "echo", "dumpinterrupted", "dumpfiltered", } var s string left := uint(f) for i, name := range names { if f&(1< 0 { if s != "" { s += "|" } s += fmt.Sprintf("%#x", left) } return s } // HeaderType specifies the type of a Header. type HeaderType uint16 const ( // Noop indicates that no action was taken. Noop HeaderType = 0x1 // Error indicates an error code is present, which is also used to indicate // success when the code is 0. Error HeaderType = 0x2 // Done indicates the end of a multi-part message. Done HeaderType = 0x3 // Overrun indicates that data was lost from this message. 
Overrun HeaderType = 0x4 ) // String returns the string representation of a HeaderType. func (t HeaderType) String() string { switch t { case Noop: return "noop" case Error: return "error" case Done: return "done" case Overrun: return "overrun" default: return fmt.Sprintf("unknown(%d)", t) } } // NB: the memory layout of Header and Linux's syscall.NlMsgHdr must be // exactly the same. Cannot reorder, change data type, add, or remove fields. // Named types of the same size (e.g. HeaderFlags is a uint16) are okay. // A Header is a netlink header. A Header is sent and received with each // Message to indicate metadata regarding a Message. type Header struct { // Length of a Message, including this Header. Length uint32 // Contents of a Message. Type HeaderType // Flags which may be used to modify a request or response. Flags HeaderFlags // The sequence number of a Message. Sequence uint32 // The port ID of the sending process. PID uint32 } // A Message is a netlink message. It contains a Header and an arbitrary // byte payload, which may be decoded using information from the Header. // // Data is often populated with netlink attributes. For easy encoding and // decoding of attributes, see the AttributeDecoder and AttributeEncoder types. type Message struct { Header Header Data []byte } // MarshalBinary marshals a Message into a byte slice. func (m Message) MarshalBinary() ([]byte, error) { ml := nlmsgAlign(int(m.Header.Length)) if ml < nlmsgHeaderLen || ml != int(m.Header.Length) { return nil, errIncorrectMessageLength } b := make([]byte, ml) nlenc.PutUint32(b[0:4], m.Header.Length) nlenc.PutUint16(b[4:6], uint16(m.Header.Type)) nlenc.PutUint16(b[6:8], uint16(m.Header.Flags)) nlenc.PutUint32(b[8:12], m.Header.Sequence) nlenc.PutUint32(b[12:16], m.Header.PID) copy(b[16:], m.Data) return b, nil } // UnmarshalBinary unmarshals the contents of a byte slice into a Message. 
func (m *Message) UnmarshalBinary(b []byte) error { if len(b) < nlmsgHeaderLen { return errShortMessage } if len(b) != nlmsgAlign(len(b)) { return errUnalignedMessage } // Don't allow misleading length m.Header.Length = nlenc.Uint32(b[0:4]) if int(m.Header.Length) != len(b) { return errShortMessage } m.Header.Type = HeaderType(nlenc.Uint16(b[4:6])) m.Header.Flags = HeaderFlags(nlenc.Uint16(b[6:8])) m.Header.Sequence = nlenc.Uint32(b[8:12]) m.Header.PID = nlenc.Uint32(b[12:16]) m.Data = b[16:] return nil } // checkMessage checks a single Message for netlink errors. func checkMessage(m Message) error { // NB: All non-nil errors returned from this function *must* be of type // OpError in order to maintain the appropriate contract with callers of // this package. // The libnl documentation indicates that type error can // contain error codes: // https://www.infradead.org/~tgr/libnl/doc/core.html#core_errmsg. // // However, rtnetlink at least seems to also allow errors to occur at the // end of a multipart message with done/multi and an error number. var hasHeader bool switch { case m.Header.Type == Error: // Error code followed by nlmsghdr/ext ack attributes. hasHeader = true case m.Header.Type == Done && m.Header.Flags&Multi != 0: // If no data, there must be no error number so just exit early. Some // of the unit tests hard-coded this but I don't actually know if this // case occurs in the wild. if len(m.Data) == 0 { return nil } // Done|Multi potentially followed by ext ack attributes. default: // Neither, nothing to do. return nil } // Errno occupies 4 bytes. const endErrno = 4 if len(m.Data) < endErrno { return newOpError("receive", errShortErrorMessage) } c := nlenc.Int32(m.Data[:endErrno]) if c == 0 { // 0 indicates no error. 
return nil } oerr := &OpError{ Op: "receive", // Error code is a negative integer, convert it into an OS-specific raw // system call error, but do not wrap with os.NewSyscallError to signify // that this error was produced by a netlink message; not a system call. Err: newError(-1 * int(c)), } // TODO(mdlayher): investigate the Capped flag. if m.Header.Flags&AcknowledgeTLVs == 0 { // No extended acknowledgement. return oerr } // Flags indicate an extended acknowledgement. The type/flags combination // checked above determines the offset where the TLVs occur. var off int if hasHeader { // There is an nlmsghdr preceding the TLVs. if len(m.Data) < endErrno+nlmsgHeaderLen { return newOpError("receive", errShortErrorMessage) } // The TLVs should be at the offset indicated by the nlmsghdr.length, // plus the offset where the header began. But make sure the calculated // offset is still in-bounds. h := *(*Header)(unsafe.Pointer(&m.Data[endErrno : endErrno+nlmsgHeaderLen][0])) off = endErrno + int(h.Length) if len(m.Data) < off { return newOpError("receive", errShortErrorMessage) } } else { // There is no nlmsghdr preceding the TLVs, parse them directly. off = endErrno } ad, err := NewAttributeDecoder(m.Data[off:]) if err != nil { // Malformed TLVs, just return the OpError with the info we have. return oerr } for ad.Next() { switch ad.Type() { case 1: // unix.NLMSGERR_ATTR_MSG oerr.Message = ad.String() case 2: // unix.NLMSGERR_ATTR_OFFS oerr.Offset = int(ad.Uint32()) } } // Explicitly ignore ad.Err: malformed TLVs, just return the OpError with // the info we have. 
return oerr } netlink-1.7.2/message_linux_test.go000066400000000000000000000066071442322526200174310ustar00rootroot00000000000000//go:build linux // +build linux package netlink import ( "syscall" "testing" "unsafe" "github.com/google/go-cmp/cmp" "github.com/mdlayher/netlink/nlenc" "golang.org/x/sys/unix" ) func TestHeaderMemoryLayoutLinux(t *testing.T) { var nh Header var sh syscall.NlMsghdr if want, got := unsafe.Sizeof(sh), unsafe.Sizeof(nh); want != got { t.Fatalf("unexpected structure sizes:\n- want: %v\n- got: %v", want, got) } sh = syscall.NlMsghdr{ Len: 0x10101010, Type: 0x2020, Flags: 0x3030, Seq: 0x40404040, Pid: 0x50505050, } nh = sysToHeader(sh) if want, got := sh.Len, nh.Length; want != got { t.Fatalf("unexpected header length:\n- want: %v\n- got: %v", want, got) } if want, got := sh.Type, uint16(nh.Type); want != got { t.Fatalf("unexpected header type:\n- want: %v\n- got: %v", want, got) } if want, got := sh.Flags, uint16(nh.Flags); want != got { t.Fatalf("unexpected header flags:\n- want: %v\n- got: %v", want, got) } if want, got := sh.Seq, nh.Sequence; want != got { t.Fatalf("unexpected header sequence:\n- want: %v\n- got: %v", want, got) } if want, got := sh.Pid, nh.PID; want != got { t.Fatalf("unexpected header PID:\n- want: %v\n- got: %v", want, got) } } func Test_checkMessageExtendedAcknowledgementTLVs(t *testing.T) { tests := []struct { name string m Message err *OpError }{ { name: "error", m: Message{ Header: Header{ Type: Error, // Indicate the use of extended acknowledgement. Flags: AcknowledgeTLVs, }, Data: packExtACK( -1, // The caller's request message with arbitrary bytes that we // skip over when parsing the TLVs. &Message{ Header: Header{Length: 4}, Data: []byte{0xff, 0xff, 0xff, 0xff}, }, // The actual extended acknowledgement TLVs. 
[]Attribute{ { Type: 1, Data: nlenc.Bytes("bad request"), }, { Type: 2, Data: nlenc.Uint32Bytes(2), }, }, ), }, err: &OpError{ Op: "receive", Err: unix.Errno(1), Message: "bad request", Offset: 2, }, }, { name: "done multi", m: Message{ Header: Header{ Type: Done, // Indicate the use of extended acknowledgement. Flags: Multi | AcknowledgeTLVs, }, Data: packExtACK( -1, // No message, straight to TLVs. nil, []Attribute{ { Type: 1, Data: nlenc.Bytes("bad request"), }, { Type: 2, Data: nlenc.Uint32Bytes(2), }, }, ), }, err: &OpError{ Op: "receive", Err: unix.Errno(1), Message: "bad request", Offset: 2, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if diff := cmp.Diff(tt.err, checkMessage(tt.m)); diff != "" { t.Fatalf("unexpected OpError (-want +got):\n%s", diff) } }) } } // packExtACK packs an extended acknowledgement response. func packExtACK(errno int32, m *Message, tlvs []Attribute) []byte { b := nlenc.Int32Bytes(errno) if m != nil { // Copy the header length logic from Conn. m.Header.Length = uint32(nlmsgAlign(nlmsgLength(len(m.Data)))) mb, err := m.MarshalBinary() if err != nil { panicf("failed to marshal message: %v", err) } b = append(b, mb...) } ab, err := MarshalAttributes(tlvs) if err != nil { panicf("failed to marshal attributes: %v", err) } return append(b, ab...) 
} netlink-1.7.2/message_test.go000066400000000000000000000167571442322526200162210ustar00rootroot00000000000000package netlink import ( "bytes" "encoding/binary" "errors" "reflect" "testing" "github.com/josharian/native" ) func TestHeaderFlagsString(t *testing.T) { tests := []struct { f HeaderFlags s string }{ { f: 0, s: "0", }, { f: Request, s: "request", }, { f: Multi, s: "multi", }, { f: Echo, s: "echo", }, { f: DumpInterrupted, s: "dumpinterrupted", }, { f: DumpFiltered, s: "dumpfiltered", }, { f: Root, s: "0x100", }, { f: Replace, s: "0x100", }, { f: Match, s: "0x200", }, { f: Excl, s: "0x200", }, { f: Atomic, s: "0x400", }, { f: Create, s: "0x400", }, { f: Append, s: "0x800", }, { f: Dump, s: "0x300", }, { f: Request | Dump, s: "request|0x300", }, { f: Request | Acknowledge | Create | Replace, s: "request|acknowledge|0x500", }, } for _, tt := range tests { t.Run(tt.s, func(t *testing.T) { if want, got := tt.s, tt.f.String(); want != got { t.Fatalf("unexpected flag string for: %016b\n- want: %q\n- got: %q", tt.f, want, got) } }) } } func TestHeaderTypeString(t *testing.T) { tests := []struct { t HeaderType s string }{ { t: 0, s: "unknown(0)", }, { t: Noop, s: "noop", }, { t: Error, s: "error", }, { t: Done, s: "done", }, { t: Overrun, s: "overrun", }, } for _, tt := range tests { t.Run(tt.s, func(t *testing.T) { if want, got := tt.s, tt.t.String(); want != got { t.Fatalf("unexpected header type string:\n- want: %q\n- got: %q", want, got) } }) } } func TestMessageMarshal(t *testing.T) { skipBigEndian(t) tests := []struct { name string m Message b []byte err error }{ { name: "empty", m: Message{}, err: errIncorrectMessageLength, }, { name: "short", m: Message{ Header: Header{ Length: 15, }, }, err: errIncorrectMessageLength, }, { name: "unaligned", m: Message{ Header: Header{ Length: 17, }, }, err: errIncorrectMessageLength, }, { name: "OK no data", m: Message{ Header: Header{ Length: 16, }, }, b: []byte{ 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, { name: "OK unaligned data", m: Message{ Header: Header{ Length: 20, Flags: Request, Sequence: 1, PID: 10, }, Data: []byte("abc"), }, b: []byte{ 0x14, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0a, 0x00, 0x00, 0x00, 0x61, 0x62, 0x63, 0x00, /* last byte padded */ }, }, { name: "OK aligned data", m: Message{ Header: Header{ Length: 20, Type: Error, Sequence: 2, PID: 20, }, Data: []byte("abcd"), }, b: []byte{ 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x61, 0x62, 0x63, 0x64, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { b, err := tt.m.MarshalBinary() if want, got := tt.err, err; want != got { t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) } if err != nil { return } if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected Message bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestMessageUnmarshal(t *testing.T) { skipBigEndian(t) tests := []struct { name string b []byte m Message err error }{ { name: "empty", err: errShortMessage, }, { name: "short", b: make([]byte, 15), err: errShortMessage, }, { name: "unaligned", b: make([]byte, 17), err: errUnalignedMessage, }, { name: "fuzz crasher: length shorter than slice", b: []byte("\x1d000000000000000"), err: errShortMessage, }, { name: "fuzz crasher: length longer than slice", b: []byte("\x13\x00\x00\x000000000000000000"), err: errShortMessage, }, { name: "OK no data", b: []byte{ 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, m: Message{ Header: Header{ Length: 16, }, Data: make([]byte, 0), }, }, { name: "OK data", m: Message{ Header: Header{ Length: 20, Type: Error, Sequence: 2, PID: 20, }, Data: []byte("abcd"), }, b: []byte{ 0x14, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x14, 0x00, 0x00, 0x00, 0x61, 0x62, 0x63, 0x64, }, }, } for _, tt := range tests 
{ t.Run(tt.name, func(t *testing.T) { var m Message err := (&m).UnmarshalBinary(tt.b) if want, got := tt.err, err; want != got { t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) } if err != nil { return } if want, got := tt.m, m; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected Message:\n- want: %#v\n- got: %#v", want, got) } }) } } func TestValidate(t *testing.T) { tests := []struct { name string req Message rep []Message err error }{ { name: "mismatched sequence", req: Message{ Header: Header{ Sequence: 1, }, }, rep: []Message{{ Header: Header{ Sequence: 2, }, }}, err: errMismatchedSequence, }, { name: "mismatched sequence second message", req: Message{ Header: Header{ Sequence: 1, }, }, rep: []Message{ { Header: Header{ Sequence: 1, }, }, { Header: Header{ Sequence: 2, }, }, }, err: errMismatchedSequence, }, { name: "mismatched PID", req: Message{ Header: Header{ PID: 1, }, }, rep: []Message{{ Header: Header{ PID: 2, }, }}, err: errMismatchedPID, }, { name: "mismatched PID second message", req: Message{ Header: Header{ PID: 1, }, }, rep: []Message{ { Header: Header{ PID: 1, }, }, { Header: Header{ PID: 2, }, }, }, err: errMismatchedPID, }, { name: "OK matching sequence and PID", req: Message{ Header: Header{ Sequence: 1, PID: 1, }, }, rep: []Message{{ Header: Header{ Sequence: 1, PID: 1, }, }}, }, { name: "OK multicast messages", // No request req: Message{}, rep: []Message{{ Header: Header{ Sequence: 1, PID: 0, }, }}, }, { name: "OK no PID assigned yet", // No request req: Message{ Header: Header{ Sequence: 1, PID: 0, }, }, rep: []Message{{ Header: Header{ Sequence: 1, PID: 9999, }, }}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := Validate(tt.req, tt.rep) if err == nil { if tt.err != nil { t.Fatal("expected an error, but none occurred") } return } var oerr *OpError if !errors.As(err, &oerr) { t.Fatalf("unexpected validate error type: %T", err) } if want, got := "validate", oerr.Op; want != got { 
t.Fatalf("unexpected op:\n- want: %v\n- got: %v", want, got) } if want, got := tt.err, oerr.Err; want != got { t.Fatalf("unexpected error:\n- want: %v\n- got: %v", want, got) } }) } } func skipBigEndian(t *testing.T) { if binary.ByteOrder(native.Endian) == binary.BigEndian { t.Skip("skipping test on big-endian system") } } netlink-1.7.2/nlenc/000077500000000000000000000000001442322526200142665ustar00rootroot00000000000000netlink-1.7.2/nlenc/doc.go000066400000000000000000000005731442322526200153670ustar00rootroot00000000000000// Package nlenc implements encoding and decoding functions for netlink // messages and attributes. package nlenc import ( "encoding/binary" "github.com/josharian/native" ) // NativeEndian returns the native byte order of this system. func NativeEndian() binary.ByteOrder { // TODO(mdlayher): consider deprecating and removing this function for v2. return native.Endian } netlink-1.7.2/nlenc/int.go000066400000000000000000000105211442322526200154060ustar00rootroot00000000000000package nlenc import ( "fmt" "unsafe" ) // PutUint8 encodes a uint8 into b. // If b is not exactly 1 byte in length, PutUint8 will panic. func PutUint8(b []byte, v uint8) { if l := len(b); l != 1 { panic(fmt.Sprintf("PutUint8: unexpected byte slice length: %d", l)) } b[0] = v } // PutUint16 encodes a uint16 into b using the host machine's native endianness. // If b is not exactly 2 bytes in length, PutUint16 will panic. func PutUint16(b []byte, v uint16) { if l := len(b); l != 2 { panic(fmt.Sprintf("PutUint16: unexpected byte slice length: %d", l)) } *(*uint16)(unsafe.Pointer(&b[0])) = v } // PutUint32 encodes a uint32 into b using the host machine's native endianness. // If b is not exactly 4 bytes in length, PutUint32 will panic. 
func PutUint32(b []byte, v uint32) { if l := len(b); l != 4 { panic(fmt.Sprintf("PutUint32: unexpected byte slice length: %d", l)) } *(*uint32)(unsafe.Pointer(&b[0])) = v } // PutUint64 encodes a uint64 into b using the host machine's native endianness. // If b is not exactly 8 bytes in length, PutUint64 will panic. func PutUint64(b []byte, v uint64) { if l := len(b); l != 8 { panic(fmt.Sprintf("PutUint64: unexpected byte slice length: %d", l)) } *(*uint64)(unsafe.Pointer(&b[0])) = v } // PutInt32 encodes a int32 into b using the host machine's native endianness. // If b is not exactly 4 bytes in length, PutInt32 will panic. func PutInt32(b []byte, v int32) { if l := len(b); l != 4 { panic(fmt.Sprintf("PutInt32: unexpected byte slice length: %d", l)) } *(*int32)(unsafe.Pointer(&b[0])) = v } // Uint8 decodes a uint8 from b. // If b is not exactly 1 byte in length, Uint8 will panic. func Uint8(b []byte) uint8 { if l := len(b); l != 1 { panic(fmt.Sprintf("Uint8: unexpected byte slice length: %d", l)) } return b[0] } // Uint16 decodes a uint16 from b using the host machine's native endianness. // If b is not exactly 2 bytes in length, Uint16 will panic. func Uint16(b []byte) uint16 { if l := len(b); l != 2 { panic(fmt.Sprintf("Uint16: unexpected byte slice length: %d", l)) } return *(*uint16)(unsafe.Pointer(&b[0])) } // Uint32 decodes a uint32 from b using the host machine's native endianness. // If b is not exactly 4 bytes in length, Uint32 will panic. func Uint32(b []byte) uint32 { if l := len(b); l != 4 { panic(fmt.Sprintf("Uint32: unexpected byte slice length: %d", l)) } return *(*uint32)(unsafe.Pointer(&b[0])) } // Uint64 decodes a uint64 from b using the host machine's native endianness. // If b is not exactly 8 bytes in length, Uint64 will panic. 
func Uint64(b []byte) uint64 { if l := len(b); l != 8 { panic(fmt.Sprintf("Uint64: unexpected byte slice length: %d", l)) } return *(*uint64)(unsafe.Pointer(&b[0])) } // Int32 decodes an int32 from b using the host machine's native endianness. // If b is not exactly 4 bytes in length, Int32 will panic. func Int32(b []byte) int32 { if l := len(b); l != 4 { panic(fmt.Sprintf("Int32: unexpected byte slice length: %d", l)) } return *(*int32)(unsafe.Pointer(&b[0])) } // Uint8Bytes encodes a uint8 into a newly-allocated byte slice. It is a // shortcut for allocating a new byte slice and filling it using PutUint8. func Uint8Bytes(v uint8) []byte { b := make([]byte, 1) PutUint8(b, v) return b } // Uint16Bytes encodes a uint16 into a newly-allocated byte slice using the // host machine's native endianness. It is a shortcut for allocating a new // byte slice and filling it using PutUint16. func Uint16Bytes(v uint16) []byte { b := make([]byte, 2) PutUint16(b, v) return b } // Uint32Bytes encodes a uint32 into a newly-allocated byte slice using the // host machine's native endianness. It is a shortcut for allocating a new // byte slice and filling it using PutUint32. func Uint32Bytes(v uint32) []byte { b := make([]byte, 4) PutUint32(b, v) return b } // Uint64Bytes encodes a uint64 into a newly-allocated byte slice using the // host machine's native endianness. It is a shortcut for allocating a new // byte slice and filling it using PutUint64. func Uint64Bytes(v uint64) []byte { b := make([]byte, 8) PutUint64(b, v) return b } // Int32Bytes encodes a int32 into a newly-allocated byte slice using the // host machine's native endianness. It is a shortcut for allocating a new // byte slice and filling it using PutInt32. 
func Int32Bytes(v int32) []byte { b := make([]byte, 4) PutInt32(b, v) return b } netlink-1.7.2/nlenc/int_test.go000066400000000000000000000175741442322526200164640ustar00rootroot00000000000000package nlenc import ( "bytes" "encoding/binary" "fmt" "testing" ) func TestUintPanic(t *testing.T) { tests := []struct { name string b []byte fn func(b []byte) }{ { name: "short put 8", b: make([]byte, 0), fn: func(b []byte) { PutUint8(b, 0) }, }, { name: "long put 8", b: make([]byte, 2), fn: func(b []byte) { PutUint8(b, 0) }, }, { name: "short get 8", b: make([]byte, 0), fn: func(b []byte) { Uint8(b) }, }, { name: "long get 8", b: make([]byte, 2), fn: func(b []byte) { Uint8(b) }, }, { name: "short put 16", b: make([]byte, 1), fn: func(b []byte) { PutUint16(b, 0) }, }, { name: "long put 16", b: make([]byte, 3), fn: func(b []byte) { PutUint16(b, 0) }, }, { name: "short get 16", b: make([]byte, 1), fn: func(b []byte) { Uint16(b) }, }, { name: "long get 16", b: make([]byte, 3), fn: func(b []byte) { Uint16(b) }, }, { name: "short put 32", b: make([]byte, 3), fn: func(b []byte) { PutUint32(b, 0) }, }, { name: "long put 32", b: make([]byte, 5), fn: func(b []byte) { PutUint32(b, 0) }, }, { name: "short get 32", b: make([]byte, 3), fn: func(b []byte) { Uint32(b) }, }, { name: "long get 32", b: make([]byte, 5), fn: func(b []byte) { Uint32(b) }, }, { name: "short put 64", b: make([]byte, 7), fn: func(b []byte) { PutUint64(b, 0) }, }, { name: "long put 64", b: make([]byte, 9), fn: func(b []byte) { PutUint64(b, 0) }, }, { name: "short get 64", b: make([]byte, 7), fn: func(b []byte) { Uint64(b) }, }, { name: "long get 64", b: make([]byte, 9), fn: func(b []byte) { Uint64(b) }, }, { name: "short put signed 32", b: make([]byte, 3), fn: func(b []byte) { PutInt32(b, 0) }, }, { name: "short get signed 32", b: make([]byte, 3), fn: func(b []byte) { Int32(b) }, }, { name: "long put signed 32", b: make([]byte, 5), fn: func(b []byte) { PutInt32(b, 0) }, }, { name: "long get signed 32", b: 
make([]byte, 5), fn: func(b []byte) { Int32(b) }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { defer func() { if r := recover(); r == nil { t.Fatal("expected panic, but none occurred") } }() tt.fn(tt.b) t.Fatal("reached end of test case without panic") }) } } func TestUint8(t *testing.T) { tests := []struct { v uint8 b []byte }{ { v: 0x01, b: []byte{0x01}, }, { v: 0xff, b: []byte{0xff}, }, } for _, tt := range tests { t.Run(fmt.Sprintf("0x%03x", tt.v), func(t *testing.T) { b := make([]byte, 1) PutUint8(b, tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } v := Uint8(b) if want, got := tt.v, v; want != got { t.Fatalf("unexpected integer:\n- want: 0x%03x\n- got: 0x%03x", want, got) } b = Uint8Bytes(tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestUint16(t *testing.T) { skipBigEndian(t) tests := []struct { v uint16 b []byte }{ { v: 0x1, b: []byte{0x01, 0x00}, }, { v: 0x0102, b: []byte{0x02, 0x01}, }, { v: 0x1234, b: []byte{0x34, 0x12}, }, { v: 0xffff, b: []byte{0xff, 0xff}, }, } for _, tt := range tests { t.Run(fmt.Sprintf("0x%04x", tt.v), func(t *testing.T) { b := make([]byte, 2) PutUint16(b, tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } v := Uint16(b) if want, got := tt.v, v; want != got { t.Fatalf("unexpected integer:\n- want: 0x%04x\n- got: 0x%04x", want, got) } b = Uint16Bytes(tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestUint32(t *testing.T) { skipBigEndian(t) tests := []struct { v uint32 b []byte }{ { v: 0x1, b: []byte{0x01, 0x00, 0x00, 0x00}, }, { v: 0x0102, b: []byte{0x02, 0x01, 0x00, 0x00}, }, { v: 0x1234, b: []byte{0x34, 0x12, 0x00, 0x00}, }, { v: 0xffff, b: []byte{0xff, 0xff, 
0x00, 0x00}, }, { v: 0x01020304, b: []byte{0x04, 0x03, 0x02, 0x01}, }, { v: 0x1a2a3a4a, b: []byte{0x4a, 0x3a, 0x2a, 0x1a}, }, } for _, tt := range tests { t.Run(fmt.Sprintf("0x%08x", tt.v), func(t *testing.T) { b := make([]byte, 4) PutUint32(b, tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } v := Uint32(b) if want, got := tt.v, v; want != got { t.Fatalf("unexpected integer:\n- want: 0x%04x\n- got: 0x%04x", want, got) } b = Uint32Bytes(tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestUint64(t *testing.T) { skipBigEndian(t) tests := []struct { v uint64 b []byte }{ { v: 0x1, b: []byte{0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { v: 0x0102, b: []byte{0x02, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { v: 0x1234, b: []byte{0x34, 0x12, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { v: 0xffff, b: []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, }, { v: 0x01020304, b: []byte{0x04, 0x03, 0x02, 0x01, 0x00, 0x00, 0x00, 0x00}, }, { v: 0x1a2a3a4a, b: []byte{0x4a, 0x3a, 0x2a, 0x1a, 0x00, 0x00, 0x00, 0x00}, }, { v: 0x0102030405060708, b: []byte{0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01}, }, { v: 0x1a2a3a4a5a6a7a8a, b: []byte{0x8a, 0x7a, 0x6a, 0x5a, 0x4a, 0x3a, 0x2a, 0x1a}, }, } for _, tt := range tests { t.Run(fmt.Sprintf("0x%016x", tt.v), func(t *testing.T) { b := make([]byte, 8) PutUint64(b, tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } v := Uint64(b) if want, got := tt.v, v; want != got { t.Fatalf("unexpected integer:\n- want: 0x%04x\n- got: 0x%04x", want, got) } b = Uint64Bytes(tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func TestInt32(t *testing.T) { skipBigEndian(t) tests := []struct { v int32 b []byte 
}{ { v: 0x1, b: []byte{0x01, 0x00, 0x00, 0x00}, }, { v: 0x0102, b: []byte{0x02, 0x01, 0x00, 0x00}, }, { v: 0x1234, b: []byte{0x34, 0x12, 0x00, 0x00}, }, { v: 0xffff, b: []byte{0xff, 0xff, 0x00, 0x00}, }, { v: 0x01020304, b: []byte{0x04, 0x03, 0x02, 0x01}, }, { v: 0x1a2a3a4a, b: []byte{0x4a, 0x3a, 0x2a, 0x1a}, }, { v: -1, b: []byte{0xff, 0xff, 0xff, 0xff}, }, { v: -2, b: []byte{0xfe, 0xff, 0xff, 0xff}, }, } for _, tt := range tests { t.Run(fmt.Sprintf("0x%08x", tt.v), func(t *testing.T) { b := make([]byte, 4) PutInt32(b, tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } v := Int32(b) if want, got := tt.v, v; want != got { t.Fatalf("unexpected integer:\n- want: 0x%04x\n- got: 0x%04x", want, got) } b = Int32Bytes(tt.v) if want, got := tt.b, b; !bytes.Equal(want, got) { t.Fatalf("unexpected bytes:\n- want: [%# x]\n- got: [%# x]", want, got) } }) } } func skipBigEndian(t *testing.T) { if NativeEndian() == binary.BigEndian { t.Skip("skipping test on big-endian system") } } netlink-1.7.2/nlenc/string.go000066400000000000000000000011111442322526200161150ustar00rootroot00000000000000package nlenc import "bytes" // Bytes returns a null-terminated byte slice with the contents of s. func Bytes(s string) []byte { return append([]byte(s), 0x00) } // String returns a string with the contents of b from a null-terminated // byte slice. func String(b []byte) string { // If the string has more than one NULL terminator byte, we want to remove // all of them before returning the string to the caller; hence the use of // strings.TrimRight instead of strings.TrimSuffix (which previously only // removed a single NULL). 
return string(bytes.TrimRight(b, "\x00")) } netlink-1.7.2/nlenc/string_test.go000066400000000000000000000016071442322526200171660ustar00rootroot00000000000000package nlenc import ( "testing" "github.com/google/go-cmp/cmp" ) func TestBytesString(t *testing.T) { tests := []struct { s string b []byte }{ { s: "foo", b: []byte{'f', 'o', 'o', 0x00}, }, { s: "nl80211", b: []byte{'n', 'l', '8', '0', '2', '1', '1', 0x00}, }, { s: "TASKSTATS", b: []byte{'T', 'A', 'S', 'K', 'S', 'T', 'A', 'T', 'S', 0x00}, }, } for _, tt := range tests { t.Run(tt.s, func(t *testing.T) { s := String(Bytes(tt.s)) if want, got := tt.s, s; want != got { t.Fatalf("unexpected string:\n- want: %q\n- got: %q", want, got) } }) } } func TestStringTrailingNull(t *testing.T) { const want = "hello world" // Buffer has many trailing NULL bytes which should all be removed. var b [64]byte copy(b[:], want) if diff := cmp.Diff(want, String(b[:])); diff != "" { t.Fatalf("unexpected string (-want +got):\n%s", diff) } } netlink-1.7.2/nltest/000077500000000000000000000000001442322526200145005ustar00rootroot00000000000000netlink-1.7.2/nltest/errors_others.go000066400000000000000000000001711442322526200177260ustar00rootroot00000000000000//go:build plan9 || windows // +build plan9 windows package nltest func isSyscallError(_ error) bool { return false } netlink-1.7.2/nltest/errors_unix.go000066400000000000000000000002671442322526200174130ustar00rootroot00000000000000//go:build !plan9 && !windows // +build !plan9,!windows package nltest import "golang.org/x/sys/unix" func isSyscallError(err error) bool { _, ok := err.(unix.Errno) return ok } netlink-1.7.2/nltest/nltest.go000066400000000000000000000137301442322526200163440ustar00rootroot00000000000000// Package nltest provides utilities for netlink testing. package nltest import ( "fmt" "io" "os" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nlenc" ) // PID is the netlink header PID value assigned by nltest. 
const PID = 1 // MustMarshalAttributes marshals a slice of netlink.Attributes to their binary // format, but panics if any errors occur. func MustMarshalAttributes(attrs []netlink.Attribute) []byte { b, err := netlink.MarshalAttributes(attrs) if err != nil { panic(fmt.Sprintf("failed to marshal attributes to binary: %v", err)) } return b } // Multipart sends a slice of netlink.Messages to the caller as a // netlink multi-part message. If less than two messages are present, // the messages are not altered. func Multipart(msgs []netlink.Message) ([]netlink.Message, error) { if len(msgs) < 2 { return msgs, nil } for i := range msgs { // Last message has header type "done" in addition to multi-part flag. if i == len(msgs)-1 { msgs[i].Header.Type = netlink.Done } msgs[i].Header.Flags |= netlink.Multi } return msgs, nil } // Error returns a netlink error to the caller with the specified error // number, in the body of the specified request message. func Error(number int, reqs []netlink.Message) ([]netlink.Message, error) { req := reqs[0] req.Header.Length += 4 req.Header.Type = netlink.Error errno := -1 * int32(number) req.Data = append(nlenc.Int32Bytes(errno), req.Data...) return []netlink.Message{req}, nil } // A Func is a function that can be used to test netlink.Conn interactions. // The function can choose to return zero or more netlink messages, or an // error if needed. // // For a netlink request/response interaction, a request req is populated by // netlink.Conn.Send and passed to the function. // // For multicast interactions, an empty request req is passed to the function // when netlink.Conn.Receive is called. // // If a Func returns an error, the error will be returned as-is to the caller. // If no messages and io.EOF are returned, no messages and no error will be // returned to the caller, simulating a multi-part message with no data. 
type Func func(req []netlink.Message) ([]netlink.Message, error) // Dial sets up a netlink.Conn for testing using the specified Func. All requests // sent from the connection will be passed to the Func. The connection should be // closed as usual when it is no longer needed. func Dial(fn Func) *netlink.Conn { sock := &socket{ fn: fn, } return netlink.NewConn(sock, PID) } // CheckRequest returns a Func that verifies that each message in an incoming // request has the specified netlink header type and flags in the same slice // position index, and then passes the request through to fn. // // The length of the types and flags slices must match the number of requests // passed to the returned Func, or CheckRequest will panic. // // As an example: // - types[0] and flags[0] will be checked against reqs[0] // - types[1] and flags[1] will be checked against reqs[1] // - ... and so on // // If an element of types or flags is set to the zero value, that check will // be skipped for the request message that occurs at the same index. // // As an example, if types[0] is 0 and reqs[0].Header.Type is 1, the check will // succeed because types[0] was not specified. 
func CheckRequest(types []netlink.HeaderType, flags []netlink.HeaderFlags, fn Func) Func { if len(types) != len(flags) { panicf("nltest: CheckRequest called with mismatched types and flags slice lengths: %d != %d", len(types), len(flags)) } return func(req []netlink.Message) ([]netlink.Message, error) { if len(types) != len(req) { panicf("nltest: CheckRequest function invoked types/flags and request message slice lengths: %d != %d", len(types), len(req)) } for i := range req { if want, got := types[i], req[i].Header.Type; types[i] != 0 && want != got { return nil, fmt.Errorf("nltest: unexpected netlink header type: %s, want: %s", got, want) } if want, got := flags[i], req[i].Header.Flags; flags[i] != 0 && want != got { return nil, fmt.Errorf("nltest: unexpected netlink header flags: %s, want: %s", got, want) } } return fn(req) } } // A socket is a netlink.Socket used for testing. type socket struct { fn Func msgs []netlink.Message err error } func (c *socket) Close() error { return nil } func (c *socket) SendMessages(messages []netlink.Message) error { msgs, err := c.fn(messages) c.msgs = append(c.msgs, msgs...) c.err = err return nil } func (c *socket) Send(m netlink.Message) error { c.msgs, c.err = c.fn([]netlink.Message{m}) return nil } func (c *socket) Receive() ([]netlink.Message, error) { // No messages set by Send means that we are emulating a // multicast response or an error occurred. if len(c.msgs) == 0 { switch c.err { case nil: // No error, simulate multicast, but also return EOF to simulate // no replies if needed. msgs, err := c.fn(nil) if err == io.EOF { err = nil } return msgs, err case io.EOF: // EOF, simulate no replies in multi-part message. return nil, nil } // If the error is a system call error, wrap it in os.NewSyscallError // to simulate what the Linux netlink.Conn does. if isSyscallError(c.err) { return nil, os.NewSyscallError("recvmsg", c.err) } // Some generic error occurred and should be passed to the caller. 
return nil, c.err } // Detect multi-part messages. var multi bool for _, m := range c.msgs { if m.Header.Flags&netlink.Multi != 0 && m.Header.Type != netlink.Done { multi = true } } // When a multi-part message is detected, return all messages except for the // final "multi-part done", so that a second call to Receive from netlink.Conn // will drain that message. if multi { last := c.msgs[len(c.msgs)-1] ret := c.msgs[:len(c.msgs)-1] c.msgs = []netlink.Message{last} return ret, c.err } msgs, err := c.msgs, c.err c.msgs, c.err = nil, nil return msgs, err } func panicf(format string, a ...interface{}) { panic(fmt.Sprintf(format, a...)) } netlink-1.7.2/nltest/nltest_linux_test.go000066400000000000000000000017541442322526200206250ustar00rootroot00000000000000package nltest_test import ( "errors" "os" "testing" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" "golang.org/x/sys/unix" ) func TestLinuxDialError(t *testing.T) { c := nltest.Dial(func(req []netlink.Message) ([]netlink.Message, error) { return nltest.Error(int(unix.ENOENT), req) }) if _, err := c.Execute(netlink.Message{}); !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected error is not exist, but got: %v", err) } } func TestLinuxSyscallError(t *testing.T) { c := nltest.Dial(func(req []netlink.Message) ([]netlink.Message, error) { return nil, unix.ENOENT }) _, err := c.Execute(netlink.Message{}) if !errors.Is(err, os.ErrNotExist) { t.Fatalf("expected error is not exist, but got: %v", err) } // Expect raw system call errors to be wrapped. 
var serr *os.SyscallError if !errors.As(err, &serr) { t.Fatalf("error did not contain *os.SyscallError") } if serr.Err != unix.ENOENT { t.Fatalf("expected ENOENT, but got: %v", serr.Err) } } netlink-1.7.2/nltest/nltest_test.go000066400000000000000000000256401442322526200174060ustar00rootroot00000000000000package nltest_test import ( "bytes" "encoding/binary" "errors" "io" "reflect" "testing" "github.com/google/go-cmp/cmp" "github.com/josharian/native" "github.com/mdlayher/netlink" "github.com/mdlayher/netlink/nltest" ) func TestConnSend(t *testing.T) { req := netlink.Message{ Data: []byte{0xff, 0xff, 0xff, 0xff}, } c := nltest.Dial(func(creq []netlink.Message) ([]netlink.Message, error) { if got, want := len(creq), 1; got != want { t.Fatalf("unexpected number of messages: got %d, want %d", got, want) } if want, got := req.Data, creq[0].Data; !bytes.Equal(want, got) { t.Fatalf("unexpected request data:\n- want: %v\n- got: %v", want, got) } return nil, nil }) defer c.Close() if _, err := c.Send(req); err != nil { t.Fatalf("failed to send request: %v", err) } } func TestConnReceiveMulticast(t *testing.T) { msgs := []netlink.Message{{ Data: []byte{0xff, 0xff, 0xff, 0xff}, }} c := nltest.Dial(func(zero []netlink.Message) ([]netlink.Message, error) { if zero == nil { return msgs, nil } if want, got := (netlink.Message{}), zero; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected zero message:\n- want: %v\n- got: %v", want, got) } return msgs, nil }) defer c.Close() got, err := c.Receive() if err != nil { t.Fatalf("failed to receive messages: %v", err) } if want := msgs; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected multicast messages:\n- want: %v\n- got: %v", want, got) } } func TestConnReceiveNoMessages(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, io.EOF }) defer c.Close() msgs, err := c.Receive() if err != nil { t.Fatalf("failed to execute: %v", err) } if l := len(msgs); l > 0 { t.Fatalf("expected no 
messages, but got: %d", l) } } func TestConnReceiveError(t *testing.T) { errFoo := errors.New("foo") c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, errFoo }) defer c.Close() want := &netlink.OpError{ Op: "receive", Err: errFoo, } _, err := c.Receive() if diff := cmp.Diff(want.Error(), err.Error()); diff != "" { t.Fatalf("unexpected error (-want +got):\n%s", diff) } } func TestConnExecuteOK(t *testing.T) { req := netlink.Message{ Header: netlink.Header{ Length: 16, Flags: netlink.Request, Sequence: 1, PID: 1, }, } c := nltest.Dial(func(creq []netlink.Message) ([]netlink.Message, error) { // Turn the request back around to the client. return creq, nil }) defer c.Close() got, err := c.Execute(req) if err != nil { t.Fatalf("failed to execute request: %v", err) } if want := []netlink.Message{req}; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected response messages:\n- want: %v\n- got: %v", want, got) } } func TestConnExecuteMultipartOK(t *testing.T) { req := netlink.Message{ Header: netlink.Header{ Length: 16, Flags: netlink.Request, Sequence: 1, PID: 1, }, } c := nltest.Dial(func(creq []netlink.Message) ([]netlink.Message, error) { // Client should only receive one message with multipart flag set. // TODO: append(creq, creq)? creqs := make([]netlink.Message, 2*len(creq)) copy(creqs, creq) copy(creqs[len(creq):], creq) return nltest.Multipart(creqs) }) defer c.Close() got, err := c.Execute(req) if err != nil { t.Fatalf("failed to execute request: %v", err) } req.Header.Flags |= netlink.Multi if want := []netlink.Message{req}; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected response messages:\n- want: %v\n- got: %v", want, got) } } func TestConnExecuteError(t *testing.T) { err := errors.New("foo") c := nltest.Dial(func(creq []netlink.Message) ([]netlink.Message, error) { // Error should be surfaced by Execute's call to Receive. 
return nil, err }) defer c.Close() want := &netlink.OpError{ Op: "receive", Err: err, } _, got := c.Execute(netlink.Message{}) if diff := cmp.Diff(want.Error(), got.Error()); diff != "" { t.Fatalf("unexpected error (-want +got):\n%s", diff) } } func TestConnExecuteNoMessages(t *testing.T) { c := nltest.Dial(func(_ []netlink.Message) ([]netlink.Message, error) { return nil, io.EOF }) defer c.Close() msgs, err := c.Execute(netlink.Message{}) if err != nil { t.Fatalf("failed to execute: %v", err) } if l := len(msgs); l > 0 { t.Fatalf("expected no messages, but got: %d", l) } } func TestError(t *testing.T) { skipBigEndian(t) const ( eperm = 1 enoent = 2 ) tests := []struct { name string number int in []netlink.Message out []netlink.Message }{ { name: "EPERM", number: eperm, in: []netlink.Message{ { Header: netlink.Header{ Length: 24, Flags: netlink.Request | netlink.Dump, Sequence: 10, PID: 1000, }, Data: []byte{0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88}, }, }, out: []netlink.Message{{ Header: netlink.Header{ Length: 28, Type: netlink.Error, Flags: netlink.Request | netlink.Dump, Sequence: 10, PID: 1000, }, Data: []byte{ 0xff, 0xff, 0xff, 0xff, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, }, }}, }, { name: "ENOENT", number: enoent, in: []netlink.Message{ { Header: netlink.Header{ Length: 20, Flags: netlink.Request, Sequence: 1, PID: 100, }, Data: []byte{0x11, 0x22, 0x33, 0x44}, }, }, out: []netlink.Message{{ Header: netlink.Header{ Length: 24, Type: netlink.Error, Flags: netlink.Request, Sequence: 1, PID: 100, }, Data: []byte{ 0xfe, 0xff, 0xff, 0xff, 0x11, 0x22, 0x33, 0x44, }, }}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { out, err := nltest.Error(tt.number, tt.in) if err != nil { t.Fatalf("unexpected error: %v", err) } if want, got := tt.out, out; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected output messages:\n- want: %v\n- got: %v", want, got) } }) } } func TestMultipart(t *testing.T) { tests := []struct { name string in 
[]netlink.Message out []netlink.Message }{ { name: "no messages", }, { name: "one message, no changes", in: []netlink.Message{{ Header: netlink.Header{ Length: 20, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }}, out: []netlink.Message{{ Header: netlink.Header{ Length: 20, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }}, }, { name: "two messages, multipart", in: []netlink.Message{ { Header: netlink.Header{ Length: 20, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 16, }, }, }, out: []netlink.Message{ { Header: netlink.Header{ Length: 20, Flags: netlink.Multi, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 16, Type: netlink.Done, Flags: netlink.Multi, }, }, }, }, { name: "three messages, multipart", in: []netlink.Message{ { Header: netlink.Header{ Length: 20, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 24, }, Data: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 16, }, }, }, out: []netlink.Message{ { Header: netlink.Header{ Length: 20, Flags: netlink.Multi, }, Data: []byte{0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 24, Flags: netlink.Multi, }, Data: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}, }, { Header: netlink.Header{ Length: 16, Type: netlink.Done, Flags: netlink.Multi, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { out, err := nltest.Multipart(tt.in) if err != nil { t.Fatalf("unexpected error: %v", err) } if want, got := tt.out, out; !reflect.DeepEqual(want, got) { t.Fatalf("unexpected output messages:\n- want: %v\n- got: %v", want, got) } }) } } func TestCheckRequestPanic(t *testing.T) { tests := []struct { name string types []netlink.HeaderType flags []netlink.HeaderFlags reqs []netlink.Message }{ { name: "types", types: []netlink.HeaderType{0}, }, { name: "flags", flags: []netlink.HeaderFlags{0}, }, { name: "requests", types: []netlink.HeaderType{0}, flags: 
[]netlink.HeaderFlags{0}, reqs: []netlink.Message{{}, {}}, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { defer func() { if r := recover(); r == nil { t.Fatal("expected a panic, but none occurred") } }() fn := nltest.CheckRequest(tt.types, tt.flags, noop) fn(tt.reqs) }) } } func TestCheckRequest(t *testing.T) { tests := []struct { name string types []netlink.HeaderType flags []netlink.HeaderFlags reqs []netlink.Message ok bool }{ { name: "no checking", types: []netlink.HeaderType{0}, flags: []netlink.HeaderFlags{0}, reqs: []netlink.Message{{}}, ok: true, }, { name: "type only", types: []netlink.HeaderType{10}, flags: []netlink.HeaderFlags{0}, reqs: []netlink.Message{{ Header: netlink.Header{ Type: 10, Flags: netlink.Request, }, }}, ok: true, }, { name: "flags only", types: []netlink.HeaderType{0}, flags: []netlink.HeaderFlags{netlink.Request}, reqs: []netlink.Message{{ Header: netlink.Header{ Type: 10, Flags: netlink.Request, }, }}, ok: true, }, { name: "bad type", types: []netlink.HeaderType{10, 20}, flags: []netlink.HeaderFlags{netlink.Request, netlink.Replace}, reqs: []netlink.Message{ { Header: netlink.Header{ Type: 10, Flags: netlink.Request, }, }, { Header: netlink.Header{ Type: 99, Flags: netlink.Replace, }, }, }, }, { name: "bad flags", types: []netlink.HeaderType{10, 20}, flags: []netlink.HeaderFlags{netlink.Request, netlink.Replace}, reqs: []netlink.Message{ { Header: netlink.Header{ Type: 10, Flags: netlink.Request, }, }, { Header: netlink.Header{ Type: 20, Flags: netlink.Request, }, }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { fn := nltest.CheckRequest(tt.types, tt.flags, noop) _, err := fn(tt.reqs) if err != nil && tt.ok { t.Fatalf("unexpected error: %v", err) } if err == nil && !tt.ok { t.Fatal("expected an error, but none occurred") } }) } } var noop = func(req []netlink.Message) ([]netlink.Message, error) { return nil, nil } func skipBigEndian(t *testing.T) { if binary.ByteOrder(native.Endian) == 
binary.BigEndian { t.Skip("skipping test on big-endian system") } }