pax_global_header00006660000000000000000000000064145242726050014521gustar00rootroot0000000000000052 comment=86973d514e41d284888068bdde79cb706e996278 golang-github-pin-tftp-3.1.0/000077500000000000000000000000001452427260500160105ustar00rootroot00000000000000golang-github-pin-tftp-3.1.0/.github/000077500000000000000000000000001452427260500173505ustar00rootroot00000000000000golang-github-pin-tftp-3.1.0/.github/workflows/000077500000000000000000000000001452427260500214055ustar00rootroot00000000000000golang-github-pin-tftp-3.1.0/.github/workflows/macos.yml000066400000000000000000000005721452427260500232360ustar00rootroot00000000000000name: MacOS test on: push: branches: [ "master" ] pull_request: branches: [ "master" ] jobs: build: runs-on: macos-latest steps: - uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v3 with: go-version: 1.13 - name: Build run: go build -v ./... - name: Test run: go test -v ./... -race golang-github-pin-tftp-3.1.0/.github/workflows/ubuntu.yml000066400000000000000000000005731452427260500234570ustar00rootroot00000000000000name: Linux test on: push: branches: [ "master" ] pull_request: branches: [ "master" ] jobs: build: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v3 with: go-version: 1.18 - name: Build run: go build -v ./... - name: Test run: go test -v ./... -race golang-github-pin-tftp-3.1.0/.github/workflows/windows.yml000066400000000000000000000005761452427260500236320ustar00rootroot00000000000000name: Windows test on: push: branches: [ "master" ] pull_request: branches: [ "master" ] jobs: build: runs-on: windows-latest steps: - uses: actions/checkout@v3 - name: Set up Go uses: actions/setup-go@v3 with: go-version: 1.18 - name: Build run: go build -v ./... - name: Test run: go test -v ./... 
-race golang-github-pin-tftp-3.1.0/.gitignore000066400000000000000000000004121452427260500177750ustar00rootroot00000000000000# Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.test *.prof golang-github-pin-tftp-3.1.0/CONTRIBUTORS000066400000000000000000000001601452427260500176650ustar00rootroot00000000000000Dmitri Popov Mojo Talantikite Giovanni Bajo Andrew Danforth Victor Lowther minghuadev on github.com Owen Mooney golang-github-pin-tftp-3.1.0/LICENSE000066400000000000000000000020661452427260500170210ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2016 Dmitri Popov Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
golang-github-pin-tftp-3.1.0/README.md000066400000000000000000000127261452427260500172770ustar00rootroot00000000000000TFTP server and client library for Golang ========================================= [![GoDoc](https://godoc.org/github.com/pin/tftp/v3?status.svg)](https://godoc.org/github.com/pin/tftp/v3) Implements: * [RFC 1350](https://tools.ietf.org/html/rfc1350) - The TFTP Protocol (Revision 2) * [RFC 2347](https://tools.ietf.org/html/rfc2347) - TFTP Option Extension * [RFC 2348](https://tools.ietf.org/html/rfc2348) - TFTP Blocksize Option Partially implements (tsize server side only): * [RFC 2349](https://tools.ietf.org/html/rfc2349) - TFTP Timeout Interval and Transfer Size Options Set of features is sufficient for PXE boot support. ``` go import "github.com/pin/tftp/v3" ``` The package is cohesive to Golang `io` and implements `io.ReaderFrom` and `io.WriterTo` interfaces. That allows efficient data transmission without unnecessary memory copying and allocations. TFTP Server ----------- ```go // readHandler is called when client starts file download from server func readHandler(filename string, rf io.ReaderFrom) error { file, err := os.Open(filename) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err } n, err := rf.ReadFrom(file) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err } fmt.Printf("%d bytes sent\n", n) return nil } // writeHandler is called when client starts file upload to server func writeHandler(filename string, wt io.WriterTo) error { file, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0644) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err } n, err := wt.WriteTo(file) if err != nil { fmt.Fprintf(os.Stderr, "%v\n", err) return err } fmt.Printf("%d bytes received\n", n) return nil } func main() { // use nil in place of handler to disable read or write operations s := tftp.NewServer(readHandler, writeHandler) s.SetTimeout(5 * time.Second) // optional err := s.ListenAndServe(":69") // blocks 
until s.Shutdown() is called if err != nil { fmt.Fprintf(os.Stdout, "server: %v\n", err) os.Exit(1) } } ``` See [gotftpd](https://github.com/pin/golang-tftp-example/blob/master/src/gotftpd/main.go) in [golang-tftp-example](https://github.com/pin/golang-tftp-example) repository for working code. TFTP Client ----------- Upload file to server: ```go c, err := tftp.NewClient("172.16.4.21:69") file, err := os.Open(path) c.SetTimeout(5 * time.Second) // optional rf, err := c.Send("foobar.txt", "octet") n, err := rf.ReadFrom(file) fmt.Printf("%d bytes sent\n", n) ``` Download file from server: ```go c, err := tftp.NewClient("172.16.4.21:69") wt, err := c.Receive("foobar.txt", "octet") file, err := os.Create(path) // Optionally obtain transfer size before actual data. if n, ok := wt.(tftp.IncomingTransfer).Size(); ok { fmt.Printf("Transfer size: %d\n", n) } n, err := wt.WriteTo(file) fmt.Printf("%d bytes received\n", n) ``` See [goftp](https://github.com/pin/golang-tftp-example/blob/master/src/gotftp/main.go) in [golang-tftp-example](https://github.com/pin/golang-tftp-example) repository for working code. TSize option ------------ PXE boot ROM often expects tsize option support from a server: client (e.g. computer that boots over the network) wants to know size of a download before the actual data comes. Server has to obtain stream size and send it to a client. Often it will happen automatically because TFTP library tries to check if `io.Reader` provided to `ReadFrom` method also satisfies `io.Seeker` interface (`os.File` for instance) and uses `Seek` to determine file size. In case `io.Reader` you provide to `ReadFrom` in read handler does not satisfy `io.Seeker` interface or you do not want TFTP library to call `Seek` on your reader but still want to respond with tsize option during outgoing request you can use an `OutgoingTransfer` interface: ```go func readHandler(filename string, rf io.ReaderFrom) error { ... // Set transfer size before calling ReadFrom. 
rf.(tftp.OutgoingTransfer).SetSize(myFileSize) ... // ReadFrom ... ``` Similarly, it is possible to obtain size of a file that is about to be received using `IncomingTransfer` interface (see `Size` method). Local and Remote Address ------------------------ The `OutgoingTransfer` and `IncomingTransfer` interfaces also provide the `RemoteAddr` method which returns the peer IP address and port as a `net.UDPAddr`. The `RequestPacketInfo` interface provides a `LocalIP` method with returns the local IP address as a `net.IP` that the request is being handled on. These can be used for detailed logging in a server handler, among other things. Note that LocalIP may return nil or an unspecified IP address if finding that is not supported on a particular operating system by the Go net libraries, or if you call it as a TFTP client. ```go func readHandler(filename string, rf io.ReaderFrom) error { ... raddr := rf.(tftp.OutgoingTransfer).RemoteAddr() laddr := rf.(tftp.RequestPacketInfo).LocalIP() log.Println("RRQ from", raddr.String(), "To ",laddr.String()) log.Println("") ... // ReadFrom ... ``` Backoff ------- The default backoff before retransmitting an unacknowledged packet is a random duration between 0 and 1 second. This behavior can be overridden in clients and servers by providing a custom backoff calculation function. 
```go s := tftp.NewServer(readHandler, writeHandler) s.SetBackoff(func (attempts int) time.Duration { return time.Duration(attempts) * time.Second }) ``` or, for no backoff ```go s.SetBackoff(func (int) time.Duration { return 0 }) ``` golang-github-pin-tftp-3.1.0/backoff.go000066400000000000000000000007511452427260500177350ustar00rootroot00000000000000package tftp import ( "math/rand" "time" ) const ( defaultTimeout = 5 * time.Second defaultRetries = 5 ) type backoffFunc func(int) time.Duration type backoff struct { attempt int handler backoffFunc } func (b *backoff) reset() { b.attempt = 0 } func (b *backoff) count() int { return b.attempt } func (b *backoff) backoff() { if b.handler == nil { time.Sleep(time.Duration(rand.Int63n(int64(time.Second)))) } else { time.Sleep(b.handler(b.attempt)) } b.attempt++ } golang-github-pin-tftp-3.1.0/client.go000066400000000000000000000063161452427260500176230ustar00rootroot00000000000000package tftp import ( "fmt" "io" "net" "strconv" "time" ) // NewClient creates TFTP client for server on address provided. func NewClient(addr string) (*Client, error) { a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, fmt.Errorf("resolving address %s: %v", addr, err) } return &Client{ addr: a, timeout: defaultTimeout, retries: defaultRetries, }, nil } // SetTimeout sets maximum time client waits for single network round-trip to succeed. // Default is 5 seconds. func (c *Client) SetTimeout(t time.Duration) { if t <= 0 { c.timeout = defaultTimeout } c.timeout = t } // SetRetries sets maximum number of attempts client made to transmit a packet. // Default is 5 attempts. func (c *Client) SetRetries(count int) { if count < 1 { c.retries = defaultRetries } c.retries = count } // SetBackoff sets a user provided function that is called to provide a // backoff duration prior to retransmitting an unacknowledged packet. 
func (c *Client) SetBackoff(h backoffFunc) { c.backoff = h } // SetBlockSize sets a custom block size used in the transmission. func (c *Client) SetBlockSize(s int) { c.blksize = s } // RequestTSize sets flag to indicate if tsize should be requested. func (c *Client) RequestTSize(s bool) { c.tsize = s } // Client stores data about a single TFTP client type Client struct { addr *net.UDPAddr timeout time.Duration retries int backoff backoffFunc blksize int tsize bool } // Send starts outgoing file transmission. It returns io.ReaderFrom or error. func (c Client) Send(filename string, mode string) (io.ReaderFrom, error) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { return nil, err } s := &sender{ send: make([]byte, datagramLength), receive: make([]byte, datagramLength), conn: &connConnection{conn: conn}, retry: &backoff{handler: c.backoff}, timeout: c.timeout, retries: c.retries, addr: c.addr, mode: mode, } if c.blksize != 0 { s.opts = make(options) s.opts["blksize"] = strconv.Itoa(c.blksize) } n := packRQ(s.send, opWRQ, filename, mode, s.opts) addr, err := s.sendWithRetry(n) if err != nil { return nil, err } s.addr = addr s.opts = nil return s, nil } // Receive starts incoming file transmission. It returns io.WriterTo or error. 
func (c Client) Receive(filename string, mode string) (io.WriterTo, error) { conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { return nil, err } if c.timeout == 0 { c.timeout = defaultTimeout } r := &receiver{ send: make([]byte, datagramLength), receive: make([]byte, datagramLength), conn: &connConnection{conn: conn}, retry: &backoff{handler: c.backoff}, timeout: c.timeout, retries: c.retries, addr: c.addr, autoTerm: true, block: 1, mode: mode, } if c.blksize != 0 || c.tsize { r.opts = make(options) } if c.blksize != 0 { r.opts["blksize"] = strconv.Itoa(c.blksize) // Clean it up so we don't send options twice defer func() { delete(r.opts, "blksize") }() } if c.tsize { r.opts["tsize"] = "0" } n := packRQ(r.send, opRRQ, filename, mode, r.opts) l, addr, err := r.receiveWithRetry(n) if err != nil { return nil, err } r.l = l r.addr = addr return r, nil } golang-github-pin-tftp-3.1.0/connection.go000066400000000000000000000044401452427260500205000ustar00rootroot00000000000000package tftp import ( "fmt" "net" "time" "golang.org/x/net/ipv6" "golang.org/x/net/ipv4" ) type connectionError struct { error timeout bool temporary bool } func (t *connectionError) Timeout() bool { return t.timeout } func (t *connectionError) Temporary() bool { return t.temporary } type connection interface { sendTo([]byte, *net.UDPAddr) error readFrom([]byte) (int, *net.UDPAddr, error) setDeadline(time.Duration) error close() } type connConnection struct { conn *net.UDPConn } type chanConnection struct { server *Server channel chan []byte srcAddr, addr *net.UDPAddr timeout time.Duration complete chan string } func (c *chanConnection) sendTo(data []byte, addr *net.UDPAddr) error { var err error c.server.Lock() defer c.server.Unlock() if conn, ok := c.server.conn.(*net.UDPConn); ok { srcAddr := c.srcAddr.IP.To4() var cmm []byte if srcAddr != nil { cm := &ipv4.ControlMessage{Src: srcAddr} cmm = cm.Marshal() } else { cm := &ipv6.ControlMessage{Src: c.srcAddr.IP} cmm = cm.Marshal() } _, 
_, err = conn.WriteMsgUDP(data, cmm, c.addr) } else { _, err = c.server.conn.WriteTo(data, addr) } return err } func (c *chanConnection) readFrom(buffer []byte) (int, *net.UDPAddr, error) { select { case data := <-c.channel: for i := range data { buffer[i] = data[i] } return len(data), c.addr, nil case <-time.After(c.timeout): return 0, nil, makeError(c.addr.String()) } } func (c *chanConnection) setDeadline(deadline time.Duration) error { c.timeout = deadline return nil } func (c *chanConnection) close() { c.server.Lock() defer c.server.Unlock() close(c.channel) delete(c.server.handlers, c.addr.String()) } func (c *connConnection) sendTo(data []byte, addr *net.UDPAddr) error { _, err := c.conn.WriteToUDP(data, addr) return err } func makeError(addr string) net.Error { error := connectionError{ timeout: true, temporary: true, } error.error = fmt.Errorf("Channel timeout: %v", addr) return &error } func (c *connConnection) readFrom(buffer []byte) (int, *net.UDPAddr, error) { return c.conn.ReadFromUDP(buffer) } func (c *connConnection) setDeadline(deadline time.Duration) error { return c.conn.SetReadDeadline(time.Now().Add(deadline)) } func (c *connConnection) close() { c.conn.Close() } golang-github-pin-tftp-3.1.0/go.mod000066400000000000000000000001441452427260500171150ustar00rootroot00000000000000module github.com/pin/tftp/v3 go 1.13 require golang.org/x/net v0.0.0-20200202094626-16171245cfb2 golang-github-pin-tftp-3.1.0/go.sum000066400000000000000000000011341452427260500171420ustar00rootroot00000000000000golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20200202094626-16171245cfb2 h1:CCH4IOTTfewWjGOlSp+zGcjutRKlBEZQ6wTn8ozI/nI= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang-github-pin-tftp-3.1.0/netascii/000077500000000000000000000000001452427260500176075ustar00rootroot00000000000000golang-github-pin-tftp-3.1.0/netascii/netascii.go000066400000000000000000000026441452427260500217430ustar00rootroot00000000000000package netascii // TODO: make it work not only on linux import "io" const ( CR = '\x0d' LF = '\x0a' NUL = '\x00' ) func ToReader(r io.Reader) io.Reader { return &toReader{ r: r, buf: make([]byte, 256), } } type toReader struct { r io.Reader buf []byte n int i int err error lf bool nul bool } func (r *toReader) Read(p []byte) (int, error) { var n int for n < len(p) { if r.lf { p[n] = LF n++ r.lf = false continue } if r.nul { p[n] = NUL n++ r.nul = false continue } if r.i < r.n { if r.buf[r.i] == LF { p[n] = CR r.lf = true } else if r.buf[r.i] == CR { p[n] = CR r.nul = true } else { p[n] = r.buf[r.i] } r.i++ n++ continue } if r.err == nil { r.n, r.err = r.r.Read(r.buf) r.i = 0 } else { return n, r.err } } return n, r.err } type fromWriter struct { w io.Writer buf []byte i int cr bool } func FromWriter(w io.Writer) io.Writer { return &fromWriter{ w: w, buf: make([]byte, 256), } } func (w *fromWriter) Write(p []byte) (n int, err error) { for n < len(p) { if w.cr { if p[n] == LF { w.buf[w.i] = LF } if p[n] == NUL { w.buf[w.i] = CR } w.cr = false w.i++ } else if p[n] == CR { w.cr = true } else { w.buf[w.i] = p[n] w.i++ } n++ if w.i == len(w.buf) || n == len(p) { _, err = w.w.Write(w.buf[:w.i]) w.i = 0 } } return n, err } golang-github-pin-tftp-3.1.0/netascii/netascii_test.go000066400000000000000000000050101452427260500227700ustar00rootroot00000000000000package netascii import ( "bytes" "io/ioutil" "strings" "testing" "testing/iotest" ) var basic = map[string]string{ "\r": "\r\x00", "\n": "\r\n", "la\nbu": "la\r\nbu", "la\rbu": "la\r\x00bu", "\r\r\r": 
"\r\x00\r\x00\r\x00", "\n\n\n": "\r\n\r\n\r\n", } func TestTo(t *testing.T) { for text, netascii := range basic { to := ToReader(strings.NewReader(text)) n, _ := ioutil.ReadAll(to) if !bytes.Equal(n, []byte(netascii)) { t.Errorf("%q to netascii: %q != %q", text, n, netascii) } } } func TestFrom(t *testing.T) { for text, netascii := range basic { r := bytes.NewReader([]byte(netascii)) b := &bytes.Buffer{} from := FromWriter(b) r.WriteTo(from) n, _ := ioutil.ReadAll(b) if string(n) != text { t.Errorf("%q from netascii: %q != %q", netascii, n, text) } } } const text = ` Therefore, the sequence "CR LF" must be treated as a single "new line" character and used whenever their combined action is intended; the sequence "CR NUL" must be used where a carriage return alone is actually desired; and the CR character must be avoided in other contexts. This rule gives assurance to systems which must decide whether to perform a "new line" function or a multiple-backspace that the TELNET stream contains a character following a CR that will allow a rational decision. (in the default ASCII mode), to preserve the symmetry of the NVT model. Even though it may be known in some situations (e.g., with remote echo and suppress go ahead options in effect) that characters are not being sent to an actual printer, nonetheless, for the sake of consistency, the protocol requires that a NUL be inserted following a CR not followed by a LF in the data stream. The converse of this is that a NUL received in the data stream after a CR (in the absence of options negotiations which explicitly specify otherwise) should be stripped out prior to applying the NVT to local character set mapping. 
` func TestWriteRead(t *testing.T) { var one bytes.Buffer to := ToReader(strings.NewReader(text)) one.ReadFrom(to) two := &bytes.Buffer{} from := FromWriter(two) one.WriteTo(from) text2, _ := ioutil.ReadAll(two) if text != string(text2) { t.Errorf("text mismatch \n%x \n%x", text, text2) } } func TestOneByte(t *testing.T) { var one bytes.Buffer to := iotest.OneByteReader(ToReader(strings.NewReader(text))) one.ReadFrom(to) two := &bytes.Buffer{} from := FromWriter(two) one.WriteTo(from) text2, _ := ioutil.ReadAll(two) if text != string(text2) { t.Errorf("text mismatch \n%x \n%x", text, text2) } } golang-github-pin-tftp-3.1.0/packet.go000066400000000000000000000077671452427260500176270ustar00rootroot00000000000000package tftp import ( "bytes" "encoding/binary" "fmt" ) const ( opRRQ = uint16(1) // Read request (RRQ) opWRQ = uint16(2) // Write request (WRQ) opDATA = uint16(3) // Data opACK = uint16(4) // Acknowledgement opERROR = uint16(5) // Error opOACK = uint16(6) // Options Acknowledgment ) const ( blockLength = 512 datagramLength = 516 ) type options map[string]string // RRQ/WRQ packet // // 2 bytes string 1 byte string 1 byte // -------------------------------------------------- // | Opcode | Filename | 0 | Mode | 0 | // -------------------------------------------------- type pRRQ []byte type pWRQ []byte // packRQ returns length of the packet in b func packRQ(p []byte, op uint16, filename, mode string, opts options) int { binary.BigEndian.PutUint16(p, op) n := 2 n += copy(p[2:len(p)-10], filename) p[n] = 0 n++ n += copy(p[n:], mode) p[n] = 0 n++ for name, value := range opts { n += copy(p[n:], name) p[n] = 0 n++ n += copy(p[n:], value) p[n] = 0 n++ } return n } func unpackRQ(p []byte) (filename, mode string, opts options, err error) { bs := bytes.Split(p[2:], []byte{0}) if len(bs) < 2 { return "", "", nil, fmt.Errorf("missing filename or mode") } filename = string(bs[0]) mode = string(bs[1]) if len(bs) < 4 { return filename, mode, nil, nil } opts = make(options) 
for i := 2; i+1 < len(bs); i += 2 { opts[string(bs[i])] = string(bs[i+1]) } return filename, mode, opts, nil } // OACK packet // // +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ // | Opcode | opt1 | 0 | value1 | 0 | optN | 0 | valueN | 0 | // +----------+---~~---+---+---~~---+---+---~~---+---+---~~---+---+ type pOACK []byte func packOACK(p []byte, opts options) int { binary.BigEndian.PutUint16(p, opOACK) n := 2 for name, value := range opts { n += copy(p[n:], name) p[n] = 0 n++ n += copy(p[n:], value) p[n] = 0 n++ } return n } func unpackOACK(p []byte) (opts options, err error) { bs := bytes.Split(p[2:], []byte{0}) opts = make(options) for i := 0; i+1 < len(bs); i += 2 { opts[string(bs[i])] = string(bs[i+1]) } return opts, nil } // ERROR packet // // 2 bytes 2 bytes string 1 byte // ------------------------------------------ // | Opcode | ErrorCode | ErrMsg | 0 | // ------------------------------------------ type pERROR []byte func packERROR(p []byte, code uint16, message string) int { binary.BigEndian.PutUint16(p, opERROR) binary.BigEndian.PutUint16(p[2:], code) n := copy(p[4:len(p)-2], message) p[4+n] = 0 return n + 5 } func (p pERROR) code() uint16 { return binary.BigEndian.Uint16(p[2:]) } func (p pERROR) message() string { return string(p[4:]) } // DATA packet // // 2 bytes 2 bytes n bytes // ---------------------------------- // | Opcode | Block # | Data | // ---------------------------------- type pDATA []byte func (p pDATA) block() uint16 { return binary.BigEndian.Uint16(p[2:]) } // ACK packet // // 2 bytes 2 bytes // ----------------------- // | Opcode | Block # | // ----------------------- type pACK []byte func (p pACK) block() uint16 { return binary.BigEndian.Uint16(p[2:]) } func parsePacket(p []byte) (interface{}, error) { l := len(p) if l < 2 { return nil, fmt.Errorf("short packet") } opcode := binary.BigEndian.Uint16(p) switch opcode { case opRRQ: if l < 4 { return nil, fmt.Errorf("short RRQ packet: %d", l) } return pRRQ(p), nil case 
opWRQ: if l < 4 { return nil, fmt.Errorf("short WRQ packet: %d", l) } return pWRQ(p), nil case opDATA: if l < 4 { return nil, fmt.Errorf("short DATA packet: %d", l) } return pDATA(p), nil case opACK: if l < 4 { return nil, fmt.Errorf("short ACK packet: %d", l) } return pACK(p), nil case opERROR: if l < 5 { return nil, fmt.Errorf("short ERROR packet: %d", l) } return pERROR(p), nil case opOACK: if l < 6 { return nil, fmt.Errorf("short OACK packet: %d", l) } return pOACK(p), nil default: return nil, fmt.Errorf("unknown opcode: %d", opcode) } } golang-github-pin-tftp-3.1.0/receiver.go000066400000000000000000000130071452427260500201440ustar00rootroot00000000000000package tftp import ( "encoding/binary" "fmt" "io" "net" "strconv" "time" "github.com/pin/tftp/v3/netascii" ) // IncomingTransfer provides methods that expose information associated with // an incoming transfer. type IncomingTransfer interface { // Size returns the size of an incoming file if the request included the // tsize option (see RFC2349). To differentiate a zero-sized file transfer // from a request without tsize use the second boolean "ok" return value. Size() (n int64, ok bool) // RemoteAddr returns the remote peer's IP address and port. 
RemoteAddr() net.UDPAddr } func (r *receiver) RemoteAddr() net.UDPAddr { return *r.addr } func (r *receiver) LocalIP() net.IP { return r.localIP } func (r *receiver) Size() (n int64, ok bool) { if r.opts != nil { if s, ok := r.opts["tsize"]; ok { n, err := strconv.ParseInt(s, 10, 64) if err != nil { return 0, false } return n, true } } return 0, false } type receiver struct { send []byte receive []byte addr *net.UDPAddr filename string localIP net.IP tid int conn connection block uint16 retry *backoff timeout time.Duration retries int l int autoTerm bool dally bool mode string opts options singlePort bool maxBlockLen int hook Hook startTime time.Time datagramsSent int datagramsAcked int } func (r *receiver) WriteTo(w io.Writer) (n int64, err error) { if r.mode == "netascii" { w = netascii.FromWriter(w) } if r.opts != nil { err := r.sendOptions() if err != nil { r.abort(err) return 0, err } } binary.BigEndian.PutUint16(r.send[0:2], opACK) for { if r.l > 0 { l, err := w.Write(r.receive[4:r.l]) n += int64(l) if err != nil { r.abort(err) return n, err } if r.l < len(r.receive) { if r.autoTerm { r.terminate() } return n, nil } } binary.BigEndian.PutUint16(r.send[2:4], r.block) r.block++ // send ACK for current block and expect next one ll, _, err := r.receiveWithRetry(4) if err != nil { r.abort(err) return n, err } r.l = ll } } func (r *receiver) sendOptions() error { for name, value := range r.opts { if name == "blksize" { err := r.setBlockSize(value) if err != nil { delete(r.opts, name) continue } } else { delete(r.opts, name) } } if len(r.opts) > 0 { m := packOACK(r.send, r.opts) r.block = 1 // expect data block number 1 ll, _, err := r.receiveWithRetry(m) if err != nil { r.abort(err) return err } r.l = ll } return nil } func (r *receiver) setBlockSize(blksize string) error { n, err := strconv.Atoi(blksize) if err != nil { return err } if n < 512 { return fmt.Errorf("blksize too small: %d", n) } if n > 65464 { return fmt.Errorf("blksize too large: %d", n) } if 
r.maxBlockLen > 0 && n > r.maxBlockLen { n = r.maxBlockLen r.opts["blksize"] = strconv.Itoa(n) } r.receive = make([]byte, n+4) return nil } func (r *receiver) receiveWithRetry(l int) (int, *net.UDPAddr, error) { r.retry.reset() for { n, addr, err := r.receiveDatagram(l) if _, ok := err.(net.Error); ok && r.retry.count() < r.retries { r.retry.backoff() continue } return n, addr, err } } func (r *receiver) receiveDatagram(l int) (int, *net.UDPAddr, error) { err := r.conn.setDeadline(r.timeout) if err != nil { return 0, nil, err } err = r.conn.sendTo(r.send[:l], r.addr) if err != nil { return 0, nil, err } r.datagramsSent++ for { c, addr, err := r.conn.readFrom(r.receive) if err != nil { return 0, nil, err } if !addr.IP.Equal(r.addr.IP) || (r.tid != 0 && addr.Port != r.tid) { continue } p, err := parsePacket(r.receive[:c]) if err != nil { return 0, addr, err } r.tid = addr.Port switch p := p.(type) { case pDATA: if p.block() == r.block { r.datagramsAcked++ return c, addr, nil } case pOACK: opts, err := unpackOACK(p) if r.block != 1 { continue } if err != nil { r.abort(err) return 0, addr, err } for name, value := range opts { if name == "blksize" { err := r.setBlockSize(value) if err != nil { continue } } } r.block = 0 // ACK with block number 0 r.opts = opts return 0, addr, nil case pERROR: return 0, addr, fmt.Errorf("code: %d, message: %s", p.code(), p.message()) } } } func (r *receiver) terminate() error { if r.conn == nil { return nil } defer func() { if r.hook != nil { r.hook.OnSuccess(r.buildTransferStats()) } r.conn.close() }() binary.BigEndian.PutUint16(r.send[2:4], r.block) if r.dally { for i := 0; i < 3; i++ { _, _, err := r.receiveDatagram(4) if err != nil { return nil } } return fmt.Errorf("dallying termination failed") } err := r.conn.sendTo(r.send[:4], r.addr) if err != nil { return err } return nil } func (r *receiver) buildTransferStats() TransferStats { return TransferStats{ RemoteAddr: r.addr.IP, Filename: r.filename, Tid: r.tid, Mode: r.mode, Opts: 
r.opts, Duration: time.Since(r.startTime), DatagramsSent: r.datagramsSent, DatagramsAcked: r.datagramsAcked, } } func (r *receiver) abort(err error) error { if r.conn == nil { return nil } if r.hook != nil { r.hook.OnFailure(r.buildTransferStats(), err) } n := packERROR(r.send, 1, err.Error()) err = r.conn.sendTo(r.send[:n], r.addr) if err != nil { return err } r.conn.close() r.conn = nil return nil } golang-github-pin-tftp-3.1.0/sender.go000066400000000000000000000144411452427260500176230ustar00rootroot00000000000000package tftp import ( "encoding/binary" "fmt" "io" "net" "strconv" "time" "github.com/pin/tftp/v3/netascii" ) // OutgoingTransfer provides methods to set the outgoing transfer size and // retrieve the remote address of the peer. type OutgoingTransfer interface { // SetSize is used to set the outgoing transfer size (tsize option: RFC2349) // manually in a server write transfer handler. // // It is not necessary in most cases; when the io.Reader provided to // ReadFrom also satisfies io.Seeker (e.g. os.File) the transfer size will // be determined automatically. Seek will not be attempted when the // transfer size option is set with SetSize. // // The value provided will be used only if SetSize is called before ReadFrom // and only on in a server read handler. SetSize(n int64) // RemoteAddr returns the remote peer's IP address and port. 
RemoteAddr() net.UDPAddr } type sender struct { conn connection addr *net.UDPAddr filename string localIP net.IP tid int send []byte sendA senderAnticipate receive []byte retry *backoff timeout time.Duration retries int block uint16 maxBlockLen int mode string opts options hook Hook startTime time.Time datagramsSent int datagramsAcked int } func (s *sender) RemoteAddr() net.UDPAddr { return *s.addr } func (s *sender) LocalIP() net.IP { return s.localIP } func (s *sender) SetSize(n int64) { if s.opts != nil { if _, ok := s.opts["tsize"]; ok { s.opts["tsize"] = strconv.FormatInt(n, 10) } } } func (s *sender) ReadFrom(r io.Reader) (n int64, err error) { if s.mode == "netascii" { r = netascii.ToReader(r) } defer func() { if s.conn != nil { s.conn.close() s.conn = nil } }() if s.opts != nil { // check that tsize is set if ts, ok := s.opts["tsize"]; ok { // check that tsize is not set with SetSize already i, err := strconv.ParseInt(ts, 10, 64) if err == nil && i == 0 { if rs, ok := r.(io.Seeker); ok { pos, err := rs.Seek(0, 1) if err != nil { return 0, err } size, err := rs.Seek(0, 2) if err != nil { return 0, err } s.opts["tsize"] = strconv.FormatInt(size, 10) _, err = rs.Seek(pos, 0) if err != nil { return 0, err } } } } err = s.sendOptions() if err != nil { s.abort(err) return 0, err } } if s.sendA.enabled { /* senderAnticipate */ return readFromAnticipate(s, r) } s.block = 1 // start data transmission with block 1 binary.BigEndian.PutUint16(s.send[0:2], opDATA) for { l, err := io.ReadFull(r, s.send[4:]) n += int64(l) if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { binary.BigEndian.PutUint16(s.send[2:4], s.block) _, err = s.sendWithRetry(4) if err != nil { s.abort(err) return n, err } if s.hook != nil { s.hook.OnSuccess(s.buildTransferStats()) } return n, nil } s.abort(err) return n, err } binary.BigEndian.PutUint16(s.send[2:4], s.block) _, err = s.sendWithRetry(4 + l) if err != nil { s.abort(err) return n, err } if l < len(s.send)-4 { if s.hook != 
nil { s.hook.OnSuccess(s.buildTransferStats()) } return n, nil } s.block++ } } func (s *sender) sendOptions() error { for name, value := range s.opts { if name == "blksize" { err := s.setBlockSize(value) if err != nil { delete(s.opts, name) continue } } else if name == "tsize" { if value != "0" { s.opts["tsize"] = value } else { delete(s.opts, name) continue } } else { delete(s.opts, name) } } if len(s.opts) > 0 { m := packOACK(s.send, s.opts) _, err := s.sendWithRetry(m) if err != nil { return err } } return nil } func (s *sender) setBlockSize(blksize string) error { n, err := strconv.Atoi(blksize) if err != nil { return err } if n < 512 { return fmt.Errorf("blksize too small: %d", n) } if n > 65464 { return fmt.Errorf("blksize too large: %d", n) } if s.maxBlockLen > 0 && n > s.maxBlockLen { n = s.maxBlockLen s.opts["blksize"] = strconv.Itoa(n) } s.send = make([]byte, n+4) if s.sendA.enabled { /* senderAnticipate */ sendAInit(&s.sendA, uint(n+4), s.sendA.winsz) } return nil } func (s *sender) sendWithRetry(l int) (*net.UDPAddr, error) { s.retry.reset() for { addr, err := s.sendDatagram(l) if _, ok := err.(net.Error); ok && s.retry.count() < s.retries { s.retry.backoff() continue } return addr, err } } func (s *sender) sendDatagram(l int) (*net.UDPAddr, error) { err := s.conn.setDeadline(s.timeout) if err != nil { return nil, err } err = s.conn.sendTo(s.send[:l], s.addr) if err != nil { return nil, err } s.datagramsSent++ for { n, addr, err := s.conn.readFrom(s.receive) if err != nil { return nil, err } if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) { continue } p, err := parsePacket(s.receive[:n]) if err != nil { continue } s.tid = addr.Port switch p := p.(type) { case pACK: if p.block() == s.block { s.datagramsAcked++ return addr, nil } case pOACK: opts, err := unpackOACK(p) if s.block != 0 { continue } if err != nil { s.abort(err) return addr, err } for name, value := range opts { if name == "blksize" { err := s.setBlockSize(value) if err != 
nil { continue } } } return addr, nil case pERROR: return nil, fmt.Errorf("sending block %d: code=%d, error: %s", s.block, p.code(), p.message()) } } } func (s *sender) buildTransferStats() TransferStats { return TransferStats{ RemoteAddr: s.addr.IP, Filename: s.filename, Tid: s.tid, SenderAnticipateEnabled: s.sendA.enabled, Mode: s.mode, Opts: s.opts, Duration: time.Since(s.startTime), DatagramsSent: s.datagramsSent, DatagramsAcked: s.datagramsAcked, } } func (s *sender) abort(err error) error { if s.conn == nil { return nil } if s.hook != nil { s.hook.OnFailure(s.buildTransferStats(), err) } n := packERROR(s.send, 1, err.Error()) err = s.conn.sendTo(s.send[:n], s.addr) if err != nil { return err } s.conn.close() s.conn = nil return nil } golang-github-pin-tftp-3.1.0/sender_anticipate.go000066400000000000000000000106051452427260500220220ustar00rootroot00000000000000package tftp import ( "encoding/binary" "fmt" "io" "net" ) // the struct embedded into sender{} as sendA type senderAnticipate struct { enabled bool winsz uint /* init windows size in number of buffers */ num uint /* actual packets to send. 
*/ sends [][]byte /* buffers for a number of packets */ sendslens []uint /* data lens in buffers */ } const anticipateWindowDefMax = 60 /* 60 by 512 is about 30k */ const anticipateDebug bool = false func sendAInit(sA *senderAnticipate, ln uint, winSz uint) { var ksz uint if winSz > anticipateWindowDefMax { ksz = anticipateWindowDefMax } else if winSz < 2 { ksz = 2 } else { ksz = winSz } sA.sends = make([][]byte, ksz) sA.sendslens = make([]uint, ksz) for k := uint(0); k < ksz; k++ { sA.sends[k] = make([]byte, ln) sA.sendslens[k] = 0 } sA.winsz = ksz //fmt.Printf(" Set packet buffer size %v\n", ln) } // derived from ReadFrom() func readFromAnticipate(s *sender, r io.Reader) (n int64, err error) { s.block = 1 // start data transmission with block 1 ksz := uint(len(s.sendA.sends)) for k := uint(0); k < ksz; k++ { binary.BigEndian.PutUint16(s.sendA.sends[k][0:2], opDATA) s.sendA.sendslens[k] = 0 } s.sendA.num = 0 for { nx := int64(0) knum := uint(0) kfillOk := true /* default ok */ kfillPartial := false for k := uint(0); k < ksz; k++ { lx, err := io.ReadFull(r, s.sendA.sends[k][4:]) nx += int64(lx) if err != nil && err != io.ErrUnexpectedEOF { if err == io.EOF { if kfillPartial { break /* short packet already sent in last loop */ } binary.BigEndian.PutUint16(s.sendA.sends[k][2:4], s.block+uint16(k)) s.sendA.sendslens[k] = 4 knum = k + 1 kfillPartial = true break } kfillOk = false break /* fail */ } else if err != nil /* has to be io.ErrUnexpectedEOF now */ { kfillPartial = true /* set the flag and send the packet */ } binary.BigEndian.PutUint16(s.sendA.sends[k][2:4], s.block+uint16(k)) s.sendA.sendslens[k] = uint(4 + lx) knum = k + 1 } if !kfillOk { s.abort(err) return n, err } s.sendA.num = knum n += int64(nx) if anticipateDebug { fmt.Printf(" **** sends s.block %v pkts %v ", s.block, knum) for k := uint(0); k < ksz; k++ { fmt.Printf(" %v ", s.sendA.sendslens[k]) } fmt.Println("") } _, err = s.sendWithRetryAnticipate() if err != nil { s.abort(err) return n, err } if 
kfillPartial { s.conn.close() return n, nil } s.block += uint16(knum) } } // derived from sendWithRetry() func (s *sender) sendWithRetryAnticipate() (*net.UDPAddr, error) { s.retry.reset() for { addr, err := s.sendDatagramAnticipate() if _, ok := err.(net.Error); ok && s.retry.count() < s.retries { s.retry.backoff() continue } return addr, err } } // derived from sendDatagram() func (s *sender) sendDatagramAnticipate() (*net.UDPAddr, error) { err1 := s.conn.setDeadline(s.timeout) if err1 != nil { return nil, err1 } var err error ksz := uint(len(s.sendA.sends)) knum := s.sendA.num if knum > ksz { err = fmt.Errorf("knum %v bigger than ksz %v", knum, ksz) return nil, err } for k := uint(0); k < knum; k++ { lx := s.sendA.sendslens[k] if lx < 4 { err = fmt.Errorf("lx smaller than 4") break } errx := s.conn.sendTo(s.sendA.sends[k][:lx], s.addr) if errx != nil { err = fmt.Errorf("k %v errx %v", k, errx.Error()) break } } if err != nil { return nil, err } k := uint(0) for { n, addr, err := s.conn.readFrom(s.receive) if err != nil { return nil, err } if !addr.IP.Equal(s.addr.IP) || (s.tid != 0 && addr.Port != s.tid) { continue } p, err := parsePacket(s.receive[:n]) if err != nil { continue } s.tid = addr.Port switch p := p.(type) { case pACK: if anticipateDebug { fmt.Printf(" **** pACK p.block %v s.block %v k %v\n", p.block(), s.block, k) } if p.block() == s.block+uint16(k) { k++ if k == knum { return addr, nil } } case pOACK: opts, err := unpackOACK(p) if s.block != 0 { continue } if err != nil { s.abort(err) return addr, err } for name, value := range opts { if name == "blksize" { err := s.setBlockSize(value) if err != nil { continue } } } return addr, nil case pERROR: return nil, fmt.Errorf("sending block %d: code=%d, error: %s", s.block, p.code(), p.message()) } } } golang-github-pin-tftp-3.1.0/server.go000066400000000000000000000316301452427260500176500ustar00rootroot00000000000000package tftp import ( "context" "fmt" "io" "net" "sync" "time" "golang.org/x/net/ipv4" 
"golang.org/x/net/ipv6" ) // NewServer creates TFTP server. It requires two functions to handle // read and write requests. // In case nil is provided for read or write handler the respective // operation is disabled. func NewServer(readHandler func(filename string, rf io.ReaderFrom) error, writeHandler func(filename string, wt io.WriterTo) error) *Server { s := &Server{ Mutex: &sync.Mutex{}, timeout: defaultTimeout, retries: defaultRetries, packetReadTimeout: 100 * time.Millisecond, readHandler: readHandler, writeHandler: writeHandler, wg: &sync.WaitGroup{}, } s.cancel, s.cancelFn = context.WithCancel(context.Background()) return s } // RequestPacketInfo provides a method of getting the local IP address // that is handling a UDP request. It relies for its accuracy on the // OS providing methods to inspect the underlying UDP and IP packets // directly. type RequestPacketInfo interface { // LocalIP returns the IP address we are servicing the request on. // If it is unable to determine what address that is, the returned // net.IP will be nil. 
LocalIP() net.IP } // Server is an instance of a TFTP server type Server struct { *sync.Mutex readHandler func(filename string, rf io.ReaderFrom) error writeHandler func(filename string, wt io.WriterTo) error hook Hook backoff backoffFunc conn net.PacketConn conn6 *ipv6.PacketConn conn4 *ipv4.PacketConn wg *sync.WaitGroup timeout time.Duration retries int maxBlockLen int sendAEnable bool /* senderAnticipate enable by server */ sendAWinSz uint // Single port fields singlePort bool handlers map[string]chan []byte packetReadTimeout time.Duration cancel context.Context cancelFn context.CancelFunc } // TransferStats contains details about a single TFTP transfer type TransferStats struct { RemoteAddr net.IP Filename string Tid int SenderAnticipateEnabled bool Mode string Opts options Duration time.Duration DatagramsSent int DatagramsAcked int } // Hook is an interface used to provide the server with success and failure hooks type Hook interface { OnSuccess(stats TransferStats) OnFailure(stats TransferStats, err error) } // SetAnticipate provides an experimental feature in which when a packets // is requested the server will keep sending a number of packets before // checking whether an ack has been received. It improves tftp downloading // speed by a few times. // The argument winsz specifies how many packets will be sent before // waiting for an ack packet. // When winsz is bigger than 1, the feature is enabled, and the server // runs through a different experimental code path. When winsz is 0 or 1, // the feature is disabled. func (s *Server) SetAnticipate(winsz uint) { s.Lock() defer s.Unlock() if winsz > 1 { s.sendAEnable = true s.sendAWinSz = winsz } else { s.sendAEnable = false s.sendAWinSz = 1 } } // SetHook sets the Hook for success and failure of transfers func (s *Server) SetHook(hook Hook) { s.Lock() defer s.Unlock() s.hook = hook } // EnableSinglePort enables an experimental mode where the server will // serve all connections on port 69 only. 
There will be no random TIDs // on the server side. // // Enabling this will negatively impact performance func (s *Server) EnableSinglePort() { s.Lock() defer s.Unlock() s.singlePort = true s.handlers = make(map[string]chan []byte) if s.maxBlockLen == 0 { s.maxBlockLen = blockLength } } // SetTimeout sets maximum time server waits for single network // round-trip to succeed. // Default is 5 seconds. func (s *Server) SetTimeout(t time.Duration) { s.Lock() defer s.Unlock() if t <= 0 { s.timeout = defaultTimeout } else { s.timeout = t } } // SetBlockSize sets the maximum size of an individual data block. // This must be a value between 512 (the default block size for TFTP) // and 65456 (the max size a UDP packet payload can be). // // This is an advisory value -- it will be clamped to the smaller of // the block size the client wants and the MTU of the interface being // communicated over munis overhead. func (s *Server) SetBlockSize(i int) { s.Lock() defer s.Unlock() if i > 512 && i < 65465 { s.maxBlockLen = i } } // SetRetries sets maximum number of attempts server made to transmit a // packet. // Default is 5 attempts. func (s *Server) SetRetries(count int) { s.Lock() defer s.Unlock() if count < 1 { s.retries = defaultRetries } else { s.retries = count } } // SetBackoff sets a user provided function that is called to provide a // backoff duration prior to retransmitting an unacknowledged packet. func (s *Server) SetBackoff(h backoffFunc) { s.Lock() defer s.Unlock() s.backoff = h } // ListenAndServe binds to address provided and start the server. // ListenAndServe returns when Shutdown is called. func (s *Server) ListenAndServe(addr string) error { a, err := net.ResolveUDPAddr("udp", addr) if err != nil { return err } conn, err := net.ListenUDP("udp", a) if err != nil { return err } return s.Serve(conn) } // Serve starts server provided already opened UDP connection. 
It is // useful for the case when you want to run server in separate goroutine // but still want to be able to handle any errors opening connection. // Serve returns when Shutdown is called. func (s *Server) Serve(conn net.PacketConn) error { laddr := conn.LocalAddr() host, _, err := net.SplitHostPort(laddr.String()) if err != nil { return err } s.Lock() s.conn = conn s.Unlock() // Having seperate control paths for IP4 and IP6 is annoying, // but necessary at this point. addr := net.ParseIP(host) if addr == nil { return fmt.Errorf("Failed to determine IP class of listening address") } if conn, ok := s.conn.(*net.UDPConn); ok { if addr.To4() != nil { s.conn4 = ipv4.NewPacketConn(conn) if err := s.conn4.SetControlMessage(ipv4.FlagDst|ipv4.FlagInterface, true); err != nil { s.conn4 = nil } } else { s.conn6 = ipv6.NewPacketConn(conn) if err := s.conn6.SetControlMessage(ipv6.FlagDst|ipv6.FlagInterface, true); err != nil { s.conn6 = nil } } } if s.singlePort { return s.singlePortProcessRequests() } for { select { case <-s.cancel.Done(): // Stop server because Shutdown was called return nil default: var err error if s.conn4 != nil { err = s.processRequest4() } else if s.conn6 != nil { err = s.processRequest6() } else { err = s.processRequest() } if err != nil && s.hook != nil { s.hook.OnFailure(TransferStats{ SenderAnticipateEnabled: s.sendAEnable, }, err) } } } } // Yes, I don't really like having separate IPv4 and IPv6 variants, // bit we are relying on the low-level packet control channel info to // get a reliable source address, and those have different types and // the struct itself is not easily interface-ized or embedded. // // If control is nil for whatever reason (either things not being // implemented on a target OS or whatever other reason), localIP // (and hence LocalIP()) will return a nil IP address. 
func (s *Server) processRequest4() error { buf := make([]byte, datagramLength) cnt, control, srcAddr, err := s.conn4.ReadFrom(buf) if err != nil { return nil } maxSz := blockLength var localAddr net.IP if control != nil { localAddr = control.Dst if intf, err := net.InterfaceByIndex(control.IfIndex); err == nil { // mtu - ipv4 overhead - udp overhead maxSz = intf.MTU - 28 } } return s.handlePacket(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, maxSz, nil) } func (s *Server) processRequest6() error { buf := make([]byte, datagramLength) cnt, control, srcAddr, err := s.conn6.ReadFrom(buf) if err != nil { return nil } maxSz := blockLength var localAddr net.IP if control != nil { localAddr = control.Dst if intf, err := net.InterfaceByIndex(control.IfIndex); err == nil { // mtu - ipv6 overhead - udp overhead maxSz = intf.MTU - 48 } } return s.handlePacket(localAddr, srcAddr.(*net.UDPAddr), buf, cnt, maxSz, nil) } // Fallback if we had problems opening a ipv4/6 control channel func (s *Server) processRequest() error { buf := make([]byte, datagramLength) cnt, srcAddr, err := s.conn.ReadFrom(buf) if err != nil { return fmt.Errorf("reading UDP: %v", err) } return s.handlePacket(nil, srcAddr.(*net.UDPAddr), buf, cnt, blockLength, nil) } // Shutdown make server stop listening for new requests, allows // server to finish outstanding transfers and stops server. // Shutdown blocks until all outstanding requests are processed or timed out. // Calling Shutdown from the handler or hook might cause deadlock. func (s *Server) Shutdown() { if !s.singlePort { s.Lock() // Connection could not exist if Serve or // ListenAndServe was never called. 
if s.conn != nil { s.conn.Close() } s.Unlock() } s.cancelFn() if !s.singlePort { s.wg.Wait() } } func (s *Server) handlePacket(localAddr net.IP, remoteAddr *net.UDPAddr, buffer []byte, n, maxBlockLen int, listener chan []byte) error { s.Lock() defer s.Unlock() // Cope with packets received on the broadcast address // We can't use this address as the source address in responses // so fallback to the OS default. if localAddr.Equal(net.IPv4bcast) { localAddr = net.IPv4zero } // handlePacket is always called with maxBlockLen = blockLength (above, in processRequest). // As a result, the block size would always be capped at 512 bytes, even when the tftp // client indicated to use a larger value. So override that value. And make sure to // use that value below, when allocating buffers. (Happening on Windows Server 2016.) // if s.maxBlockLen > 0 { if s.maxBlockLen > 0 && s.maxBlockLen < maxBlockLen { maxBlockLen = s.maxBlockLen } if maxBlockLen < blockLength { maxBlockLen = blockLength } p, err := parsePacket(buffer[:n]) if err != nil { return err } listenAddr := &net.UDPAddr{IP: localAddr} switch p := p.(type) { case pWRQ: filename, mode, opts, err := unpackRQ(p) if err != nil { return fmt.Errorf("unpack WRQ: %v", err) } //fmt.Printf("got WRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) if err != nil { return fmt.Errorf("open transmission: %v", err) } wt := &receiver{ send: make([]byte, datagramLength), receive: make([]byte, datagramLength), retry: &backoff{handler: s.backoff}, timeout: s.timeout, retries: s.retries, addr: remoteAddr, localIP: localAddr, mode: mode, opts: opts, maxBlockLen: maxBlockLen, hook: s.hook, filename: filename, startTime: time.Now(), } if s.singlePort { wt.conn = &chanConnection{ server: s, srcAddr: listenAddr, addr: remoteAddr, channel: listener, timeout: s.timeout, } wt.singlePort = true } else { conn, err := net.ListenUDP("udp", listenAddr) if err != nil { return err } wt.conn = &connConnection{conn: conn} } s.wg.Add(1) go func() 
{ if s.writeHandler != nil { err := s.writeHandler(filename, wt) if err != nil { wt.abort(err) } else { wt.terminate() } } else { wt.abort(fmt.Errorf("server does not support write requests")) } s.wg.Done() }() case pRRQ: filename, mode, opts, err := unpackRQ(p) if err != nil { return fmt.Errorf("unpack RRQ: %v", err) } //fmt.Printf("got RRQ (filename=%s, mode=%s, opts=%v)\n", filename, mode, opts) rf := &sender{ send: make([]byte, datagramLength), sendA: senderAnticipate{enabled: false}, receive: make([]byte, datagramLength), tid: remoteAddr.Port, retry: &backoff{handler: s.backoff}, timeout: s.timeout, retries: s.retries, addr: remoteAddr, localIP: localAddr, mode: mode, opts: opts, maxBlockLen: maxBlockLen, hook: s.hook, filename: filename, startTime: time.Now(), } if s.singlePort { rf.conn = &chanConnection{ server: s, srcAddr: listenAddr, addr: remoteAddr, channel: listener, timeout: s.timeout, } } else { conn, err := net.ListenUDP("udp", listenAddr) if err != nil { return err } rf.conn = &connConnection{conn: conn} } if s.sendAEnable { /* senderAnticipate if enabled in server */ rf.sendA.enabled = true /* pass enable from server to sender */ sendAInit(&rf.sendA, datagramLength, s.sendAWinSz) } s.wg.Add(1) go func(rh func(string, io.ReaderFrom) error, rf *sender, wg *sync.WaitGroup) { if s.readHandler != nil { err := s.readHandler(filename, rf) if err != nil { rf.abort(err) } } else { rf.abort(fmt.Errorf("server does not support read requests")) } s.wg.Done() }(s.readHandler, rf, s.wg) default: return fmt.Errorf("unexpected %T", p) } return nil } golang-github-pin-tftp-3.1.0/single_port.go000066400000000000000000000041611452427260500206660ustar00rootroot00000000000000package tftp import ( "net" ) func (s *Server) singlePortProcessRequests() error { for { select { case <-s.cancel.Done(): s.wg.Wait() return nil default: buf := make([]byte, s.maxBlockLen+4) cnt, localAddr, srcAddr, maxSz, err := s.getPacket(buf) if err != nil || cnt == 0 { if s.hook != nil { 
s.hook.OnFailure(TransferStats{ SenderAnticipateEnabled: s.sendAEnable, }, err) } continue } s.Lock() if receiverChannel, ok := s.handlers[srcAddr.String()]; ok { s.Unlock() select { case receiverChannel <- buf[:cnt]: default: // We don't want to block the main loop if a channel is full } } else { lc := make(chan []byte, 1) s.handlers[srcAddr.String()] = lc s.Unlock() go func() { err := s.handlePacket(localAddr, srcAddr, buf, cnt, maxSz, lc) if err != nil && s.hook != nil { s.hook.OnFailure(TransferStats{ SenderAnticipateEnabled: s.sendAEnable, }, err) } }() } } } } func (s *Server) getPacket(buf []byte) (int, net.IP, *net.UDPAddr, int, error) { if s.conn6 != nil { cnt, control, srcAddr, err := s.conn6.ReadFrom(buf) if err != nil || cnt == 0 { return 0, nil, nil, 0, err } var localAddr net.IP maxSz := blockLength if control != nil { localAddr = control.Dst if intf, err := net.InterfaceByIndex(control.IfIndex); err == nil { // mtu - ipv4 overhead - udp overhead maxSz = intf.MTU - 28 } } return cnt, localAddr, srcAddr.(*net.UDPAddr), maxSz, nil } else if s.conn4 != nil { cnt, control, srcAddr, err := s.conn4.ReadFrom(buf) if err != nil || cnt == 0 { return 0, nil, nil, 0, err } var localAddr net.IP maxSz := blockLength if control != nil { localAddr = control.Dst if intf, err := net.InterfaceByIndex(control.IfIndex); err == nil { // mtu - ipv6 overhead - udp overhead maxSz = intf.MTU - 48 } } return cnt, localAddr, srcAddr.(*net.UDPAddr), maxSz, nil } else { cnt, srcAddr, err := s.conn.ReadFrom(buf) if err != nil { return 0, nil, nil, 0, err } return cnt, nil, srcAddr.(*net.UDPAddr), blockLength, nil } } golang-github-pin-tftp-3.1.0/single_port_test.go000066400000000000000000000014121452427260500217210ustar00rootroot00000000000000package tftp import ( "testing" ) func TestZeroLengthSinglePort(t *testing.T) { s, c := makeTestServer(true) defer s.Shutdown() testSendReceive(t, c, 0) } func TestSendReceiveSinglePort(t *testing.T) { s, c := makeTestServer(true) defer 
s.Shutdown() for i := 600; i < 1000; i++ { testSendReceive(t, c, 5000+int64(i)) } } func TestSendReceiveSinglePortWithBlockSize(t *testing.T) { s, c := makeTestServer(true) defer s.Shutdown() for i := 600; i < 1000; i++ { c.blksize = i testSendReceive(t, c, 5000+int64(i)) } } func TestServerSendTimeoutSinglePort(t *testing.T) { s, c := makeTestServer(true) serverTimeoutSendTest(s, c, t) } func TestServerReceiveTimeoutSinglePort(t *testing.T) { s, c := makeTestServer(true) serverReceiveTimeoutTest(s, c, t) } golang-github-pin-tftp-3.1.0/tftp_anticipate_test.go000066400000000000000000000013771452427260500225640ustar00rootroot00000000000000package tftp import ( "net" "testing" ) // derived from Test900 func TestAnticipateWindow900(t *testing.T) { s, c := makeTestServerAnticipateWindow() defer s.Shutdown() for i := 600; i < 4000; i++ { c.blksize = i testSendReceive(t, c, 9000+int64(i)) } } // derived from makeTestServer func makeTestServerAnticipateWindow() (*Server, *Client) { b := &testBackend{} b.m = make(map[string][]byte) // Create server s := NewServer(b.handleRead, b.handleWrite) s.SetAnticipate(16) /* senderAnticipate window size set to 16 */ conn, err := net.ListenUDP("udp", &net.UDPAddr{}) if err != nil { panic(err) } go s.Serve(conn) // Create client for that server c, err := NewClient(localSystem(conn)) if err != nil { panic(err) } return s, c } golang-github-pin-tftp-3.1.0/tftp_test.go000066400000000000000000000520071452427260500203570ustar00rootroot00000000000000package tftp import ( "bytes" "errors" "fmt" "io" "io/ioutil" "math/rand" "net" "os" "strconv" "sync" "testing" "testing/iotest" "time" ) var localhost = determineLocalhost() func determineLocalhost() string { l, err := net.ListenTCP("tcp", nil) if err != nil { panic(fmt.Sprintf("ListenTCP error: %s", err)) } _, lport, _ := net.SplitHostPort(l.Addr().String()) defer l.Close() lo := make(chan string) go func() { for { conn, err := l.Accept() if err != nil { break } conn.Close() } }() go func() { 
port, _ := strconv.Atoi(lport) for _, af := range []string{"tcp6", "tcp4"} { conn, err := net.DialTCP(af, &net.TCPAddr{}, &net.TCPAddr{Port: port}) if err == nil { conn.Close() host, _, _ := net.SplitHostPort(conn.LocalAddr().String()) lo <- host return } } panic("could not determine address family") }() return <-lo } func localSystem(c *net.UDPConn) string { _, port, _ := net.SplitHostPort(c.LocalAddr().String()) return net.JoinHostPort(localhost, port) } func TestPackUnpack(t *testing.T) { v := []string{"test-filename/with-subdir"} testOptsList := []options{ nil, { "tsize": "1234", "blksize": "22", }, } for _, filename := range v { for _, mode := range []string{"octet", "netascii"} { for _, opts := range testOptsList { packUnpack(t, filename, mode, opts) } } } } func packUnpack(t *testing.T, filename, mode string, opts options) { b := make([]byte, datagramLength) for _, op := range []uint16{opRRQ, opWRQ} { n := packRQ(b, op, filename, mode, opts) f, m, o, err := unpackRQ(b[:n]) if err != nil { t.Errorf("%s pack/unpack: %v", filename, err) } if f != filename { t.Errorf("filename mismatch (%s): '%x' vs '%x'", filename, f, filename) } if m != mode { t.Errorf("mode mismatch (%s): '%x' vs '%x'", mode, m, mode) } for name, value := range opts { v, ok := o[name] if !ok { t.Errorf("missing %s option", name) } if v != value { t.Errorf("option %s mismatch: '%x' vs '%x'", name, v, value) } } } } func TestZeroLength(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() testSendReceive(t, c, 0) } func Test900(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() for i := 600; i < 4000; i++ { c.SetBlockSize(i) s.SetBlockSize(4600 - i) testSendReceive(t, c, 9000+int64(i)) } } func Test1000(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() for i := int64(0); i < 5000; i++ { filename := fmt.Sprintf("length-%d-bytes-%d", i, time.Now().UnixNano()) rf, err := c.Send(filename, "octet") if err != nil { t.Fatalf("requesting %s write: %v", filename, 
err) } r := io.LimitReader(newRandReader(rand.NewSource(i)), i) n, err := rf.ReadFrom(r) if err != nil { t.Fatalf("sending %s: %v", filename, err) } if n != i { t.Errorf("%s length mismatch: %d != %d", filename, n, i) } } } func Test1810(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() c.SetBlockSize(1810) testSendReceive(t, c, 9000+1810) } type testHook struct { *sync.Mutex transfersCompleted int transfersFailed int } func newTestHook() *testHook { return &testHook{ Mutex: &sync.Mutex{}, } } func (h *testHook) OnSuccess(result TransferStats) { h.Lock() defer h.Unlock() h.transfersCompleted++ } func (h *testHook) OnFailure(result TransferStats, err error) { h.Lock() defer h.Unlock() h.transfersFailed++ } func TestHookSuccess(t *testing.T) { s, c := makeTestServer(false) th := newTestHook() s.SetHook(th) c.SetBlockSize(1810) length := int64(9000) filename := fmt.Sprintf("length-%d-bytes-%d", length, time.Now().UnixNano()) rf, err := c.Send(filename, "octet") if err != nil { t.Fatalf("requesting %s write: %v", filename, err) } r := io.LimitReader(newRandReader(rand.NewSource(length)), length) n, err := rf.ReadFrom(r) if err != nil { t.Fatalf("sending %s: %v", filename, err) } if n != length { t.Errorf("%s length mismatch: %d != %d", filename, n, length) } s.Shutdown() th.Lock() defer th.Unlock() if th.transfersCompleted != 1 { t.Errorf("unexpected completed transfers count: %d", th.transfersCompleted) } } func TestHookFailure(t *testing.T) { s, c := makeTestServer(false) th := newTestHook() s.SetHook(th) filename := "test-not-exists" mode := "octet" _, err := c.Receive(filename, mode) if err == nil { t.Fatalf("file not exists: %v", err) } t.Logf("receiving file that does not exist: %v", err) s.Shutdown() th.Lock() defer th.Unlock() if th.transfersFailed == 0 { // TODO: there are two failures, not one on Windows? 
t.Errorf("unexpected failed transfers count: %d", th.transfersFailed) } } func TestTSize(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() c.tsize = true testSendReceive(t, c, 640) } func TestNearBlockLength(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() for i := 450; i < 520; i++ { testSendReceive(t, c, int64(i)) } } func TestBlockWrapsAround(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() n := 65535 * 512 for i := n - 2; i < n+2; i++ { testSendReceive(t, c, int64(i)) } } func TestRandomLength(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() r := rand.New(rand.NewSource(42)) for i := 0; i < 100; i++ { testSendReceive(t, c, r.Int63n(100000)) } } func TestBigFile(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() testSendReceive(t, c, 3*1000*1000) } func TestByOneByte(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() filename := "test-by-one-byte" mode := "octet" const length = 80000 sender, err := c.Send(filename, mode) if err != nil { t.Fatalf("requesting write: %v", err) } r := iotest.OneByteReader(io.LimitReader( newRandReader(rand.NewSource(42)), length)) n, err := sender.ReadFrom(r) if err != nil { t.Fatalf("send error: %v", err) } if n != length { t.Errorf("%s read length mismatch: %d != %d", filename, n, length) } readTransfer, err := c.Receive(filename, mode) if err != nil { t.Fatalf("requesting read %s: %v", filename, err) } buf := &bytes.Buffer{} n, err = readTransfer.WriteTo(buf) if err != nil { t.Fatalf("%s read error: %v", filename, err) } if n != length { t.Errorf("%s read length mismatch: %d != %d", filename, n, length) } bs, _ := ioutil.ReadAll(io.LimitReader( newRandReader(rand.NewSource(42)), length)) if !bytes.Equal(bs, buf.Bytes()) { t.Errorf("\nsent: %x\nrcvd: %x", bs, buf) } } func TestDuplicate(t *testing.T) { s, c := makeTestServer(false) defer s.Shutdown() filename := "test-duplicate" mode := "octet" bs := []byte("lalala") sender, err := 
c.Send(filename, mode)
	if err != nil {
		t.Fatalf("requesting write: %v", err)
	}
	buf := bytes.NewBuffer(bs)
	_, err = sender.ReadFrom(buf)
	if err != nil {
		t.Fatalf("send error: %v", err)
	}
	sender, err = c.Send(filename, mode)
	if err == nil {
		t.Fatalf("file already exists")
	}
	t.Logf("sending file that already exists: %v", err)
}

// TestNotFound checks that requesting a file unknown to the test backend
// makes the client's Receive call return an error.
func TestNotFound(t *testing.T) {
	s, c := makeTestServer(false)
	defer s.Shutdown()
	filename := "test-not-exists"
	mode := "octet"
	_, err := c.Receive(filename, mode)
	if err == nil {
		// err is nil on this path, so there is nothing useful to format.
		t.Fatalf("receiving %s: error expected", filename)
	}
	t.Logf("receiving file that does not exist: %v", err)
}

// testSendReceive uploads length pseudo-random bytes under a name derived
// from length, downloads the same file again and verifies that the byte
// counts, the advertised transfer size (when the transfer reports one) and
// the content all match what was sent.
func testSendReceive(t *testing.T, client *Client, length int64) {
	filename := fmt.Sprintf("length-%d-bytes", length)
	mode := "octet"

	writeTransfer, err := client.Send(filename, mode)
	if err != nil {
		t.Fatalf("requesting write %s: %v", filename, err)
	}
	r := io.LimitReader(newRandReader(rand.NewSource(42)), length)
	n, err := writeTransfer.ReadFrom(r)
	if err != nil {
		t.Fatalf("%s write error: %v", filename, err)
	}
	if n != length {
		t.Errorf("%s write length mismatch: %d != %d", filename, n, length)
	}

	readTransfer, err := client.Receive(filename, mode)
	if err != nil {
		t.Fatalf("requesting read %s: %v", filename, err)
	}
	if it, ok := readTransfer.(IncomingTransfer); ok {
		if n, ok := it.Size(); ok {
			fmt.Printf("Transfer size: %d\n", n)
			if n != length {
				t.Errorf("tsize mismatch: %d vs %d", n, length)
			}
		}
	}
	buf := &bytes.Buffer{}
	n, err = readTransfer.WriteTo(buf)
	if err != nil {
		t.Fatalf("%s read error: %v", filename, err)
	}
	if n != length {
		t.Errorf("%s read length mismatch: %d != %d", filename, n, length)
	}

	// Regenerate the expected payload from the same seed that was sent.
	bs, err := ioutil.ReadAll(io.LimitReader(
		newRandReader(rand.NewSource(42)), length))
	if err != nil {
		t.Fatalf("%s generating expected data: %v", filename, err)
	}
	if !bytes.Equal(bs, buf.Bytes()) {
		t.Errorf("\nsent: %x\nrcvd: %x", bs, buf.Bytes())
	}
}

// TestSendTsizeFromSeek checks that a client which requested the transfer
// size receives 100 for a 100-byte payload served from a bytes.Reader, and
// that no size is reported once the client stops requesting it.
func TestSendTsizeFromSeek(t *testing.T) {
	// create read-only server
	s := NewServer(func(filename string, rf io.ReaderFrom) error {
		b := make([]byte, 100)
		rr := newRandReader(rand.NewSource(42))
		// randReader.Read always fills the buffer and never returns an error.
		rr.Read(b)
		// bytes.Reader implements io.Seeker
		r := bytes.NewReader(b)
		_, err := rf.ReadFrom(r)
		if err != nil {
			t.Errorf("sending bytes: %v", err)
		}
		return nil
	}, nil)

	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		t.Fatalf("listening: %v", err)
	}

	go s.Serve(conn)
	defer s.Shutdown()

	c, err := NewClient(localSystem(conn))
	if err != nil {
		t.Fatalf("creating client: %v", err)
	}

	// First transfer: size requested, so it must be reported.
	c.RequestTSize(true)
	r, err := c.Receive("f", "octet")
	if err != nil {
		t.Fatalf("requesting read with tsize: %v", err)
	}
	var size int64
	if it, ok := r.(IncomingTransfer); ok {
		if n, ok := it.Size(); ok {
			size = n
			fmt.Printf("Transfer size: %d\n", n)
		}
	}
	if size != 100 {
		t.Errorf("size expected: 100, got %d", size)
	}
	if _, err := r.WriteTo(ioutil.Discard); err != nil {
		t.Errorf("receiving with tsize: %v", err)
	}

	// Second transfer: size not requested, so none may be reported.
	c.RequestTSize(false)
	r, err = c.Receive("f", "octet")
	if err != nil {
		t.Fatalf("requesting read without tsize: %v", err)
	}
	if it, ok := r.(IncomingTransfer); ok {
		_, ok := it.Size()
		if ok {
			t.Errorf("unexpected size received")
		}
	}
	if _, err := r.WriteTo(ioutil.Discard); err != nil {
		t.Errorf("receiving without tsize: %v", err)
	}
}

// testBackend is an in-memory file store used as the server side of tests.
// m maps file names to contents; mu is held by both handlers for the whole
// duration of a transfer.
type testBackend struct {
	m  map[string][]byte
	mu sync.Mutex
}

// makeTestServer starts a server backed by a fresh testBackend on an
// OS-assigned UDP port and returns it together with a client pointed at it.
// When singlePort is true the server gets a block size of 2000 and single
// port mode enabled. It panics if the listener or the client cannot be
// created.
func makeTestServer(singlePort bool) (*Server, *Client) {
	b := &testBackend{}
	b.m = make(map[string][]byte)

	// Create server
	s := NewServer(b.handleRead, b.handleWrite)

	if singlePort {
		s.SetBlockSize(2000)
		s.EnableSinglePort()
	}

	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		panic(err)
	}

	go s.Serve(conn)

	// Create client for that server
	c, err := NewClient(localSystem(conn))
	if err != nil {
		panic(err)
	}

	return s, c
}

// TestNoHandlers checks that a server created without read and write
// handlers makes both Send and Receive fail on the client.
func TestNoHandlers(t *testing.T) {
	s := NewServer(nil, nil)

	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		t.Fatalf("listening: %v", err)
	}

	go s.Serve(conn)
	// Stop the server when the test ends so it does not leak.
	defer s.Shutdown()

	c, err := NewClient(localSystem(conn))
	if err != nil {
		t.Fatalf("creating client: %v", err)
	}

	_, err = c.Send("test", "octet")
	if err == nil {
		t.Errorf("error expected")
	}

	_, err = c.Receive("test", "octet")
	if err == nil {
		t.Errorf("error expected")
	}
}

// handleWrite stores an incoming file in the backend map. It refuses to
// overwrite an existing entry and prints the announced transfer size when
// the transfer exposes one.
func (b *testBackend) handleWrite(filename string, wt io.WriterTo) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	_, ok := b.m[filename]
	if ok {
		fmt.Fprintf(os.Stderr, "File %s already exists\n", filename)
		return fmt.Errorf("file already exists")
	}
	if it, ok := wt.(IncomingTransfer); ok {
		if n, ok := it.Size(); ok {
			fmt.Printf("Transfer size: %d\n", n)
		}
	}
	buf := &bytes.Buffer{}
	_, err := wt.WriteTo(buf)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Can't receive %s: %v\n", filename, err)
		return err
	}
	b.m[filename] = buf.Bytes()
	return nil
}

// handleRead sends a stored file to the client. It returns an error when
// the file is not in the backend map, and sets the transfer size up front
// when the transfer supports it.
func (b *testBackend) handleRead(filename string, rf io.ReaderFrom) error {
	b.mu.Lock()
	defer b.mu.Unlock()
	bs, ok := b.m[filename]
	if !ok {
		fmt.Fprintf(os.Stderr, "File %s not found\n", filename)
		return fmt.Errorf("file not found")
	}
	if ot, ok := rf.(OutgoingTransfer); ok {
		ot.SetSize(int64(len(bs)))
	}
	_, err := rf.ReadFrom(bytes.NewBuffer(bs))
	if err != nil {
		fmt.Fprintf(os.Stderr, "Can't send %s: %v\n", filename, err)
		return err
	}
	return nil
}

// randReader is an endless io.Reader producing bytes drawn from a
// rand.Source. next holds the not yet consumed part of the current 63-bit
// value and i counts how many of its bytes were already handed out.
type randReader struct {
	src  rand.Source
	next int64
	i    int8
}

// newRandReader returns a reader that yields a deterministic byte stream
// for a given source.
func newRandReader(src rand.Source) io.Reader {
	r := &randReader{
		src:  src,
		next: src.Int63(),
	}
	return r
}

// Read fills p completely and never returns an error. Seven bytes are taken
// from every Int63 value before a new value is drawn.
func (r *randReader) Read(p []byte) (n int, err error) {
	next, i := r.next, r.i
	for n = 0; n < len(p); n++ {
		if i == 7 {
			next, i = r.src.Int63(), 0
		}
		p[n] = byte(next)
		next >>= 8
		i++
	}
	r.next, r.i = next, i
	return
}

// serverTimeoutSendTest configures s with a one second timeout and two
// retries, replaces its read handler with one that streams 80000 bytes, and
// has the client consume them through a writer that stalls after three
// writes. It expects the server-side ReadFrom to fail with a net.Error
// reporting a timeout. The server is shut down on return.
func serverTimeoutSendTest(s *Server, c *Client, t *testing.T) {
	s.SetTimeout(time.Second)
	s.SetRetries(2)
	// Buffered so the handler never blocks when delivering its result.
	sec := make(chan error, 1)
	s.Lock()
	s.readHandler = func(filename string, rf io.ReaderFrom) error {
		r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
		_, err := rf.ReadFrom(r)
		sec <- err
		return err
	}
	s.Unlock()
	defer s.Shutdown()
	filename := "test-server-send-timeout"
	mode := "octet"
	readTransfer, err := c.Receive(filename, mode)
	if err != nil {
		t.Fatalf("requesting read %s: %v", filename, err)
	}
	w := &slowWriter{
		n:     3,
		delay: 8 * time.Second,
	}
	// The client-side result is irrelevant; only the server error is checked.
	_, _ = readTransfer.WriteTo(w)
	servErr := <-sec
	netErr, ok := servErr.(net.Error)
	if !ok {
		t.Fatalf("network error expected: %T", servErr)
	}
	if !netErr.Timeout() {
		t.Fatalf("timout is expected: %v", servErr)
	}
}

// TestServerSendTimeout runs serverTimeoutSendTest against a regular
// (non single-port) test server.
func TestServerSendTimeout(t *testing.T) {
	s, c := makeTestServer(false)
	serverTimeoutSendTest(s, c, t)
}

// serverReceiveTimeoutTest configures s with a one second timeout and two
// retries, replaces its write handler with one that collects the upload,
// and has the client upload from a reader that stalls after three reads.
// It expects the server-side WriteTo to fail with a net.Error reporting a
// timeout. The server is shut down on return.
func serverReceiveTimeoutTest(s *Server, c *Client, t *testing.T) {
	s.SetTimeout(time.Second)
	s.SetRetries(2)
	// Buffered so the handler never blocks when delivering its result.
	sec := make(chan error, 1)
	s.Lock()
	s.writeHandler = func(filename string, wt io.WriterTo) error {
		buf := &bytes.Buffer{}
		_, err := wt.WriteTo(buf)
		sec <- err
		return err
	}
	s.Unlock()
	defer s.Shutdown()
	filename := "test-server-receive-timeout"
	mode := "octet"
	writeTransfer, err := c.Send(filename, mode)
	if err != nil {
		t.Fatalf("requesting write %s: %v", filename, err)
	}
	r := &slowReader{
		r:     io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
		n:     3,
		delay: 8 * time.Second,
	}
	// The client-side result is irrelevant; only the server error is checked.
	_, _ = writeTransfer.ReadFrom(r)
	servErr := <-sec
	netErr, ok := servErr.(net.Error)
	if !ok {
		t.Fatalf("network error expected: %T", servErr)
	}
	if !netErr.Timeout() {
		t.Fatalf("timout is expected: %v", servErr)
	}
}

// TestServerReceiveTimeout runs serverReceiveTimeoutTest against a regular
// (non single-port) test server.
func TestServerReceiveTimeout(t *testing.T) {
	s, c := makeTestServer(false)
	serverReceiveTimeoutTest(s, c, t)
}

// TestClientReceiveTimeout gives the client a one second timeout and two
// retries and lets the server feed the download from a reader that stalls
// after three reads. The client's WriteTo must fail with a net.Error
// reporting a timeout.
func TestClientReceiveTimeout(t *testing.T) {
	s, c := makeTestServer(false)
	c.SetTimeout(time.Second)
	c.SetRetries(2)
	s.Lock()
	s.readHandler = func(filename string, rf io.ReaderFrom) error {
		r := &slowReader{
			r:     io.LimitReader(newRandReader(rand.NewSource(42)), 80000),
			n:     3,
			delay: 8 * time.Second,
		}
		_, err := rf.ReadFrom(r)
		return err
	}
	s.Unlock()
	defer s.Shutdown()
	filename := "test-client-receive-timeout"
	mode := "octet"
	readTransfer, err := c.Receive(filename, mode)
	if err != nil {
		t.Fatalf("requesting read %s: %v", filename, err)
	}
	buf := &bytes.Buffer{}
	_, err = readTransfer.WriteTo(buf)
	netErr, ok := err.(net.Error)
	if !ok {
		t.Fatalf("network error expected: %T", err)
	}
	if !netErr.Timeout() {
		t.Fatalf("timout is expected: %v", err)
	}
}

// TestClientSendTimeout gives the client a one second timeout and two
// retries and lets the server drain the upload into a writer that stalls
// after three writes. The client's ReadFrom must fail with a net.Error
// reporting a timeout.
func TestClientSendTimeout(t *testing.T) {
	s, c := makeTestServer(false)
	c.SetTimeout(time.Second)
	c.SetRetries(2)
	s.Lock()
	s.writeHandler = func(filename string, wt io.WriterTo) error {
		w := &slowWriter{
			n:     3,
			delay: 8 * time.Second,
		}
		_, err := wt.WriteTo(w)
		return err
	}
	s.Unlock()
	defer s.Shutdown()
	filename := "test-client-send-timeout"
	mode := "octet"
	writeTransfer, err := c.Send(filename, mode)
	if err != nil {
		t.Fatalf("requesting write %s: %v", filename, err)
	}
	r := io.LimitReader(newRandReader(rand.NewSource(42)), 80000)
	_, err = writeTransfer.ReadFrom(r)
	netErr, ok := err.(net.Error)
	if !ok {
		t.Fatalf("network error expected: %T", err)
	}
	if !netErr.Timeout() {
		t.Fatalf("timout is expected: %v", err)
	}
}

// slowReader passes the first n reads straight through to r and sleeps for
// delay before every read after that.
type slowReader struct {
	r     io.Reader
	n     int64
	delay time.Duration
}

// Read forwards to the wrapped reader, sleeping first once the budget of
// fast reads is used up.
func (r *slowReader) Read(p []byte) (n int, err error) {
	if r.n > 0 {
		r.n--
	} else {
		time.Sleep(r.delay)
	}
	return r.r.Read(p)
}

// slowWriter accepts and discards everything written to it. The first n
// writes return immediately; every later write sleeps for delay first.
// The r field is not used by Write.
type slowWriter struct {
	r     io.Reader
	n     int64
	delay time.Duration
}

// Write reports p as fully written, sleeping first once the budget of fast
// writes is used up.
func (w *slowWriter) Write(p []byte) (n int, err error) {
	if w.n > 0 {
		w.n--
	} else {
		time.Sleep(w.delay)
	}
	return len(p), nil
}

// TestRequestPacketInfo checks that request packet destination address
// obtained by server using out-of-band socket info is sane.
// It creates server and tries to do transfers using different local interfaces.
// NB: Test ignores transfer errors and validates RequestPacketInfo only
// if transfer is completed successfully. So it checks that LocalIP returns
// correct result if any result is returned, but does not check if result was
// returned at all when it should.
func TestRequestPacketInfo(t *testing.T) {
	// localIP keeps value received from RequestPacketInfo.LocalIP
	// call inside handler.
	// If RequestPacketInfo is not supported, value is set to unspecified
	// IP address.
	var localIP net.IP
	var localIPMu sync.Mutex

	s := NewServer(
		func(_ string, rf io.ReaderFrom) error {
			localIPMu.Lock()
			if rpi, ok := rf.(RequestPacketInfo); ok {
				localIP = rpi.LocalIP()
			} else {
				localIP = net.IP{}
			}
			localIPMu.Unlock()
			_, err := rf.ReadFrom(io.LimitReader(
				newRandReader(rand.NewSource(42)), 42))
			if err != nil {
				t.Logf("sending to client: %v", err)
			}
			return nil
		},
		func(_ string, wt io.WriterTo) error {
			localIPMu.Lock()
			if rpi, ok := wt.(RequestPacketInfo); ok {
				localIP = rpi.LocalIP()
			} else {
				localIP = net.IP{}
			}
			localIPMu.Unlock()
			_, err := wt.WriteTo(ioutil.Discard)
			if err != nil {
				t.Logf("receiving from client: %v", err)
			}
			return nil
		},
	)

	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		t.Fatalf("listen UDP: %v", err)
	}

	_, port, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		t.Fatalf("parsing server port: %v", err)
	}

	// Start server
	go func() {
		err := s.Serve(conn)
		if err != nil {
			// Fatalf must only be called from the test goroutine,
			// so report with Errorf here.
			t.Errorf("serve: %v", err)
		}
	}()
	defer s.Shutdown()

	addrs, err := net.InterfaceAddrs()
	if err != nil {
		t.Fatalf("listing interface addresses: %v", err)
	}

	for _, addr := range addrs {
		// Skip address kinds other than *net.IPNet instead of panicking
		// on the type assertion.
		ipNet, ok := addr.(*net.IPNet)
		if !ok {
			continue
		}
		ip := networkIP(ipNet)
		if ip == nil {
			continue
		}

		c, err := NewClient(net.JoinHostPort(ip.String(), port))
		if err != nil {
			t.Fatalf("new client: %v", err)
		}

		// Skip re-tries to skip non-routable interfaces faster
		c.SetRetries(0)

		ot, err := c.Send("a", "octet")
		if err != nil {
			t.Logf("start sending to %v: %v", ip, err)
			continue
		}
		_, err = ot.ReadFrom(io.LimitReader(
			newRandReader(rand.NewSource(42)), 42))
		if err != nil {
			t.Logf("sending to %v: %v", ip, err)
			continue
		}

		// Check that write handler received IP that was used
		// to create the client.
		localIPMu.Lock()
		if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info
			if !localIP.Equal(ip) {
				t.Errorf("sent to: %v, request packet: %v", ip, localIP)
			}
		} else {
			fmt.Printf("Skip %v\n", ip)
		}
		localIPMu.Unlock()

		it, err := c.Receive("a", "octet")
		if err != nil {
			t.Logf("start receiving from %v: %v", ip, err)
			continue
		}
		_, err = it.WriteTo(ioutil.Discard)
		if err != nil {
			t.Logf("receiving from %v: %v", ip, err)
			continue
		}

		// Check that read handler received IP that was used
		// to create the client.
		localIPMu.Lock()
		if localIP != nil && !localIP.IsUnspecified() { // Skip check if no packet info
			if !localIP.Equal(ip) {
				t.Errorf("sent to: %v, request packet: %v", ip, localIP)
			}
		} else {
			fmt.Printf("Skip %v\n", ip)
		}
		localIPMu.Unlock()

		fmt.Printf("Done %v\n", ip)
	}
}

// networkIP returns the address of n in 4-byte form when it is an IPv4
// address, unchanged when it is 16 bytes long, and nil otherwise.
func networkIP(n *net.IPNet) net.IP {
	if ip := n.IP.To4(); ip != nil {
		return ip
	}
	if len(n.IP) == net.IPv6len {
		return n.IP
	}
	return nil
}

// TestReadWriteErrors checks that errors returned by io.Reader or io.Writer used by
// the handler are handled correctly.
func TestReadWriteErrors(t *testing.T) {
	s := NewServer(
		func(_ string, rf io.ReaderFrom) error {
			_, err := rf.ReadFrom(&failingReader{}) // Read operation fails immediately.
			if err != errRead {
				t.Errorf("want: %v, got: %v", errRead, err)
			}
			// return no error from handler, client still should receive error
			return nil
		},
		func(_ string, wt io.WriterTo) error {
			_, err := wt.WriteTo(&failingWriter{}) // Write operation fails immediately.
			if err != errWrite {
				t.Errorf("want: %v, got: %v", errWrite, err)
			}
			// return no error from handler, client still should receive error
			return nil
		},
	)

	conn, err := net.ListenUDP("udp", &net.UDPAddr{})
	if err != nil {
		t.Fatalf("listen UDP: %v", err)
	}

	_, port, err := net.SplitHostPort(conn.LocalAddr().String())
	if err != nil {
		t.Fatalf("parsing server port: %v", err)
	}

	// Start server
	go func() {
		err := s.Serve(conn)
		if err != nil {
			// Fatalf must only be called from the test goroutine,
			// so report with Errorf here.
			t.Errorf("running serve: %v", err)
		}
	}()
	defer s.Shutdown()

	// Create client
	c, err := NewClient(net.JoinHostPort(localhost, port))
	if err != nil {
		t.Fatalf("creating new client: %v", err)
	}

	ot, err := c.Send("a", "octet")
	if err != nil {
		// ot cannot be used after a failed request, so stop here.
		t.Fatalf("start sending: %v", err)
	}

	_, err = ot.ReadFrom(io.LimitReader(
		newRandReader(rand.NewSource(42)), 42))
	if err == nil {
		t.Errorf("missing write error")
	}

	_, err = c.Receive("a", "octet")
	if err == nil {
		t.Errorf("missing read error")
	}
}

// failingReader is an io.Reader whose Read always fails with errRead.
type failingReader struct{}

// errRead is the error returned by every failingReader.Read call.
var errRead = errors.New("read error")

// Read reads nothing and returns errRead.
func (r *failingReader) Read(_ []byte) (int, error) {
	return 0, errRead
}

// failingWriter is an io.Writer whose Write always fails with errWrite.
type failingWriter struct{}

// errWrite is the error returned by every failingWriter.Write call.
var errWrite = errors.New("write error")

// Write writes nothing and returns errWrite.
func (w *failingWriter) Write(_ []byte) (int, error) {
	return 0, errWrite
}