pax_global_header00006660000000000000000000000064141661754460014530gustar00rootroot0000000000000052 comment=e06b719a377d74a18bdff0d0ffb9abf4a4afdbac grab-3.0.1/000077500000000000000000000000001416617544600124445ustar00rootroot00000000000000grab-3.0.1/.gitignore000066400000000000000000000000501416617544600144270ustar00rootroot00000000000000# ignore IDE project files *.iml .idea/ grab-3.0.1/.travis.yml000066400000000000000000000002001416617544600145450ustar00rootroot00000000000000language: go go: - tip - 1.17.x - 1.16.x - 1.15.x - 1.14.x script: make check env: - GOARCH=amd64 - GOARCH=386 grab-3.0.1/LICENSE000066400000000000000000000027341416617544600134570ustar00rootroot00000000000000Copyright (c) 2017 Ryan Armstrong. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. grab-3.0.1/Makefile000066400000000000000000000004051416617544600141030ustar00rootroot00000000000000GO = go GOGET = $(GO) get -u all: check check: cd v3 && $(GO) test -v -cover -race ./... cd v3/cmd/grab && $(MAKE) -B all install: cd v3/cmd/grab && $(MAKE) install clean: cd v3 && $(GO) clean -x ./... rm -rvf ./.test* .PHONY: all check install clean grab-3.0.1/README.md000066400000000000000000000077721416617544600137400ustar00rootroot00000000000000# grab [![GoDoc](https://godoc.org/github.com/cavaliercoder/grab?status.svg)](https://godoc.org/github.com/cavaliercoder/grab) [![Build Status](https://travis-ci.org/cavaliercoder/grab.svg?branch=master)](https://travis-ci.org/cavaliercoder/grab) [![Go Report Card](https://goreportcard.com/badge/github.com/cavaliercoder/grab)](https://goreportcard.com/report/github.com/cavaliercoder/grab) *Downloading the internet, one goroutine at a time!* $ go get github.com/cavaliergopher/grab/v3 Grab is a Go package for downloading files from the internet with the following rad features: * Monitor download progress concurrently * Auto-resume incomplete downloads * Guess filename from content header or URL path * Safely cancel downloads using context.Context * Validate downloads using checksums * Download batches of files concurrently * Apply rate limiters Requires Go v1.7+ ## Example The following example downloads a PDF copy of the free eBook, "An Introduction to Programming in Go" into the current working directory. 
```go resp, err := grab.Get(".", "http://www.golang-book.com/public/pdf/gobook.pdf") if err != nil { log.Fatal(err) } fmt.Println("Download saved to", resp.Filename) ``` The following, more complete example allows for more granular control and periodically prints the download progress until it is complete. The second time you run the example, it will auto-resume the previous download and exit sooner. ```go package main import ( "fmt" "os" "time" "github.com/cavaliergopher/grab/v3" ) func main() { // create client client := grab.NewClient() req, _ := grab.NewRequest(".", "http://www.golang-book.com/public/pdf/gobook.pdf") // start download fmt.Printf("Downloading %v...\n", req.URL()) resp := client.Do(req) fmt.Printf(" %v\n", resp.HTTPResponse.Status) // start UI loop t := time.NewTicker(500 * time.Millisecond) defer t.Stop() Loop: for { select { case <-t.C: fmt.Printf(" transferred %v / %v bytes (%.2f%%)\n", resp.BytesComplete(), resp.Size(), 100*resp.Progress()) case <-resp.Done: // download is complete break Loop } } // check for errors if err := resp.Err(); err != nil { fmt.Fprintf(os.Stderr, "Download failed: %v\n", err) os.Exit(1) } fmt.Printf("Download saved to ./%v \n", resp.Filename) // Output: // Downloading http://www.golang-book.com/public/pdf/gobook.pdf... // 200 OK // transferred 42970 / 2893557 bytes (1.49%) // transferred 1207474 / 2893557 bytes (41.73%) // transferred 2758210 / 2893557 bytes (95.32%) // Download saved to ./gobook.pdf } ``` ## Design trade-offs The primary use case for Grab is to concurrently download thousands of large files from remote file repositories where the remote files are immutable. Examples include operating system package repositories or ISO libraries. Grab aims to provide robust, sane defaults. These are usually determined using the HTTP specifications, or by mimicking the behavior of common web clients like cURL, wget and common web browsers. Grab aims to be stateless.
The only state consists of the remote files you wish to download and their local copies, which may be complete, partially complete or not yet created. The advantage to this is that the local file system is not cluttered unnecessarily with additional state files (like a `.crdownload` file). The disadvantage of this approach is that grab must make assumptions about the local and remote state; specifically, that they have not been modified by another program. If the local or remote file is modified outside of grab, and you download the file again with resuming enabled, the local file will likely become corrupted. In this case, you might consider making remote files immutable, or disabling resume. Grab aims to enable best-in-class functionality for more complex features through extensible interfaces, rather than reimplementation. For example, you can provide your own Hash algorithm to compute file checksums, or your own rate limiter implementation (with all the associated trade-offs) to rate limit downloads.
grab-3.0.1/states.wsd000066400000000000000000000043421416617544600144710ustar00rootroot00000000000000@startuml title Grab transfer state legend | # | Meaning | | D | Destination path known | | S | File size known | | O | Server options known (Accept-Ranges) | | R | Resume supported (Accept-Ranges) | | Z | Local file empty or missing | | P | Local file partially complete | endlegend [*] --> Empty [*] --> D [*] --> S [*] --> DS Empty : Filename: "" Empty : Size: 0 Empty --> O : HEAD: Method not allowed Empty --> DSO : HEAD: Range not supported Empty --> DSOR : HEAD: Range supported DS : Filename: "foo.bar" DS : Size: > 0 DS --> DSZ : checkExisting(): File missing DS --> DSP : checkExisting(): File partial DS --> [*] : checkExisting(): File complete DS --> ERROR S : Filename: "" S : Size: > 0 S --> SO : HEAD: Method not allowed S --> DSO : HEAD: Range not supported S --> DSOR : HEAD: Range supported D : Filename: "foo.bar" D : Size: 0 D --> DO : HEAD: Method not allowed D --> DSO : HEAD: Range not supported D --> DSOR : HEAD: Range supported O : Filename: "" O : Size: 0 O : CanResume: false O --> DSO : GET 200 O --> ERROR SO : Filename: "" SO : Size: > 0 SO : CanResume: false SO --> DSO : GET: 200 SO --> ERROR DO : Filename: "foo.bar" DO : Size: 0 DO : CanResume: false DO --> DSO : GET 200 DO --> ERROR DSZ : Filename: "foo.bar" DSZ : Size: > 0 DSZ : File: empty DSZ --> DSORZ : HEAD: Range supported DSZ --> DSOZ : HEAD 405 or Range unsupported DSP : Filename: "foo.bar" DSP : Size: > 0 DSP : File: partial DSP --> DSORP : HEAD: Range supported DSP --> DSOZ : HEAD: 405 or Range unsupported DSO : Filename: "foo.bar" DSO : Size: > 0 DSO : CanResume: false DSO --> DSOZ : checkExisting(): File partial|missing DSO --> [*] : checkExisting(): File complete DSOR : Filename: "foo.bar" DSOR : Size: > 0 DSOR : CanResume: true DSOR --> DSORP : CheckLocal: File partial DSOR --> DSORZ : CheckLocal: File missing DSORP : Filename: "foo.bar" DSORP : Size: > 0 DSORP : CanResume: true DSORP 
: File: partial DSORP --> Transferring DSORZ : Filename: "foo.bar" DSORZ : Size: > 0 DSORZ : CanResume: true DSORZ : File: empty DSORZ --> Transferring DSOZ : Filename: "foo.bar" DSOZ : Size: > 0 DSOZ : CanResume: false DSOZ : File: empty DSOZ --> Transferring Transferring --> [*] Transferring --> ERROR ERROR : Something went wrong ERROR --> [*] @endumlgrab-3.0.1/v3/000077500000000000000000000000001416617544600127745ustar00rootroot00000000000000grab-3.0.1/v3/client.go000066400000000000000000000357501416617544600146130ustar00rootroot00000000000000package grab

import (
	"bytes"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sync"
	"sync/atomic"
	"time"
)

// HTTPClient provides an interface allowing us to perform HTTP requests.
//
// It is satisfied by *http.Client, which is what NewClient installs.
type HTTPClient interface {
	Do(req *http.Request) (*http.Response, error)
}

// truncater is a private interface allowing different response Writers to be
// truncated (for example *os.File). It is used to clear an existing
// destination file that is being overwritten rather than resumed.
type truncater interface {
	Truncate(size int64) error
}

// A Client is a file download client.
//
// Clients are safe for concurrent use by multiple goroutines.
type Client struct {
	// HTTPClient specifies the http.Client which will be used for communicating
	// with the remote server during the file transfer.
	//
	// Requests are sent by calling HTTPClient.Do directly, so this field must
	// be non-nil; use NewClient to obtain a Client with a default HTTPClient.
	HTTPClient HTTPClient

	// UserAgent specifies the User-Agent string which will be set in the
	// headers of all requests made by this client.
	//
	// The user agent string may be overridden in the headers of each request.
	// If UserAgent is empty, the client adds no User-Agent header of its own.
	UserAgent string

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring all requested files. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. The BufferSize of each request can
	// be overridden on each Request object; a Request.BufferSize of 0 inherits
	// this value. If the effective size is less than 1, 32KB is used.
	// Default: 32KB.
	BufferSize int
}

// NewClient returns a new file download Client using the default
// configuration: the User-Agent "grab" and an http.Client whose Transport takes
// its proxy settings from the environment (http.ProxyFromEnvironment).
func NewClient() *Client { return &Client{ UserAgent: "grab", HTTPClient: &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, }, }, } } // DefaultClient is the default client and is used by all Get convenience // functions. var DefaultClient = NewClient() // Do sends a file transfer request and returns a file transfer response, // following policy (e.g. redirects, cookies, auth) as configured on the // client's HTTPClient. // // Like http.Get, Do blocks while the transfer is initiated, but returns as soon // as the transfer has started transferring in a background goroutine, or if it // failed early. // // An error is returned via Response.Err if caused by client policy (such as // CheckRedirect), or if there was an HTTP protocol or IO error. Response.Err // will block the caller until the transfer is completed, successfully or // otherwise. func (c *Client) Do(req *Request) *Response { // cancel will be called on all code-paths via closeResponse ctx, cancel := context.WithCancel(req.Context()) req = req.WithContext(ctx) resp := &Response{ Request: req, Start: time.Now(), Done: make(chan struct{}, 0), Filename: req.Filename, ctx: ctx, cancel: cancel, bufferSize: req.BufferSize, } if resp.bufferSize == 0 { // default to Client.BufferSize resp.bufferSize = c.BufferSize } // Run state-machine while caller is blocked to initialize the file transfer. // Must never transition to the copyFile state - this happens next in another // goroutine. c.run(resp, c.statFileInfo) // Run copyFile in a new goroutine. copyFile will no-op if the transfer is // already complete or failed. go c.run(resp, c.copyFile) return resp } // DoChannel executes all requests sent through the given Request channel, one // at a time, until it is closed by another goroutine. The caller is blocked // until the Request channel is closed and all transfers have completed. 
All // responses are sent through the given Response channel as soon as they are // received from the remote servers and can be used to track the progress of // each download. // // Slow Response receivers will cause a worker to block and therefore delay the // start of the transfer for an already initiated connection - potentially // causing a server timeout. It is the caller's responsibility to ensure a // sufficient buffer size is used for the Response channel to prevent this. // // If an error occurs during any of the file transfers it will be accessible via // the associated Response.Err function. func (c *Client) DoChannel(reqch <-chan *Request, respch chan<- *Response) { // TODO: enable cancelling of batch jobs for req := range reqch { resp := c.Do(req) respch <- resp <-resp.Done } } // DoBatch executes all the given requests using the given number of concurrent // workers. Control is passed back to the caller as soon as the workers are // initiated. // // If the requested number of workers is less than one, a worker will be created // for every request. I.e. all requests will be executed concurrently. // // If an error occurs during any of the file transfers it will be accessible via // call to the associated Response.Err. // // The returned Response channel is closed only after all of the given Requests // have completed, successfully or otherwise. func (c *Client) DoBatch(workers int, requests ...*Request) <-chan *Response { if workers < 1 { workers = len(requests) } reqch := make(chan *Request, len(requests)) respch := make(chan *Response, len(requests)) wg := sync.WaitGroup{} for i := 0; i < workers; i++ { wg.Add(1) go func() { c.DoChannel(reqch, respch) wg.Done() }() } // queue requests go func() { for _, req := range requests { reqch <- req } close(reqch) wg.Wait() close(respch) }() return respch } // An stateFunc is an action that mutates the state of a Response and returns // the next stateFunc to be called. 
type stateFunc func(*Response) stateFunc

// run drives the state machine. Starting at f, it calls each stateFunc in turn
// with resp and stops as soon as one returns nil. Before every step it checks
// resp.ctx: once the context is done, run either returns (if the response is
// already complete) or records the context error in resp.err and makes
// closeResponse the next step.
func (c *Client) run(resp *Response, f stateFunc) {
	next := f
	for {
		if ctxErr := resp.ctx.Err(); ctxErr != nil {
			if resp.IsComplete() {
				return
			}
			resp.err = ctxErr
			next = c.closeResponse
		}

		if next = next(resp); next == nil {
			return
		}
	}
}

// statFileInfo looks up the local file named by Response.Filename.
//
// The next step is headRequest when nothing useful exists locally: the request
// is NoStore, the filename is still unknown, the file is missing, or the path
// is a directory (in which case Filename is cleared so it can be derived
// later). When a regular file exists, Response.fi is populated and the next
// step is validateLocal. Any other stat failure is stored in Response.err and
// the next step is closeResponse.
func (c *Client) statFileInfo(resp *Response) stateFunc {
	if resp.Request.NoStore || resp.Filename == "" {
		return c.headRequest
	}

	info, err := os.Stat(resp.Filename)
	switch {
	case os.IsNotExist(err):
		return c.headRequest
	case err != nil:
		resp.err = err
		return c.closeResponse
	case info.IsDir():
		resp.Filename = ""
		return c.headRequest
	}

	resp.fi = info
	return c.validateLocal
}

// validateLocal compares an existing local file against the expected remote
// size.
//
// It fails when Request.SkipExisting is set, or when the local file is larger
// than a known remote size. A local file whose size already matches moves on
// to checksumFile. A smaller local file is resumed with a ranged getRequest
// when the server is known to support ranges.
func (c *Client) validateLocal(resp *Response) stateFunc { if resp.Request.SkipExisting { resp.err = ErrFileExists return c.closeResponse } // determine target file size expectedSize := resp.Request.Size if expectedSize == 0 && resp.HTTPResponse != nil { expectedSize = resp.HTTPResponse.ContentLength } if expectedSize == 0 { // size is either actually 0 or unknown // if unknown, we ask the remote server // if known to be 0, we proceed with a GET return c.headRequest } if expectedSize == resp.fi.Size() { // local file matches remote file size - wrap it up resp.DidResume = true resp.bytesResumed = resp.fi.Size() return c.checksumFile } if resp.Request.NoResume { // local file should be overwritten return c.getRequest } if expectedSize >= 0 && expectedSize < resp.fi.Size() { // remote size is known, is smaller than local size and we want to resume resp.err = ErrBadLength return c.closeResponse } if resp.CanResume { // set resume range on GET request resp.Request.HTTPRequest.Header.Set( "Range", fmt.Sprintf("bytes=%d-", resp.fi.Size())) resp.DidResume = true resp.bytesResumed = resp.fi.Size() return c.getRequest } return c.headRequest } func (c *Client) checksumFile(resp *Response) stateFunc { if resp.Request.hash == nil { return c.closeResponse } if resp.Filename == "" { panic("grab: developer error: filename not set") } if resp.Size() < 0 { panic("grab: developer error: size unknown") } req := resp.Request // compute checksum var sum []byte sum, resp.err = resp.checksumUnsafe() if resp.err != nil { return c.closeResponse } // compare checksum if !bytes.Equal(sum, req.checksum) { resp.err = ErrBadChecksum if !resp.Request.NoStore && req.deleteOnError { if err := os.Remove(resp.Filename); err != nil { // err should be os.PathError and include file path resp.err = fmt.Errorf( "cannot remove downloaded file with checksum mismatch: %v", err) } } } return c.closeResponse } // doHTTPRequest sends a HTTP Request and returns the response func (c *Client) doHTTPRequest(req 
*http.Request) (*http.Response, error) { if c.UserAgent != "" && req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", c.UserAgent) } return c.HTTPClient.Do(req) } func (c *Client) headRequest(resp *Response) stateFunc { if resp.optionsKnown { return c.getRequest } resp.optionsKnown = true if resp.Request.NoResume { return c.getRequest } if resp.Filename != "" && resp.fi == nil { // destination path is already known and does not exist return c.getRequest } hreq := new(http.Request) *hreq = *resp.Request.HTTPRequest hreq.Method = "HEAD" resp.HTTPResponse, resp.err = c.doHTTPRequest(hreq) if resp.err != nil { return c.closeResponse } resp.HTTPResponse.Body.Close() if resp.HTTPResponse.StatusCode != http.StatusOK { return c.getRequest } // In case of redirects during HEAD, record the final URL and use it // instead of the original URL when sending future requests. // This way we avoid sending potentially unsupported requests to // the original URL, e.g. "Range", since it was the final URL // that advertised its support. 
resp.Request.HTTPRequest.URL = resp.HTTPResponse.Request.URL resp.Request.HTTPRequest.Host = resp.HTTPResponse.Request.Host return c.readResponse } func (c *Client) getRequest(resp *Response) stateFunc { resp.HTTPResponse, resp.err = c.doHTTPRequest(resp.Request.HTTPRequest) if resp.err != nil { return c.closeResponse } // TODO: check Content-Range // check status code if !resp.Request.IgnoreBadStatusCodes { if resp.HTTPResponse.StatusCode < 200 || resp.HTTPResponse.StatusCode > 299 { resp.err = StatusCodeError(resp.HTTPResponse.StatusCode) return c.closeResponse } } return c.readResponse } func (c *Client) readResponse(resp *Response) stateFunc { if resp.HTTPResponse == nil { panic("grab: developer error: Response.HTTPResponse is nil") } // check expected size resp.sizeUnsafe = resp.HTTPResponse.ContentLength if resp.sizeUnsafe >= 0 { // remote size is known resp.sizeUnsafe += resp.bytesResumed if resp.Request.Size > 0 && resp.Request.Size != resp.sizeUnsafe { resp.err = ErrBadLength return c.closeResponse } } // check filename if resp.Filename == "" { filename, err := guessFilename(resp.HTTPResponse) if err != nil { resp.err = err return c.closeResponse } // Request.Filename will be empty or a directory resp.Filename = filepath.Join(resp.Request.Filename, filename) } if !resp.Request.NoStore && resp.requestMethod() == "HEAD" { if resp.HTTPResponse.Header.Get("Accept-Ranges") == "bytes" { resp.CanResume = true } return c.statFileInfo } return c.openWriter } // openWriter opens the destination file for writing and seeks to the location // from whence the file transfer will resume. // // Requires that Response.Filename and resp.DidResume are already be set. 
func (c *Client) openWriter(resp *Response) stateFunc { if !resp.Request.NoStore && !resp.Request.NoCreateDirectories { resp.err = mkdirp(resp.Filename) if resp.err != nil { return c.closeResponse } } if resp.Request.NoStore { resp.writer = &resp.storeBuffer } else { // compute write flags flag := os.O_CREATE | os.O_WRONLY if resp.fi != nil { if resp.DidResume { flag = os.O_APPEND | os.O_WRONLY } else { // truncate later in copyFile, if not cancelled // by BeforeCopy hook flag = os.O_WRONLY } } // open file f, err := os.OpenFile(resp.Filename, flag, 0666) if err != nil { resp.err = err return c.closeResponse } resp.writer = f // seek to start or end whence := os.SEEK_SET if resp.bytesResumed > 0 { whence = os.SEEK_END } _, resp.err = f.Seek(0, whence) if resp.err != nil { return c.closeResponse } } // init transfer if resp.bufferSize < 1 { resp.bufferSize = 32 * 1024 } b := make([]byte, resp.bufferSize) resp.transfer = newTransfer( resp.Request.Context(), resp.Request.RateLimiter, resp.writer, resp.HTTPResponse.Body, b) // next step is copyFile, but this will be called later in another goroutine return nil } // copy transfers content for a HTTP connection established via Client.do() func (c *Client) copyFile(resp *Response) stateFunc { if resp.IsComplete() { return nil } // run BeforeCopy hook if f := resp.Request.BeforeCopy; f != nil { resp.err = f(resp) if resp.err != nil { return c.closeResponse } } var bytesCopied int64 if resp.transfer == nil { panic("grab: developer error: Response.transfer is nil") } // We waited to truncate the file in openWriter() to make sure // the BeforeCopy didn't cancel the copy. If this was an existing // file that is not going to be resumed, truncate the contents. 
if t, ok := resp.writer.(truncater); ok && resp.fi != nil && !resp.DidResume { t.Truncate(0) } bytesCopied, resp.err = resp.transfer.copy() if resp.err != nil { return c.closeResponse } closeWriter(resp) // set file timestamp if !resp.Request.NoStore && !resp.Request.IgnoreRemoteTime { resp.err = setLastModified(resp.HTTPResponse, resp.Filename) if resp.err != nil { return c.closeResponse } } // update transfer size if previously unknown if resp.Size() < 0 { discoveredSize := resp.bytesResumed + bytesCopied atomic.StoreInt64(&resp.sizeUnsafe, discoveredSize) if resp.Request.Size > 0 && resp.Request.Size != discoveredSize { resp.err = ErrBadLength return c.closeResponse } } // run AfterCopy hook if f := resp.Request.AfterCopy; f != nil { resp.err = f(resp) if resp.err != nil { return c.closeResponse } } return c.checksumFile } func closeWriter(resp *Response) { if closer, ok := resp.writer.(io.Closer); ok { closer.Close() } resp.writer = nil } // close finalizes the Response func (c *Client) closeResponse(resp *Response) stateFunc { if resp.IsComplete() { panic("grab: developer error: response already closed") } resp.fi = nil closeWriter(resp) resp.closeResponseBody() resp.End = time.Now() close(resp.Done) if resp.cancel != nil { resp.cancel() } return nil } grab-3.0.1/v3/client_test.go000066400000000000000000000657041416617544600156540ustar00rootroot00000000000000package grab import ( "bytes" "context" "crypto/md5" "crypto/sha1" "crypto/sha256" "crypto/sha512" "errors" "fmt" "hash" "io/ioutil" "math/rand" "net/http" "os" "path/filepath" "strings" "testing" "time" "github.com/cavaliergopher/grab/v3/pkg/grabtest" ) // TestFilenameResolutions tests that the destination filename for Requests can // be determined correctly, using an explicitly requested path, // Content-Disposition headers or a URL path - with or without an existing // target directory. 
func TestFilenameResolution(t *testing.T) { tests := []struct { Name string Filename string URL string AttachmentFilename string Expect string }{ {"Using Request.Filename", ".testWithFilename", "/url-filename", "header-filename", ".testWithFilename"}, {"Using Content-Disposition Header", "", "/url-filename", ".testWithHeaderFilename", ".testWithHeaderFilename"}, {"Using Content-Disposition Header with target directory", ".test", "/url-filename", "header-filename", ".test/header-filename"}, {"Using URL Path", "", "/.testWithURLFilename?params-filename", "", ".testWithURLFilename"}, {"Using URL Path with target directory", ".test", "/url-filename?garbage", "", ".test/url-filename"}, {"Failure", "", "", "", ""}, } err := os.Mkdir(".test", 0777) if err != nil { panic(err) } defer os.RemoveAll(".test") for _, test := range tests { t.Run(test.Name, func(t *testing.T) { opts := []grabtest.HandlerOption{} if test.AttachmentFilename != "" { opts = append(opts, grabtest.AttachmentFilename(test.AttachmentFilename)) } grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(test.Filename, url+test.URL) resp := DefaultClient.Do(req) defer os.Remove(resp.Filename) if err := resp.Err(); err != nil { if test.Expect != "" || err != ErrNoFilename { panic(err) } } else { if test.Expect == "" { t.Errorf("expected: %v, got: %v", ErrNoFilename, err) } } if resp.Filename != test.Expect { t.Errorf("Filename mismatch. Expected '%s', got '%s'.", test.Expect, resp.Filename) } testComplete(t, resp) }, opts...) }) } } // TestChecksums checks that checksum validation behaves as expected for valid // and corrupted downloads. 
func TestChecksums(t *testing.T) { tests := []struct { size int hash hash.Hash sum string match bool }{ {128, md5.New(), "37eff01866ba3f538421b30b7cbefcac", true}, {128, md5.New(), "37eff01866ba3f538421b30b7cbefcad", false}, {1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855b", true}, {1024, md5.New(), "b2ea9f7fcea831a4a63b213f41a8855c", false}, {1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef372", true}, {1048576, md5.New(), "c35cc7d8d91728a0cb052831bc4ef373", false}, {128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d535", true}, {128, sha1.New(), "e6434bc401f98603d7eda504790c98c67385d536", false}, {1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b77", true}, {1024, sha1.New(), "5b00669c480d5cffbdfa8bdba99561160f2d1b78", false}, {1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923be", true}, {1048576, sha1.New(), "ecfc8e86fdd83811f9cc9bf500993b63069923bf", false}, {128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be5", true}, {128, sha256.New(), "471fb943aa23c511f6f72f8d1652d9c880cfa392ad80503120547703e56a2be4", false}, {1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c9", true}, {1024, sha256.New(), "785b0751fc2c53dc14a4ce3d800e69ef9ce1009eb327ccf458afe09c242c26c8", false}, {1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83", true}, {1048576, sha256.New(), "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c82", false}, {128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f7", true}, {128, sha512.New(), "1dffd5e3adb71d45d2245939665521ae001a317a03720a45732ba1900ca3b8351fc5c9b4ca513eba6f80bc7b1d1fdad4abd13491cb824d61b08d8c0e1561b3f8", false}, {1024, sha512.New(), "37f652be867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566c", true}, {1024, sha512.New(), 
"37f652bf867f28ed033269cbba201af2112c2b3fd334a89fd2f757938ddee815787cc61d6e24a8a33340d0f7e86ffc058816b88530766ba6e231620a130b566d", false}, {1048576, sha512.New(), "ac1d097b4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", true}, {1048576, sha512.New(), "ac1d097c4ea6f6ad7ba640275b9ac290e4828cd760a0ebf76d555463a4f505f95df4f611629539a2dd1848e7c1304633baa1826462b3c87521c0c6e3469b67af", false}, } for _, test := range tests { var expect error comparison := "Match" if !test.match { comparison = "Mismatch" expect = ErrBadChecksum } t.Run(fmt.Sprintf("With%s%s", comparison, test.sum[:8]), func(t *testing.T) { filename := fmt.Sprintf(".testChecksum-%s-%s", comparison, test.sum[:8]) defer os.Remove(filename) grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.SetChecksum(test.hash, grabtest.MustHexDecodeString(test.sum), true) resp := DefaultClient.Do(req) err := resp.Err() if err != expect { t.Errorf("expected error: %v, got: %v", expect, err) } // ensure mismatch file was deleted if !test.match { if _, err := os.Stat(filename); err == nil { t.Errorf("checksum failure not cleaned up: %s", filename) } else if !os.IsNotExist(err) { panic(err) } } testComplete(t, resp) }, grabtest.ContentLength(test.size)) }) } } // TestContentLength ensures that ErrBadLength is returned if a server response // does not match the requested length. 
func TestContentLength(t *testing.T) { size := int64(32768) testCases := []struct { Name string NoHead bool Size int64 Expect int64 Match bool }{ {"Good size in HEAD request", false, size, size, true}, {"Good size in GET request", true, size, size, true}, {"Bad size in HEAD request", false, size - 1, size, false}, {"Bad size in GET request", true, size - 1, size, false}, } for _, test := range testCases { t.Run(test.Name, func(t *testing.T) { opts := []grabtest.HandlerOption{ grabtest.ContentLength(int(test.Size)), } if test.NoHead { opts = append(opts, grabtest.MethodWhitelist("GET")) } grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(".testSize-mismatch-head", url) req.Size = size resp := DefaultClient.Do(req) defer os.Remove(resp.Filename) err := resp.Err() if test.Match { if err == ErrBadLength { t.Errorf("error: %v", err) } else if err != nil { panic(err) } else if resp.Size() != size { t.Errorf("expected %v bytes, got %v bytes", size, resp.Size()) } } else { if err == nil { t.Errorf("expected: %v, got %v", ErrBadLength, err) } else if err != ErrBadLength { panic(err) } } testComplete(t, resp) }, opts...) }) } } // TestAutoResume tests segmented downloading of a large file. 
func TestAutoResume(t *testing.T) { segs := 8 size := 1048576 sum := grabtest.DefaultHandlerSHA256ChecksumBytes //grab/v3test.MustHexDecodeString("fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83") filename := ".testAutoResume" defer os.Remove(filename) for i := 0; i < segs; i++ { segsize := (i + 1) * (size / segs) t.Run(fmt.Sprintf("With%vBytes", segsize), func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) if i == segs-1 { req.SetChecksum(sha256.New(), sum, false) } resp := mustDo(req) if i > 0 && !resp.DidResume { t.Errorf("expected Response.DidResume to be true") } testComplete(t, resp) }, grabtest.ContentLength(segsize), ) }) } t.Run("WithFailure", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { // request smaller segment req := mustNewRequest(filename, url) resp := DefaultClient.Do(req) if err := resp.Err(); err != ErrBadLength { t.Errorf("expected ErrBadLength for smaller request, got: %v", err) } }, grabtest.ContentLength(size-128), ) }) t.Run("WithNoResume", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.NoResume = true resp := mustDo(req) if resp.DidResume { t.Errorf("expected Response.DidResume to be false") } testComplete(t, resp) }, grabtest.ContentLength(size+128), ) }) t.Run("WithNoResumeAndTruncate", func(t *testing.T) { size := size - 128 grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.NoResume = true resp := mustDo(req) if resp.DidResume { t.Errorf("expected Response.DidResume to be false") } if v := resp.BytesComplete(); v != int64(size) { t.Errorf("expected Response.BytesComplete: %d, got: %d", size, v) } testComplete(t, resp) }, grabtest.ContentLength(size), ) }) t.Run("WithNoContentLengthHeader", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.SetChecksum(sha256.New(), sum, false) resp := mustDo(req) 
if !resp.DidResume { t.Errorf("expected Response.DidResume to be true") } if actual := resp.Size(); actual != int64(size) { t.Errorf("expected Response.Size: %d, got: %d", size, actual) } testComplete(t, resp) }, grabtest.ContentLength(size), grabtest.HeaderBlacklist("Content-Length"), ) }) t.Run("WithNoContentLengthHeaderAndChecksumFailure", func(t *testing.T) { // ref: https://github.com/cavaliergopher/grab/v3/pull/27 size := size * 2 grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.SetChecksum(sha256.New(), sum, false) resp := DefaultClient.Do(req) if err := resp.Err(); err != ErrBadChecksum { t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err) } if !resp.DidResume { t.Errorf("expected Response.DidResume to be true") } if actual := resp.BytesComplete(); actual != int64(size) { t.Errorf("expected Response.BytesComplete: %d, got: %d", size, actual) } if actual := resp.Size(); actual != int64(size) { t.Errorf("expected Response.Size: %d, got: %d", size, actual) } testComplete(t, resp) }, grabtest.ContentLength(size), grabtest.HeaderBlacklist("Content-Length"), ) }) // TODO: test when existing file is corrupted } func TestSkipExisting(t *testing.T) { filename := ".testSkipExisting" defer os.Remove(filename) // download a file grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(filename, url)) testComplete(t, resp) }) // redownload grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(filename, url)) testComplete(t, resp) // ensure download was resumed if !resp.DidResume { t.Fatalf("Expected download to skip existing file, but it did not") } // ensure all bytes were resumed if resp.Size() == 0 || resp.Size() != resp.bytesResumed { t.Fatalf("Expected to skip %d bytes in redownload; got %d", resp.Size(), resp.bytesResumed) } }) // ensure checksum is performed on pre-existing file grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) 
req.SetChecksum(sha256.New(), []byte{0x01, 0x02, 0x03, 0x04}, true) resp := DefaultClient.Do(req) if err := resp.Err(); err != ErrBadChecksum { t.Fatalf("Expected checksum error, got: %v", err) } }) } // TestBatch executes multiple requests simultaneously and validates the // responses. func TestBatch(t *testing.T) { tests := 32 size := 32768 sum := grabtest.MustHexDecodeString("e11360251d1173650cdcd20f111d8f1ca2e412f572e8b36a4dc067121c1799b8") // test with 4 workers and with one per request grabtest.WithTestServer(t, func(url string) { for _, workerCount := range []int{4, 0} { // create requests reqs := make([]*Request, tests) for i := 0; i < len(reqs); i++ { filename := fmt.Sprintf(".testBatch.%d", i+1) reqs[i] = mustNewRequest(filename, url+fmt.Sprintf("/request_%d?", i+1)) reqs[i].Label = fmt.Sprintf("Test %d", i+1) reqs[i].SetChecksum(sha256.New(), sum, false) } // batch run responses := DefaultClient.DoBatch(workerCount, reqs...) // listen for responses Loop: for i := 0; i < len(reqs); { select { case resp := <-responses: if resp == nil { break Loop } testComplete(t, resp) if err := resp.Err(); err != nil { t.Errorf("%s: %v", resp.Filename, err) } // remove test file if resp.IsComplete() { os.Remove(resp.Filename) // ignore errors } i++ } } } }, grabtest.ContentLength(size), ) } // TestCancelContext tests that a batch of requests can be cancel using a // context.Context cancellation. Requests are cancelled in multiple states: // in-progress and unstarted. func TestCancelContext(t *testing.T) { fileSize := 134217728 tests := 256 client := NewClient() ctx, cancel := context.WithCancel(context.Background()) defer cancel() grabtest.WithTestServer(t, func(url string) { reqs := make([]*Request, tests) for i := 0; i < tests; i++ { req := mustNewRequest("", fmt.Sprintf("%s/.testCancelContext%d", url, i)) reqs[i] = req.WithContext(ctx) } respch := client.DoBatch(8, reqs...) 
time.Sleep(time.Millisecond * 500) cancel() for resp := range respch { defer os.Remove(resp.Filename) // err should be context.Canceled or http.errRequestCanceled if resp.Err() == nil || !strings.Contains(resp.Err().Error(), "canceled") { t.Errorf("expected '%v', got '%v'", context.Canceled, resp.Err()) } if resp.BytesComplete() >= int64(fileSize) { t.Errorf("expected Response.BytesComplete: < %d, got: %d", fileSize, resp.BytesComplete()) } } }, grabtest.ContentLength(fileSize), ) } // TestCancelHangingResponse tests that a never ending request is terminated // when the response is cancelled. func TestCancelHangingResponse(t *testing.T) { fileSize := 10 client := NewClient() grabtest.WithTestServer(t, func(url string) { req := mustNewRequest("", fmt.Sprintf("%s/.testCancelHangingResponse", url)) resp := client.Do(req) defer os.Remove(resp.Filename) // Wait for some bytes to be transferred for resp.BytesComplete() == 0 { time.Sleep(50 * time.Millisecond) } done := make(chan error) go func() { done <- resp.Cancel() }() select { case err := <-done: if err != context.Canceled { t.Errorf("Expected context.Canceled error, go: %v", err) } case <-time.After(time.Second): t.Fatal("response was not cancelled within 1s") } if resp.BytesComplete() == int64(fileSize) { t.Error("download was not supposed to be complete") } }, grabtest.RateLimiter(1), grabtest.ContentLength(fileSize), ) } // TestNestedDirectory tests that missing subdirectories are created. 
func TestNestedDirectory(t *testing.T) { dir := "./.testNested/one/two/three" filename := ".testNestedFile" expect := dir + "/" + filename t.Run("Create", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(expect, url+"/"+filename)) defer os.RemoveAll("./.testNested/") if resp.Filename != expect { t.Errorf("expected nested Request.Filename to be %v, got %v", expect, resp.Filename) } }) }) t.Run("No create", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(expect, url+"/"+filename) req.NoCreateDirectories = true resp := DefaultClient.Do(req) err := resp.Err() if !os.IsNotExist(err) { t.Errorf("expected: %v, got: %v", os.ErrNotExist, err) } }) }) } // TestRemoteTime tests that the timestamp of the downloaded file can be set // according to the timestamp of the remote file. func TestRemoteTime(t *testing.T) { filename := "./.testRemoteTime" defer os.Remove(filename) // random time between epoch and now expect := time.Unix(rand.Int63n(time.Now().Unix()), 0) grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(filename, url)) fi, err := os.Stat(resp.Filename) if err != nil { panic(err) } actual := fi.ModTime() if !actual.Equal(expect) { t.Errorf("expected %v, got %v", expect, actual) } }, grabtest.LastModified(expect), ) } func TestResponseCode(t *testing.T) { filename := "./.testResponseCode" t.Run("With404", func(t *testing.T) { defer os.Remove(filename) grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) resp := DefaultClient.Do(req) expect := StatusCodeError(http.StatusNotFound) err := resp.Err() if err != expect { t.Errorf("expected %v, got '%v'", expect, err) } if !IsStatusCodeError(err) { t.Errorf("expected IsStatusCodeError to return true for %T: %v", err, err) } }, grabtest.StatusCodeStatic(http.StatusNotFound), ) }) t.Run("WithIgnoreNon2XX", func(t *testing.T) { defer os.Remove(filename) grabtest.WithTestServer(t, func(url 
string) { req := mustNewRequest(filename, url) req.IgnoreBadStatusCodes = true resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { t.Errorf("expected nil, got '%v'", err) } }, grabtest.StatusCodeStatic(http.StatusNotFound), ) }) } func TestBeforeCopyHook(t *testing.T) { filename := "./.testBeforeCopy" t.Run("Noop", func(t *testing.T) { defer os.RemoveAll(filename) grabtest.WithTestServer(t, func(url string) { called := false req := mustNewRequest(filename, url) req.BeforeCopy = func(resp *Response) error { called = true if resp.IsComplete() { t.Error("Response object passed to BeforeCopy hook has already been closed") } if resp.Progress() != 0 { t.Error("Download progress already > 0 when BeforeCopy hook was called") } if resp.Duration() == 0 { t.Error("Duration was zero when BeforeCopy was called") } if resp.BytesComplete() != 0 { t.Error("BytesComplete already > 0 when BeforeCopy hook was called") } return nil } resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { t.Errorf("unexpected error using BeforeCopy hook: %v", err) } testComplete(t, resp) if !called { t.Error("BeforeCopy hook was never called") } }) }) t.Run("WithError", func(t *testing.T) { defer os.RemoveAll(filename) grabtest.WithTestServer(t, func(url string) { testError := errors.New("test") req := mustNewRequest(filename, url) req.BeforeCopy = func(resp *Response) error { return testError } resp := DefaultClient.Do(req) if err := resp.Err(); err != testError { t.Errorf("expected error '%v', got '%v'", testError, err) } if resp.BytesComplete() != 0 { t.Errorf("expected 0 bytes completed for canceled BeforeCopy hook, got %d", resp.BytesComplete()) } testComplete(t, resp) }) }) // Assert that an existing local file will not be truncated prior to the // BeforeCopy hook has a chance to cancel the request t.Run("NoTruncate", func(t *testing.T) { tfile, err := ioutil.TempFile("", "grab_client_test.*.file") if err != nil { t.Fatal(err) } defer os.Remove(tfile.Name()) const size = 
128 _, err = tfile.Write(bytes.Repeat([]byte("x"), size)) if err != nil { t.Fatal(err) } grabtest.WithTestServer(t, func(url string) { called := false req := mustNewRequest(tfile.Name(), url) req.NoResume = true req.BeforeCopy = func(resp *Response) error { called = true fi, err := tfile.Stat() if err != nil { t.Errorf("failed to stat temp file: %v", err) return nil } if fi.Size() != size { t.Errorf("expected existing file size of %d bytes "+ "prior to BeforeCopy hook, got %d", size, fi.Size()) } return nil } resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { t.Errorf("unexpected error using BeforeCopy hook: %v", err) } testComplete(t, resp) if !called { t.Error("BeforeCopy hook was never called") } }) }) } func TestAfterCopyHook(t *testing.T) { filename := "./.testAfterCopy" t.Run("Noop", func(t *testing.T) { defer os.RemoveAll(filename) grabtest.WithTestServer(t, func(url string) { called := false req := mustNewRequest(filename, url) req.AfterCopy = func(resp *Response) error { called = true if resp.IsComplete() { t.Error("Response object passed to AfterCopy hook has already been closed") } if resp.Progress() <= 0 { t.Error("Download progress was 0 when AfterCopy hook was called") } if resp.Duration() == 0 { t.Error("Duration was zero when AfterCopy was called") } if resp.BytesComplete() <= 0 { t.Error("BytesComplete was 0 when AfterCopy hook was called") } return nil } resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { t.Errorf("unexpected error using AfterCopy hook: %v", err) } testComplete(t, resp) if !called { t.Error("AfterCopy hook was never called") } }) }) t.Run("WithError", func(t *testing.T) { defer os.RemoveAll(filename) grabtest.WithTestServer(t, func(url string) { testError := errors.New("test") req := mustNewRequest(filename, url) req.AfterCopy = func(resp *Response) error { return testError } resp := DefaultClient.Do(req) if err := resp.Err(); err != testError { t.Errorf("expected error '%v', got '%v'", testError, err) } 
if resp.BytesComplete() <= 0 { t.Errorf("ByteCompleted was %d after AfterCopy hook was called", resp.BytesComplete()) } testComplete(t, resp) }) }) } func TestIssue37(t *testing.T) { // ref: https://github.com/cavaliergopher/grab/v3/issues/37 filename := "./.testIssue37" largeSize := int64(2097152) smallSize := int64(1048576) defer os.RemoveAll(filename) // download large file grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest(filename, url)) if resp.Size() != largeSize { t.Errorf("expected response size: %d, got: %d", largeSize, resp.Size()) } }, grabtest.ContentLength(int(largeSize))) // download new, smaller version of same file grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.NoResume = true resp := mustDo(req) if resp.Size() != smallSize { t.Errorf("expected response size: %d, got: %d", smallSize, resp.Size()) } // local file should have truncated and not resumed if resp.DidResume { t.Errorf("expected download to truncate, resumed instead") } }, grabtest.ContentLength(int(smallSize))) fi, err := os.Stat(filename) if err != nil { t.Fatal(err) } if fi.Size() != int64(smallSize) { t.Errorf("expected file size %d, got %d", smallSize, fi.Size()) } } // TestHeadBadStatus validates that HEAD requests that return non-200 can be // ignored and succeed if the GET requests succeeeds. 
// // Fixes: https://github.com/cavaliergopher/grab/v3/issues/43 func TestHeadBadStatus(t *testing.T) { expect := http.StatusOK filename := ".testIssue43" statusFunc := func(r *http.Request) int { if r.Method == "HEAD" { return http.StatusForbidden } return http.StatusOK } grabtest.WithTestServer(t, func(url string) { testURL := fmt.Sprintf("%s/%s", url, filename) resp := mustDo(mustNewRequest("", testURL)) if resp.HTTPResponse.StatusCode != expect { t.Errorf( "expected status code: %d, got:% d", expect, resp.HTTPResponse.StatusCode) } }, grabtest.StatusCode(statusFunc), ) } // TestMissingContentLength ensures that the Response.Size is correct for // transfers where the remote server does not send a Content-Length header. // // TestAutoResume also covers cases with checksum validation. // // Kudos to Setnička Jiří for identifying and raising // a solution to this issue. Ref: https://github.com/cavaliergopher/grab/v3/pull/27 func TestMissingContentLength(t *testing.T) { // expectSize must be sufficiently large that DefaultClient.Do won't prefetch // the entire body and compute ContentLength before returning a Response. 
expectSize := 1048576 opts := []grabtest.HandlerOption{ grabtest.ContentLength(expectSize), grabtest.HeaderBlacklist("Content-Length"), grabtest.TimeToFirstByte(time.Millisecond * 100), // delay for initial read } grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(".testMissingContentLength", url) req.SetChecksum( md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, false) resp := DefaultClient.Do(req) // ensure remote server is not sending content-length header if v := resp.HTTPResponse.Header.Get("Content-Length"); v != "" { panic(fmt.Sprintf("http header content length must be empty, got: %s", v)) } if v := resp.HTTPResponse.ContentLength; v != -1 { panic(fmt.Sprintf("http response content length must be -1, got: %d", v)) } // before completion, response size should be -1 if resp.Size() != -1 { t.Errorf("expected response size: -1, got: %d", resp.Size()) } // block for completion if err := resp.Err(); err != nil { panic(err) } // on completion, response size should be actual transfer size if resp.Size() != int64(expectSize) { t.Errorf("expected response size: %d, got: %d", expectSize, resp.Size()) } }, opts...) 
} func TestNoStore(t *testing.T) { filename := ".testSubdir/testNoStore" t.Run("DefaultCase", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest(filename, url) req.NoStore = true req.SetChecksum(md5.New(), grabtest.DefaultHandlerMD5ChecksumBytes, true) resp := mustDo(req) // ensure Response.Bytes is correct and can be reread b, err := resp.Bytes() if err != nil { panic(err) } grabtest.AssertSHA256Sum( t, grabtest.DefaultHandlerSHA256ChecksumBytes, bytes.NewReader(b), ) // ensure Response.Open stream is correct and can be reread r, err := resp.Open() if err != nil { panic(err) } defer r.Close() grabtest.AssertSHA256Sum( t, grabtest.DefaultHandlerSHA256ChecksumBytes, r, ) // Response.Filename should still be set if resp.Filename != filename { t.Errorf("expected Response.Filename: %s, got: %s", filename, resp.Filename) } // ensure no files were written paths := []string{ filename, filepath.Base(filename), filepath.Dir(filename), resp.Filename, filepath.Base(resp.Filename), filepath.Dir(resp.Filename), } for _, path := range paths { _, err := os.Stat(path) if !os.IsNotExist(err) { t.Errorf( "expect error: %v, got: %v, for path: %s", os.ErrNotExist, err, path) } } }) }) t.Run("ChecksumValidation", func(t *testing.T) { grabtest.WithTestServer(t, func(url string) { req := mustNewRequest("", url) req.NoStore = true req.SetChecksum( md5.New(), grabtest.MustHexDecodeString("deadbeefcafebabe"), true) resp := DefaultClient.Do(req) if err := resp.Err(); err != ErrBadChecksum { t.Errorf("expected error: %v, got: %v", ErrBadChecksum, err) } }) }) } grab-3.0.1/v3/cmd/000077500000000000000000000000001416617544600135375ustar00rootroot00000000000000grab-3.0.1/v3/cmd/grab/000077500000000000000000000000001416617544600144525ustar00rootroot00000000000000grab-3.0.1/v3/cmd/grab/.gitignore000066400000000000000000000000051416617544600164350ustar00rootroot00000000000000grab 
grab-3.0.1/v3/cmd/grab/Makefile000066400000000000000000000003011416617544600161040ustar00rootroot00000000000000SOURCES = main.go all : grab grab: $(SOURCES) go build -o grab $(SOURCES) clean: go clean -x rm -vf grab check: go test -v . install: go install -v . .PHONY: all clean check install grab-3.0.1/v3/cmd/grab/main.go000066400000000000000000000010731416617544600157260ustar00rootroot00000000000000package main

import (
	"context"
	"fmt"
	"os"

	"github.com/cavaliergopher/grab/v3/pkg/grabui"
)

// maxExitCode is the largest exit status os.Exit documents as portable.
// Larger values can be truncated by the OS (on Unix, 256 becomes 0), which
// would make a run with many failed downloads look successful.
const maxExitCode = 125

// main downloads every URL given on the command line into the current
// directory and exits with the number of failed downloads, capped at
// maxExitCode.
func main() {
	// validate command args
	if len(os.Args) < 2 {
		fmt.Fprintf(os.Stderr, "usage: %s url...\n", os.Args[0])
		os.Exit(1)
	}

	urls := os.Args[1:]

	// download files
	respch, err := grabui.GetBatch(context.Background(), 0, ".", urls...)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// return the number of failed downloads as exit code, clamped so it can
	// never wrap around to a success status
	failed := 0
	for resp := range respch {
		if resp.Err() != nil {
			failed++
		}
	}
	if failed > maxExitCode {
		failed = maxExitCode
	}
	os.Exit(failed)
}
grab-3.0.1/v3/doc.go000066400000000000000000000034521416617544600140740ustar00rootroot00000000000000/* Package grab provides an HTTP download manager implementation. Get is the simplest way to download a file: resp, err := grab.Get("/tmp", "http://example.com/example.zip") // ... Get will download the given URL and save it to the given destination directory. The destination filename will be determined automatically by grab using Content-Disposition headers returned by the remote server, or by inspecting the requested URL path. An empty destination string or "." means the transfer will be stored in the current working directory. If a destination file already exists, grab will assume it is a complete or partially complete download of the requested file. If the remote server supports resuming interrupted downloads, grab will resume downloading from the end of the partial file. If the server does not support resumed downloads, the file will be retransferred in its entirety. If the file is already complete, grab will return successfully.
For control over the HTTP client, destination path, auto-resume, checksum validation and other settings, create a Client: client := grab.NewClient() client.HTTPClient.Transport.DisableCompression = true req, err := grab.NewRequest("/tmp", "http://example.com/example.zip") // ... req.NoResume = true req.HTTPRequest.Header.Set("Authorization", "Basic YWxhZGRpbjpvcGVuc2VzYW1l") resp := client.Do(req) // ... You can monitor the progress of downloads while they are transferring: client := grab.NewClient() req, err := grab.NewRequest("", "http://example.com/example.zip") // ... resp := client.Do(req) t := time.NewTicker(time.Second) defer t.Stop() for { select { case <-t.C: fmt.Printf("%.02f%% complete\n", resp.Progress()) case <-resp.Done: if err := resp.Err(); err != nil { // ... } // ... return } } */
package grab
grab-3.0.1/v3/error.go000066400000000000000000000025501416617544600144560ustar00rootroot00000000000000package grab

import (
	"errors"
	"fmt"
	"net/http"
)

// Sentinel errors reported by this package. Each is a single package-level
// value, so callers can compare a returned error against it directly (the
// package tests do, e.g. err != ErrBadChecksum) or with errors.Is.
var (
	// ErrBadLength indicates that the server response or an existing file does
	// not match the expected content length.
	ErrBadLength = errors.New("bad content length")

	// ErrBadChecksum indicates that a downloaded file failed to pass checksum
	// validation.
	ErrBadChecksum = errors.New("checksum mismatch")

	// ErrNoFilename indicates that a reasonable filename could not be
	// automatically determined using the URL or response headers from a server.
	ErrNoFilename = errors.New("no filename could be determined")

	// ErrNoTimestamp indicates that a timestamp could not be automatically
	// determined using the response headers from the remote server.
	ErrNoTimestamp = errors.New("no timestamp could be determined for the remote file")

	// ErrFileExists indicates that the destination path already exists.
	ErrFileExists = errors.New("file exists")
)

// StatusCodeError indicates that the server response had a status code that
// was not in the 200-299 range (after following any redirects).
type StatusCodeError int func (err StatusCodeError) Error() string { return fmt.Sprintf("server returned %d %s", err, http.StatusText(int(err))) } // IsStatusCodeError returns true if the given error is of type StatusCodeError. func IsStatusCodeError(err error) bool { _, ok := err.(StatusCodeError) return ok } grab-3.0.1/v3/example_client_test.go000066400000000000000000000037621416617544600173630ustar00rootroot00000000000000package grab import ( "fmt" "sync" ) func ExampleClient_Do() { client := NewClient() req, err := NewRequest("/tmp", "http://example.com/example.zip") if err != nil { panic(err) } resp := client.Do(req) if err := resp.Err(); err != nil { panic(err) } fmt.Println("Download saved to", resp.Filename) } // This example uses DoChannel to create a Producer/Consumer model for // downloading multiple files concurrently. This is similar to how DoBatch uses // DoChannel under the hood except that it allows the caller to continually send // new requests until they wish to close the request channel. 
func ExampleClient_DoChannel() { // create a request and a buffered response channel reqch := make(chan *Request) respch := make(chan *Response, 10) // start 4 workers client := NewClient() wg := sync.WaitGroup{} for i := 0; i < 4; i++ { wg.Add(1) go func() { client.DoChannel(reqch, respch) wg.Done() }() } go func() { // send requests for i := 0; i < 10; i++ { url := fmt.Sprintf("http://example.com/example%d.zip", i+1) req, err := NewRequest("/tmp", url) if err != nil { panic(err) } reqch <- req } close(reqch) // wait for workers to finish wg.Wait() close(respch) }() // check each response for resp := range respch { // block until complete if err := resp.Err(); err != nil { panic(err) } fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename) } } func ExampleClient_DoBatch() { // create multiple download requests reqs := make([]*Request, 0) for i := 0; i < 10; i++ { url := fmt.Sprintf("http://example.com/example%d.zip", i+1) req, err := NewRequest("/tmp", url) if err != nil { panic(err) } reqs = append(reqs, req) } // start downloads with 4 workers client := NewClient() respch := client.DoBatch(4, reqs...) 
// check each response for resp := range respch { if err := resp.Err(); err != nil { panic(err) } fmt.Printf("Downloaded %s to %s\n", resp.Request.URL(), resp.Filename) } } grab-3.0.1/v3/example_request_test.go000066400000000000000000000020711416617544600175650ustar00rootroot00000000000000package grab import ( "context" "crypto/sha256" "encoding/hex" "fmt" "time" ) func ExampleRequest_WithContext() { // create context with a 100ms timeout ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() // create download request with context req, err := NewRequest("", "http://example.com/example.zip") if err != nil { panic(err) } req = req.WithContext(ctx) // send download request resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { fmt.Println("error: request cancelled") } // Output: // error: request cancelled } func ExampleRequest_SetChecksum() { // create download request req, err := NewRequest("", "http://example.com/example.zip") if err != nil { panic(err) } // set request checksum sum, err := hex.DecodeString("33daf4c03f86120fdfdc66bddf6bfff4661c7ca11c5da473e537f4d69b470e57") if err != nil { panic(err) } req.SetChecksum(sha256.New(), sum, true) // download and validate file resp := DefaultClient.Do(req) if err := resp.Err(); err != nil { panic(err) } } grab-3.0.1/v3/go.mod000066400000000000000000000000621416617544600141000ustar00rootroot00000000000000module github.com/cavaliergopher/grab/v3 go 1.14 grab-3.0.1/v3/grab.go000066400000000000000000000036131416617544600142410ustar00rootroot00000000000000package grab import ( "fmt" "os" ) // Get sends a HTTP request and downloads the content of the requested URL to // the given destination file path. The caller is blocked until the download is // completed, successfully or otherwise. // // An error is returned if caused by client policy (such as CheckRedirect), or // if there was an HTTP protocol or IO error. 
//
// For non-blocking calls or control over HTTP client headers, redirect policy,
// and other settings, create a Client instead.
func Get(dst, urlStr string) (*Response, error) {
	req, err := NewRequest(dst, urlStr)
	if err != nil {
		return nil, err
	}

	// resp.Err blocks until the transfer has completed, which is what makes
	// Get a synchronous call
	resp := DefaultClient.Do(req)
	return resp, resp.Err()
}

// GetBatch sends multiple HTTP requests and downloads the content of the
// requested URLs to the given destination directory using the given number of
// concurrent worker goroutines.
//
// The Response for each requested URL is sent through the returned Response
// channel, as soon as a worker receives a response from the remote server. The
// Response can then be used to track the progress of the download while it is
// in progress.
//
// The returned Response channel will be closed by Grab, only once all downloads
// have completed or failed.
//
// If an error occurs during any download, it will be available via call to the
// associated Response.Err.
//
// For control over HTTP client headers, redirect policy, and other settings,
// create a Client instead.
func GetBatch(workers int, dst string, urlStrs ...string) (<-chan *Response, error) {
	// all files are stored under dst, so it must be an existing directory
	fi, err := os.Stat(dst)
	if err != nil {
		return nil, err
	}
	if !fi.IsDir() {
		return nil, fmt.Errorf("destination is not a directory")
	}

	// build every request up front so that a single invalid URL fails the
	// whole batch before any download is started
	reqs := make([]*Request, len(urlStrs))
	for i := 0; i < len(urlStrs); i++ {
		req, err := NewRequest(dst, urlStrs[i])
		if err != nil {
			return nil, err
		}
		reqs[i] = req
	}

	ch := DefaultClient.DoBatch(workers, reqs...)
	return ch, nil
}
grab-3.0.1/v3/grab_test.go000066400000000000000000000024721416617544600153020ustar00rootroot00000000000000
package grab

import (
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"testing"

	"github.com/cavaliergopher/grab/v3/pkg/grabtest"
)

// TestMain runs the package tests from inside a fresh temporary directory and
// afterwards restores the original working directory and removes the
// temporary one.
func TestMain(m *testing.M) {
	// the work runs in a closure so its deferred cleanup executes before
	// os.Exit, which does not run deferred functions
	os.Exit(func() int {
		// chdir to temp so test files downloaded to pwd are isolated and cleaned up
		cwd, err := os.Getwd()
		if err != nil {
			panic(err)
		}
		tmpDir, err := ioutil.TempDir("", "grab-")
		if err != nil {
			panic(err)
		}
		if err := os.Chdir(tmpDir); err != nil {
			panic(err)
		}
		defer func() {
			// NOTE(review): the error from restoring cwd is ignored
			os.Chdir(cwd)
			if err := os.RemoveAll(tmpDir); err != nil {
				panic(err)
			}
		}()
		return m.Run()
	}())
}

// TestGet tests grab.Get by downloading from a local grabtest server and
// checking the completed Response.
func TestGet(t *testing.T) {
	filename := ".testGet"
	defer os.Remove(filename)
	grabtest.WithTestServer(t, func(url string) {
		resp, err := Get(filename, url)
		if err != nil {
			t.Fatalf("error in Get(): %v", err)
		}
		testComplete(t, resp)
	})
}

// ExampleGet demonstrates a blocking download of a single file into /tmp.
func ExampleGet() {
	// download a file to /tmp
	resp, err := Get("/tmp", "http://example.com/example.zip")
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("Download saved to", resp.Filename)
}

// mustNewRequest is a test helper that calls NewRequest and panics on error.
func mustNewRequest(dst, urlStr string) *Request {
	req, err := NewRequest(dst, urlStr)
	if err != nil {
		panic(err)
	}
	return req
}

// mustDo is a test helper that sends req with DefaultClient, blocks until the
// transfer completes and panics if it failed.
func mustDo(req *Request) *Response {
	resp := DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		panic(err)
	}
	return resp
}
grab-3.0.1/v3/pkg/000077500000000000000000000000001416617544600135555ustar00rootroot00000000000000
grab-3.0.1/v3/pkg/bps/000077500000000000000000000000001416617544600143415ustar00rootroot00000000000000
grab-3.0.1/v3/pkg/bps/bps.go000066400000000000000000000033531416617544600154600ustar00rootroot00000000000000
/*
Package bps provides gauges for calculating the Bytes Per Second transfer rate
of data streams.
*/
package bps

import (
	"context"
	"time"
)

// Gauge is the common interface for all BPS gauges in this package.
Given a // set of samples over time, each gauge type can be used to measure the Bytes // Per Second transfer rate of a data stream. // // All samples must monotonically increase in timestamp and value. Each sample // should represent the total number of bytes sent in a stream, rather than // accounting for the number sent since the last sample. // // To ensure a gauge can report progress as quickly as possible, take an initial // sample when your stream first starts. // // All gauge implementations are safe for concurrent use. type Gauge interface { // Sample adds a new sample of the progress of the monitored stream. Sample(t time.Time, n int64) // BPS returns the calculated Bytes Per Second rate of the monitored stream. BPS() float64 } // SampleFunc is used by Watch to take periodic samples of a monitored stream. type SampleFunc func() (n int64) // Watch will periodically call the given SampleFunc to sample the progress of // a monitored stream and update the given gauge. SampleFunc should return the // total number of bytes transferred by the stream since it started. // // Watch is a blocking call and should typically be called in a new goroutine. // To prevent the goroutine from leaking, make sure to cancel the given context // once the stream is completed or canceled. func Watch(ctx context.Context, g Gauge, f SampleFunc, interval time.Duration) { g.Sample(time.Now(), f()) t := time.NewTicker(interval) defer t.Stop() for { select { case <-ctx.Done(): return case now := <-t.C: g.Sample(now, f()) } } } grab-3.0.1/v3/pkg/bps/sma.go000066400000000000000000000043351416617544600154550ustar00rootroot00000000000000package bps import ( "sync" "time" ) // NewSMA returns a gauge that uses a Simple Moving Average with the given // number of samples to measure the bytes per second of a byte stream. // // BPS is computed using the timestamp of the most recent and oldest sample in // the sample buffer. 
When a new sample is added, the oldest sample is dropped // if the sample count exceeds maxSamples. // // The gauge does not account for any latency in arrival time of new samples or // the desired window size. Any variance in the arrival of samples will result // in a BPS measurement that is correct for the submitted samples, but over a // varying time window. // // maxSamples should be equal to 1 + (window size / sampling interval) where // window size is the number of seconds over which the moving average is // smoothed and sampling interval is the number of seconds between each sample. // // For example, if you want a five second window, sampling once per second, // maxSamples should be 1 + 5/1 = 6. func NewSMA(maxSamples int) Gauge { if maxSamples < 2 { panic("sample count must be greater than 1") } return &sma{ maxSamples: uint64(maxSamples), samples: make([]int64, maxSamples), timestamps: make([]time.Time, maxSamples), } } type sma struct { mu sync.Mutex index uint64 maxSamples uint64 sampleCount uint64 samples []int64 timestamps []time.Time } func (c *sma) Sample(t time.Time, n int64) { c.mu.Lock() defer c.mu.Unlock() c.timestamps[c.index] = t c.samples[c.index] = n c.index = (c.index + 1) % c.maxSamples // prevent integer overflow in sampleCount. Values greater or equal to // maxSamples have the same semantic meaning. 
c.sampleCount++ if c.sampleCount > c.maxSamples { c.sampleCount = c.maxSamples } } func (c *sma) BPS() float64 { c.mu.Lock() defer c.mu.Unlock() // we need two samples to start if c.sampleCount < 2 { return 0 } // First sample is always the oldest until ring buffer first overflows oldest := c.index if c.sampleCount < c.maxSamples { oldest = 0 } newest := (c.index + c.maxSamples - 1) % c.maxSamples seconds := c.timestamps[newest].Sub(c.timestamps[oldest]).Seconds() bytes := float64(c.samples[newest] - c.samples[oldest]) return bytes / seconds } grab-3.0.1/v3/pkg/bps/sma_test.go000066400000000000000000000021431416617544600165070ustar00rootroot00000000000000package bps import ( "testing" "time" ) type Sample struct { N int64 Expect float64 } func getSimpleSamples(sampleCount, rate int) []Sample { a := make([]Sample, sampleCount) for i := 1; i < sampleCount; i++ { a[i] = Sample{N: int64(i * rate), Expect: float64(rate)} } return a } type SampleSetTest struct { Gauge Gauge Interval time.Duration Samples []Sample } func (c *SampleSetTest) Run(t *testing.T) { ts := time.Unix(0, 0) for i, sample := range c.Samples { c.Gauge.Sample(ts, sample.N) if actual := c.Gauge.BPS(); actual != sample.Expect { t.Errorf("expected: Gauge.BPS() → %0.2f, got %0.2f in test %d", sample.Expect, actual, i+1) } ts = ts.Add(c.Interval) } } func TestSMA_SimpleSteadyCase(t *testing.T) { test := &SampleSetTest{ Interval: time.Second, Samples: getSimpleSamples(100000, 3), } t.Run("SmallSampleSize", func(t *testing.T) { test.Gauge = NewSMA(2) test.Run(t) }) t.Run("RegularSize", func(t *testing.T) { test.Gauge = NewSMA(6) test.Run(t) }) t.Run("LargeSampleSize", func(t *testing.T) { test.Gauge = NewSMA(1000) test.Run(t) }) } grab-3.0.1/v3/pkg/grabtest/000077500000000000000000000000001416617544600153705ustar00rootroot00000000000000grab-3.0.1/v3/pkg/grabtest/assert.go000066400000000000000000000042411416617544600172210ustar00rootroot00000000000000package grabtest import ( "bytes" "crypto/sha256" "fmt" 
"io" "io/ioutil" "net/http" "testing" ) func AssertHTTPResponseStatusCode(t *testing.T, resp *http.Response, expect int) (ok bool) { if resp.StatusCode != expect { t.Errorf("expected status code: %d, got: %d", expect, resp.StatusCode) return } ok = true return true } func AssertHTTPResponseHeader(t *testing.T, resp *http.Response, key, format string, a ...interface{}) (ok bool) { expect := fmt.Sprintf(format, a...) actual := resp.Header.Get(key) if actual != expect { t.Errorf("expected header %s: %s, got: %s", key, expect, actual) return } ok = true return } func AssertHTTPResponseContentLength(t *testing.T, resp *http.Response, n int64) (ok bool) { ok = true if resp.ContentLength != n { ok = false t.Errorf("expected header Content-Length: %d, got: %d", n, resp.ContentLength) } if !AssertHTTPResponseBodyLength(t, resp, n) { ok = false } return } func AssertHTTPResponseBodyLength(t *testing.T, resp *http.Response, n int64) (ok bool) { defer func() { if err := resp.Body.Close(); err != nil { panic(err) } }() b, err := ioutil.ReadAll(resp.Body) if err != nil { panic(err) } if int64(len(b)) != n { ok = false t.Errorf("expected body length: %d, got: %d", n, len(b)) } return } func MustHTTPNewRequest(method, url string, body io.Reader) *http.Request { req, err := http.NewRequest(method, url, body) if err != nil { panic(err) } return req } func MustHTTPDo(req *http.Request) *http.Response { resp, err := http.DefaultClient.Do(req) if err != nil { panic(err) } return resp } func MustHTTPDoWithClose(req *http.Request) *http.Response { resp := MustHTTPDo(req) if _, err := io.Copy(ioutil.Discard, resp.Body); err != nil { panic(err) } if err := resp.Body.Close(); err != nil { panic(err) } return resp } func AssertSHA256Sum(t *testing.T, sum []byte, r io.Reader) (ok bool) { h := sha256.New() if _, err := io.Copy(h, r); err != nil { panic(err) } computed := h.Sum(nil) ok = bytes.Equal(sum, computed) if !ok { t.Errorf( "expected checksum: %s, got: %s", MustHexEncodeString(sum), 
MustHexEncodeString(computed), ) } return } grab-3.0.1/v3/pkg/grabtest/handler.go000066400000000000000000000072461416617544600173450ustar00rootroot00000000000000package grabtest import ( "bufio" "fmt" "net/http" "net/http/httptest" "testing" "time" ) var ( DefaultHandlerContentLength = 1 << 20 DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372" DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum) DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83" DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum) ) type StatusCodeFunc func(req *http.Request) int type handler struct { statusCodeFunc StatusCodeFunc methodWhitelist []string headerBlacklist []string contentLength int acceptRanges bool attachmentFilename string lastModified time.Time ttfb time.Duration rateLimiter *time.Ticker } func NewHandler(options ...HandlerOption) (http.Handler, error) { h := &handler{ statusCodeFunc: func(req *http.Request) int { return http.StatusOK }, methodWhitelist: []string{"GET", "HEAD"}, contentLength: DefaultHandlerContentLength, acceptRanges: true, } for _, option := range options { if err := option(h); err != nil { return nil, err } } return h, nil } func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) { h, err := NewHandler(options...) 
if err != nil { t.Fatalf("unable to create test server handler: %v", err) return } s := httptest.NewServer(h) defer func() { h.(*handler).close() s.Close() }() f(s.URL) } func (h *handler) close() { if h.rateLimiter != nil { h.rateLimiter.Stop() } } func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // delay response if h.ttfb > 0 { time.Sleep(h.ttfb) } // validate request method allowed := false for _, m := range h.methodWhitelist { if r.Method == m { allowed = true break } } if !allowed { httpError(w, http.StatusMethodNotAllowed) return } // set server options if h.acceptRanges { w.Header().Set("Accept-Ranges", "bytes") } // set attachment filename if h.attachmentFilename != "" { w.Header().Set( "Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename), ) } // set last modified timestamp lastMod := time.Now() if !h.lastModified.IsZero() { lastMod = h.lastModified } w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat)) // set content-length offset := 0 if h.acceptRanges { if reqRange := r.Header.Get("Range"); reqRange != "" { if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil { httpError(w, http.StatusBadRequest) return } if offset >= h.contentLength { httpError(w, http.StatusRequestedRangeNotSatisfiable) return } } } w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset)) // apply header blacklist for _, key := range h.headerBlacklist { w.Header().Del(key) } // send header and status code w.WriteHeader(h.statusCodeFunc(r)) // send body if r.Method == "GET" { // use buffered io to reduce overhead on the reader bw := bufio.NewWriterSize(w, 4096) for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ { bw.Write([]byte{byte(i)}) if h.rateLimiter != nil { bw.Flush() w.(http.Flusher).Flush() // force the server to send the data to the client select { case <-h.rateLimiter.C: case <-r.Context().Done(): } } } if !isRequestClosed(r) { bw.Flush() } } } // 
isRequestClosed returns true if the client request has been canceled. func isRequestClosed(r *http.Request) bool { return r.Context().Err() != nil } func httpError(w http.ResponseWriter, code int) { http.Error(w, http.StatusText(code), code) } grab-3.0.1/v3/pkg/grabtest/handler_option.go000066400000000000000000000033701416617544600207270ustar00rootroot00000000000000package grabtest import ( "errors" "net/http" "time" ) type HandlerOption func(*handler) error func StatusCodeStatic(code int) HandlerOption { return func(h *handler) error { return StatusCode(func(req *http.Request) int { return code })(h) } } func StatusCode(f StatusCodeFunc) HandlerOption { return func(h *handler) error { if f == nil { return errors.New("status code function cannot be nil") } h.statusCodeFunc = f return nil } } func MethodWhitelist(methods ...string) HandlerOption { return func(h *handler) error { h.methodWhitelist = methods return nil } } func HeaderBlacklist(headers ...string) HandlerOption { return func(h *handler) error { h.headerBlacklist = headers return nil } } func ContentLength(n int) HandlerOption { return func(h *handler) error { if n < 0 { return errors.New("content length must be zero or greater") } h.contentLength = n return nil } } func AcceptRanges(enabled bool) HandlerOption { return func(h *handler) error { h.acceptRanges = enabled return nil } } func LastModified(t time.Time) HandlerOption { return func(h *handler) error { h.lastModified = t.UTC() return nil } } func TimeToFirstByte(d time.Duration) HandlerOption { return func(h *handler) error { if d < 1 { return errors.New("time to first byte must be greater than zero") } h.ttfb = d return nil } } func RateLimiter(bps int) HandlerOption { return func(h *handler) error { if bps < 1 { return errors.New("bytes per second must be greater than zero") } h.rateLimiter = time.NewTicker(time.Second / time.Duration(bps)) return nil } } func AttachmentFilename(filename string) HandlerOption { return func(h *handler) error { 
h.attachmentFilename = filename return nil } } grab-3.0.1/v3/pkg/grabtest/handler_test.go000066400000000000000000000077351416617544600204070ustar00rootroot00000000000000package grabtest import ( "fmt" "io/ioutil" "net/http" "testing" "time" ) func TestHandlerDefaults(t *testing.T) { WithTestServer(t, func(url string) { resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) AssertHTTPResponseStatusCode(t, resp, http.StatusOK) AssertHTTPResponseContentLength(t, resp, 1048576) AssertHTTPResponseHeader(t, resp, "Accept-Ranges", "bytes") }) } func TestHandlerMethodWhitelist(t *testing.T) { tests := []struct { Whitelist []string Method string ExpectStatusCode int }{ {[]string{"GET", "HEAD"}, "GET", http.StatusOK}, {[]string{"GET", "HEAD"}, "HEAD", http.StatusOK}, {[]string{"GET"}, "HEAD", http.StatusMethodNotAllowed}, {[]string{"HEAD"}, "GET", http.StatusMethodNotAllowed}, } for _, test := range tests { WithTestServer(t, func(url string) { resp := MustHTTPDoWithClose(MustHTTPNewRequest(test.Method, url, nil)) AssertHTTPResponseStatusCode(t, resp, test.ExpectStatusCode) }, MethodWhitelist(test.Whitelist...)) } } func TestHandlerHeaderBlacklist(t *testing.T) { contentLength := 4096 WithTestServer(t, func(url string) { resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) defer resp.Body.Close() if resp.ContentLength != -1 { t.Errorf("expected Response.ContentLength: -1, got: %d", resp.ContentLength) } AssertHTTPResponseHeader(t, resp, "Content-Length", "") AssertHTTPResponseBodyLength(t, resp, int64(contentLength)) }, ContentLength(contentLength), HeaderBlacklist("Content-Length"), ) } func TestHandlerStatusCodeFuncs(t *testing.T) { expect := 418 // I'm a teapot WithTestServer(t, func(url string) { resp := MustHTTPDo(MustHTTPNewRequest("GET", url, nil)) defer resp.Body.Close() AssertHTTPResponseStatusCode(t, resp, expect) }, StatusCode(func(req *http.Request) int { return expect }), ) } func TestHandlerContentLength(t *testing.T) { tests := []struct { Method string 
ContentLength int ExpectHeaderLen int64 ExpectBodyLen int }{ {"GET", 321, 321, 321}, {"HEAD", 321, 321, 0}, {"GET", 0, 0, 0}, {"HEAD", 0, 0, 0}, } for _, test := range tests { WithTestServer(t, func(url string) { resp := MustHTTPDo(MustHTTPNewRequest(test.Method, url, nil)) defer resp.Body.Close() AssertHTTPResponseHeader(t, resp, "Content-Length", "%d", test.ExpectHeaderLen) b, err := ioutil.ReadAll(resp.Body) if err != nil { panic(err) } if len(b) != test.ExpectBodyLen { t.Errorf( "expected body length: %v, got: %v, in: %v", test.ExpectBodyLen, len(b), test, ) } }, ContentLength(test.ContentLength), ) } } func TestHandlerAcceptRanges(t *testing.T) { header := "Accept-Ranges" n := 128 t.Run("Enabled", func(t *testing.T) { WithTestServer(t, func(url string) { req := MustHTTPNewRequest("GET", url, nil) req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2)) resp := MustHTTPDo(req) AssertHTTPResponseHeader(t, resp, header, "bytes") AssertHTTPResponseContentLength(t, resp, int64(n/2)) }, ContentLength(n), ) }) t.Run("Disabled", func(t *testing.T) { WithTestServer(t, func(url string) { req := MustHTTPNewRequest("GET", url, nil) req.Header.Set("Range", fmt.Sprintf("bytes=%d-", n/2)) resp := MustHTTPDo(req) AssertHTTPResponseHeader(t, resp, header, "") AssertHTTPResponseContentLength(t, resp, int64(n)) }, AcceptRanges(false), ContentLength(n), ) }) } func TestHandlerAttachmentFilename(t *testing.T) { filename := "foo.pdf" WithTestServer(t, func(url string) { resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil)) AssertHTTPResponseHeader(t, resp, "Content-Disposition", `attachment;filename="%s"`, filename) }, AttachmentFilename(filename), ) } func TestHandlerLastModified(t *testing.T) { WithTestServer(t, func(url string) { resp := MustHTTPDoWithClose(MustHTTPNewRequest("GET", url, nil)) AssertHTTPResponseHeader(t, resp, "Last-Modified", "Thu, 29 Nov 1973 21:33:09 GMT") }, LastModified(time.Unix(123456789, 0)), ) } 
grab-3.0.1/v3/pkg/grabtest/util.go000066400000000000000000000004021416617544600166700ustar00rootroot00000000000000package grabtest import "encoding/hex" func MustHexDecodeString(s string) (b []byte) { var err error b, err = hex.DecodeString(s) if err != nil { panic(err) } return } func MustHexEncodeString(b []byte) (s string) { return hex.EncodeToString(b) } grab-3.0.1/v3/pkg/grabui/000077500000000000000000000000001416617544600150265ustar00rootroot00000000000000grab-3.0.1/v3/pkg/grabui/console_client.go000066400000000000000000000065461416617544600203700ustar00rootroot00000000000000package grabui import ( "context" "fmt" "os" "sync" "time" "github.com/cavaliergopher/grab/v3" ) type ConsoleClient struct { mu sync.Mutex client *grab.Client succeeded, failed, inProgress int responses []*grab.Response } func NewConsoleClient(client *grab.Client) *ConsoleClient { return &ConsoleClient{ client: client, } } func (c *ConsoleClient) Do( ctx context.Context, workers int, reqs ...*grab.Request, ) <-chan *grab.Response { // buffer size prevents slow receivers causing back pressure pump := make(chan *grab.Response, len(reqs)) go func() { c.mu.Lock() defer c.mu.Unlock() c.failed = 0 c.inProgress = 0 c.succeeded = 0 c.responses = make([]*grab.Response, 0, len(reqs)) if c.client == nil { c.client = grab.DefaultClient } fmt.Printf("Downloading %d files...\n", len(reqs)) respch := c.client.DoBatch(workers, reqs...) 
t := time.NewTicker(200 * time.Millisecond) defer t.Stop() Loop: for { select { case <-ctx.Done(): break Loop case resp := <-respch: if resp != nil { // a new response has been received and has started downloading c.responses = append(c.responses, resp) pump <- resp // send to caller } else { // channel is closed - all downloads are complete break Loop } case <-t.C: // update UI on clock tick c.refresh() } } c.refresh() close(pump) fmt.Printf( "Finished %d successful, %d failed, %d incomplete.\n", c.succeeded, c.failed, c.inProgress) }() return pump } // refresh prints the progress of all downloads to the terminal func (c *ConsoleClient) refresh() { // clear lines for incomplete downloads if c.inProgress > 0 { fmt.Printf("\033[%dA\033[K", c.inProgress) } // print newly completed downloads for i, resp := range c.responses { if resp != nil && resp.IsComplete() { if resp.Err() != nil { c.failed++ fmt.Fprintf(os.Stderr, "Error downloading %s: %v\n", resp.Request.URL(), resp.Err()) } else { c.succeeded++ fmt.Printf("Finished %s %s / %s (%d%%)\n", resp.Filename, byteString(resp.BytesComplete()), byteString(resp.Size()), int(100*resp.Progress())) } c.responses[i] = nil } } // print progress for incomplete downloads c.inProgress = 0 for _, resp := range c.responses { if resp != nil { fmt.Printf("Downloading %s %s / %s (%d%%) - %s ETA: %s \033[K\n", resp.Filename, byteString(resp.BytesComplete()), byteString(resp.Size()), int(100*resp.Progress()), bpsString(resp.BytesPerSecond()), etaString(resp.ETA())) c.inProgress++ } } } func bpsString(n float64) string { if n < 1e3 { return fmt.Sprintf("%.02fBps", n) } if n < 1e6 { return fmt.Sprintf("%.02fKB/s", n/1e3) } if n < 1e9 { return fmt.Sprintf("%.02fMB/s", n/1e6) } return fmt.Sprintf("%.02fGB/s", n/1e9) } func byteString(n int64) string { if n < 1<<10 { return fmt.Sprintf("%dB", n) } if n < 1<<20 { return fmt.Sprintf("%dKB", n>>10) } if n < 1<<30 { return fmt.Sprintf("%dMB", n>>20) } if n < 1<<40 { return fmt.Sprintf("%dGB", 
n>>30) } return fmt.Sprintf("%dTB", n>>40) } func etaString(eta time.Time) string { d := eta.Sub(time.Now()) if d < time.Second { return "<1s" } // truncate to 1s resolution d /= time.Second d *= time.Second return d.String() } grab-3.0.1/v3/pkg/grabui/grabui.go000066400000000000000000000007671416617544600166400ustar00rootroot00000000000000package grabui import ( "context" "github.com/cavaliergopher/grab/v3" ) func GetBatch( ctx context.Context, workers int, dst string, urlStrs ...string, ) (<-chan *grab.Response, error) { reqs := make([]*grab.Request, len(urlStrs)) for i := 0; i < len(urlStrs); i++ { req, err := grab.NewRequest(dst, urlStrs[i]) if err != nil { return nil, err } req = req.WithContext(ctx) reqs[i] = req } ui := NewConsoleClient(grab.DefaultClient) return ui.Do(ctx, workers, reqs...), nil } grab-3.0.1/v3/rate_limiter.go000066400000000000000000000005611416617544600160050ustar00rootroot00000000000000package grab import "context" // RateLimiter is an interface that must be satisfied by any third-party rate // limiters that may be used to limit download transfer speeds. // // A recommended token bucket implementation can be found at // https://godoc.org/golang.org/x/time/rate#Limiter. type RateLimiter interface { WaitN(ctx context.Context, n int) (err error) } grab-3.0.1/v3/rate_limiter_test.go000066400000000000000000000033571416617544600170520ustar00rootroot00000000000000package grab import ( "context" "log" "os" "testing" "time" "github.com/cavaliergopher/grab/v3/pkg/grabtest" ) // testRateLimiter is a naive rate limiter that limits throughput to r tokens // per second. The total number of tokens issued is tracked as n. 
type testRateLimiter struct {
	// r is the permitted rate in tokens (bytes) per second; n counts the
	// tokens issued so far.
	r, n int
}

// NewLimiter returns a testRateLimiter allowing r tokens per second. It stands
// in for a real token-bucket limiter in ExampleRateLimiter.
func NewLimiter(r int) RateLimiter {
	return &testRateLimiter{r: r}
}

// WaitN records n issued tokens and sleeps for n/r seconds. It ignores ctx, so
// the sleep cannot be interrupted, and it never returns an error. It is not
// safe for concurrent use because n is updated without synchronization.
func (c *testRateLimiter) WaitN(ctx context.Context, n int) (err error) {
	c.n += n
	time.Sleep(
		time.Duration(1.00 / float64(c.r) * float64(n) * float64(time.Second)))
	return
}

// TestRateLimiter checks that every downloaded byte passes through the
// Request's RateLimiter and that the limiter slows the transfer down.
func TestRateLimiter(t *testing.T) {
	// download a 128 byte file, 8 bytes at a time, with a naive 512bps limiter
	// should take > 250ms
	filesize := 128
	filename := ".testRateLimiter"
	defer os.Remove(filename)

	grabtest.WithTestServer(t, func(url string) {
		// limit to 512bps
		lim := &testRateLimiter{r: 512}
		req := mustNewRequest(filename, url)

		// ensure multiple trips to the rate limiter by downloading 8 bytes at a time
		req.BufferSize = 8
		req.RateLimiter = lim

		resp := mustDo(req)
		testComplete(t, resp)
		if lim.n != filesize {
			t.Errorf("expected %d bytes to pass through limiter, got %d", filesize, lim.n)
		}
		// 128 bytes at 512 bytes per second adds up to 250ms of sleeping
		if resp.Duration().Seconds() < 0.25 {
			// BUG: this test can pass if the transfer was slow for unrelated reasons
			t.Errorf("expected transfer to take >250ms, took %v", resp.Duration())
		}
	}, grabtest.ContentLength(filesize))
}

// ExampleRateLimiter demonstrates capping the transfer speed of a download by
// attaching a RateLimiter to its Request.
func ExampleRateLimiter() {
	// the error from NewRequest is ignored for brevity; check it in real code
	req, _ := NewRequest("", "http://www.golang-book.com/public/pdf/gobook.pdf")

	// Attach a 1 MiB/s (1048576 bytes per second) rate limiter, like the token
	// bucket implementation from golang.org/x/time/rate.
	req.RateLimiter = NewLimiter(1048576)

	resp := DefaultClient.Do(req)
	if err := resp.Err(); err != nil {
		log.Fatal(err)
	}
}
grab-3.0.1/v3/request.go000066400000000000000000000141071416617544600150160ustar00rootroot00000000000000
package grab

import (
	"context"
	"hash"
	"net/http"
	"net/url"
)

// A Hook is a user provided callback function that can be called by grab at
// various stages of a request's lifecycle. If a hook returns an error, the
// associated request is canceled and the same error is returned on the Response
// object.
//
// Hook functions are called synchronously and should never block unnecessarily.
// Response methods that block until a download is complete, such as
// Response.Err, Response.Cancel or Response.Wait will deadlock. To cancel a
// download from a callback, simply return a non-nil error.
type Hook func(*Response) error

// A Request represents an HTTP file transfer request to be sent by a Client.
type Request struct {
	// Label is an arbitrary string which may be used to label a Request with a
	// user friendly name.
	Label string

	// Tag is an arbitrary interface which may be used to relate a Request to
	// other data.
	Tag interface{}

	// HTTPRequest specifies the http.Request to be sent to the remote server to
	// initiate a file transfer. It includes request configuration such as URL,
	// protocol version, HTTP method, request headers and authentication.
	HTTPRequest *http.Request

	// Filename specifies the path where the file transfer will be stored in
	// local storage. If Filename is empty or a directory, the true Filename will
	// be resolved using Content-Disposition headers or the request URL.
	//
	// An empty string means the transfer will be stored in the current working
	// directory.
	Filename string

	// SkipExisting specifies that ErrFileExists should be returned if the
	// destination path already exists. The existing file will not be checked for
	// completeness.
	SkipExisting bool

	// NoResume specifies that a partially completed download will be restarted
	// without attempting to resume any existing file. If the download is already
	// completed in full, it will not be restarted.
	NoResume bool

	// NoStore specifies that grab should not write to the local file system.
	// Instead, the download will be stored in memory and accessible only via
	// Response.Open or Response.Bytes.
	NoStore bool

	// NoCreateDirectories specifies that any missing directories in the given
	// Filename path should not be created automatically, if they do not already
	// exist.
	NoCreateDirectories bool

	// IgnoreBadStatusCodes specifies that grab should accept any status code in
	// the response from the remote server. Otherwise, grab expects the response
	// status code to be within the 2XX range (after following redirects).
	IgnoreBadStatusCodes bool

	// IgnoreRemoteTime specifies that grab should not attempt to set the
	// timestamp of the local file to match the remote file.
	IgnoreRemoteTime bool

	// Size specifies the expected size of the file transfer if known. If the
	// server response size does not match, the transfer is cancelled and
	// ErrBadLength returned.
	Size int64

	// BufferSize specifies the size in bytes of the buffer that is used for
	// transferring the requested file. Larger buffers may result in faster
	// throughput but will use more memory and result in less frequent updates
	// to the transfer progress statistics. If a RateLimiter is configured,
	// BufferSize should be much lower than the rate limit. Default: 32KB.
	BufferSize int

	// RateLimiter allows the transfer rate of a download to be limited. The given
	// Request.BufferSize determines how frequently the RateLimiter will be
	// polled.
	RateLimiter RateLimiter

	// BeforeCopy is a user provided callback that is called immediately before
	// a request starts downloading. If BeforeCopy returns an error, the request
	// is cancelled and the same error is returned on the Response object.
	BeforeCopy Hook

	// AfterCopy is a user provided callback that is called immediately after a
	// request has finished downloading, before checksum validation and closure.
	// This hook is only called if the transfer was successful. If AfterCopy
	// returns an error, the request is canceled and the same error is returned on
	// the Response object.
	AfterCopy Hook

	// hash, checksum and deleteOnError - set via SetChecksum.
	hash          hash.Hash
	checksum      []byte
	deleteOnError bool

	// Context for cancellation and timeout - set via WithContext
	ctx context.Context
}

// NewRequest returns a new file transfer Request suitable for use with
// Client.Do.
//
// An empty dst is replaced with ".", the current working directory. The
// underlying HTTP request uses the GET method; any error from http.NewRequest
// (such as an unparsable urlStr) is returned.
func NewRequest(dst, urlStr string) (*Request, error) {
	if dst == "" {
		dst = "."
	}
	req, err := http.NewRequest("GET", urlStr, nil)
	if err != nil {
		return nil, err
	}
	return &Request{
		HTTPRequest: req,
		Filename:    dst,
	}, nil
}

// Context returns the request's context. To change the context, use
// WithContext.
//
// The returned context is always non-nil; it defaults to the background
// context.
//
// The context controls cancelation.
func (r *Request) Context() context.Context {
	if r.ctx != nil {
		return r.ctx
	}
	return context.Background()
}

// WithContext returns a shallow copy of r with its context changed
// to ctx. The provided ctx must be non-nil.
func (r *Request) WithContext(ctx context.Context) *Request {
	if ctx == nil {
		panic("nil context")
	}
	r2 := new(Request)
	*r2 = *r
	r2.ctx = ctx
	// rebind the embedded http.Request too, so the HTTP exchange observes ctx
	r2.HTTPRequest = r2.HTTPRequest.WithContext(ctx)
	return r2
}

// URL returns the URL to be downloaded. It reads r.HTTPRequest, which is
// always set for Requests created by NewRequest.
func (r *Request) URL() *url.URL {
	return r.HTTPRequest.URL
}

// SetChecksum sets the desired hashing algorithm and checksum value to validate
// a downloaded file. Once the download is complete, the given hashing algorithm
// will be used to compute the actual checksum of the downloaded file. If the
// checksums do not match, an error will be returned by the associated
// Response.Err method.
//
// If deleteOnError is true, the downloaded file will be deleted automatically
// if it fails checksum validation.
//
// To prevent corruption of the computed checksum, the given hash must not be
// used by any other request or goroutines.
//
// To disable checksum validation, call SetChecksum with a nil hash.
func (r *Request) SetChecksum(h hash.Hash, sum []byte, deleteOnError bool) { r.hash = h r.checksum = sum r.deleteOnError = deleteOnError } grab-3.0.1/v3/response.go000066400000000000000000000164651416617544600151750ustar00rootroot00000000000000package grab import ( "bytes" "context" "io" "io/ioutil" "net/http" "os" "sync/atomic" "time" ) // Response represents the response to a completed or in-progress download // request. // // A response may be returned as soon a HTTP response is received from a remote // server, but before the body content has started transferring. // // All Response method calls are thread-safe. type Response struct { // The Request that was submitted to obtain this Response. Request *Request // HTTPResponse represents the HTTP response received from an HTTP request. // // The response Body should not be used as it will be consumed and closed by // grab. HTTPResponse *http.Response // Filename specifies the path where the file transfer is stored in local // storage. Filename string // Size specifies the total expected size of the file transfer. sizeUnsafe int64 // Start specifies the time at which the file transfer started. Start time.Time // End specifies the time at which the file transfer completed. // // This will return zero until the transfer has completed. End time.Time // CanResume specifies that the remote server advertised that it can resume // previous downloads, as the 'Accept-Ranges: bytes' header is set. CanResume bool // DidResume specifies that the file transfer resumed a previously incomplete // transfer. DidResume bool // Done is closed once the transfer is finalized, either successfully or with // errors. Errors are available via Response.Err Done chan struct{} // ctx is a Context that controls cancelation of an inprogress transfer ctx context.Context // cancel is a cancel func that can be used to cancel the context of this // Response. 
cancel context.CancelFunc // fi is the FileInfo for the destination file if it already existed before // transfer started. fi os.FileInfo // optionsKnown indicates that a HEAD request has been completed and the // capabilities of the remote server are known. optionsKnown bool // writer is the file handle used to write the downloaded file to local // storage writer io.Writer // storeBuffer receives the contents of the transfer if Request.NoStore is // enabled. storeBuffer bytes.Buffer // bytesCompleted specifies the number of bytes which were already // transferred before this transfer began. bytesResumed int64 // transfer is responsible for copying data from the remote server to a local // file, tracking progress and allowing for cancelation. transfer *transfer // bufferSize specifies the size in bytes of the transfer buffer. bufferSize int // Error contains any error that may have occurred during the file transfer. // This should not be read until IsComplete returns true. err error } // IsComplete returns true if the download has completed. If an error occurred // during the download, it can be returned via Err. func (c *Response) IsComplete() bool { select { case <-c.Done: return true default: return false } } // Cancel cancels the file transfer by canceling the underlying Context for // this Response. Cancel blocks until the transfer is closed and returns any // error - typically context.Canceled. func (c *Response) Cancel() error { c.cancel() return c.Err() } // Wait blocks until the download is completed. func (c *Response) Wait() { <-c.Done } // Err blocks the calling goroutine until the underlying file transfer is // completed and returns any error that may have occurred. If the download is // already completed, Err returns immediately. func (c *Response) Err() error { <-c.Done return c.err } // Size returns the size of the file transfer. If the remote server does not // specify the total size and the transfer is incomplete, the return value is // -1. 
func (c *Response) Size() int64 {
	// sizeUnsafe is only ever accessed atomically.
	return atomic.LoadInt64(&c.sizeUnsafe)
}

// BytesComplete returns the total number of bytes which have been copied to
// the destination, including any bytes that were resumed from a previous
// download.
func (c *Response) BytesComplete() int64 {
	return c.bytesResumed + c.transfer.N()
}

// BytesPerSecond returns the number of bytes per second transferred using a
// simple moving average of the last five seconds. If the download is already
// complete, the average bytes/sec for the life of the download is returned.
// A completed transfer with no measurable duration reports 0 rather than an
// infinite or NaN rate.
func (c *Response) BytesPerSecond() float64 {
	if c.IsComplete() {
		secs := c.Duration().Seconds()
		if secs <= 0 {
			// Avoid dividing by zero for instantaneous (or never-started)
			// transfers.
			return 0
		}
		return float64(c.transfer.N()) / secs
	}
	return c.transfer.BPS()
}

// Progress returns the ratio of total bytes that have been downloaded. Multiply
// the returned value by 100 to return the percentage completed.
func (c *Response) Progress() float64 {
	size := c.Size()
	if size <= 0 {
		return 0
	}
	return float64(c.BytesComplete()) / float64(size)
}

// Duration returns the duration of a file transfer. If the transfer is in
// process, the duration will be between now and the start of the transfer. If
// the transfer is complete, the duration will be between the start and end of
// the completed transfer process.
func (c *Response) Duration() time.Duration {
	if c.IsComplete() {
		return c.End.Sub(c.Start)
	}
	return time.Since(c.Start)
}

// ETA returns the estimated time at which the download will complete, given
// the current BytesPerSecond. If the transfer has already completed, the actual
// end time will be returned. If no estimate is possible (the total size is
// unknown or no bytes are flowing yet), the zero time.Time is returned.
func (c *Response) ETA() time.Time {
	if c.IsComplete() {
		return c.End
	}
	size := c.Size()
	if size < 0 {
		// Size reports -1 when the server did not advertise a length, so the
		// remaining byte count is meaningless.
		return time.Time{}
	}
	bps := c.transfer.BPS()
	if bps <= 0 {
		return time.Time{}
	}
	remaining := float64(size - c.BytesComplete())
	// Convert at nanosecond precision; converting whole seconds first would
	// truncate any fractional part of the estimate.
	return time.Now().Add(time.Duration(remaining / bps * float64(time.Second)))
}

// Open blocks the calling goroutine until the underlying file transfer is
// completed and then opens the transferred file for reading.
If Request.NoStore // was enabled, the reader will read from memory. // // If an error occurred during the transfer, it will be returned. // // It is the callers responsibility to close the returned file handle. func (c *Response) Open() (io.ReadCloser, error) { if err := c.Err(); err != nil { return nil, err } return c.openUnsafe() } func (c *Response) openUnsafe() (io.ReadCloser, error) { if c.Request.NoStore { return ioutil.NopCloser(bytes.NewReader(c.storeBuffer.Bytes())), nil } return os.Open(c.Filename) } // Bytes blocks the calling goroutine until the underlying file transfer is // completed and then reads all bytes from the completed tranafer. If // Request.NoStore was enabled, the bytes will be read from memory. // // If an error occurred during the transfer, it will be returned. func (c *Response) Bytes() ([]byte, error) { if err := c.Err(); err != nil { return nil, err } if c.Request.NoStore { return c.storeBuffer.Bytes(), nil } f, err := c.Open() if err != nil { return nil, err } defer f.Close() return ioutil.ReadAll(f) } func (c *Response) requestMethod() string { if c == nil || c.HTTPResponse == nil || c.HTTPResponse.Request == nil { return "" } return c.HTTPResponse.Request.Method } func (c *Response) checksumUnsafe() ([]byte, error) { f, err := c.openUnsafe() if err != nil { return nil, err } defer f.Close() t := newTransfer(c.Request.Context(), nil, c.Request.hash, f, nil) if _, err = t.copy(); err != nil { return nil, err } sum := c.Request.hash.Sum(nil) return sum, nil } func (c *Response) closeResponseBody() error { if c.HTTPResponse == nil || c.HTTPResponse.Body == nil { return nil } return c.HTTPResponse.Body.Close() } grab-3.0.1/v3/response_test.go000066400000000000000000000052201416617544600162170ustar00rootroot00000000000000package grab import ( "bytes" "os" "testing" "time" "github.com/cavaliergopher/grab/v3/pkg/grabtest" ) // testComplete validates that a completed Response has all the desired fields. 
func testComplete(t *testing.T, resp *Response) { <-resp.Done if !resp.IsComplete() { t.Errorf("Response.IsComplete returned false") } if resp.Start.IsZero() { t.Errorf("Response.Start is zero") } if resp.End.IsZero() { t.Error("Response.End is zero") } if eta := resp.ETA(); eta != resp.End { t.Errorf("Response.ETA is not equal to Response.End: %v", eta) } // the following fields should only be set if no error occurred if resp.Err() == nil { if resp.Filename == "" { t.Errorf("Response.Filename is empty") } if resp.Size() == 0 { t.Error("Response.Size is zero") } if p := resp.Progress(); p != 1.00 { t.Errorf("Response.Progress returned %v (%v/%v bytes), expected 1", p, resp.BytesComplete(), resp.Size()) } } } // TestResponseProgress tests the functions which indicate the progress of an // in-process file transfer. func TestResponseProgress(t *testing.T) { filename := ".testResponseProgress" defer os.Remove(filename) sleep := 300 * time.Millisecond size := 1024 * 8 // bytes grabtest.WithTestServer(t, func(url string) { // request a slow transfer req := mustNewRequest(filename, url) resp := DefaultClient.Do(req) // make sure transfer has not started if resp.IsComplete() { t.Errorf("Transfer should not have started") } if p := resp.Progress(); p != 0 { t.Errorf("Transfer should not have started yet but progress is %v", p) } // wait for transfer to complete <-resp.Done // make sure transfer is complete if p := resp.Progress(); p != 1 { t.Errorf("Transfer is complete but progress is %v", p) } if s := resp.BytesComplete(); s != int64(size) { t.Errorf("Expected to transfer %v bytes, got %v", size, s) } }, grabtest.TimeToFirstByte(sleep), grabtest.ContentLength(size), ) } func TestResponseOpen(t *testing.T) { grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest("", url+"/someFilename")) f, err := resp.Open() if err != nil { t.Error(err) return } defer func() { if err := f.Close(); err != nil { t.Error(err) } }() grabtest.AssertSHA256Sum(t, 
grabtest.DefaultHandlerSHA256ChecksumBytes, f) }) } func TestResponseBytes(t *testing.T) { grabtest.WithTestServer(t, func(url string) { resp := mustDo(mustNewRequest("", url+"/someFilename")) b, err := resp.Bytes() if err != nil { t.Error(err) return } grabtest.AssertSHA256Sum( t, grabtest.DefaultHandlerSHA256ChecksumBytes, bytes.NewReader(b), ) }) } grab-3.0.1/v3/transfer.go000066400000000000000000000036761416617544600151630ustar00rootroot00000000000000package grab import ( "context" "io" "sync/atomic" "time" "github.com/cavaliergopher/grab/v3/pkg/bps" ) type transfer struct { n int64 // must be 64bit aligned on 386 ctx context.Context gauge bps.Gauge lim RateLimiter w io.Writer r io.Reader b []byte } func newTransfer(ctx context.Context, lim RateLimiter, dst io.Writer, src io.Reader, buf []byte) *transfer { return &transfer{ ctx: ctx, gauge: bps.NewSMA(6), // five second moving average sampling every second lim: lim, w: dst, r: src, b: buf, } } // copy behaves similarly to io.CopyBuffer except that it checks for cancelation // of the given context.Context, reports progress in a thread-safe manner and // tracks the transfer rate. func (c *transfer) copy() (written int64, err error) { // maintain a bps gauge in another goroutine ctx, cancel := context.WithCancel(c.ctx) defer cancel() go bps.Watch(ctx, c.gauge, c.N, time.Second) // start the transfer if c.b == nil { c.b = make([]byte, 32*1024) } for { select { case <-c.ctx.Done(): err = c.ctx.Err() return default: // keep working } nr, er := c.r.Read(c.b) if nr > 0 { nw, ew := c.w.Write(c.b[0:nr]) if nw > 0 { written += int64(nw) atomic.StoreInt64(&c.n, written) } if ew != nil { err = ew break } if nr != nw { err = io.ErrShortWrite break } // wait for rate limiter if c.lim != nil { err = c.lim.WaitN(c.ctx, nr) if err != nil { return } } } if er != nil { if er != io.EOF { err = er } break } } return written, err } // N returns the number of bytes transferred. 
func (c *transfer) N() (n int64) { if c == nil { return 0 } n = atomic.LoadInt64(&c.n) return } // BPS returns the current bytes per second transfer rate using a simple moving // average. func (c *transfer) BPS() (bps float64) { if c == nil || c.gauge == nil { return 0 } return c.gauge.BPS() } grab-3.0.1/v3/util.go000066400000000000000000000037271416617544600143110ustar00rootroot00000000000000package grab import ( "fmt" "mime" "net/http" "os" "path" "path/filepath" "strings" "time" ) // setLastModified sets the last modified timestamp of a local file according to // the Last-Modified header returned by a remote server. func setLastModified(resp *http.Response, filename string) error { // https://tools.ietf.org/html/rfc7232#section-2.2 // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Last-Modified header := resp.Header.Get("Last-Modified") if header == "" { return nil } lastmod, err := time.Parse(http.TimeFormat, header) if err != nil { return nil } return os.Chtimes(filename, lastmod, lastmod) } // mkdirp creates all missing parent directories for the destination file path. func mkdirp(path string) error { dir := filepath.Dir(path) if fi, err := os.Stat(dir); err != nil { if !os.IsNotExist(err) { return fmt.Errorf("error checking destination directory: %v", err) } if err := os.MkdirAll(dir, 0777); err != nil { return fmt.Errorf("error creating destination directory: %v", err) } } else if !fi.IsDir() { panic("grab: developer error: destination path is not directory") } return nil } // guessFilename returns a filename for the given http.Response. If none can be // determined ErrNoFilename is returned. 
// // TODO: NoStore operations should not require a filename func guessFilename(resp *http.Response) (string, error) { filename := resp.Request.URL.Path if cd := resp.Header.Get("Content-Disposition"); cd != "" { if _, params, err := mime.ParseMediaType(cd); err == nil { if val, ok := params["filename"]; ok { filename = val } // else filename directive is missing.. fallback to URL.Path } } // sanitize if filename == "" || strings.HasSuffix(filename, "/") || strings.Contains(filename, "\x00") { return "", ErrNoFilename } filename = filepath.Base(path.Clean("/" + filename)) if filename == "" || filename == "." || filename == "/" { return "", ErrNoFilename } return filename, nil } grab-3.0.1/v3/util_test.go000066400000000000000000000070221416617544600153400ustar00rootroot00000000000000package grab import ( "fmt" "net/http" "net/url" "testing" ) func TestURLFilenames(t *testing.T) { t.Run("Valid", func(t *testing.T) { expect := "filename" testCases := []string{ "http://test.com/filename", "http://test.com/path/filename", "http://test.com/deep/path/filename", "http://test.com/filename?with=args", "http://test.com/filename#with-fragment", "http://test.com/filename?with=args&and#with-fragment", } for _, tc := range testCases { req, _ := http.NewRequest("GET", tc, nil) resp := &http.Response{ Request: req, } actual, err := guessFilename(resp) if err != nil { t.Errorf("%v", err) } if actual != expect { t.Errorf("expected '%v', got '%v'", expect, actual) } } }) t.Run("Invalid", func(t *testing.T) { testCases := []string{ "http://test.com", "http://test.com/", "http://test.com/filename/", "http://test.com/filename/?with=args", "http://test.com/filename/#with-fragment", "http://test.com/filename\x00", } for _, tc := range testCases { t.Run(tc, func(t *testing.T) { req, err := http.NewRequest("GET", tc, nil) if err != nil { if tc == "http://test.com/filename\x00" { // Since go1.12, urls with invalid control character return an error // See 
https://github.com/golang/go/commit/829c5df58694b3345cb5ea41206783c8ccf5c3ca t.Skip() } } resp := &http.Response{ Request: req, } _, err = guessFilename(resp) if err != ErrNoFilename { t.Errorf("expected '%v', got '%v'", ErrNoFilename, err) } }) } }) } func TestHeaderFilenames(t *testing.T) { u, _ := url.ParseRequestURI("http://test.com/badfilename") resp := &http.Response{ Request: &http.Request{ URL: u, }, Header: http.Header{}, } setFilename := func(resp *http.Response, filename string) { resp.Header.Set("Content-Disposition", fmt.Sprintf("attachment;filename=\"%s\"", filename)) } t.Run("Valid", func(t *testing.T) { expect := "filename" testCases := []string{ "filename", "path/filename", "/path/filename", "../../filename", "/path/../../filename", "/../../././///filename", } for _, tc := range testCases { setFilename(resp, tc) actual, err := guessFilename(resp) if err != nil { t.Errorf("error (%v): %v", tc, err) } if actual != expect { t.Errorf("expected '%v' (%v), got '%v'", expect, tc, actual) } } }) t.Run("Invalid", func(t *testing.T) { testCases := []string{ "", "/", ".", "/.", "/./", "..", "../", "/../", "/path/", "../path/", "filename\x00", "filename/", "filename//", "filename/..", } for _, tc := range testCases { setFilename(resp, tc) if actual, err := guessFilename(resp); err != ErrNoFilename { t.Errorf("expected: %v (%v), got: %v (%v)", ErrNoFilename, tc, err, actual) } } }) } func TestHeaderWithMissingDirective(t *testing.T) { u, _ := url.ParseRequestURI("http://test.com/filename") resp := &http.Response{ Request: &http.Request{ URL: u, }, Header: http.Header{}, } setHeader := func(resp *http.Response, value string) { resp.Header.Set("Content-Disposition", value) } t.Run("Valid", func(t *testing.T) { expect := "filename" testCases := []string{ "inline", "attachment", } for _, tc := range testCases { setHeader(resp, tc) actual, err := guessFilename(resp) if err != nil { t.Errorf("error (%v): %v", tc, err) } if actual != expect { t.Errorf("expected '%v' 
(%v), got '%v'", expect, tc, actual) } } }) }