pax_global_header00006660000000000000000000000064142137143400014511gustar00rootroot0000000000000052 comment=b41869697484173d2e936754152cd64e40a046d4 martian-3.3.2/000077500000000000000000000000001421371434000131515ustar00rootroot00000000000000martian-3.3.2/.gitignore000066400000000000000000000000341421371434000151360ustar00rootroot00000000000000node_modules dist .DS_Store martian-3.3.2/.travis.yml000066400000000000000000000006711421371434000152660ustar00rootroot00000000000000# Allow Travis to use container based infrastructure sudo: false language: go go: - 1.13.x install: - go get -u golang.org/x/net/http2 - go get -u golang.org/x/net/http2/hpack - go get -u golang.org/x/net/websocket - go get -u golang.org/x/lint/golint - go get -u google.golang.org/protobuf/proto - go get -u google.golang.org/grpc - go get -u github.com/golang/snappy script: - golint ./... - go test -v ./... --race martian-3.3.2/CONTRIBUTING000066400000000000000000000025631421371434000150110ustar00rootroot00000000000000Want to contribute? Great! First, read this page (including the small print at the end). ### Before you contribute Before we can use your code, you must sign the [Google Individual Contributor License Agreement](https://developers.google.com/open-source/cla/individual?csw=1) (CLA), which you can do online. The CLA is necessary mainly because you own the copyright to your changes, even after your contribution becomes part of our codebase, so we need your permission to use and distribute your code. We also need to be sure of various other things—for instance that you'll tell us if you know that your code infringes on other people's patents. You don't have to sign the CLA until after you've submitted your code for review and a member has approved it, but you must do it before we can put your code into our codebase. 
Before you start working on a larger contribution, you should get in touch with us first through the issue tracker with your idea so that we can help out and possibly guide you. Coordinating up front makes it much easier to avoid frustration later on. ### Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. ### The small print Contributions made by corporations are covered by a different agreement than the one above, the Software Grant and Corporate Contributor License Agreement. martian-3.3.2/LICENSE000066400000000000000000000261361421371434000141660ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. 
Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. martian-3.3.2/README.md000066400000000000000000000250271421371434000144360ustar00rootroot00000000000000# Martian Proxy [![Build Status](https://travis-ci.org/google/martian.svg?branch=master)](https://travis-ci.org/google/martian) Martian Proxy is a programmable HTTP proxy designed to be used for testing. Martian is a great tool to use if you want to: * Verify that all (or some subset) of requests are secure * Mock external services at the network layer * Inject headers, modify cookies or perform other mutations of HTTP requests and responses * Verify that pingbacks happen when you think they should * Unwrap encrypted traffic (requires install of CA certificate in browser) By taking advantage of Go cross-compilation, Martian can be deployed anywhere that Go can target. ## Latest Version v3.0.0 ## Requirements Go 1.11 ## Go Modules Support Martian Proxy added support for Go modules since v3.0.0. If you use a Go version that does not support modules, this will break you. The latest version without Go modules support was tagged v2.1.0. 
## Getting Started ### Installation Martian Proxy can be installed using `go get` and `go install`: go get github.com/google/martian/ && \ go install github.com/google/martian/cmd/proxy ### Start the Proxy Assuming you've installed Martian, running the proxy is as simple as $GOPATH/bin/proxy If you want to see system logs as Martian is running, pass in the verbosity flag: $GOPATH/bin/proxy -v=2 By default, Martian will be running on port 8080, and the Martian API will be running on port 8181. The ports can be specified via flags: $GOPATH/bin/proxy -addr=:9999 -api-addr=:9898 ### Logging For logging of requests and responses a [logging modifier](https://github.com/google/martian/wiki/Modifier-Reference#logging) is available or [HAR](http://www.softwareishard.com/blog/har-12-spec/) logs are available if the `-har` flag is used. #### HAR Logging To enable HAR logging in Martian call the binary with the `-har` flag: $GOPATH/bin/proxy -har If the `-har` flag has been enabled, two HAR-related endpoints will be available: GET http://martian.proxy/logs Will retrieve the HAR log of all requests and responses seen by the proxy since the last reset. DELETE http://martian.proxy/logs/reset Will reset the in-memory HAR log. Note that the log will grow unbounded unless it is periodically reset. ### Configure Once Martian is running, you need to configure its behavior. Without configuration, Martian is just proxying without doing anything to the requests or responses. If enabled, logging will take place without additional configuration. Martian is configured by JSON messages sent over HTTP that take the general form of: { "header.Modifier": { "scope": ["response"], "name": "Test-Header", "value": "true" } } The above configuration tells Martian to inject a header with the name "Test-Header" and the value "true" on all responses. Let's break down the parts of this message. * `[package.Type]`: The package.Type of the modifier that you want to use.
In this case, it's "header.Modifier", which is the name of the modifier that sets headers (to learn more about the `header.Modifier`, please refer to the [modifier reference](https://github.com/google/martian/wiki/Modifier-Reference)). * `[package.Type].scope`: Indicates whether to apply the modifier to requests, responses or both. This can be an array containing "request", "response", or both. * `[package.Type].[key]`: Modifier specific data. In the case of the header modifier, we need the `name` and `value` of the header. This is a simple configuration; for more complex configurations, modifiers are combined with groups and filters to compose the desired behavior. To configure Martian, `POST` the JSON to `http://martian.proxy/configure`. You'll want to use whatever mechanism your language of choice provides you to make HTTP requests, but for demo purposes, curl works (assuming your configuration is in a file called `modifier.json`). curl -x localhost:8080 \ -X POST \ -H "Content-Type: application/json" \ -d @modifier.json \ "http://martian.proxy/configure" ### Intercepting HTTPS Requests and Responses Martian supports modifying HTTPS requests and responses if configured to do so. In order for Martian to intercept HTTPS traffic, a custom CA certificate must be installed in the browser so that connection warnings are not shown. The easiest way to install the CA certificate is to start the proxy with the necessary flags to use a custom CA certificate and private key using the `-cert` and `-key` flags, or to have the proxy generate one using the `-generate-ca-cert` flag. After the proxy has started, visit http://martian.proxy/authority.cer in the browser configured to use the proxy and a prompt will be displayed to install the certificate.
Several flags are available in the proxy binary (`cmd/proxy/main.go`) to help configure MITM functionality: -key="" PEM encoded private key file of the CA certificate provided in -cert; used to sign certificates that are generated on-the-fly -cert="" PEM encoded CA certificate file used to generate certificates -generate-ca-cert=false generates a CA certificate and private key to use for man-in-the-middle; most users choosing this option will immediately visit http://martian.proxy/authority.cer in the browser whose traffic is to be intercepted to install the newly generated CA certificate -organization="Martian Proxy" organization name set on the dynamically-generated certificates during man-in-the-middle -validity="1h" window of time around the time of request that the dynamically-generated certificate is valid for; the duration is set such that the total valid timeframe is double the value of validity (1h before & 1h after) ### Check Verifiers Let's assume that you've configured Martian to verify the presence of a specific header in responses to a specific URL. Here's a configuration to verify that all requests to `example.com` return responses with a `200 OK`. { "url.Filter": { "scope": ["request", "response"], "host" : "example.com", "modifier" : { "status.Verifier": { "scope" : ["response"], "statusCode": 200 } } } } Once Martian is running, configured and the requests and resultant responses you wish to verify have taken place, you can verify your expectation that you only got back `200 OK` responses.
To check verifications, perform GET http://martian.proxy/verify Failed expectations are tracked as errors, and the list of errors is retrieved by making a `GET` request to `http://martian.proxy/verify`, which will return a list of errors: { "errors" : [ { "message": "response(http://example.com) status code verify failure: got 500, want 200" }, { "message": "response(http://example.com/foo) status code verify failure: got 500, want 200" } ] } Verification errors are held in memory until they are explicitly cleared by POST http://martian.proxy/verify/reset ## Martian as a Library Martian can also be included in any Go program and used as a library. ## Modifiers All The Way Down Martian's request and response modification system is designed to be general and extensible. The design objective is to provide individual modifier behaviors that can be arranged to build out nearly any desired modification. When working with Martian to compose behaviors, you'll need to be familiar with these different types of interactions: * Modifiers: Changes the state of a request or a response * Filters: Conditionally allows a contained Modifier to execute * Groups: Bundles multiple modifiers to be executed in the order specified in the group * Verifiers: Tracks network traffic against expectations Modifiers, filters and groups all implement `RequestModifier`, `ResponseModifier` or `RequestResponseModifier` (defined in [`martian.go`](https://github.com/google/martian/blob/master/martian.go)). ```go ModifyRequest(req *http.Request) error ModifyResponse(res *http.Response) error ``` Throughout the code (and this documentation) you'll see the word "modifier" used as a term that encompasses modifiers, groups and filters. Even though a group does not modify a request or response, we still refer to it as a "modifier". We refer to anything that implements the `modifier` interface as a Modifier. ### Parser Registration Each modifier must register its own parser with Martian.
The parser is responsible for parsing a JSON message into a Go struct that implements a modifier interface. Martian holds modifier parsers as a map of strings to functions that is built out at run-time. Each modifier is responsible for registering its parser with a call to `parse.Register` in `init()`. Signature of parse.Register: ```go Register(name string, parseFunc func(b []byte) (interface{}, error)) ``` Register takes in the key as a string in the form `package.Type`. For instance, `cookie_modifier` registers itself with the key `cookie.Modifier` and `query_string_filter` registers itself as `querystring.Filter`. This string is the same as the top-level key in the JSON configuration message (not the `name` field, which is modifier-specific data). In the following configuration message, `header.Modifier` is how the header modifier is registered in the `init()` of `header_modifier.go`. { "header.Modifier": { "scope": ["response"], "name" : "Test-Header", "value" : "true" } } Example of parser registration from `header_modifier.go`: ```go func init() { parse.Register("header.Modifier", modifierFromJSON) } func modifierFromJSON(b []byte) (interface{}, error) { ... } ``` ### Adding Your Own Modifier If you have a use-case in mind that we have not developed modifiers, filters or verifiers for, you can easily extend Martian to your very specific needs. There are 2 mandatory parts of a modifier: * Implement the modifier interface * Register the parser Any Go struct that implements those interfaces can act as a `modifier`. ## Contact For questions and comments on how to use Martian, feature announcements, or design discussions, check out our public Google Group at https://groups.google.com/forum/#!forum/martianproxy-users. For security-related issues, please send a detailed report to our private core group at martianproxy-core@googlegroups.com. ## Disclaimer This is not an official Google product (experimental or otherwise), it is just code that happens to be owned by Google.
martian-3.3.2/api/000077500000000000000000000000001421371434000137225ustar00rootroot00000000000000martian-3.3.2/api/forwarder.go000066400000000000000000000033361421371434000162510ustar00rootroot00000000000000
// Copyright 2016 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package api contains a forwarder to route system HTTP requests to the
// local API server.
package api

import (
	"net"
	"net/http"
	"strconv"
	"strings"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/log"
)

// Forwarder is a request modifier that routes the request to the API server and
// marks the request for skipped logging.
type Forwarder struct {
	host string
	port int
}

// NewForwarder returns a Forwarder that rewrites requests to host:port. An
// empty host defaults to "localhost".
func NewForwarder(host string, port int) *Forwarder {
	if host == "" {
		host = "localhost"
	}

	return &Forwarder{
		host: host,
		port: port,
	}
}

// ModifyRequest forwards the request to the local API server running at
// f.host:f.port, downgrades the scheme to http and marks the request context
// for skipped logging. API requests are marked for skipping the roundtrip.
//
// The authority is built with net.JoinHostPort so that an IPv6 literal host,
// given either as "::1" or "[::1]", becomes the valid authority "[::1]:port".
func (f *Forwarder) ModifyRequest(req *http.Request) error {
	ctx := martian.NewContext(req)
	ctx.APIRequest()
	ctx.SkipLogging()

	in := req.URL.String()

	// JoinHostPort adds brackets itself, so unwrap an already-bracketed IPv6
	// literal to avoid producing "[[::1]]:port".
	host := f.host
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	req.URL.Scheme = "http"
	req.URL.Host = net.JoinHostPort(host, strconv.Itoa(f.port))

	out := req.URL.String()
	log.Infof("api.Forwarder: forwarding %s to %s", in, out)

	return nil
}
martian-3.3.2/api/forwarder_test.go000066400000000000000000000042661421371434000173130ustar00rootroot00000000000000
// Copyright 2016 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package api

import (
	"net/http"
	"testing"

	"github.com/google/martian/v3"
)

func TestApiForwarder(t *testing.T) {
	forwarder := NewForwarder("", 8181)

	req, err := http.NewRequest("GET", "https://martian.proxy/configure", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	ctx, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := forwarder.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.URL.Scheme, "http"; got != want {
		t.Errorf("req.URL.Scheme: got %s, want %s", got, want)
	}

	if got, want := req.URL.Host, "localhost:8181"; got != want {
		t.Errorf("req.URL.Host: got %s, want %s", got, want)
	}

	if !ctx.SkippingLogging() {
		t.Errorf("ctx.SkippingLogging: got false, want true")
	}

	if !ctx.IsAPIRequest() {
		t.Errorf("ctx.IsAPIRequest: got false, want true")
	}
}

func TestApiForwarderWithHost(t *testing.T) {
	forwarder := NewForwarder("example.com", 8181)

	req, err := http.NewRequest("GET", "https://martian.proxy/configure", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := forwarder.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.URL.Host, "example.com:8181"; got != want {
		t.Errorf("req.URL.Host: got %s, want %s", got, want)
	}
}

// TestApiForwarderWithIPv6Host checks that IPv6 literal hosts, bracketed or
// not, are rewritten to a valid "[addr]:port" authority.
func TestApiForwarderWithIPv6Host(t *testing.T) {
	tests := []struct {
		host string
		want string
	}{
		{host: "::1", want: "[::1]:8181"},
		{host: "[::1]", want: "[::1]:8181"},
	}

	for _, tc := range tests {
		tc := tc
		t.Run(tc.host, func(t *testing.T) {
			forwarder := NewForwarder(tc.host, 8181)

			req, err := http.NewRequest("GET", "https://martian.proxy/configure", nil)
			if err != nil {
				t.Fatalf("NewRequest(): got %v, want no error", err)
			}

			_, remove, err := martian.TestContext(req, nil, nil)
			if err != nil {
				t.Fatalf("TestContext(): got %v, want no error", err)
			}
			defer remove()

			if err := forwarder.ModifyRequest(req); err != nil {
				t.Fatalf("ModifyRequest(): got %v, want no error", err)
			}

			if got := req.URL.Host; got != tc.want {
				t.Errorf("req.URL.Host: got %s, want %s", got, tc.want)
			}
		})
	}
}
martian-3.3.2/auth/000077500000000000000000000000001421371434000141125ustar00rootroot00000000000000martian-3.3.2/auth/auth_filter.go000066400000000000000000000072401421371434000167520ustar00rootroot00000000000000
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package auth provides filtering support for a martian.Proxy based on auth // ID. package auth import ( "fmt" "net/http" "sync" "github.com/google/martian/v3" ) // Filter filters RequestModifiers and ResponseModifiers by auth ID. type Filter struct { authRequired bool mu sync.RWMutex reqmods map[string]martian.RequestModifier resmods map[string]martian.ResponseModifier } // NewFilter returns a new auth.Filter. func NewFilter() *Filter { return &Filter{ reqmods: make(map[string]martian.RequestModifier), resmods: make(map[string]martian.ResponseModifier), } } // SetAuthRequired determines whether the auth ID must have an associated // RequestModifier or ResponseModifier. If true, it will set auth error. func (f *Filter) SetAuthRequired(required bool) { f.authRequired = required } // SetRequestModifier sets the RequestModifier for the given ID. It will // overwrite any existing modifier with the same ID. func (f *Filter) SetRequestModifier(id string, reqmod martian.RequestModifier) error { f.mu.Lock() defer f.mu.Unlock() if reqmod != nil { f.reqmods[id] = reqmod } else { delete(f.reqmods, id) } return nil } // SetResponseModifier sets the ResponseModifier for the given ID. It will // overwrite any existing modifier with the same ID. func (f *Filter) SetResponseModifier(id string, resmod martian.ResponseModifier) error { f.mu.Lock() defer f.mu.Unlock() if resmod != nil { f.resmods[id] = resmod } else { delete(f.resmods, id) } return nil } // RequestModifier retrieves the RequestModifier for the given ID. Returns nil // if no modifier exists for the given ID. 
func (f *Filter) RequestModifier(id string) martian.RequestModifier { f.mu.RLock() defer f.mu.RUnlock() return f.reqmods[id] } // ResponseModifier retrieves the ResponseModifier for the given ID. Returns nil // if no modifier exists for the given ID. func (f *Filter) ResponseModifier(id string) martian.ResponseModifier { f.mu.RLock() defer f.mu.RUnlock() return f.resmods[id] } // ModifyRequest runs the RequestModifier for the associated auth ID. If no // modifier is found for auth ID then auth error is set. func (f *Filter) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) actx := FromContext(ctx) if reqmod, ok := f.reqmods[actx.ID()]; ok { return reqmod.ModifyRequest(req) } if err := f.requireKnownAuth(actx.ID()); err != nil { actx.SetError(err) } return nil } // ModifyResponse runs the ResponseModifier for the associated auth ID. If no // modifier is found for the auth ID then the auth error is set. func (f *Filter) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) actx := FromContext(ctx) if resmod, ok := f.resmods[actx.ID()]; ok { return resmod.ModifyResponse(res) } if err := f.requireKnownAuth(actx.ID()); err != nil { actx.SetError(err) } return nil } func (f *Filter) requireKnownAuth(id string) error { _, reqok := f.reqmods[id] _, resok := f.resmods[id] if !reqok && !resok && f.authRequired { return fmt.Errorf("auth: unrecognized credentials: %s", id) } return nil } martian-3.3.2/auth/auth_filter_test.go000066400000000000000000000101341421371434000200050ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" ) func TestFilter(t *testing.T) { f := NewFilter() if f.RequestModifier("id") != nil { t.Fatalf("f.RequestModifier(%q): got reqmod, want nil", "id") } if f.ResponseModifier("id") != nil { t.Fatalf("f.ResponseModifier(%q): got resmod, want nil", "id") } tm := martiantest.NewModifier() f.SetRequestModifier("id", tm) f.SetResponseModifier("id", tm) if f.RequestModifier("id") != tm { t.Errorf("f.RequestModifier(%q): got nil, want martiantest.Modifier", "id") } if f.ResponseModifier("id") != tm { t.Errorf("f.ResponseModifier(%q): got nil, want martiantest.Modifier", "id") } } func TestModifyRequest(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetRequestModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } // No ID, auth required. f.SetAuthRequired(true) ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } tm.Reset() // No ID, auth not required. 
f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } // Valid ID. actx.SetError(nil) actx.SetID("id") if err := f.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if actx.Error() != nil { t.Errorf("actx.Error(): got %v, want no error", actx.Error()) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } } func TestModifyResponse(t *testing.T) { f := NewFilter() tm := martiantest.NewModifier() f.SetResponseModifier("id", tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) // No ID, auth required. f.SetAuthRequired(true) ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } actx := FromContext(ctx) if actx.Error() == nil { t.Error("actx.Error(): got nil, want error") } if tm.ResponseModified() { t.Error("tm.RequestModified(): got true, want false") } // No ID, no auth required. f.SetAuthRequired(false) actx.SetError(nil) if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } // Valid ID. 
actx.SetID("id") if err := f.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } } martian-3.3.2/auth/context.go000066400000000000000000000031241421371434000161250ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package auth import ( "sync" "github.com/google/martian/v3" ) const key = "auth.Context" // Context contains authentication information. type Context struct { mu sync.RWMutex id string err error } // FromContext retrieves the auth.Context from the session. func FromContext(ctx *martian.Context) *Context { if v, ok := ctx.Session().Get(key); ok { return v.(*Context) } actx := &Context{} ctx.Session().Set(key, actx) return actx } // ID returns the ID. func (ctx *Context) ID() string { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.id } // SetID sets the ID. func (ctx *Context) SetID(id string) { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.err = nil if id == "" { return } ctx.id = id } // SetError sets the error and resets the ID. func (ctx *Context) SetError(err error) { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.id = "" ctx.err = err } // Error returns the error. 
func (ctx *Context) Error() error { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.err } martian-3.3.2/body/000077500000000000000000000000001421371434000141065ustar00rootroot00000000000000martian-3.3.2/body/body_modifier.go000066400000000000000000000131621421371434000172530ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package body allows for the replacement of message body on responses. package body import ( "bytes" "crypto/rand" "encoding/json" "fmt" "io" "io/ioutil" "mime/multipart" "net/http" "net/textproto" "strconv" "strings" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" ) func init() { parse.Register("body.Modifier", modifierFromJSON) } // Modifier substitutes the body on an HTTP response. type Modifier struct { contentType string body []byte boundary string } type modifierJSON struct { ContentType string `json:"contentType"` Body []byte `json:"body"` // Body is expected to be a Base64 encoded string. Scope []parse.ModifierType `json:"scope"` } // NewModifier constructs and returns a body.Modifier. func NewModifier(b []byte, contentType string) *Modifier { log.Debugf("body.NewModifier: len(b): %d, contentType %s", len(b), contentType) return &Modifier{ contentType: contentType, body: b, boundary: randomBoundary(), } } // modifierFromJSON takes a JSON message as a byte slice and returns a // body.Modifier and an error. 
// // Example JSON Configuration message: // { // "scope": ["request", "response"], // "contentType": "text/plain", // "body": "c29tZSBkYXRhIHdpdGggACBhbmQg77u/" // Base64 encoded body // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } mod := NewModifier(msg.Body, msg.ContentType) return parse.NewResult(mod, msg.Scope) } // ModifyRequest sets the Content-Type header and overrides the request body. func (m *Modifier) ModifyRequest(req *http.Request) error { log.Debugf("body.ModifyRequest: request: %s", req.URL) req.Body.Close() req.Header.Set("Content-Type", m.contentType) // Reset the Content-Encoding since we know that the new body isn't encoded. req.Header.Del("Content-Encoding") req.ContentLength = int64(len(m.body)) req.Body = ioutil.NopCloser(bytes.NewReader(m.body)) return nil } // SetBoundary set the boundary string used for multipart range responses. func (m *Modifier) SetBoundary(boundary string) { m.boundary = boundary } // ModifyResponse sets the Content-Type header and overrides the response body. func (m *Modifier) ModifyResponse(res *http.Response) error { log.Debugf("body.ModifyResponse: request: %s", res.Request.URL) // Replace the existing body, close it first. res.Body.Close() res.Header.Set("Content-Type", m.contentType) // Reset the Content-Encoding since we know that the new body isn't encoded. res.Header.Del("Content-Encoding") // If no range request header is present, return the body as the response body. 
if res.Request.Header.Get("Range") == "" { res.ContentLength = int64(len(m.body)) res.Body = ioutil.NopCloser(bytes.NewReader(m.body)) return nil } rh := res.Request.Header.Get("Range") rh = strings.ToLower(rh) sranges := strings.Split(strings.TrimLeft(rh, "bytes="), ",") var ranges [][]int for _, rng := range sranges { if strings.HasSuffix(rng, "-") { rng = fmt.Sprintf("%s%d", rng, len(m.body)-1) } rs := strings.Split(rng, "-") if len(rs) != 2 { res.StatusCode = http.StatusRequestedRangeNotSatisfiable return nil } start, err := strconv.Atoi(strings.TrimSpace(rs[0])) if err != nil { return err } end, err := strconv.Atoi(strings.TrimSpace(rs[1])) if err != nil { return err } if start > end { res.StatusCode = http.StatusRequestedRangeNotSatisfiable return nil } ranges = append(ranges, []int{start, end}) } // Range request. res.StatusCode = http.StatusPartialContent // Single range request. if len(ranges) == 1 { start := ranges[0][0] end := ranges[0][1] seg := m.body[start : end+1] res.ContentLength = int64(len(seg)) res.Body = ioutil.NopCloser(bytes.NewReader(seg)) res.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(m.body))) return nil } // Multipart range request. var mpbody bytes.Buffer mpw := multipart.NewWriter(&mpbody) mpw.SetBoundary(m.boundary) for _, rng := range ranges { start, end := rng[0], rng[1] mimeh := make(textproto.MIMEHeader) mimeh.Set("Content-Type", m.contentType) mimeh.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, len(m.body))) seg := m.body[start : end+1] pw, err := mpw.CreatePart(mimeh) if err != nil { return err } if _, err := pw.Write(seg); err != nil { return err } } mpw.Close() res.ContentLength = int64(len(mpbody.Bytes())) res.Body = ioutil.NopCloser(bytes.NewReader(mpbody.Bytes())) res.Header.Set("Content-Type", fmt.Sprintf("multipart/byteranges; boundary=%s", m.boundary)) return nil } // randomBoundary generates a 30 character string for boundaries for mulipart range // requests. 
This func panics if io.Readfull fails. // Borrowed from: https://golang.org/src/mime/multipart/writer.go?#L73 func randomBoundary() string { var buf [30]byte _, err := io.ReadFull(rand.Reader, buf[:]) if err != nil { panic(err) } return fmt.Sprintf("%x", buf[:]) } martian-3.3.2/body/body_modifier_test.go000066400000000000000000000232161421371434000203130ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package body import ( "bytes" "encoding/base64" "fmt" "io" "io/ioutil" "mime/multipart" "net/http" "strings" "testing" "github.com/google/martian/v3/messageview" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestBodyModifier(t *testing.T) { mod := NewModifier([]byte("text"), "text/plain") req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Encoding", "gzip") if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want) } if got, want := req.ContentLength, int64(len([]byte("text"))); got != want { t.Errorf("req.ContentLength: got %d, want %d", got, want) } if got, want := req.Header.Get("Content-Encoding"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, 
want %q", "Content-Encoding", got, want) } got, err := ioutil.ReadAll(req.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } req.Body.Close() if want := []byte("text"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Content-Encoding", "gzip") if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want) } if got, want := res.ContentLength, int64(len([]byte("text"))); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Encoding"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } got, err = ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("text"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } } func TestRangeHeaderRequestSingleRange(t *testing.T) { mod := NewModifier([]byte("0123456789"), "text/plain") req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Range", "bytes=1-4") res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusPartialContent; got != want { t.Errorf("res.Status: got %v, want %v", got, want) } if got, want := res.ContentLength, int64(len([]byte("1234"))); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Range"), "bytes 1-4/10"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } 
got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("1234"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } } func TestRangeHeaderRequestSingleRangeHasAllTheBytes(t *testing.T) { mod := NewModifier([]byte("0123456789"), "text/plain") req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Range", "bytes=0-") res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusPartialContent; got != want { t.Errorf("res.Status: got %v, want %v", got, want) } if got, want := res.ContentLength, int64(len([]byte("0123456789"))); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Range"), "bytes 0-9/10"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("0123456789"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } } func TestRangeNoEndingIndexSpecified(t *testing.T) { mod := NewModifier([]byte("0123456789"), "text/plain") req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Range", "bytes=8-") res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusPartialContent; got != want { t.Errorf("res.Status: got %v, want %v", got, want) } if got, want := res.ContentLength, int64(len([]byte("89"))); got != want { 
t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Range"), "bytes 8-9/10"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } } func TestRangeHeaderMultipartRange(t *testing.T) { mod := NewModifier([]byte("0123456789"), "text/plain") bndry := "3d6b6a416f9b5" mod.SetBoundary(bndry) req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Range", "bytes=1-4, 7-9") res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusPartialContent; got != want { t.Errorf("res.Status: got %v, want %v", got, want) } if got, want := res.Header.Get("Content-Type"), "multipart/byteranges; boundary=3d6b6a416f9b5"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Type", got, want) } mv := messageview.New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("mv.SnapshotResponse(res): got %v, want no error", err) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } mpr := multipart.NewReader(br, bndry) prt1, err := mpr.NextPart() if err != nil { t.Fatalf("mpr.NextPart(): got %v, want no error", err) } defer prt1.Close() if got, want := prt1.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Type", got, want) } if got, want := prt1.Header.Get("Content-Range"), "bytes 1-4/10"; got != want { t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Range", got, want) } prt1b, err := ioutil.ReadAll(prt1) if err != nil { t.Errorf("ioutil.Readall(prt1): got %v, want no error", err) } if got, want := string(prt1b), "1234"; got != want { t.Errorf("prt1 body: got %s, want %s", got, want) } prt2, err := mpr.NextPart() if err != nil { 
t.Fatalf("mpr.NextPart(): got %v, want no error", err) } defer prt2.Close() if got, want := prt2.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Type", got, want) } if got, want := prt2.Header.Get("Content-Range"), "bytes 7-9/10"; got != want { t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Range", got, want) } prt2b, err := ioutil.ReadAll(prt2) if err != io.ErrUnexpectedEOF && err != nil { t.Errorf("ioutil.Readall(prt2): got %v, want no error", err) } if got, want := string(prt2b), "789"; got != want { t.Errorf("prt2 body: got %s, want %s", got, want) } _, err = mpr.NextPart() if err == nil { t.Errorf("mpr.NextPart: want io.EOF, got no error") } if err != io.EOF { t.Errorf("mpr.NextPart: want io.EOF, got %v", err) } } func TestModifierFromJSON(t *testing.T) { data := base64.StdEncoding.EncodeToString([]byte("data")) msg := fmt.Sprintf(`{ "body.Modifier":{ "scope": ["response"], "contentType": "text/plain", "body": %q } }`, data) r, err := parse.FromJSON([]byte(msg)) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } req, err := http.NewRequest("GET", "/", strings.NewReader("")) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want { t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("data"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } } 
martian-3.3.2/cmd/000077500000000000000000000000001421371434000137145ustar00rootroot00000000000000martian-3.3.2/cmd/marbl/000077500000000000000000000000001421371434000150115ustar00rootroot00000000000000martian-3.3.2/cmd/marbl/viewer.go000066400000000000000000000056121421371434000166450ustar00rootroot00000000000000// Copyright 2018 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Command-line tool to view .marbl files. This tool reads all headers from provided .marbl // file and prints them to stdout. Bodies of request/response are not printed to stdout, // instead they are saved into individual files in form of "marbl_ID_TYPE" where // ID is the ID of request or response and TYPE is "request" or "response". // // Command line arguments: // --file Path to the .marbl file to view. // --out Optional, folder where this tool will save request/response bodies. // uses current folder by default. package main import ( "flag" "fmt" "io" "log" "os" "github.com/google/martian/v3/marbl" ) var ( file = flag.String("file", "", ".marbl file to show contents of") out = flag.String("out", "", "folder to write request/response bodies to. Folder must exist.") ) func main() { flag.Parse() if *file == "" { fmt.Println("--file flag is required") return } file, err := os.Open(*file) if err != nil { log.Fatal(err) } reader := marbl.NewReader(file) // Iterate through all frames in .marbl file. 
for { frame, err := reader.ReadFrame() if frame == nil && err == io.EOF { break } if err != nil { log.Fatalf("reader.ReadFrame(): got %v, want no error or io.EOF\n", err) break } // Print current frame to stdout. if frame.FrameType() == marbl.HeaderFrame { fmt.Print("Header ") } else { fmt.Print("Data ") } fmt.Println(frame.String()) // If frame is Data then we write it into separate // file that can be inspected later. if frame.FrameType() == marbl.DataFrame { df := frame.(marbl.Data) var t string if df.MessageType == marbl.Request { t = "request" } else if df.MessageType == marbl.Response { t = "response" } else { t = fmt.Sprintf("unknown_%d", df.MessageType) } fout := fmt.Sprintf("marbl_%s_%s", df.ID, t) if *out != "" { fout = *out + "/" + fout } fmt.Printf("Appending data to file %s\n", fout) // Append data to the file. Note that body can be split // into multiple frames so we have to append and not overwrite. f, err := os.OpenFile(fout, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { log.Fatal(err) } if _, err := f.Write(df.Data); err != nil { log.Fatal(err) } if err := f.Close(); err != nil { log.Fatal(err) } } } } martian-3.3.2/cmd/proxy/000077500000000000000000000000001421371434000150755ustar00rootroot00000000000000martian-3.3.2/cmd/proxy/main.go000066400000000000000000000332101421371434000163470ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
// proxy is an HTTP/S proxy configurable via an HTTP API. // // It can be dynamically configured/queried at runtime by issuing requests to // proxy specific paths using JSON. // // Supported configuration endpoints: // // POST http://martian.proxy/configure // // sets the request and response modifier of the proxy; modifiers adhere to the // following top-level JSON structure: // // { // "package.Modifier": { // "scope": ["request", "response"], // "attribute 1": "value", // "attribute 2": "value" // } // } // // modifiers may be "stacked" to provide support for additional behaviors; for // example, to add a "Martian-Test" header with the value "true" for requests // with the domain "www.example.com" the JSON message would be: // // { // "url.Filter": { // "scope": ["request"], // "host": "www.example.com", // "modifier": { // "header.Modifier": { // "name": "Martian-Test", // "value": "true" // } // } // } // } // // url.Filter parses the JSON object in the value of the "url.Filter" attribute; // the "host" key tells the url.Filter to filter requests if the host explicitly // matches "www.example.com" // // the "modifier" key within the "url.Filter" JSON object contains another // modifier message of the type header.Modifier to run iff the filter passes // // groups may also be used to run multiple modifiers sequentially; for example to // log requests and responses after adding the "Martian-Test" header to the // request, but only when the host matches www.example.com: // // { // "url.Filter": { // "host": "www.example.com", // "modifier": { // "fifo.Group": { // "modifiers": [ // { // "header.Modifier": { // "scope": ["request"], // "name": "Martian-Test", // "value": "true" // } // }, // { // "log.Logger": { } // } // ] // } // } // } // } // // modifiers are designed to be composed together in ways that allow the user to // write a single JSON structure to accomplish a variety of functionality // // GET http://martian.proxy/verify // // retrieves the 
verifications errors as JSON with the following structure: // // { // "errors": [ // { // "message": "request(url) verification failure" // }, // { // "message": "response(url) verification failure" // } // ] // } // // verifiers also adhere to the modifier interface and thus can be included in the // modifier configuration request; for example, to verify that all requests to // "www.example.com" are sent over HTTPS send the following JSON to the // configuration endpoint: // // { // "url.Filter": { // "scope": ["request"], // "host": "www.example.com", // "modifier": { // "url.Verifier": { // "scope": ["request"], // "scheme": "https" // } // } // } // } // // sending a request to "http://martian.proxy/verify" will then return errors from the url.Verifier // // POST http://martian.proxy/verify/reset // // resets the verifiers to their initial state; note some verifiers may start in // a failure state (e.g., pingback.Verifier is failed if no requests have been // seen by the proxy) // // GET http://martian.proxy/authority.cer // // prompts the user to install the CA certificate used by the proxy if MITM is enabled // // GET http://martian.proxy/logs // // retrieves the HAR logs for all requests and responses seen by the proxy if // the HAR flag is enabled // // DELETE http://martian.proxy/logs/reset // // reset the in-memory HAR log; note that the log will grow unbounded unless it // is periodically reset // // passing the -cors flag will enable CORS support for the endpoints so that they // may be called via AJAX // // Sending a sigint will cause the proxy to stop receiving new connections, // finish processing any inflight requests, and close existing connections without // reading anymore requests from them. 
// // The flags are: // -addr=":8080" // host:port of the proxy // -api-addr=":8181" // host:port of the proxy API // -tls-addr=":4443" // host:port of the proxy over TLS // -api="martian.proxy" // hostname that can be used to reference the configuration API when // configuring through the proxy // -cert="" // PEM encoded X.509 CA certificate; if set, it will be set as the // issuer for dynamically-generated certificates during man-in-the-middle // -key="" // PEM encoded private key of cert (RSA or ECDSA); if set, the key will be used // to sign dynamically-generated certificates during man-in-the-middle // -generate-ca-cert=false // generates a CA certificate and private key to use for man-in-the-middle; // the certificate is only valid while the proxy is running and will be // discarded on shutdown // -organization="Martian Proxy" // organization name set on the dynamically-generated certificates during // man-in-the-middle // -validity="1h" // window of time around the time of request that the dynamically-generated // certificate is valid for; the duration is set such that the total valid // timeframe is double the value of validity (1h before & 1h after) // -cors=false // allow the proxy to be configured via CORS requests; such as when // configuring the proxy via AJAX // -har=false // enable logging endpoints for retrieving full request/response logs in // HAR format. // -traffic-shaping=false // enable traffic shaping endpoints for simulating latency and constrained // bandwidth conditions (e.g. mobile, exotic network infrastructure, the // 90's) // -skip-tls-verify=false // skip TLS server verification; insecure and intended for testing only // -v=0 // log level for console logs; defaults to error only. 
package main import ( "crypto/tls" "crypto/x509" "flag" "log" "net" "net/http" "net/url" "os" "os/signal" "path" "strconv" "strings" "time" "github.com/google/martian/v3" mapi "github.com/google/martian/v3/api" "github.com/google/martian/v3/cors" "github.com/google/martian/v3/fifo" "github.com/google/martian/v3/har" "github.com/google/martian/v3/httpspec" "github.com/google/martian/v3/marbl" "github.com/google/martian/v3/martianhttp" "github.com/google/martian/v3/martianlog" "github.com/google/martian/v3/mitm" "github.com/google/martian/v3/servemux" "github.com/google/martian/v3/trafficshape" "github.com/google/martian/v3/verify" _ "github.com/google/martian/v3/body" _ "github.com/google/martian/v3/cookie" _ "github.com/google/martian/v3/failure" _ "github.com/google/martian/v3/martianurl" _ "github.com/google/martian/v3/method" _ "github.com/google/martian/v3/pingback" _ "github.com/google/martian/v3/port" _ "github.com/google/martian/v3/priority" _ "github.com/google/martian/v3/querystring" _ "github.com/google/martian/v3/skip" _ "github.com/google/martian/v3/stash" _ "github.com/google/martian/v3/static" _ "github.com/google/martian/v3/status" ) var ( addr = flag.String("addr", ":8080", "host:port of the proxy") apiAddr = flag.String("api-addr", ":8181", "host:port of the configuration API") tlsAddr = flag.String("tls-addr", ":4443", "host:port of the proxy over TLS") api = flag.String("api", "martian.proxy", "hostname for the API") generateCA = flag.Bool("generate-ca-cert", false, "generate CA certificate and private key for MITM") cert = flag.String("cert", "", "filepath to the CA certificate used to sign MITM certificates") key = flag.String("key", "", "filepath to the private key of the CA used to sign MITM certificates") organization = flag.String("organization", "Martian Proxy", "organization name for MITM certificates") validity = flag.Duration("validity", time.Hour, "window of time that MITM certificates are valid") allowCORS = flag.Bool("cors", false, 
"allow CORS requests to configure the proxy") harLogging = flag.Bool("har", false, "enable HAR logging API") marblLogging = flag.Bool("marbl", false, "enable MARBL logging API") trafficShaping = flag.Bool("traffic-shaping", false, "enable traffic shaping API") skipTLSVerify = flag.Bool("skip-tls-verify", false, "skip TLS server verification; insecure") dsProxyURL = flag.String("downstream-proxy-url", "", "URL of downstream proxy") ) func main() { flag.Parse() martian.Init() p := martian.NewProxy() defer p.Close() l, err := net.Listen("tcp", *addr) if err != nil { log.Fatal(err) } lAPI, err := net.Listen("tcp", *apiAddr) if err != nil { log.Fatal(err) } log.Printf("martian: starting proxy on %s and api on %s", l.Addr().String(), lAPI.Addr().String()) tr := &http.Transport{ Dial: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: time.Second, TLSClientConfig: &tls.Config{ InsecureSkipVerify: *skipTLSVerify, }, } p.SetRoundTripper(tr) if *dsProxyURL != "" { u, err := url.Parse(*dsProxyURL) if err != nil { log.Fatal(err) } p.SetDownstreamProxy(u) } mux := http.NewServeMux() var x509c *x509.Certificate var priv interface{} if *generateCA { var err error x509c, priv, err = mitm.NewAuthority("martian.proxy", "Martian Authority", 30*24*time.Hour) if err != nil { log.Fatal(err) } } else if *cert != "" && *key != "" { tlsc, err := tls.LoadX509KeyPair(*cert, *key) if err != nil { log.Fatal(err) } priv = tlsc.PrivateKey x509c, err = x509.ParseCertificate(tlsc.Certificate[0]) if err != nil { log.Fatal(err) } } if x509c != nil && priv != nil { mc, err := mitm.NewConfig(x509c, priv) if err != nil { log.Fatal(err) } mc.SetValidity(*validity) mc.SetOrganization(*organization) mc.SkipTLSVerify(*skipTLSVerify) p.SetMITM(mc) // Expose certificate authority. ah := martianhttp.NewAuthorityHandler(x509c) configure("/authority.cer", ah, mux) // Start TLS listener for transparent MITM. 
tl, err := net.Listen("tcp", *tlsAddr) if err != nil { log.Fatal(err) } go p.Serve(tls.NewListener(tl, mc.TLS())) } stack, fg := httpspec.NewStack("martian") // wrap stack in a group so that we can forward API requests to the API port // before the httpspec modifiers which include the via modifier which will // trip loop detection topg := fifo.NewGroup() // Redirect API traffic to API server. if *apiAddr != "" { addrParts := strings.Split(lAPI.Addr().String(), ":") apip := addrParts[len(addrParts)-1] port, err := strconv.Atoi(apip) if err != nil { log.Fatal(err) } host := strings.Join(addrParts[:len(addrParts)-1], ":") // Forward traffic that pattern matches in http.DefaultServeMux apif := servemux.NewFilter(mux) apif.SetRequestModifier(mapi.NewForwarder(host, port)) topg.AddRequestModifier(apif) } topg.AddRequestModifier(stack) topg.AddResponseModifier(stack) p.SetRequestModifier(topg) p.SetResponseModifier(topg) m := martianhttp.NewModifier() fg.AddRequestModifier(m) fg.AddResponseModifier(m) if *harLogging { hl := har.NewLogger() muxf := servemux.NewFilter(mux) // Only append to HAR logs when the requests are not API requests, // that is, they are not matched in http.DefaultServeMux muxf.RequestWhenFalse(hl) muxf.ResponseWhenFalse(hl) stack.AddRequestModifier(muxf) stack.AddResponseModifier(muxf) configure("/logs", har.NewExportHandler(hl), mux) configure("/logs/reset", har.NewResetHandler(hl), mux) } logger := martianlog.NewLogger() logger.SetDecode(true) stack.AddRequestModifier(logger) stack.AddResponseModifier(logger) if *marblLogging { lsh := marbl.NewHandler() lsm := marbl.NewModifier(lsh) muxf := servemux.NewFilter(mux) muxf.RequestWhenFalse(lsm) muxf.ResponseWhenFalse(lsm) stack.AddRequestModifier(muxf) stack.AddResponseModifier(muxf) // retrieve binary marbl logs mux.Handle("/binlogs", lsh) } // Configure modifiers. configure("/configure", m, mux) // Verify assertions. 
vh := verify.NewHandler() vh.SetRequestVerifier(m) vh.SetResponseVerifier(m) configure("/verify", vh, mux) // Reset verifications. rh := verify.NewResetHandler() rh.SetRequestVerifier(m) rh.SetResponseVerifier(m) configure("/verify/reset", rh, mux) if *trafficShaping { tsl := trafficshape.NewListener(l) tsh := trafficshape.NewHandler(tsl) configure("/shape-traffic", tsh, mux) l = tsl } go p.Serve(l) go http.Serve(lAPI, mux) sigc := make(chan os.Signal, 1) signal.Notify(sigc, os.Interrupt) <-sigc log.Println("martian: shutting down") os.Exit(0) } // configure installs a configuration handler at path. func configure(pattern string, handler http.Handler, mux *http.ServeMux) { if *allowCORS { handler = cors.NewHandler(handler) } // register handler for martian.proxy to be forwarded to // local API server mux.Handle(path.Join(*api, pattern), handler) // register handler for local API server p := path.Join("localhost"+*apiAddr, pattern) mux.Handle(p, handler) } martian-3.3.2/cmd/proxy/main_test.go000066400000000000000000000264241421371434000174170ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package main import ( "crypto/tls" "crypto/x509" "encoding/base64" "fmt" "io/ioutil" "net" "net/http" "net/url" "os" "os/exec" "path/filepath" "strings" "testing" "time" "github.com/google/martian/v3/mitm" ) func waitForProxy(t *testing.T, c *http.Client, apiURL string) { timeout := 5 * time.Second deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { res, err := c.Get(apiURL) if err != nil { time.Sleep(200 * time.Millisecond) continue } defer res.Body.Close() if got, want := res.StatusCode, http.StatusOK; got != want { t.Fatalf("waitForProxy: c.Get(%q): got status %d, want %d", apiURL, got, want) } return } t.Fatalf("waitForProxy: did not start up within %.1f seconds", timeout.Seconds()) } // getFreePort returns a port string preceded by a colon, e.g. ":1234" func getFreePort(t *testing.T) string { l, err := net.Listen("tcp", ":") if err != nil { t.Fatalf("getFreePort: could not get free port: %v", err) } defer l.Close() return l.Addr().String()[strings.LastIndex(l.Addr().String(), ":"):] } func parseURL(t *testing.T, u string) *url.URL { p, err := url.Parse(u) if err != nil { t.Fatalf("url.Parse(%q): got error %v, want no error", u, err) } return p } func TestProxyMain(t *testing.T) { tempDir, err := ioutil.TempDir("", t.Name()) if err != nil { t.Fatal(err) } defer os.RemoveAll(tempDir) // Build proxy binary binPath := filepath.Join(tempDir, "proxy") cmd := exec.Command("go", "build", "-o", binPath) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Run(); err != nil { t.Fatal(err) } t.Run("Http", func(t *testing.T) { // Start proxy proxyPort := getFreePort(t) apiPort := getFreePort(t) cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatal(err) } defer cmd.Wait() defer cmd.Process.Signal(os.Interrupt) proxyURL := "http://localhost" + proxyPort apiURL := "http://localhost" + apiPort configureURL := "http://martian.proxy/configure" // 
TODO: Make using API hostport directly work on Travis. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}} waitForProxy(t, apiClient, configureURL) // Configure modifiers config := strings.NewReader(` { "fifo.Group": { "scope": ["request", "response"], "modifiers": [ { "status.Modifier": { "scope": ["response"], "statusCode": 418 } }, { "skip.RoundTrip": {} } ] } }`) res, err := apiClient.Post(configureURL, "application/json", config) if err != nil { t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusOK; got != want { t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want) } // Exercise proxy client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}} testURL := "http://super.fake.domain/" res, err = client.Get(testURL) if err != nil { t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusTeapot; got != want { t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want) } }) t.Run("HttpsGenerateCert", func(t *testing.T) { // Create test certificate for test TLS server certName := "martian.proxy" certOrg := "Martian Authority" certExpiry := 90 * time.Minute servCert, servPriv, err := mitm.NewAuthority(certName, certOrg, certExpiry) if err != nil { t.Fatalf("mitm.NewAuthority(%q, %q, %q): got error %v, want no error", certName, certOrg, certExpiry, err) } mc, err := mitm.NewConfig(servCert, servPriv) if err != nil { t.Fatalf("mitm.NewConfig(%p, %q): got error %v, want no error", servCert, servPriv, err) } sc := mc.TLS() // Configure and start test TLS server servPort := getFreePort(t) l, err := tls.Listen("tcp", servPort, sc) if err != nil { t.Fatalf("tls.Listen(\"tcp\", %q, %p): got error %v, want no error", servPort, sc, err) } defer l.Close() server := &http.Server{ Handler: 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusTeapot) w.Write([]byte("Hello!")) }), } go server.Serve(l) defer server.Close() // Start proxy proxyPort := getFreePort(t) apiPort := getFreePort(t) cmd := exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-generate-ca-cert", "-skip-tls-verify") cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatal(err) } defer cmd.Wait() defer cmd.Process.Signal(os.Interrupt) proxyURL := "http://localhost" + proxyPort apiURL := "http://localhost" + apiPort configureURL := "http://martian.proxy/configure" // TODO: Make using API hostport directly work on Travis. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}} waitForProxy(t, apiClient, configureURL) // Configure modifiers config := strings.NewReader(fmt.Sprintf(` { "body.Modifier": { "scope": ["response"], "contentType": "text/plain", "body": "%s" } }`, base64.StdEncoding.EncodeToString([]byte("茶壺")))) res, err := apiClient.Post(configureURL, "application/json", config) if err != nil { t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusOK; got != want { t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want) } // Install proxy's CA cert into http client caCertURL := "http://martian.proxy/authority.cer" res, err = apiClient.Get(caCertURL) if err != nil { t.Fatalf("apiClient.Get(%q): got error %v, want no error", caCertURL, err) } defer res.Body.Close() caCert, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err) } caCertPool := x509.NewCertPool() caCertPool.AppendCertsFromPEM(caCert) // Exercise proxy client := &http.Client{Transport: &http.Transport{ Proxy: http.ProxyURL(parseURL(t, proxyURL)), TLSClientConfig: &tls.Config{ RootCAs: caCertPool, }, }} testURL := 
"https://localhost" + servPort res, err = client.Get(testURL) if err != nil { t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusTeapot; got != want { t.Fatalf("client.Get(%q): got status %d, want %d", testURL, got, want) } body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err) } if got, want := string(body), "茶壺"; got != want { t.Fatalf("modified response body: got %s, want %s", got, want) } }) t.Run("DownstreamProxy", func(t *testing.T) { // Start downstream proxy dsProxyPort := getFreePort(t) dsAPIPort := getFreePort(t) cmd := exec.Command(binPath, "-addr="+dsProxyPort, "-api-addr="+dsAPIPort) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatal(err) } defer cmd.Wait() defer cmd.Process.Signal(os.Interrupt) dsProxyURL := "http://localhost" + dsProxyPort dsAPIURL := "http://localhost" + dsAPIPort configureURL := "http://martian.proxy/configure" // TODO: Make using API hostport directly work on Travis. 
dsAPIClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, dsAPIURL))}} waitForProxy(t, dsAPIClient, configureURL) // Configure modifiers config := strings.NewReader(` { "fifo.Group": { "scope": ["request", "response"], "modifiers": [ { "status.Modifier": { "scope": ["response"], "statusCode": 418 } }, { "skip.RoundTrip": {} } ] } }`) res, err := dsAPIClient.Post(configureURL, "application/json", config) if err != nil { t.Fatalf("dsApiClient.Post(%q): got error %v, want no error", configureURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusOK; got != want { t.Fatalf("dsApiClient.Post(%q): got status %d, want %d", configureURL, got, want) } // Start main proxy proxyPort := getFreePort(t) apiPort := getFreePort(t) cmd = exec.Command(binPath, "-addr="+proxyPort, "-api-addr="+apiPort, "-downstream-proxy-url="+dsProxyURL) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatal(err) } defer cmd.Wait() defer cmd.Process.Signal(os.Interrupt) proxyURL := "http://localhost" + proxyPort apiURL := "http://localhost" + apiPort // TODO: Make using API hostport directly work on Travis. apiClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, apiURL))}} waitForProxy(t, apiClient, configureURL) // Configure modifiers // Setting a different Via header value to circumvent loop detection. 
config = strings.NewReader(fmt.Sprintf(` { "fifo.Group": { "scope": ["request", "response"], "modifiers": [ { "header.Modifier": { "scope": ["request"], "name": "Via", "value": "martian_1" } }, { "body.Modifier": { "scope": ["response"], "contentType": "text/plain", "body": "%s" } } ] } }`, base64.StdEncoding.EncodeToString([]byte("茶壺")))) res, err = apiClient.Post(configureURL, "application/json", config) if err != nil { t.Fatalf("apiClient.Post(%q): got error %v, want no error", configureURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusOK; got != want { t.Fatalf("apiClient.Post(%q): got status %d, want %d", configureURL, got, want) } // Exercise proxy client := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(parseURL(t, proxyURL))}} testURL := "http://super.fake.domain/" res, err = client.Get(testURL) if err != nil { t.Fatalf("client.Get(%q): got error %v, want no error", testURL, err) } defer res.Body.Close() if got, want := res.StatusCode, http.StatusTeapot; got != want { t.Errorf("client.Get(%q): got status %d, want %d", testURL, got, want) } body, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(res.Body): got error %v, want no error", err) } if got, want := string(body), "茶壺"; got != want { t.Fatalf("modified response body: got %s, want %s", got, want) } }) } martian-3.3.2/context.go000066400000000000000000000160521421371434000151700ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package martian import ( "bufio" "crypto/rand" "encoding/hex" "fmt" "net" "net/http" "sync" ) // Context provides information and storage for a single request/response pair. // Contexts are linked to shared session that is used for multiple requests on // a single connection. type Context struct { session *Session id string mu sync.RWMutex vals map[string]interface{} skipRoundTrip bool skipLogging bool apiRequest bool } // Session provides information and storage about a connection. type Session struct { mu sync.RWMutex id string secure bool hijacked bool conn net.Conn brw *bufio.ReadWriter vals map[string]interface{} } var ( ctxmu sync.RWMutex ctxs = make(map[*http.Request]*Context) ) // NewContext returns a context for the in-flight HTTP request. func NewContext(req *http.Request) *Context { ctxmu.RLock() defer ctxmu.RUnlock() return ctxs[req] } // TestContext builds a new session and associated context and returns the // context and a function to remove the associated context. If it fails to // generate either a new session or a new context it will return an error. // Intended for tests only. func TestContext(req *http.Request, conn net.Conn, bw *bufio.ReadWriter) (ctx *Context, remove func(), err error) { ctxmu.Lock() defer ctxmu.Unlock() ctx, ok := ctxs[req] if ok { return ctx, func() { unlink(req) }, nil } s, err := newSession(conn, bw) if err != nil { return nil, nil, err } ctx, err = withSession(s) if err != nil { return nil, nil, err } ctxs[req] = ctx return ctx, func() { unlink(req) }, nil } // ID returns the session ID. func (s *Session) ID() string { s.mu.RLock() defer s.mu.RUnlock() return s.id } // IsSecure returns whether the current session is from a secure connection, // such as when receiving requests from a TLS connection that has been MITM'd. 
func (s *Session) IsSecure() bool { s.mu.RLock() defer s.mu.RUnlock() return s.secure } // MarkSecure marks the session as secure. func (s *Session) MarkSecure() { s.mu.Lock() defer s.mu.Unlock() s.secure = true } // MarkInsecure marks the session as insecure. func (s *Session) MarkInsecure() { s.mu.Lock() defer s.mu.Unlock() s.secure = false } // Hijack takes control of the connection from the proxy. No further action // will be taken by the proxy and the connection will be closed following the // return of the hijacker. func (s *Session) Hijack() (net.Conn, *bufio.ReadWriter, error) { s.mu.Lock() defer s.mu.Unlock() if s.hijacked { return nil, nil, fmt.Errorf("martian: session has already been hijacked") } s.hijacked = true return s.conn, s.brw, nil } // Hijacked returns whether the connection has been hijacked. func (s *Session) Hijacked() bool { s.mu.RLock() defer s.mu.RUnlock() return s.hijacked } // setConn resets the underlying connection and bufio.ReadWriter of the // session. Used by the proxy when the connection is upgraded to TLS. func (s *Session) setConn(conn net.Conn, brw *bufio.ReadWriter) { s.mu.Lock() defer s.mu.Unlock() s.conn = conn s.brw = brw } // Get takes key and returns the associated value from the session. func (s *Session) Get(key string) (interface{}, bool) { s.mu.RLock() defer s.mu.RUnlock() val, ok := s.vals[key] return val, ok } // Set takes a key and associates it with val in the session. The value is // persisted for the entire session across multiple requests and responses. func (s *Session) Set(key string, val interface{}) { s.mu.Lock() defer s.mu.Unlock() s.vals[key] = val } // Session returns the session for the context. func (ctx *Context) Session() *Session { return ctx.session } // ID returns the context ID. func (ctx *Context) ID() string { return ctx.id } // Get takes key and returns the associated value from the context. 
func (ctx *Context) Get(key string) (interface{}, bool) { ctx.mu.RLock() defer ctx.mu.RUnlock() val, ok := ctx.vals[key] return val, ok } // Set takes a key and associates it with val in the context. The value is // persisted for the duration of the request and is removed on the following // request. func (ctx *Context) Set(key string, val interface{}) { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.vals[key] = val } // SkipRoundTrip skips the round trip for the current request. func (ctx *Context) SkipRoundTrip() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.skipRoundTrip = true } // SkippingRoundTrip returns whether the current round trip will be skipped. func (ctx *Context) SkippingRoundTrip() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.skipRoundTrip } // SkipLogging skips logging by Martian loggers for the current request. func (ctx *Context) SkipLogging() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.skipLogging = true } // SkippingLogging returns whether the current request / response pair will be logged. func (ctx *Context) SkippingLogging() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.skipLogging } // APIRequest marks the requests as a request to the proxy API. func (ctx *Context) APIRequest() { ctx.mu.Lock() defer ctx.mu.Unlock() ctx.apiRequest = true } // IsAPIRequest returns true when the request patterns matches a pattern in the proxy // mux. The mux is usually defined as a parameter to the api.Forwarder, which uses // http.DefaultServeMux by default. func (ctx *Context) IsAPIRequest() bool { ctx.mu.RLock() defer ctx.mu.RUnlock() return ctx.apiRequest } // newID creates a new 16 character random hex ID; note these are not UUIDs. func newID() (string, error) { src := make([]byte, 8) if _, err := rand.Read(src); err != nil { return "", err } return hex.EncodeToString(src), nil } // link associates the context with request. 
func link(req *http.Request, ctx *Context) { ctxmu.Lock() defer ctxmu.Unlock() ctxs[req] = ctx } // unlink removes the context for request. func unlink(req *http.Request) { ctxmu.Lock() defer ctxmu.Unlock() delete(ctxs, req) } // newSession builds a new session. func newSession(conn net.Conn, brw *bufio.ReadWriter) (*Session, error) { sid, err := newID() if err != nil { return nil, err } return &Session{ id: sid, conn: conn, brw: brw, vals: make(map[string]interface{}), }, nil } // withSession builds a new context from an existing session. Session must be // non-nil. func withSession(s *Session) (*Context, error) { cid, err := newID() if err != nil { return nil, err } return &Context{ session: s, id: cid, vals: make(map[string]interface{}), }, nil } martian-3.3.2/context_test.go000066400000000000000000000064171421371434000162330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package martian import ( "bufio" "bytes" "io/ioutil" "net" "net/http" "testing" ) func TestContexts(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if len(ctx.ID()) != 16 { t.Errorf("ctx.ID(): got %q, want 16 character random ID", ctx.ID()) } ctx.Set("key", "value") got, ok := ctx.Get("key") if !ok { t.Errorf("ctx.Get(%q): got !ok, want ok", "key") } if want := "value"; got != want { t.Errorf("ctx.Get(%q): got %q, want %q", "key", got, want) } ctx.SkipRoundTrip() if !ctx.SkippingRoundTrip() { t.Error("ctx.SkippingRoundTrip(): got false, want true") } ctx.SkipLogging() if !ctx.SkippingLogging() { t.Error("ctx.SkippingLogging(): got false, want true") } s := ctx.Session() if len(s.ID()) != 16 { t.Errorf("s.ID(): got %q, want 16 character random ID", s.ID()) } s.MarkSecure() if !s.IsSecure() { t.Error("s.IsSecure(): got false, want true") } s.Set("key", "value") got, ok = s.Get("key") if !ok { t.Errorf("s.Get(%q): got !ok, want ok", "key") } if want := "value"; got != want { t.Errorf("s.Get(%q): got %q, want %q", "key", got, want) } ctx2, remove, err := TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if ctx != ctx2 { t.Error("TestContext(): got new context, want existing context") } } func TestContextHijack(t *testing.T) { rc, wc := net.Pipe() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := TestContext(req, rc, bufio.NewReadWriter(bufio.NewReader(rc), bufio.NewWriter(rc))) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() session := ctx.Session() if session.Hijacked() { t.Fatal("session.Hijacked(): got true, want false") } 
conn, brw, err := session.Hijack() if err != nil { t.Fatalf("session.Hijack(): got %v, want no error", err) } if !session.Hijacked() { t.Fatal("session.Hijacked(): got false, want true") } if _, _, err := session.Hijack(); err == nil { t.Fatal("session.Hijack(): got nil, want rehijack error") } go func() { brw.Write([]byte("test message")) brw.Flush() conn.Close() }() got, err := ioutil.ReadAll(wc) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } if want := []byte("test message"); !bytes.Equal(got, want) { t.Errorf("connection: got %q, want %q", got, want) } } martian-3.3.2/cookie/000077500000000000000000000000001421371434000144225ustar00rootroot00000000000000martian-3.3.2/cookie/cookie_filter.go000066400000000000000000000045001421371434000175660ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cookie import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/filter" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" ) var noop = martian.Noop("cookie.Filter") type filterJSON struct { Name string `json:"name"` Value string `json:"value"` Modifier json.RawMessage `json:"modifier"` ElseModifier json.RawMessage `json:"else"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("cookie.Filter", filterFromJSON) } // NewFilter builds a new cookie filter. 
func NewFilter(cookie *http.Cookie) *filter.Filter { log.Debugf("cookie.NewFilter: cookie: %s", cookie.String()) f := filter.New() m := NewMatcher(cookie) f.SetRequestCondition(m) f.SetResponseCondition(m) return f } // filterFromJSON builds a header.Filter from JSON. // // Example JSON: // { // "scope": ["request", "result"], // "name": "Martian-Testing", // "value": "true", // "modifier": { ... }, // "else": { ... } // } func filterFromJSON(b []byte) (*parse.Result, error) { msg := &filterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } cookie := &http.Cookie{ Name: msg.Name, Value: msg.Value, } filter := NewFilter(cookie) m, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } filter.RequestWhenTrue(m.RequestModifier()) filter.ResponseWhenTrue(m.ResponseModifier()) if msg.ElseModifier != nil { em, err := parse.FromJSON(msg.ElseModifier) if err != nil { return nil, err } filter.RequestWhenFalse(em.RequestModifier()) filter.ResponseWhenFalse(em.ResponseModifier()) } return parse.NewResult(filter, msg.Scope) } martian-3.3.2/cookie/cookie_filter_test.go000066400000000000000000000164501421371434000206340ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cookie import ( "net/http" "testing" "github.com/google/martian/v3/filter" _ "github.com/google/martian/v3/header" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestFilterFromJSON(t *testing.T) { msg := []byte(`{ "cookie.Filter": { "scope": ["request", "response"], "name": "martian-cookie", "value": "true", "modifier": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "true" } }, "else": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "false" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } for _, tc := range []struct { name string wantMatch bool cookie *http.Cookie }{ { name: "matching name and value", wantMatch: true, cookie: &http.Cookie{ Name: "martian-cookie", Value: "true", }, }, { name: "matching name with mismatched value", wantMatch: false, cookie: &http.Cookie{ Name: "martian-cookie", Value: "false", }, }, { name: "missing cookie", wantMatch: false, }, } { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Errorf("%s: http.NewRequest(): got %v, want no error", tc.name, err) continue } if tc.cookie != nil { req.AddCookie(tc.cookie) } if err := reqmod.ModifyRequest(req); err != nil { t.Errorf("%s: ModifyRequest(): got %v, want no error", tc.name, err) continue } want := "false" if tc.wantMatch { want = "true" } if got := req.Header.Get("Martian-Testing"); got != want { t.Errorf("%s: req.Header.Get(%q): got %q, want %q", "Martian-Testing", tc.name, got, want) continue } res := proxyutil.NewResponse(200, nil, req) if tc.cookie != nil { c := &http.Cookie{Name: tc.cookie.Name, Value: tc.cookie.Value} 
res.Header.Add("Set-Cookie", c.String()) } if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got := res.Header.Get("Martian-Testing"); got != want { t.Fatalf("res.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } } } func TestFilterFromJSONWithoutElse(t *testing.T) { msg := []byte(`{ "cookie.Filter": { "scope": ["request", "response"], "name": "martian-cookie", "value": "true", "modifier": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "true" } } } }`) _, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } } func TestRequestWhenTrueCondition(t *testing.T) { cm := NewMatcher(&http.Cookie{Name: "Martian-Testing", Value: "true"}) tt := []struct { name string value string want bool }{ { name: "Martian-Production", value: "true", want: false, }, { name: "Martian-Testing", value: "true", want: true, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetRequestCondition(cm) f.RequestWhenTrue(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.AddCookie(&http.Cookie{Name: tc.name, Value: tc.value}) if err := f.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. 
tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestRequestWhenFalse(t *testing.T) { cm := NewMatcher(&http.Cookie{Name: "Martian-Testing", Value: "true"}) tt := []struct { name string value string want bool }{ { name: "Martian-Production", value: "true", want: true, }, { name: "Martian-Testing", value: "true", want: false, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetRequestCondition(cm) f.RequestWhenFalse(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.AddCookie(&http.Cookie{Name: tc.name, Value: tc.value}) if err := f.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestResponseWhenTrue(t *testing.T) { cm := NewMatcher(&http.Cookie{Name: "Martian-Testing", Value: "true"}) tt := []struct { name string value string want bool }{ { name: "Martian-Production", value: "true", want: false, }, { name: "Martian-Testing", value: "true", want: true, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetResponseCondition(cm) f.ResponseWhenTrue(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) c := &http.Cookie{Name: tc.name, Value: tc.value} res.Header.Add("Set-Cookie", c.String()) if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. 
tm.ResponseModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestResponseWhenFalse(t *testing.T) { cm := NewMatcher(&http.Cookie{Name: "Martian-Testing", Value: "true"}) tt := []struct { name string value string want bool }{ { name: "Martian-Production", value: "true", want: true, }, { name: "Martian-Testing", value: "true", want: false, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetResponseCondition(cm) f.ResponseWhenFalse(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) c := &http.Cookie{Name: tc.name, Value: tc.value} res.Header.Add("Set-Cookie", c.String()) if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } martian-3.3.2/cookie/cookie_matcher.go000066400000000000000000000036001421371434000177240ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cookie import ( "net/http" "github.com/google/martian/v3/log" ) // Matcher is a conditonal evalutor of request or // response cookies to be used in structs that take conditions. type Matcher struct { cookie *http.Cookie } // NewMatcher builds a cookie matcher. 
func NewMatcher(cookie *http.Cookie) *Matcher { return &Matcher{ cookie: cookie, } } // MatchRequest evaluates a request and returns whether or not // the request contains a cookie that matches the provided name, path // and value. func (m *Matcher) MatchRequest(req *http.Request) bool { for _, c := range req.Cookies() { if m.match(c) { log.Debugf("cookie.MatchRequest: %s, matched: cookie: %s", req.URL, c) return true } } return false } // MatchResponse evaluates a response and returns whether or not the response // contains a cookie that matches the provided name and value. func (m *Matcher) MatchResponse(res *http.Response) bool { for _, c := range res.Cookies() { if m.match(c) { log.Debugf("cookie.MatchResponse: %s, matched: cookie: %s", res.Request.URL, c) return true } } return false } func (m *Matcher) match(cs *http.Cookie) bool { switch { case m.cookie.Name != cs.Name: return false case m.cookie.Value != "" && m.cookie.Value != cs.Value: return false } return true } martian-3.3.2/cookie/cookie_modifier.go000066400000000000000000000056331421371434000201070ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package cookie allows for the modification of cookies on http requests and responses. 
package cookie import ( "encoding/json" "net/http" "time" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" ) func init() { parse.Register("cookie.Modifier", modifierFromJSON) } type modifier struct { cookie *http.Cookie } type modifierJSON struct { Name string `json:"name"` Value string `json:"value"` Path string `json:"path"` Domain string `json:"domain"` Expires time.Time `json:"expires"` Secure bool `json:"secure"` HTTPOnly bool `json:"httpOnly"` MaxAge int `json:"maxAge"` Scope []parse.ModifierType `json:"scope"` } // ModifyRequest adds cookie to the request. func (m *modifier) ModifyRequest(req *http.Request) error { req.AddCookie(m.cookie) log.Debugf("cookie.ModifyRequest: %s: cookie: %s", req.URL, m.cookie) return nil } // ModifyResponse sets cookie on the response. func (m *modifier) ModifyResponse(res *http.Response) error { res.Header.Add("Set-Cookie", m.cookie.String()) log.Debugf("cookie.ModifyResponse: %s: cookie: %s", res.Request.URL, m.cookie) return nil } // NewModifier returns a modifier that injects the provided cookie into the // request or response. func NewModifier(cookie *http.Cookie) martian.RequestResponseModifier { return &modifier{ cookie: cookie, } } // modifierFromJSON takes a JSON message as a byte slice and returns a // CookieModifier and an error. 
// // Example JSON Configuration message: // { // "name": "Martian-Cookie", // "value": "some value", // "path": "/some/path", // "domain": "example.com", // "expires": "2025-04-12T23:20:50.52Z", // RFC 3339 // "secure": true, // "httpOnly": false, // "maxAge": 86400, // "scope": ["request", "result"] // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } c := &http.Cookie{ Name: msg.Name, Value: msg.Value, Path: msg.Path, Domain: msg.Domain, Expires: msg.Expires, Secure: msg.Secure, HttpOnly: msg.HTTPOnly, MaxAge: msg.MaxAge, } return parse.NewResult(NewModifier(c), msg.Scope) } martian-3.3.2/cookie/cookie_modifier_test.go000066400000000000000000000072261421371434000211460ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package cookie import ( "net/http" "testing" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestCookieModifier(t *testing.T) { cookie := &http.Cookie{ Name: "name", Value: "value", } mod := NewModifier(cookie) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Cookies()), 1; got != want { t.Errorf("len(req.Cookies): got %v, want %v", got, want) } if got, want := req.Cookies()[0].Name, cookie.Name; got != want { t.Errorf("req.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := req.Cookies()[0].Value, cookie.Value; got != want { t.Errorf("req.Cookies()[0].Value: got %v, want %v", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Cookies()), 1; got != want { t.Errorf("len(res.Cookies): got %v, want %v", got, want) } if got, want := res.Cookies()[0].Name, cookie.Name; got != want { t.Errorf("res.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := res.Cookies()[0].Value, cookie.Value; got != want { t.Errorf("res.Cookies()[0].Value: got %v, want %v", got, want) } } func TestModifierFromJSON(t *testing.T) { msg := []byte(`{ "cookie.Modifier": { "scope": ["request", "response"], "name": "martian", "value": "value" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com/path/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, 
want := len(req.Cookies()), 1; got != want { t.Fatalf("len(req.Cookies): got %v, want %v", got, want) } if got, want := req.Cookies()[0].Name, "martian"; got != want { t.Errorf("req.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := req.Cookies()[0].Value, "value"; got != want { t.Errorf("req.Cookies()[0].Value: got %v, want %v", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Cookies()), 1; got != want { t.Fatalf("len(res.Cookies): got %v, want %v", got, want) } if got, want := res.Cookies()[0].Name, "martian"; got != want { t.Errorf("res.Cookies()[0].Name: got %v, want %v", got, want) } if got, want := res.Cookies()[0].Value, "value"; got != want { t.Errorf("res.Cookies()[0].Value: got %v, want %v", got, want) } } martian-3.3.2/cors/000077500000000000000000000000001421371434000141175ustar00rootroot00000000000000martian-3.3.2/cors/cors.go000066400000000000000000000044751421371434000154260ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package cors provides CORS support for http.Handlers. package cors import ( "net/http" ) // Handler is an http.Handler that wraps other http.Handlers and provides CORS // support. 
type Handler struct { handler http.Handler origin string allowCredentials bool } // NewHandler wraps an existing http.Handler allowing it to be requested via CORS. func NewHandler(h http.Handler) *Handler { return &Handler{ handler: h, origin: "*", } } // SetOrigin sets the origin(s) to allow when requested with CORS. func (h *Handler) SetOrigin(origin string) { h.origin = origin } // AllowCredentials allows cookies to be read by the CORS request. func (h *Handler) AllowCredentials(allow bool) { h.allowCredentials = allow } // ServeHTTP determines if a request is a CORS request (normal or preflight) // and sets the appropriate Access-Control-Allow-* headers. It will send the // request to the underlying handler in all cases, except for a preflight // (OPTIONS) request. func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { // Definitely not a CORS request, send it directly to handler. if req.Header.Get("Origin") == "" { h.handler.ServeHTTP(rw, req) return } rw.Header().Set("Access-Control-Allow-Origin", h.origin) if h.allowCredentials { rw.Header().Set("Access-Control-Allow-Credentials", "true") } acrm := req.Header.Get("Access-Control-Request-Method") rw.Header().Set("Access-Control-Allow-Methods", acrm) if acrh := req.Header.Get("Access-Control-Request-Headers"); acrh != "" { rw.Header().Set("Access-Control-Allow-Headers", acrh) } // Preflight request, don't bother sending it to the handler. if req.Method == "OPTIONS" { return } h.handler.ServeHTTP(rw, req) } martian-3.3.2/cors/cors_test.go000066400000000000000000000072371421371434000164640ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package cors import ( "net/http" "net/http/httptest" "testing" ) func TestServeHTTPSameOrigin(t *testing.T) { var handlerRun bool h := NewHandler(http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { handlerRun = true })) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if !handlerRun { t.Error("handlerRun: got false, want true") } } func TestServeHTTPPreflight(t *testing.T) { var handlerRun bool h := NewHandler(http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { handlerRun = true })) h.AllowCredentials(true) req, err := http.NewRequest("OPTIONS", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Origin", "http://google.com") req.Header.Set("Access-Control-Request-Method", "PUT") req.Header.Set("Access-Control-Request-Headers", "Cors-Test") rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Header().Get("Access-Control-Allow-Origin"), "*"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Origin", got, want) } if got, want := rw.Header().Get("Access-Control-Allow-Methods"), "PUT"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Methods", got, want) } if got, want := rw.Header().Get("Access-Control-Allow-Headers"), "Cors-Test"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Headers", got, want) } if 
got, want := rw.Header().Get("Access-Control-Allow-Credentials"), "true"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Credentials", got, want) } if handlerRun { t.Error("handlerRun: got true, want false") } } func TestServeHTTPSimple(t *testing.T) { var handlerRun bool h := NewHandler(http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { handlerRun = true })) h.SetOrigin("http://martian.local") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Origin", "http://google.com") req.Header.Set("Access-Control-Request-Method", "GET") req.Header.Set("Access-Control-Request-Headers", "Cors-Test") rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Header().Get("Access-Control-Allow-Origin"), "http://martian.local"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Origin", got, want) } if got, want := rw.Header().Get("Access-Control-Allow-Methods"), "GET"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Methods", got, want) } if got, want := rw.Header().Get("Access-Control-Allow-Headers"), "Cors-Test"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Access-Control-Allow-Headers", got, want) } if !handlerRun { t.Error("handlerRun: got false, want true") } } martian-3.3.2/cybervillains/000077500000000000000000000000001421371434000160175ustar00rootroot00000000000000martian-3.3.2/cybervillains/cybervillains.go000066400000000000000000000056671421371434000212320ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package cybervillains provides the publically published Selenium project CyberVillains // certificate and key. The CyberVillains cert and key allow for a man in the middle, // and should only be used in testing scenarios. Client installation of the CyberVillains // certificate is inherently and intentionally insecure. package cybervillains // Cert is the x509 CyberVillains public key published by the Selenium project. const Cert string = `-----BEGIN CERTIFICATE----- MIIClTCCAf6gAwIBAgIBATANBgkqhkiG9w0BAQUFADBZMRowGAYDVQQKDBFDeWJl clZpbGxpYW5zLmNvbTEuMCwGA1UECwwlQ3liZXJWaWxsaWFucyBDZXJ0aWZpY2F0 aW9uIEF1dGhvcml0eTELMAkGA1UEBhMCVVMwHhcNMTEwMjEwMDMwMDEwWhcNMzEx MDIzMDMwMDEwWjBZMRowGAYDVQQKDBFDeWJlclZpbGxpYW5zLmNvbTEuMCwGA1UE CwwlQ3liZXJWaWxsaWFucyBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTELMAkGA1UE BhMCVVMwgZ8wDQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBAIVQhWYazIfMvJUBP5qh qRyh2tkrYI9wVZ9/Sj1l4tlWY4HOC6Dy5OYBRCmo2T9N8EXrAxXZKKUPzgmb3gIv AQJ9DP6woiHyyztZJ5/cbhlp8EbHBIvGWK3T0Oph3kEPPS2FWKjiH/+pV6qY0Yt+ lkzcwxrjZIah/3VHQXUDm8X1AgMBAAGjbTBrMB0GA1UdDgQWBBQKvBeVNGu8hxtb TP31Y4UttI/1bDASBgNVHRMBAf8ECDAGAQH/AgEAMAsGA1UdDwQEAwIBBjApBgNV HSUEIjAgBggrBgEFBQcDAQYIKwYBBQUHAwkGCmCGSAGG+EUBCAEwDQYJKoZIhvcN AQEFBQADgYEAD/6m8czx19uRPuaHVYhsEX5QGwJ4Y1NFswAByAuSBQB9KI9P2C7I muf1aOoslUC4TxnC6g9H5/XmlK1zbZ+2YuABOb08CTXBC2x3ewJnm94DGPBRzj9o 0rXGEC+jsqsziBw+kg69xFn7PH09ZKUCue8POaaN/z5VoQMoM4ZNTP4= -----END CERTIFICATE-----` // Key is the CyberVillains private key published by the Selenium project. 
const Key string = `-----BEGIN RSA PRIVATE KEY----- MIICXQIBAAKBgQCFUIVmGsyHzLyVAT+aoakcodrZK2CPcFWff0o9ZeLZVmOBzgug 8uTmAUQpqNk/TfBF6wMV2SilD84Jm94CLwECfQz+sKIh8ss7WSef3G4ZafBGxwSL xlit09DqYd5BDz0thVio4h//qVeqmNGLfpZM3MMa42SGof91R0F1A5vF9QIDAQAB AoGAEVuTkuDIYqIYp7n64xJLZ4v3Z7FLKEHzFApJy0a5y5yA5kTCpNkbTos5qcbv SlvGfgQEadLVhPBS3lNqC5S9J7iUmmdpveXxV5ZaOsK3Zh+QCURfjLvqLH5Fzn1c 341YTCXpPdlbZElbARh3WKtW7R4c5GNNdf7zrWRqjYsXacECQQD4CVJ0l2AOTfLh 0uOXr1wwblIVscNv5WO9WLERtDWZP2EhDkRFMFsV8gTTvs01LiX0PRkuUjP+C6/e g1DlBrqxAkEAiZhE4Ui7AHF6CYg+eamQKf4ECn4KgZ/y68Tan9YiULRXOx4HSpsM 3g+uPvwWnp9Pd/0gVSmQlJn3oNi5LQtIhQJBANF6ZgYL1lceY/NuvUJdGrnYYkDq Ocml7P98CUePb/j2OxzExMm+Vh8JoCQIr5yrVeiZNUwWpsx2qFh/hPF4JnECQCej /8wryPxStQcEAoPIjykZ7o4bS+mWbETynM3Jwm8f1bXJa+5ZhzZ+rAOnWtjuKtX1 zhfa9rVpOkdTyN2qT4UCQQD35VDm82aDi9mC8Zs1T/SrYKwRJuz25JPM8Yh9xiuK 7iI4qfwwaX99fo09cH0pfUdx+z7QyNba8bMfTWe8qPHm -----END RSA PRIVATE KEY-----` martian-3.3.2/failure/000077500000000000000000000000001421371434000146005ustar00rootroot00000000000000martian-3.3.2/failure/failure_verifier.go000066400000000000000000000053231421371434000204540ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package failure provides a verifier that always fails, adding a given message // to the multierror log. 
This can be used to turn any filter in to a defacto // verifier, by wrapping it in a filter and thus causing a verifier failure whenever // a request passes the filter. package failure import ( "encoding/json" "fmt" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) func init() { parse.Register("failure.Verifier", verifierFromJSON) } type verifier struct { message string merr *martian.MultiError } type verifierJSON struct { Message string `json:"message"` Scope []parse.ModifierType `json:"scope"` } // NewVerifier returns a new failing verifier. func NewVerifier(message string) (verify.RequestVerifier, error) { return &verifier{ message: message, merr: martian.NewMultiError(), }, nil } // ModifyRequest adds an error message containing the message field in the verifier to the verifier errors. // This means that any time a request hits the verifier it's treated as an error. func (v *verifier) ModifyRequest(req *http.Request) error { err := fmt.Errorf("request(%v) verification error: %s", req.URL, v.message) v.merr.Add(err) return nil } // VerifyRequests returns an error if any requests have hit the verifier. // If an error is returned it will be of type *martian.MultiError. func (v *verifier) VerifyRequests() error { if v.merr.Empty() { return nil } return v.merr } // ResetRequestVerifications clears all failed request verifications. 
func (v *verifier) ResetRequestVerifications() { v.merr = martian.NewMultiError() } // verifierFromJSON builds a failure.Verifier from JSON // // Example JSON: // { // "failure.Verifier": { // "scope": ["request", "response"], // "message": "Request passed a filter it should not have" // } // } func verifierFromJSON(b []byte) (*parse.Result, error) { msg := &verifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } v, err := NewVerifier(msg.Message) if err != nil { return nil, err } return parse.NewResult(v, msg.Scope) } martian-3.3.2/failure/failure_verifier_test.go000066400000000000000000000063301421371434000215120ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package failure import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) func TestVerifyRequestFails(t *testing.T) { v, err := NewVerifier("foo") if err != nil { t.Fatalf("NewVerifier(%q): got %v, want no error", "foo", err) } req, err := http.NewRequest("GET", "http://www.google.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := v.VerifyRequests(); err == nil { t.Fatalf("VerifyRequests(): got no error, want *verify.MultiError") } } func TestFailureWithMultiFail(t *testing.T) { v, err := NewVerifier("foo") if err != nil { t.Fatalf("NewVerifier(%q): got %v, want no error", "foo", err) } req, err := http.NewRequest("GET", "http://www.google.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } merr, ok := v.VerifyRequests().(*martian.MultiError) if !ok { t.Fatalf("VerifyRequests(): got nil, want *verify.MultiError") } errs := merr.Errors() if len(errs) != 2 { t.Fatalf("len(merr.Errors()): got %d, want 2", len(errs)) } expectErr := "request(http://www.google.com) verification error: foo" for i := range errs { if got, want := errs[i].Error(), expectErr; got != want { t.Errorf("%d. 
err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want) } } v.ResetRequestVerifications() if err := v.VerifyRequests(); err != nil { t.Fatalf("VerifyRequests(): got %v, want no error", err) } } func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "failure.Verifier": { "scope": ["request"], "message": "foo" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } reqv, ok := reqmod.(verify.RequestVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqv.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := reqv.VerifyRequests(); err == nil { t.Error("VerifyRequests(): got nil, want not nil") } } martian-3.3.2/fifo/000077500000000000000000000000001421371434000140745ustar00rootroot00000000000000martian-3.3.2/fifo/fifo_group.go000066400000000000000000000153541421371434000165720ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package fifo provides Group, which is a list of modifiers that are executed // consecutively. 
By default, when an error is returned by a modifier, the // execution of the modifiers is halted, and the error is returned. Optionally, // when errror aggregation is enabled (by calling SetAggretateErrors(true)), modifier // execution is not halted, and errors are aggretated and returned after all // modifiers have been executed. package fifo import ( "encoding/json" "net/http" "sync" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) // Group is a martian.RequestResponseModifier that maintains lists of // request and response modifiers executed on a first-in, first-out basis. type Group struct { reqmu sync.RWMutex reqmods []martian.RequestModifier resmu sync.RWMutex resmods []martian.ResponseModifier aggregateErrors bool } type groupJSON struct { Modifiers []json.RawMessage `json:"modifiers"` Scope []parse.ModifierType `json:"scope"` AggregateErrors bool `json:"aggregateErrors"` } func init() { parse.Register("fifo.Group", groupFromJSON) } // NewGroup returns a modifier group. func NewGroup() *Group { return &Group{} } // SetAggregateErrors sets the error behavior for the Group. When true, the Group will // continue to execute consecutive modifiers when a modifier in the group encounters an // error. The Group will then return all errors returned by each modifier after all // modifiers have been executed. When false, if an error is returned by a modifier, the // error is returned by ModifyRequest/Response and no further modifiers are run. // By default, error aggregation is disabled. func (g *Group) SetAggregateErrors(aggerr bool) { g.aggregateErrors = aggerr } // AddRequestModifier adds a RequestModifier to the group's list of request modifiers. 
func (g *Group) AddRequestModifier(reqmod martian.RequestModifier) { g.reqmu.Lock() defer g.reqmu.Unlock() g.reqmods = append(g.reqmods, reqmod) } // AddResponseModifier adds a ResponseModifier to the group's list of response modifiers. func (g *Group) AddResponseModifier(resmod martian.ResponseModifier) { g.resmu.Lock() defer g.resmu.Unlock() g.resmods = append(g.resmods, resmod) } // ModifyRequest modifies the request. By default, aggregateErrors is false; if an error is // returned by a RequestModifier the error is returned and no further modifiers are run. When // aggregateErrors is set to true, the errors returned by each modifier in the group are // aggregated. func (g *Group) ModifyRequest(req *http.Request) error { log.Debugf("fifo.ModifyRequest: %s", req.URL) g.reqmu.RLock() defer g.reqmu.RUnlock() merr := martian.NewMultiError() for _, reqmod := range g.reqmods { if err := reqmod.ModifyRequest(req); err != nil { if g.aggregateErrors { merr.Add(err) continue } return err } } if merr.Empty() { return nil } return merr } // ModifyResponse modifies the request. By default, aggregateErrors is false; if an error is // returned by a RequestModifier the error is returned and no further modifiers are run. When // aggregateErrors is set to true, the errors returned by each modifier in the group are // aggregated. func (g *Group) ModifyResponse(res *http.Response) error { requ := "" if res.Request != nil { requ = res.Request.URL.String() log.Debugf("fifo.ModifyResponse: %s", requ) } g.resmu.RLock() defer g.resmu.RUnlock() merr := martian.NewMultiError() for _, resmod := range g.resmods { if err := resmod.ModifyResponse(res); err != nil { if g.aggregateErrors { merr.Add(err) continue } return err } } if merr.Empty() { return nil } return merr } // VerifyRequests returns a MultiError containing all the // verification errors returned by request verifiers. 
func (g *Group) VerifyRequests() error { log.Debugf("fifo.VerifyRequests()") g.reqmu.Lock() defer g.reqmu.Unlock() merr := martian.NewMultiError() for _, reqmod := range g.reqmods { reqv, ok := reqmod.(verify.RequestVerifier) if !ok { continue } if err := reqv.VerifyRequests(); err != nil { merr.Add(err) } } if merr.Empty() { return nil } return merr } // VerifyResponses returns a MultiError containing all the // verification errors returned by response verifiers. func (g *Group) VerifyResponses() error { log.Debugf("fifo.VerifyResponses()") g.resmu.Lock() defer g.resmu.Unlock() merr := martian.NewMultiError() for _, resmod := range g.resmods { resv, ok := resmod.(verify.ResponseVerifier) if !ok { continue } if err := resv.VerifyResponses(); err != nil { merr.Add(err) } } if merr.Empty() { return nil } return merr } // ResetRequestVerifications resets the state of the contained request verifiers. func (g *Group) ResetRequestVerifications() { log.Debugf("fifo.ResetRequestVerifications()") g.reqmu.Lock() defer g.reqmu.Unlock() for _, reqmod := range g.reqmods { if reqv, ok := reqmod.(verify.RequestVerifier); ok { reqv.ResetRequestVerifications() } } } // ResetResponseVerifications resets the state of the contained request verifiers. func (g *Group) ResetResponseVerifications() { log.Debugf("fifo.ResetResponseVerifications()") g.resmu.Lock() defer g.resmu.Unlock() for _, resmod := range g.resmods { if resv, ok := resmod.(verify.ResponseVerifier); ok { resv.ResetResponseVerifications() } } } // groupFromJSON builds a fifo.Group from JSON. // // Example JSON: // { // "fifo.Group" : { // "scope": ["request", "result"], // "modifiers": [ // { ... }, // { ... 
}, // ] // } // } func groupFromJSON(b []byte) (*parse.Result, error) { msg := &groupJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } g := NewGroup() if msg.AggregateErrors { g.SetAggregateErrors(true) } for _, m := range msg.Modifiers { r, err := parse.FromJSON(m) if err != nil { return nil, err } reqmod := r.RequestModifier() if reqmod != nil { g.AddRequestModifier(reqmod) } resmod := r.ResponseModifier() if resmod != nil { g.AddResponseModifier(resmod) } } return parse.NewResult(g, msg.Scope) } martian-3.3.2/fifo/fifo_group_test.go000066400000000000000000000211121421371434000176160ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package fifo

import (
	"errors"
	"fmt"
	"net/http"
	"reflect"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/martiantest"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/proxyutil"
	"github.com/google/martian/v3/verify"

	// Registers header.Modifier with the parser for TestGroupFromJSON.
	_ "github.com/google/martian/v3/header"
)

// TestGroupFromJSON checks that a JSON-configured group applies every child
// modifier to both requests and responses.
func TestGroupFromJSON(t *testing.T) {
	msg := []byte(`{
		"fifo.Group": {
			"scope": ["request", "response"],
			"aggregateErrors": true,
			"modifiers": [
				{
					"header.Modifier": {
						"scope": ["request", "response"],
						"name": "X-Testing",
						"value": "true"
					}
				},
				{
					"header.Modifier": {
						"scope": ["request", "response"],
						"name": "Y-Testing",
						"value": "true"
					}
				}
			]
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.Header.Get("X-Testing"), "true"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Testing", got, want)
	}
	if got, want := req.Header.Get("Y-Testing"), "true"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Y-Testing", got, want)
	}

	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := res.Header.Get("X-Testing"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "X-Testing", got, want)
	}
	if got, want := res.Header.Get("Y-Testing"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Y-Testing", got, want)
	}
}

func TestModifyRequest(t *testing.T) {
	fg := NewGroup()
	tm := martiantest.NewModifier()
	fg.AddRequestModifier(tm)

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := fg.ModifyRequest(req); err != nil {
		t.Fatalf("fg.ModifyRequest(): got %v, want no error", err)
	}
	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
}

func TestModifyRequestHaltsOnError(t *testing.T) {
	fg := NewGroup()

	reqerr := errors.New("request error")
	tm := martiantest.NewModifier()
	tm.RequestError(reqerr)
	fg.AddRequestModifier(tm)

	tm2 := martiantest.NewModifier()
	fg.AddRequestModifier(tm2)

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := fg.ModifyRequest(req); err != reqerr {
		t.Fatalf("fg.ModifyRequest(): got %v, want %v", err, reqerr)
	}
	if tm2.RequestModified() {
		t.Error("tm2.RequestModified(): got true, want false")
	}
}

func TestModifyRequestAggregatesErrors(t *testing.T) {
	fg := NewGroup()
	fg.SetAggregateErrors(true)

	reqerr1 := errors.New("1. request error")
	tm := martiantest.NewModifier()
	tm.RequestError(reqerr1)
	fg.AddRequestModifier(tm)

	tm2 := martiantest.NewModifier()
	reqerr2 := errors.New("2. request error")
	tm2.RequestError(reqerr2)
	fg.AddRequestModifier(tm2)

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	merr := martian.NewMultiError()
	merr.Add(reqerr1)
	merr.Add(reqerr2)

	if err := fg.ModifyRequest(req); err == nil {
		t.Fatalf("fg.ModifyRequest(): got %v, want not nil", err)
	}
	if err := fg.ModifyRequest(req); err.Error() != merr.Error() {
		t.Fatalf("fg.ModifyRequest(): got %v, want %v", err, merr)
	}
	if err, want := fg.ModifyRequest(req), "1. request error\n2. request error"; err.Error() != want {
		t.Fatalf("fg.ModifyRequest(): got %v, want %v", err, want)
	}
}

func TestModifyResponse(t *testing.T) {
	fg := NewGroup()
	tm := martiantest.NewModifier()
	fg.AddResponseModifier(tm)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := fg.ModifyResponse(res); err != nil {
		t.Fatalf("fg.ModifyResponse(): got %v, want no error", err)
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}

func TestModifyResponseHaltsOnError(t *testing.T) {
	fg := NewGroup()

	reserr := errors.New("response error")
	tm := martiantest.NewModifier()
	tm.ResponseError(reserr)
	fg.AddResponseModifier(tm)

	tm2 := martiantest.NewModifier()
	fg.AddResponseModifier(tm2)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := fg.ModifyResponse(res); err != reserr {
		t.Fatalf("fg.ModifyResponse(): got %v, want %v", err, reserr)
	}
	if tm2.ResponseModified() {
		t.Error("tm2.ResponseModified(): got true, want false")
	}
}

func TestModifyResponseAggregatesErrors(t *testing.T) {
	fg := NewGroup()
	fg.SetAggregateErrors(true)

	reserr1 := errors.New("1. response error")
	tm := martiantest.NewModifier()
	tm.ResponseError(reserr1)
	fg.AddResponseModifier(tm)

	tm2 := martiantest.NewModifier()
	reserr2 := errors.New("2. response error")
	tm2.ResponseError(reserr2)
	fg.AddResponseModifier(tm2)

	req, err := http.NewRequest("GET", "http://example.com/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer remove()

	res := proxyutil.NewResponse(200, nil, req)

	merr := martian.NewMultiError()
	merr.Add(reserr1)
	merr.Add(reserr2)

	if err := fg.ModifyResponse(res); err == nil {
		t.Fatalf("fg.ModifyResponse(): got %v, want %v", err, merr)
	}
	if err := fg.ModifyResponse(res); err.Error() != merr.Error() {
		t.Fatalf("fg.ModifyResponse(): got %v, want %v", err, merr)
	}
}

func TestVerifyRequests(t *testing.T) {
	fg := NewGroup()

	if err := fg.VerifyRequests(); err != nil {
		t.Fatalf("VerifyRequests(): got %v, want no error", err)
	}

	errs := []error{}
	for i := 0; i < 3; i++ {
		err := fmt.Errorf("%d. verify request failure", i)
		tv := &verify.TestVerifier{
			RequestError: err,
		}
		fg.AddRequestModifier(tv)

		errs = append(errs, err)
	}

	err := fg.VerifyRequests()
	merr, ok := err.(*martian.MultiError)
	if !ok {
		t.Fatalf("VerifyRequests(): got %T, want *martian.MultiError", err)
	}
	if !reflect.DeepEqual(merr.Errors(), errs) {
		t.Errorf("merr.Errors(): got %v, want %v", merr.Errors(), errs)
	}
}

func TestVerifyResponses(t *testing.T) {
	fg := NewGroup()

	if err := fg.VerifyResponses(); err != nil {
		t.Fatalf("VerifyResponses(): got %v, want no error", err)
	}

	errs := []error{}
	for i := 0; i < 3; i++ {
		err := fmt.Errorf("%d. verify responses failure", i)
		tv := &verify.TestVerifier{
			ResponseError: err,
		}
		fg.AddResponseModifier(tv)

		errs = append(errs, err)
	}

	err := fg.VerifyResponses()
	merr, ok := err.(*martian.MultiError)
	if !ok {
		t.Fatalf("VerifyResponses(): got %T, want *martian.MultiError", err)
	}
	if !reflect.DeepEqual(merr.Errors(), errs) {
		t.Errorf("merr.Errors(): got %v, want %v", merr.Errors(), errs)
	}
}

func TestResets(t *testing.T) {
	fg := NewGroup()

	for i := 0; i < 3; i++ {
		tv := &verify.TestVerifier{
			RequestError:  fmt.Errorf("%d. verify request error", i),
			ResponseError: fmt.Errorf("%d. verify response error", i),
		}
		fg.AddRequestModifier(tv)
		fg.AddResponseModifier(tv)
	}

	if err := fg.VerifyRequests(); err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}
	if err := fg.VerifyResponses(); err == nil {
		t.Fatal("VerifyResponses(): got nil, want error")
	}

	fg.ResetRequestVerifications()
	fg.ResetResponseVerifications()

	if err := fg.VerifyRequests(); err != nil {
		t.Errorf("VerifyRequests(): got %v, want no error", err)
	}
	if err := fg.VerifyResponses(); err != nil {
		t.Errorf("VerifyResponses(): got %v, want no error", err)
	}
}
martian-3.3.2/filter/000077500000000000000000000000001421371434000144365ustar00rootroot00000000000000martian-3.3.2/filter/condition.go000066400000000000000000000016771421371434000167660ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package filter

import (
	"net/http"
)

// ResponseCondition is the interface that describes matchers for response filters.
type ResponseCondition interface {
	// MatchResponse reports whether the response satisfies the condition.
	MatchResponse(*http.Response) bool
}

// RequestCondition is the interface that describes matchers for request filters.
type RequestCondition interface {
	// MatchRequest reports whether the request satisfies the condition.
	MatchRequest(*http.Request) bool
}
martian-3.3.2/filter/filter.go000066400000000000000000000140741421371434000162600ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package filter provides a modifier that executes a given set of child
// modifiers based on the evaluated value of the provided conditional.
package filter

import (
	"fmt"
	"net/http"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/log"
	"github.com/google/martian/v3/verify"
)

// noop is the default child modifier for every branch of a Filter.
var noop = martian.Noop("Filter")

// Filter is a modifier that contains conditions to evaluate on request and
// response as well as a set of modifiers to execute based on the value of
// the provided RequestCondition or ResponseCondition.
type Filter struct {
	reqcond RequestCondition
	rescond ResponseCondition

	// t* modifiers run when the matching condition is true, f* when false.
	treqmod martian.RequestModifier
	tresmod martian.ResponseModifier
	freqmod martian.RequestModifier
	fresmod martian.ResponseModifier
}

// New returns a pointer to a Filter with all child modifiers initialized to
// the noop modifier.
func New() *Filter {
	f := &Filter{}
	// Every branch starts out as the shared noop so ModifyRequest and
	// ModifyResponse can call the selected modifier without a nil check.
	f.treqmod, f.freqmod = noop, noop
	f.tresmod, f.fresmod = noop, noop
	return f
}

// SetRequestCondition installs the condition evaluated against each request.
func (f *Filter) SetRequestCondition(reqcond RequestCondition) {
	f.reqcond = reqcond
}

// SetResponseCondition installs the condition evaluated against each response.
func (f *Filter) SetResponseCondition(rescond ResponseCondition) {
	f.rescond = rescond
}

// SetRequestModifier is an alias for RequestWhenTrue, kept for backwards
// compatibility with filtering as it worked before filter.Filter existed.
func (f *Filter) SetRequestModifier(reqmod martian.RequestModifier) {
	f.RequestWhenTrue(reqmod)
}

// RequestWhenTrue installs the martian.RequestModifier run when the
// RequestCondition matches. Passing nil restores the noop modifier.
func (f *Filter) RequestWhenTrue(mod martian.RequestModifier) {
	f.treqmod = noop
	if mod != nil {
		f.treqmod = mod
	}
}

// SetResponseModifier is an alias for ResponseWhenTrue, kept for backwards
// compatibility with filtering as it worked before filter.Filter existed.
func (f *Filter) SetResponseModifier(resmod martian.ResponseModifier) {
	f.ResponseWhenTrue(resmod)
}

// RequestWhenFalse installs the martian.RequestModifier run when the
// RequestCondition does not match. Passing nil restores the noop modifier.
func (f *Filter) RequestWhenFalse(mod martian.RequestModifier) {
	f.freqmod = noop
	if mod != nil {
		f.freqmod = mod
	}
}

// ResponseWhenTrue installs the martian.ResponseModifier run when the
// ResponseCondition matches. Passing nil restores the noop modifier.
func (f *Filter) ResponseWhenTrue(mod martian.ResponseModifier) {
	f.tresmod = noop
	if mod != nil {
		f.tresmod = mod
	}
}

// ResponseWhenFalse sets the martian.ResponseModifier that is executed
// when the ResponseCondition evaluates to False.
func (f *Filter) ResponseWhenFalse(mod martian.ResponseModifier) { if mod == nil { f.fresmod = noop return } f.fresmod = mod } // ModifyRequest evaluates reqcond and executes treqmod iff reqcond evaluates // to true; otherwise, freqmod is executed. func (f *Filter) ModifyRequest(req *http.Request) error { if f.reqcond == nil { return fmt.Errorf("filter.ModifyRequest: no request condition set. Set condition with SetRequestCondition") } match := f.reqcond.MatchRequest(req) if match { log.Debugf("filter.ModifyRequest: matched %s", req.URL) return f.treqmod.ModifyRequest(req) } return f.freqmod.ModifyRequest(req) } // ModifyResponse evaluates rescond and executes tresmod iff rescond evaluates // to true; otherwise, fresmod is executed. func (f *Filter) ModifyResponse(res *http.Response) error { if f.rescond == nil { return fmt.Errorf("filter.ModifyResponse: no response condition set. Set condition with SetResponseCondition") } match := f.rescond.MatchResponse(res) if match { requ := "" if res.Request != nil { requ = res.Request.URL.String() } log.Debugf("filter.ModifyResponse: %s", requ) return f.tresmod.ModifyResponse(res) } return f.fresmod.ModifyResponse(res) } // VerifyRequests returns an error containing all the verification errors // returned by request verifiers. func (f *Filter) VerifyRequests() error { merr := martian.NewMultiError() freqv, ok := f.freqmod.(verify.RequestVerifier) if ok { if ve := freqv.VerifyRequests(); ve != nil { merr.Add(ve) } } treqv, ok := f.treqmod.(verify.RequestVerifier) if ok { if ve := treqv.VerifyRequests(); ve != nil { merr.Add(ve) } } if merr.Empty() { return nil } return merr } // VerifyResponses returns an error containing all the verification errors // returned by response verifiers. 
func (f *Filter) VerifyResponses() error { merr := martian.NewMultiError() tresv, ok := f.tresmod.(verify.ResponseVerifier) if ok { if ve := tresv.VerifyResponses(); ve != nil { merr.Add(ve) } } fresv, ok := f.fresmod.(verify.ResponseVerifier) if ok { if ve := fresv.VerifyResponses(); ve != nil { merr.Add(ve) } } if merr.Empty() { return nil } return merr } // ResetRequestVerifications resets the state of the contained request verifiers. func (f *Filter) ResetRequestVerifications() { if treqv, ok := f.treqmod.(verify.RequestVerifier); ok { treqv.ResetRequestVerifications() } if freqv, ok := f.freqmod.(verify.RequestVerifier); ok { freqv.ResetRequestVerifications() } } // ResetResponseVerifications resets the state of the contained request verifiers. func (f *Filter) ResetResponseVerifications() { if tresv, ok := f.tresmod.(verify.ResponseVerifier); ok { tresv.ResetResponseVerifications() } } martian-3.3.2/filter/filter_test.go000066400000000000000000000133001421371434000173060ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package filter

import (
	"errors"
	"net/http"
	"testing"

	"github.com/google/martian/v3/martiantest"
	"github.com/google/martian/v3/proxyutil"
	"github.com/google/martian/v3/verify"
)

func TestRequestWhenTrueCondition(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.RequestEvaluatesTo(true)
	filter.SetRequestCondition(tmc)

	tmod := martiantest.NewModifier()
	filter.RequestWhenTrue(tmod)

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := filter.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := tmod.RequestModified(), true; got != want {
		t.Errorf("tmod.RequestModified(): got %t, want %t", got, want)
	}
}

func TestRequestWithoutSettingCondition(t *testing.T) {
	filter := New()
	// neglect to set a matcher

	tmod := martiantest.NewModifier()
	filter.RequestWhenFalse(tmod)

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := filter.ModifyRequest(req); err == nil {
		t.Fatal("ModifyRequest(): got no error, want error")
	}
}

func TestRequestWhenFalseCondition(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.RequestEvaluatesTo(false)
	filter.SetRequestCondition(tmc)

	tmod := martiantest.NewModifier()
	filter.RequestWhenFalse(tmod)

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := filter.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := tmod.RequestModified(), true; got != want {
		t.Errorf("tmod.RequestModified(): got %t, want %t", got, want)
	}
}

func TestResponseWithoutSettingCondition(t *testing.T) {
	filter := New()
	// neglect to set a matcher

	tmod := martiantest.NewModifier()
	filter.ResponseWhenFalse(tmod)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := filter.ModifyResponse(res); err == nil {
		t.Fatal("ModifyResponse(): got no error, want error")
	}
}

func TestResponseWhenTrueCondition(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.ResponseEvaluatesTo(true)
	filter.SetResponseCondition(tmc)

	tmod := martiantest.NewModifier()
	filter.ResponseWhenTrue(tmod)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := filter.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := tmod.ResponseModified(), true; got != want {
		t.Errorf("tmod.ResponseModified(): got %t, want %t", got, want)
	}
}

func TestResponseWhenFalseCondition(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.ResponseEvaluatesTo(false)
	filter.SetResponseCondition(tmc)

	tmod := martiantest.NewModifier()
	filter.ResponseWhenFalse(tmod)

	res := proxyutil.NewResponse(200, nil, nil)
	if err := filter.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}
	if got, want := tmod.ResponseModified(), true; got != want {
		t.Errorf("tmod.ResponseModified(): got %t, want %t", got, want)
	}
}

func TestResetVerifications(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.ResponseEvaluatesTo(true)
	filter.SetResponseCondition(tmc)

	tv := &verify.TestVerifier{
		ResponseError: errors.New("verify response failure"),
	}
	filter.ResponseWhenTrue(tv)

	tv = &verify.TestVerifier{
		RequestError: errors.New("verify request failure"),
	}
	filter.RequestWhenTrue(tv)

	if err := filter.VerifyRequests(); err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}
	if err := filter.VerifyResponses(); err == nil {
		t.Fatal("VerifyResponses(): got nil, want error")
	}

	filter.ResetRequestVerifications()
	filter.ResetResponseVerifications()

	if err := filter.VerifyResponses(); err != nil {
		t.Errorf("VerifyResponses(): got %v, want no error", err)
	}
	if err := filter.VerifyRequests(); err != nil {
		t.Errorf("VerifyRequests(): got %v, want no error", err)
	}
}

func TestPassThroughVerifyRequests(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.RequestEvaluatesTo(true)
	filter.SetRequestCondition(tmc)

	if err := filter.VerifyRequests(); err != nil {
		t.Fatalf("VerifyRequests(): got %v, want no error", err)
	}

	tv := &verify.TestVerifier{
		RequestError: errors.New("verify request failure"),
	}
	filter.RequestWhenTrue(tv)

	// Guard against nil so a regression fails the test instead of panicking.
	err := filter.VerifyRequests()
	if err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}
	if got, want := err.Error(), "verify request failure"; got != want {
		t.Fatalf("VerifyRequests(): got %s, want %s", got, want)
	}
}

func TestPassThroughVerifyResponses(t *testing.T) {
	filter := New()

	tmc := martiantest.NewMatcher()
	tmc.ResponseEvaluatesTo(true)
	filter.SetResponseCondition(tmc)

	if err := filter.VerifyResponses(); err != nil {
		t.Fatalf("VerifyResponses(): got %v, want no error", err)
	}

	tv := &verify.TestVerifier{
		ResponseError: errors.New("verify response failure"),
	}
	filter.ResponseWhenTrue(tv)

	// Guard against nil so a regression fails the test instead of panicking.
	err := filter.VerifyResponses()
	if err == nil {
		t.Fatal("VerifyResponses(): got nil, want error")
	}
	if got, want := err.Error(), "verify response failure"; got != want {
		t.Fatalf("VerifyResponses(): got %s, want %s", got, want)
	}
}
martian-3.3.2/go.mod000066400000000000000000000005051421371434000142570ustar00rootroot00000000000000module github.com/google/martian/v3

go 1.11

require (
	github.com/golang/protobuf v1.5.2
	github.com/golang/snappy v0.0.3
	golang.org/x/net v0.0.0-20190628185345-da137c7871d7
	google.golang.org/grpc v1.37.0
	google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 // indirect
	google.golang.org/protobuf v1.26.0 // indirect
)
martian-3.3.2/go.sum000066400000000000000000000224561421371434000143150ustar00rootroot00000000000000cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod
h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2 h1:+Z5KGCizgyZCbGh1KZqA0fcLLkwbsjIzS4aV2v7wJX0= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod 
h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/martian v2.1.0+incompatible h1:/CP5g8u/VJHijgedC/Legn3BAbAaWPgecwXBIDzw5no= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net 
v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190628185345-da137c7871d7 h1:rTIdg5QFRR7XCaK4LCjBiPbx8j4DQRpdYMnGn/bJUEU= golang.org/x/net v0.0.0-20190628185345-da137c7871d7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210505214959-0714010a04ed h1:V9kAVxLvz1lkufatrpHuUVyJ/5tR3Ms7rk951P4mI98= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors 
v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 h1:+kGHl1aib/qcwaRi1CbqBZ1rk19r85MNUf8HaBghugY= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.37.0 h1:uSZWeQJX5j11bIQ4AJoj+McDBo29cY1MCoC1wO3ts+c= google.golang.org/grpc v1.37.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0 h1:M1YKkFIboKNieVO5DLUEVzQfGwJD30Nv2jfUgzb5UcE= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= google.golang.org/protobuf v1.21.0/go.mod 
h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.25.0 h1:Ejskq+SyPohKW+1uil0JJMtmHCgJPJ/qWTxr8qp+R4c= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0 h1:bxAC2xTBsZGibn2RTntX0oH50xLsqy1OxA9tTL3p/lk= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= martian-3.3.2/h2/000077500000000000000000000000001421371434000134625ustar00rootroot00000000000000martian-3.3.2/h2/grpc/000077500000000000000000000000001421371434000144155ustar00rootroot00000000000000martian-3.3.2/h2/grpc/grpc.go000066400000000000000000000237151421371434000157070ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package grpc contains gRPC functionality for Martian proxy. package grpc import ( "bytes" "compress/flate" "compress/gzip" "encoding/binary" "fmt" "io/ioutil" "net/url" "sync/atomic" "github.com/golang/snappy" "github.com/google/martian/v3/h2" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) // Encoding is the grpc-encoding type. See Content-Coding entry at: // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests type Encoding uint8 const ( // Identity indicates that no compression is used. Identity Encoding = iota // Gzip indicates that Gzip compression is used. Gzip // Deflate indicates that Deflate compression is used. Deflate // Snappy indicates that Snappy compression is used. Snappy ) // ProcessorFactory creates gRPC processors that implement the Processor interface, which abstracts // away some of the details of the underlying HTTP/2 protocol. A processor must forward // invocations to the given `server` or `client` processors, which will arrange to have the data // forwarded to the destination, with possible edits. Nil values are safe to return and no // processing occurs in such cases. NOTE: an interface may have a non-nil type with a nil value. // Such values are treated as valid processors. type ProcessorFactory func(url *url.URL, server, client Processor) (Processor, Processor) // AsStreamProcessorFactory converts a ProcessorFactory into a StreamProcessorFactory. It creates // an adapter that abstracts HTTP/2 frames into a representation that is closer to gRPC. 
func AsStreamProcessorFactory(f ProcessorFactory) h2.StreamProcessorFactory { return func(url *url.URL, sinks *h2.Processors) (h2.Processor, h2.Processor) { var cToS, sToC h2.Processor // A grpc.Processor is translated into an h2.Processor in layers. // // adapter → processor → emitter → sink // \_____________________________↗ // // * The adapter wraps the grpc.Processor interface so that it conforms with h2.Processor. It // performs some processing to translate HTTP/2 frames into gRPC concepts. Frames that are // not relevant to gRPC are forwarded directly to the sink. // * The processor is the gRPC processing logic provided by the client factory. // * The emitter wraps an h2.Processor sink and translates the processed gRPC data into HTTP/2 // frames. cToSEmitter := &emitter{sink: sinks.ForDirection(h2.ClientToServer)} sToCEmitter := &emitter{sink: sinks.ForDirection(h2.ServerToClient)} cToSProcessor, sToCProcessor := f(url, cToSEmitter, sToCEmitter) // enabled indicates whether the stream should be processed as gRPC. It is shared between the // the two adapters because its detection is on a client-to-server HEADER frame and the state // applies bidirectionally. enabled := int32(0) if cToSProcessor != nil { cToSEmitter.adapter = &adapter{ enabled: &enabled, dir: h2.ClientToServer, processor: cToSProcessor, sink: sinks.ForDirection(h2.ClientToServer), } cToS = cToSEmitter.adapter } if sToCProcessor != nil { sToCEmitter.adapter = &adapter{ enabled: &enabled, dir: h2.ServerToClient, processor: sToCProcessor, sink: sinks.ForDirection(h2.ServerToClient), } sToC = sToCEmitter.adapter } return cToS, sToC } } // Processor processes gRPC traffic. type Processor interface { h2.HeaderProcessor // Message receives serialized messages. Message(data []byte, streamEnded bool) error } // dataState represents one of two possible states when consuming gRPC DATA frames. 
type dataState uint8 const ( readingMetadata dataState = iota readingMessageData ) // adapter wraps the Processor interface with an h2.Processor interface. It filters streams that // are not gRPC and handles decompressing the message data. type adapter struct { enabled *int32 dir h2.Direction processor Processor sink h2.Processor encoding Encoding // State for the data interpreter. buffer bytes.Buffer state dataState compressed bool length uint32 } func (a *adapter) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { if !a.isEnabled() { for _, h := range headers { if h.Name == "content-type" && h.Value == "application/grpc" { atomic.StoreInt32(a.enabled, 1) break } } if !a.isEnabled() { return a.sink.Header(headers, streamEnded, priority) } } for _, h := range headers { if h.Name == "grpc-encoding" { switch h.Value { case "identity": a.encoding = Identity case "gzip": a.encoding = Gzip case "deflate": a.encoding = Deflate case "snappy": a.encoding = Snappy default: return fmt.Errorf("unrecognized grpc-encoding %s in %v", h.Value, headers) } } } return a.processor.Header(headers, streamEnded, priority) } func (a *adapter) Data(data []byte, streamEnded bool) error { if !a.isEnabled() { return a.sink.Data(data, streamEnded) } a.buffer.Write(data) for { switch a.state { case readingMetadata: if streamEnded && a.buffer.Len() == 0 { // gRPC may send empty DATA frames to end a stream. 
if err := a.processor.Message(nil, true); err != nil { return err } } if a.buffer.Len() < 5 { return nil } compressed, _ := a.buffer.ReadByte() a.compressed = compressed > 0 if err := binary.Read(&a.buffer, binary.BigEndian, &a.length); err != nil { return fmt.Errorf("reading message length: %w", err) } a.state = readingMessageData case readingMessageData: if uint32(a.buffer.Len()) < a.length { return nil } data := make([]byte, a.length) a.buffer.Read(data) if a.compressed { switch a.encoding { case Identity: case Gzip: var err error data, err = gunzip(data) if err != nil { return fmt.Errorf("gunzipping data: %w", err) } case Deflate: var err error data, err = deflate(data) if err != nil { return fmt.Errorf("deflating data: %w", err) } case Snappy: var err error data, err = ioutil.ReadAll(snappy.NewReader(bytes.NewReader(data))) if err != nil { return fmt.Errorf("uncompressing snappy: %w", err) } default: panic(fmt.Sprintf("unexpected enocding: %v", a.encoding)) } } a.state = readingMetadata // Only marks stream ended for the message if there is no data remaining. For ease of // implementation, this proxy aligns messages with data frames. This means that if a data // frame with stream ended contains multiple messages, the earlier ones should not be // marked with stream ended. // // As explained in https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#data-frames, // this reframing is safe because gRPC implementations won't be making any assumptions about // the framing. 
if err := a.processor.Message(data, streamEnded && a.buffer.Len() == 0); err != nil { return err } default: panic(fmt.Sprintf("unexpected state: %v", a.state)) } if a.buffer.Len() == 0 { return nil } } } func (a *adapter) Priority(priority http2.PriorityParam) error { return a.sink.Priority(priority) } func (a *adapter) RSTStream(errCode http2.ErrCode) error { return a.sink.RSTStream(errCode) } func (a *adapter) PushPromise(promiseID uint32, headers []hpack.HeaderField) error { return a.sink.PushPromise(promiseID, headers) } func (a *adapter) isEnabled() bool { return atomic.LoadInt32(a.enabled) > 0 } // emitter is a Processor implementation that wraps a h2.Processor instance, forwarding traffic to // it. It handles recompression of the data. type emitter struct { sink h2.Processor // adapter is a reference to the adapter needed to retrieve state. adapter *adapter } func (e *emitter) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return e.sink.Header(headers, streamEnded, priority) } func (e *emitter) Message(data []byte, streamEnded bool) error { // Applies compression to `data` depending on `adapter`'s state. if e.adapter.compressed { switch e.adapter.encoding { case Identity: case Gzip: var buf bytes.Buffer w := gzip.NewWriter(&buf) if _, err := w.Write(data); err != nil { return fmt.Errorf("gzipping message data: %w", err) } if err := w.Close(); err != nil { return fmt.Errorf("gzipping message data: %w", err) } data = buf.Bytes() case Deflate: var buf bytes.Buffer w, _ := flate.NewWriter(&buf, -1) if _, err := w.Write(data); err != nil { return fmt.Errorf("flate compressing message data: %w", err) } if err := w.Close(); err != nil { return fmt.Errorf("flate compressing message data: %w", err) } data = buf.Bytes() case Snappy: data = snappy.Encode(nil, data) } } var buf bytes.Buffer // Writes the compression status. 
if e.adapter.compressed { buf.WriteByte(1) } else { buf.WriteByte(0) } binary.Write(&buf, binary.BigEndian, uint32(len(data))) // Writes the length of the data. buf.Write(data) // Writes the actual data. return e.sink.Data(buf.Bytes(), streamEnded) } func gunzip(data []byte) ([]byte, error) { r, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, err } return ioutil.ReadAll(r) } func deflate(data []byte) (_ []byte, rerr error) { r := flate.NewReader(bytes.NewReader(data)) defer func() { if err := r.Close(); err != nil && rerr != nil { rerr = err } }() return ioutil.ReadAll(r) } martian-3.3.2/h2/h2.go000066400000000000000000000117441421371434000143310ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package h2 contains basic HTTP/2 handling for Martian. package h2 import ( "bytes" "crypto/tls" "crypto/x509" "encoding/hex" "fmt" "io" "net/url" "sync" "github.com/google/martian/v3/log" "golang.org/x/net/http2" ) var ( // connectionPreface is the constant value of the connection preface. // https://tools.ietf.org/html/rfc7540#section-3.5 connectionPreface = []byte("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") ) // Config stores the configuration information needed for HTTP/2 processing. type Config struct { // AllowedHostsFilter is a function returning true if the argument is a host for which H2 is // permitted. 
AllowedHostsFilter func(string) bool // RootCAs is the pool of CA certificates used by the MitM client to authenticate the server. RootCAs *x509.CertPool // StreamProcessorFactories is a list of factories used to instantiate a chain of HTTP/2 stream // processors. A chain is created for every stream. StreamProcessorFactories []StreamProcessorFactory // EnableDebugLogs turns on fine-grained debug logging for HTTP/2. EnableDebugLogs bool } // Proxy proxies HTTP/2 traffic between a client connection, `cc`, and the HTTP/2 `url` assuming // h2 is being used. Since no browsers use h2c, it's safe to assume all traffic uses TLS. func (c *Config) Proxy(closing chan bool, cc io.ReadWriter, url *url.URL) error { if c.EnableDebugLogs { log.Infof("\u001b[1;35mProxying %v with HTTP/2\u001b[0m", url) } sc, err := tls.Dial("tcp", url.Host, &tls.Config{ RootCAs: c.RootCAs, NextProtos: []string{"h2"}, }) if err != nil { return fmt.Errorf("connecting h2 to %v: %w", url, err) } if err := forwardPreface(sc, cc); err != nil { return fmt.Errorf("initializing h2 with %v: %w", url, err) } cf, sf := http2.NewFramer(cc, cc), http2.NewFramer(sc, sc) cToS := newRelay(ClientToServer, "client", url.String(), cf, sf, &c.EnableDebugLogs) sToC := newRelay(ServerToClient, url.String(), "client", sf, cf, &c.EnableDebugLogs) // Completes circular parts of the initialization. // The client-to-server relay depends on the server-to-client relay and vice versa. cToS.peer, sToC.peer = sToC, cToS // Creating processors is circular because the create function references the relays and the // relays need to call create. cToS.processors = &streamProcessors{ create: func(id uint32) *Processors { p := &Processors{cToS: &relayAdapter{id, cToS}, sToC: &relayAdapter{id, sToC}} // Chains the pipeline of processors together. for i := len(c.StreamProcessorFactories) - 1; i >= 0; i-- { cToS, sToC := c.StreamProcessorFactories[i](url, p) // Bypasses any nil processors. 
if cToS == nil { cToS = p.ForDirection(ClientToServer) } if sToC == nil { sToC = p.ForDirection(ServerToClient) } p = &Processors{cToS: cToS, sToC: sToC} } return p }, } sToC.processors = cToS.processors var wg sync.WaitGroup wg.Add(2) go func() { // Forwards frames from client to server. defer wg.Done() if err := cToS.relayFrames(closing); err != nil { log.Errorf("relaying frame from client to %v: %v", url, err) } }() go func() { // Forwards frames from server to client. defer wg.Done() if err := sToC.relayFrames(closing); err != nil { log.Errorf("relaying frame from %v to client: %v", url, err) } }() wg.Wait() return nil } // forwardPreface forwards the connection preface from the client to the server. func forwardPreface(server io.Writer, client io.Reader) error { preface := make([]byte, len(connectionPreface)) if _, err := client.Read(preface); err != nil { return fmt.Errorf("reading preface: %w", err) } if !bytes.Equal(preface, connectionPreface) { return fmt.Errorf("client sent unexpected preface: %s", hex.Dump(preface)) } for m := len(connectionPreface); m > 0; { n, err := server.Write([]byte(preface)) if err != nil { return fmt.Errorf("writing preface: %w", err) } preface = preface[n:] m -= n } return nil } type streamProcessors struct { // processors stores `*Processors` instances keyed by uint32 stream ID. processors sync.Map // create creates `*Processors` for the given stream ID. create func(uint32) *Processors } // Get returns a the processor with the given ID and direction. func (s *streamProcessors) Get(id uint32, dir Direction) Processor { value, ok := s.processors.Load(id) if !ok { value, _ = s.processors.LoadOrStore(id, s.create(id)) } return value.(*Processors).ForDirection(dir) } martian-3.3.2/h2/h2_test.go000066400000000000000000000306621421371434000153700ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package h2_test import ( "context" "encoding/base64" "fmt" "io" "math/rand" "net/url" "sync" "testing" "github.com/google/martian/v3/h2" mgrpc "github.com/google/martian/v3/h2/grpc" ht "github.com/google/martian/v3/h2/testing" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc" "google.golang.org/grpc/encoding/gzip" "google.golang.org/protobuf/proto" tspb "github.com/google/martian/v3/h2/testservice" ) type requestProcessor struct { dest mgrpc.Processor requests *[]*tspb.EchoRequest } func (p *requestProcessor) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return p.dest.Header(headers, streamEnded, priority) } func (p *requestProcessor) Message(data []byte, streamEnded bool) error { msg := &tspb.EchoRequest{} if err := proto.Unmarshal(data, msg); err != nil { return fmt.Errorf("unmarshalling request: %w", err) } *p.requests = append(*p.requests, msg) return p.dest.Message(data, streamEnded) } type responseProcessor struct { dest mgrpc.Processor responses *[]*tspb.EchoResponse } func (p *responseProcessor) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return p.dest.Header(headers, streamEnded, priority) } func (p *responseProcessor) Message(data []byte, streamEnded bool) error { msg := &tspb.EchoResponse{} if err := proto.Unmarshal(data, msg); err != nil { return fmt.Errorf("unmarshalling response: 
%w", err) } *p.responses = append(*p.responses, msg) return p.dest.Message(data, streamEnded) } func TestEcho(t *testing.T) { // This is a basic smoke test. It verifies that the end-to-end flow works and that gRPC messages // are observed as expected in processors. var requests []*tspb.EchoRequest var responses []*tspb.EchoResponse fixture, err := ht.New([]h2.StreamProcessorFactory{ mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &requestProcessor{server, &requests}, &responseProcessor{client, &responses} }), }) if err != nil { t.Fatalf("ht.New(...) = %v, want nil", err) } defer func() { if err := fixture.Close(); err != nil { t.Fatalf("f.Close() = %v, want nil", err) } }() ctx := context.Background() req := &tspb.EchoRequest{ Payload: "Hello", } resp, err := fixture.Echo(ctx, req) if err != nil { t.Fatalf("fixture.Echo(...) = _, %v, want _, nil", err) } if got, want := resp.GetPayload(), req.GetPayload(); got != want { t.Errorf("resp.GetPayload() = %s, want = %s", got, want) } // Verifies the captured requests and responses. 
if got := len(requests); got != 1 { t.Fatalf("len(requests) = %d, want 1", got) } if got, want := requests[0].GetPayload(), req.GetPayload(); got != want { t.Errorf("requests[0].GetPayload() = %s, want = %s", got, want) } if got := len(responses); got != 1 { t.Fatalf("len(requests) = %d, want 1", got) } if got, want := responses[0].GetPayload(), req.GetPayload(); got != want { t.Errorf("responses[0].GetPayload() = %s, want = %s", got, want) } } type requestEditor struct { dest mgrpc.Processor } func (p *requestEditor) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return p.dest.Header(headers, streamEnded, priority) } func (p *requestEditor) Message(_ []byte, streamEnded bool) error { msg := &tspb.EchoRequest{ Payload: "Goodbye", } data, err := proto.Marshal(msg) if err != nil { return fmt.Errorf("marshalling request: %w", err) } return p.dest.Message(data, streamEnded) } func TestRequestEditor(t *testing.T) { // This test inserts a request modifier that changes the payload from "Hello" to "Goodbye". fixture, err := ht.New([]h2.StreamProcessorFactory{ mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &requestEditor{server}, nil }), }) if err != nil { t.Fatalf("ht.New(...) = %v, want nil", err) } defer func() { if err := fixture.Close(); err != nil { t.Fatalf("f.Close() = %v, want nil", err) } }() ctx := context.Background() req := &tspb.EchoRequest{ Payload: "Hello", } resp, err := fixture.Echo(ctx, req) if err != nil { t.Fatalf("fixture.Echo(...) 
= _, %v, want _, nil", err) } if got, want := resp.GetPayload(), "Goodbye"; got != want { t.Errorf("resp.GetPayload() = %s, want = %s", got, want) } } type plusOne struct { dest mgrpc.Processor } func (p *plusOne) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return p.dest.Header(headers, streamEnded, priority) } func (p *plusOne) Message(data []byte, streamEnded bool) error { msg := &tspb.SumRequest{} if err := proto.Unmarshal(data, msg); err != nil { return fmt.Errorf("unmarshalling request: %w", err) } msg.Values = append(msg.Values, 1) data, err := proto.Marshal(msg) if err != nil { return fmt.Errorf("marshalling request: %w", err) } return p.dest.Message(data, streamEnded) } func TestProcessorChaining(t *testing.T) { // This test constructs a chain of processors and checks that the effects are correctly applied // at the result. fixture, err := ht.New([]h2.StreamProcessorFactory{ mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &plusOne{server}, nil }), mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &plusOne{server}, nil }), mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &plusOne{server}, nil }), }) if err != nil { t.Fatalf("ht.New(...) = %v, want nil", err) } defer func() { if err := fixture.Close(); err != nil { t.Fatalf("f.Close() = %v, want nil", err) } }() ctx := context.Background() req := &tspb.SumRequest{ Values: []int32{5}, } resp, err := fixture.Sum(ctx, req) if err != nil { t.Fatalf("fixture.Sum(...) 
= _, %v, want _, nil", err) } if got, want := resp.GetValue(), int32(8); got != want { t.Errorf("resp.GetValue() = %d, want = %d", got, want) } } type headerCapture struct { dest mgrpc.Processor headers *[][]hpack.HeaderField } func (h *headerCapture) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { c := make([]hpack.HeaderField, len(headers)) copy(c, headers) *h.headers = append(*h.headers, c) return h.dest.Header(headers, streamEnded, priority) } func (h *headerCapture) Message(data []byte, streamEnded bool) error { return h.dest.Message(data, streamEnded) } func TestLargeEcho(t *testing.T) { // Sends a >128KB payload through the proxy. Since the standard gRPC frame size is only 16KB, // this exercises frame merging, splitting and flow control code. payload := make([]byte, 128*1024) rand.Read(payload) req := &tspb.EchoRequest{ Payload: base64.StdEncoding.EncodeToString(payload), } // This test also covers using gzip compression. Ideally, we would test more compression types // but the golang gRPC implementation only provides a gzip compressor. tests := []struct { name string useCompression bool }{ {"RawData", false}, {"Gzip", true}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var cToSHeaders, sToCHeaders [][]hpack.HeaderField fixture, err := ht.New([]h2.StreamProcessorFactory{ mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &headerCapture{server, &cToSHeaders}, &headerCapture{client, &sToCHeaders} }), }) if err != nil { t.Fatalf("ht.New(...) = %v, want nil", err) } defer func() { if err := fixture.Close(); err != nil { t.Fatalf("f.Close() = %v, want nil", err) } }() ctx := context.Background() var resp *tspb.EchoResponse if tc.useCompression { resp, err = fixture.Echo(ctx, req, grpc.UseCompressor(gzip.Name)) } else { resp, err = fixture.Echo(ctx, req) } if err != nil { t.Fatalf("fixture.Echo(...) 
= _, %v, want _, nil", err) } if got, want := resp.GetPayload(), req.GetPayload(); got != want { t.Errorf("resp.GetPayload() = %s, want = %s", got, want) } // Verifies that grpc-encoding=gzip is present in the first headers on the stream when // compression is active. for _, headers := range [][]hpack.HeaderField{cToSHeaders[0], sToCHeaders[0]} { foundGRPCEncoding := false for _, h := range headers { if h.Name == "grpc-encoding" { foundGRPCEncoding = true if got, want := h.Value, "gzip"; got != want { t.Errorf("h.Value = %s, want %s", got, want) } } } if got, want := foundGRPCEncoding, tc.useCompression; got != want { t.Errorf("foundGRPCEncoding = %t, want %t", got, want) } } }) } } type noopProcessor struct { sink mgrpc.Processor } func (p *noopProcessor) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return p.sink.Header(headers, streamEnded, priority) } func (p *noopProcessor) Message(data []byte, streamEnded bool) error { return p.sink.Message(data, streamEnded) } func TestStream(t *testing.T) { tests := []struct { name string factory h2.StreamProcessorFactory }{ { "NilH2Processor", func(_ *url.URL, _ *h2.Processors) (h2.Processor, h2.Processor) { return nil, nil }, }, { // This differs from NilH2Processor only in how mgrpc.AsStreamProcessorFactory handles nil // grpc.Processor values. It should end up processing exactly the same as // h2.StreamProcessorFactory afterwards. "NilGRPCProcessor", mgrpc.AsStreamProcessorFactory( func(_ *url.URL, _, _ mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return nil, nil }), }, { // This differs from NilGRPCProcessor in that NilGRPCProcessor ends up behaving like // NilH2Processor and no gRPC processing takes place. NoopProcessor causes the frames to // be processed as gRPC. 
"NoopGRPCProcessor", mgrpc.AsStreamProcessorFactory( func(_ *url.URL, server, client mgrpc.Processor) (mgrpc.Processor, mgrpc.Processor) { return &noopProcessor{server}, &noopProcessor{client} }), }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { fixture, err := ht.New([]h2.StreamProcessorFactory{tc.factory}) if err != nil { t.Fatalf("ht.New(...) = %v, want nil", err) } defer func() { if err := fixture.Close(); err != nil { t.Fatalf("f.Close() = %v, want nil", err) } }() ctx := context.Background() stream, err := fixture.DoubleEcho(ctx) if err != nil { t.Fatalf("fixture.DoubleEcho(ctx) = _, %v, want _, nil", err) } var received []*tspb.EchoResponse var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() for { resp, err := stream.Recv() if err == io.EOF { return } if err != nil { t.Errorf("stream.Recv() = %v, want nil", err) return } received = append(received, resp) } }() var sent []*tspb.EchoRequest for i := 0; i < 5; i++ { payload := make([]byte, 20*1024) rand.Read(payload) req := &tspb.EchoRequest{ Payload: base64.StdEncoding.EncodeToString(payload), } if err := stream.Send(req); err != nil { t.Fatalf("stream.Send(req) = %v, want nil", err) } sent = append(sent, req) } if err := stream.CloseSend(); err != nil { t.Fatalf("stream.CloseSend() = %v, want nil", err) } wg.Wait() for i, req := range sent { want := req.GetPayload() if got := received[2*i].GetPayload(); got != want { t.Errorf("received[2*i].GetPayload() = %s, want %s", got, want) } if got := received[2*i+1].GetPayload(); got != want { t.Errorf("received[2*i+1].GetPayload() = %s, want %s", got, want) } } }) } } martian-3.3.2/h2/processor.go000066400000000000000000000106371421371434000160370ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package h2 import ( "fmt" "net/url" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) // Direction indicates the direction of the traffic flow. type Direction uint8 const ( // ClientToServer indicates traffic flowing from client-to-server. ClientToServer Direction = iota // ServerToClient indicates traffic flowing from server-to-client. ServerToClient ) // StreamProcessorFactory is implemented by clients that wish to observe or edit HTTP/2 frames // flowing through the proxy. It creates a pair of processors for the bidirectional stream. A // processor consumes frames then calls the corresponding sink methods to forward frames to the // destination, modifying the frame if needed. // // Returns the client-to-server and server-to-client processors. Nil values are safe to return and // no processing occurs in such cases. NOTE: an interface may have a non-nil type with a nil value. // Such values are treated as valid processors. // // Concurrency: there is a separate client-to-server and server-to-client thread. Calls against // the `ClientToServer` sink must be made on the client-to-server thread and calls against // the `ServerToClient` sink must be made on the server-to-client thread. Implementors should // guard interactions across threads. type StreamProcessorFactory func(url *url.URL, sinks *Processors) (Processor, Processor) // Processors encapsulates the two traffic receiving endpoints. type Processors struct { cToS, sToC Processor } // ForDirection returns the processor receiving traffic in the given direction. 
func (s *Processors) ForDirection(dir Direction) Processor { switch dir { case ClientToServer: return s.cToS case ServerToClient: return s.sToC } panic(fmt.Sprintf("invalid direction: %v", dir)) } // Processor accepts the possible stream frames. // // This API abstracts away some of the lower level HTTP/2 mechanisms. // CONTINUATION frames are appropriately buffered and turned into Header calls and Header or // PushPromise calls are split into CONTINUATION frames when needed. // // The proxy handles WINDOW_UPDATE frames and flow control, managing it independently for both // endpoints. type Processor interface { DataFrameProcessor HeaderProcessor PriorityFrameProcessor RSTStreamProcessor PushPromiseProcessor } // DataFrameProcessor processes data frames. type DataFrameProcessor interface { Data(data []byte, streamEnded bool) error } // HeaderProcessor processes headers, abstracting out continuations. type HeaderProcessor interface { Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error } // PriorityFrameProcessor processes priority frames. type PriorityFrameProcessor interface { Priority(http2.PriorityParam) error } // RSTStreamProcessor processes RSTStream frames. type RSTStreamProcessor interface { RSTStream(http2.ErrCode) error } // PushPromiseProcessor processes push promises, abstracting out continuations. type PushPromiseProcessor interface { PushPromise(promiseID uint32, headers []hpack.HeaderField) error } // relayAdapter implements the Processor interface by delegating to an underlying relay. 
type relayAdapter struct { id uint32 relay *relay } func (r *relayAdapter) Data(data []byte, streamEnded bool) error { return r.relay.data(r.id, data, streamEnded) } func (r *relayAdapter) Header( headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { return r.relay.header(r.id, headers, streamEnded, priority) } func (r *relayAdapter) Priority(priority http2.PriorityParam) error { r.relay.priority(r.id, priority) return nil } func (r *relayAdapter) RSTStream(errCode http2.ErrCode) error { r.relay.rstStream(r.id, errCode) return nil } func (r *relayAdapter) PushPromise(promiseID uint32, headers []hpack.HeaderField) error { return r.relay.pushPromise(r.id, promiseID, headers) } martian-3.3.2/h2/queued_frames.go000066400000000000000000000153351421371434000166450ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package h2 import ( "bytes" "fmt" "golang.org/x/net/http2" ) // queuedFrame stores frames that belong to a stream and need to be kept in order. The need for // this stems from flow control needed in the context of gRPC. Since a gRPC message can be split // over multiple DATA frames, the proxy needs to buffer such frames so they can be reassembled // into messages and edited before being forwarded. 
// // Note that the proxy does man-in-the-middle flow control independently to each endpoint instead // of forwarding endpoint flow-control messages to each other directly. This is necessary because // multiple DATA frames need to be captured before they can be forwarded. While the data frames are // being held in the proxy, the destination of those frames cannot see them to send WINDOW_UPDATE // acknowledgements and the sender will stop sending data. So the proxy must emit its own // WINDOW_UPDATEs. // // Example: While DATA frames are being output-buffered due to pending WINDOW_UPDATE frames from // the destination, it's possible for the source to send subsequent HEADER frames. Those HEADER // frames must be queued after the DATA frames for consistency with HTTP/2's total ordering of // frames within a stream. // // While the example only illustrates the need for HEADER frame buffering, a similar argument // applies to other types of stream frames. WINDOW_UPDATE is a special case that is associated // with a stream but does not require buffering or special ordering. This is because WINDOW_UPDATEs // are basically acknowledgements for messages coming from the peer endpoint. In other words, // WINDOW_UPDATE frames are associated with messages being received instead of messages being sent. // The asynchrony of receiving remote messages should allow reordering freedom. type queuedFrame interface { // StreamID is the stream ID for the frame. StreamID() uint32 // flowControlSize returns the size of this frame for the purposes of flow control. It is only // non-zero for DATA frames. flowControlSize() int // send writes the frame to the provided framer. This is not thread-safe and the caller should be // holding appropriate locks. 
send(*http2.Framer) error } type queuedDataFrame struct { streamID uint32 endStream bool data []byte } func (f *queuedDataFrame) StreamID() uint32 { return f.streamID } func (f *queuedDataFrame) flowControlSize() int { return len(f.data) } func (f *queuedDataFrame) send(dest *http2.Framer) error { return dest.WriteData(f.streamID, f.endStream, f.data) } func (f *queuedDataFrame) String() string { return fmt.Sprintf("data[id=%d, endStream=%t, len=%d]", f.streamID, f.endStream, len(f.data)) } type queuedHeaderFrame struct { streamID uint32 endStream bool priority http2.PriorityParam chunks [][]byte } func (f *queuedHeaderFrame) StreamID() uint32 { return f.streamID } func (*queuedHeaderFrame) flowControlSize() int { return 0 } func (f *queuedHeaderFrame) send(dest *http2.Framer) error { if err := dest.WriteHeaders(http2.HeadersFrameParam{ StreamID: f.streamID, BlockFragment: f.chunks[0], EndStream: f.endStream, EndHeaders: len(f.chunks) <= 1, PadLength: 0, Priority: f.priority, }); err != nil { return fmt.Errorf("sending header %v: %w", f, err) } for i := 1; i < len(f.chunks); i++ { headersEnded := i == len(f.chunks)-1 if err := dest.WriteContinuation(f.streamID, headersEnded, f.chunks[i]); err != nil { return fmt.Errorf("sending header continuations %v: %w", f, err) } } return nil } func (f *queuedHeaderFrame) String() string { var buf bytes.Buffer // strings.Builder is not available on App Engine. 
fmt.Fprintf(&buf, "header[id=%d, endStream=%t", f.streamID, f.endStream) fmt.Fprintf(&buf, ", priority=%v, chunk lengths=[", f.priority) for i, c := range f.chunks { if i > 0 { fmt.Fprintf(&buf, ",") } fmt.Fprintf(&buf, "%d", len(c)) } fmt.Fprintf(&buf, "]]") return buf.String() } type queuedPushPromiseFrame struct { streamID uint32 promiseID uint32 chunks [][]byte } func (f *queuedPushPromiseFrame) StreamID() uint32 { return f.streamID } func (*queuedPushPromiseFrame) flowControlSize() int { return 0 } func (f *queuedPushPromiseFrame) send(dest *http2.Framer) error { if err := dest.WritePushPromise(http2.PushPromiseParam{ StreamID: f.streamID, PromiseID: f.promiseID, BlockFragment: f.chunks[0], EndHeaders: len(f.chunks) <= 1, PadLength: 0, }); err != nil { return fmt.Errorf("sending push promise %v: %w", f, err) } for i := 1; i < len(f.chunks); i++ { headersEnded := i == len(f.chunks)-1 if err := dest.WriteContinuation(f.streamID, headersEnded, f.chunks[i]); err != nil { return fmt.Errorf("sending push promise continuations %v: %w", f, err) } } return nil } func (f *queuedPushPromiseFrame) String() string { var buf bytes.Buffer fmt.Fprintf(&buf, "push promise[streamID=%d, promiseID= %d", f.streamID, f.promiseID) fmt.Fprintf(&buf, ", chunk lengths=[") for i, c := range f.chunks { if i > 0 { fmt.Fprintf(&buf, ",") } fmt.Fprintf(&buf, "%d", len(c)) } fmt.Fprintf(&buf, "]]") return buf.String() } type queuedPriorityFrame struct { streamID uint32 priority http2.PriorityParam } func (f *queuedPriorityFrame) StreamID() uint32 { return f.streamID } func (*queuedPriorityFrame) flowControlSize() int { return 0 } func (f *queuedPriorityFrame) send(dest *http2.Framer) error { if err := dest.WritePriority(f.streamID, f.priority); err != nil { return fmt.Errorf("sending %v: %w", f, err) } return nil } func (f *queuedPriorityFrame) String() string { return fmt.Sprintf("priority[id=%d, priority=%v]", f.streamID, f.priority) } type queuedRSTStreamFrame struct { streamID uint32 
errCode http2.ErrCode } func (f *queuedRSTStreamFrame) StreamID() uint32 { return f.streamID } func (*queuedRSTStreamFrame) flowControlSize() int { return 0 } func (f *queuedRSTStreamFrame) send(dest *http2.Framer) error { if err := dest.WriteRSTStream(f.streamID, f.errCode); err != nil { return fmt.Errorf("sending %v: %w", f, err) } return nil } func (f *queuedRSTStreamFrame) String() string { return fmt.Sprintf("RSTStream[id=%d, errCode=%v]", f.streamID, f.errCode) } martian-3.3.2/h2/relay.go000066400000000000000000000466161421371434000151420ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package h2 import ( "bytes" "container/list" "errors" "fmt" "io" "math" "sync" "sync/atomic" "github.com/google/martian/v3/log" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" ) const ( // See: https://httpwg.org/specs/rfc7540.html#SettingValues initialMaxFrameSize = 16384 initialMaxHeaderTableSize = 4096 // See: https://tools.ietf.org/html/rfc7540#section-6.9.2 defaultInitialWindowSize = 65535 // headersPriorityMetadataLength is the length of the priority metadata that optionally occurs at // the beginning of the payload of the header frame. // // See: https://tools.ietf.org/html/rfc7540#section-6.2 headersPriorityMetadataLength = 5 // pushPromiseMetadataLength is the length of the metadata that is part of the payload of the // pushPromise frame. 
This does not include the padding length octet, which isn't needed due to // the relaxed security constraints of a development proxy. // // See: https://tools.ietf.org/html/rfc7540#section-6.6 pushPromiseMetadataLength = 4 // outputChannelSize is the size of the output channel. Roughly, it should be large enough to // allow a window's worth of frames to minimize synchronization overhead. outputChannelSize = 15 ) // relay encapsulates a flow of h2 traffic in one direction. type relay struct { dir Direction // srcLabel and destLabel are used only to create debugging messages. srcLabel, destLabel string src *http2.Framer // destMu guards writes to dest, which may occur on from either the `relayFrames` thread of // this relay or `peer`. `peer` writes WINDOW_UPDATE frames to this relay when it receives // DATA frames. destMu sync.Mutex dest *http2.Framer // maxFrameSize is set by the peer relay and is accessed atomically. maxFrameSize uint32 // The decoder and encoder settings can be adjusted by the peer connection so access to these // fields must be guarded. decoderMu sync.Mutex decoder *hpack.Decoder encoderMu sync.Mutex encoder *hpack.Encoder reencoded bytes.Buffer // handle to the output buffer of `encoder` // headerBuffer collects header fragments that are received across multiple frames, i.e., // when there are continuation frames. headerBuffer bytes.Buffer continuationState continuationState // flowMu guards access to flow-control related fields. flowMu sync.Mutex initialWindowSize uint32 connectionWindowSize int // "global" connection-level window size // outputBuffers is output pending available window size per-stream outputBuffers map[uint32]*outputBuffer // output stores stream output that is ready to be sent over HTTP/2. It provides a way to // guarantee frame order without blocking on each frame being sent. 
output chan queuedFrame enableDebugLogs *bool // The following fields depend on a circular dependency between the relays in opposite directions // so must be set explicitly after initialization. // processors stores per HTTP/2 stream processors. processors *streamProcessors peer *relay // relay for traffic from the peer } // newRelay initializes a relay for the given direction. This performs only partial initialization // due to circular dependency. func newRelay( dir Direction, srcLabel, destLabel string, src, dest *http2.Framer, enableDebugLogs *bool, ) *relay { ret := &relay{ dir: dir, srcLabel: srcLabel, destLabel: destLabel, src: src, dest: dest, maxFrameSize: initialMaxFrameSize, decoder: hpack.NewDecoder(initialMaxHeaderTableSize, nil), initialWindowSize: defaultInitialWindowSize, connectionWindowSize: defaultInitialWindowSize, outputBuffers: make(map[uint32]*outputBuffer), output: make(chan queuedFrame, outputChannelSize), enableDebugLogs: enableDebugLogs, } ret.encoder = hpack.NewEncoder(&ret.reencoded) // These limits seem to be part of the Go implementation of hpack. They exist because in a // production system, there must be limits on the resources requested by clients. However, this // is irrevelevant in a development proxy context. ret.decoder.SetAllowedMaxDynamicTableSize(math.MaxUint32) ret.encoder.SetMaxDynamicTableSizeLimit(math.MaxUint32) return ret } // relayFrames reads frames from `f.src` to `f.dest` until an error occurs or the connection closes. func (r *relay) relayFrames(closing chan bool) error { // Shutting down producer-consumers linked by channels is subtle. In this function, the writer // goroutine consumes frames from `r.output`, which are populated by the reader goroutine. If // the writer shuts down before the reader, the reader may deadlock on inserting frames into // `r.output`. The writer therefore has to keep processing until the reader is done. This is // coordinated via `readerDone`. 
// // A second subtlely is that errors on the writer goroutine should stop the reader goroutine. // This is communicated via `writeErr`. To avoid deadlocks, even after the error occurs, the // writer thread must still wait until `readerDone` has been communicated to stop processing. // Communicates to the consuming writer goroutine that the reader (the calling goroutine of this // method) is done. readerDone := make(chan struct{}) defer func() { readerDone <- struct{}{} }() // Communicates errors occuring on the writer goroutine to the reader goroutine. writerErr := make(chan error, 1) // This writer goroutine consumes the strictly ordered frames in `r.output` and delivers them. go func() { var err error for { select { case f := <-r.output: if err == nil { r.destMu.Lock() err = f.send(r.dest) r.destMu.Unlock() if err != nil { writerErr <- err } } // Once an output error has occurred, the remaining frames are drained from the channel // without sending them. case <-readerDone: return } } }() // This channel is buffered to allow the ReadFrame goroutine to drain on closing. frameReady := make(chan struct{}, 1) for { var frame http2.Frame var err error go func() { // ReadFrame is called in its own goroutine to make this function responsive to closing. It // does not need to block here to close. frame, err = r.src.ReadFrame() frameReady <- struct{}{} }() select { case <-frameReady: if err != nil { if err == io.EOF { return nil } return fmt.Errorf("reading frame: %w", err) } if err := r.processFrame(frame); err != nil { return fmt.Errorf("processing frame: %w", err) } if *r.enableDebugLogs { log.Infof("%s--%v-->%s", r.srcLabel, frame, r.destLabel) } case err := <-writerErr: return fmt.Errorf("sending frame: %w", err) case <-closing: // The ReadFrame goroutine is abandoned at this point. It completes as soon as the blocking // ReadFrame call completes, but could potentially leak for an unspecified duration. 
return nil } } } func (r *relay) processFrame(f http2.Frame) error { var err error switch f := f.(type) { case *http2.DataFrame: // The proxy's window increments as soon as it receives data. This assumes that the proxy has // ample resources because it is inteded for testing and development. if err = r.peer.sendWindowUpdates(f); err == nil { err = r.processor(f.StreamID).Data(f.Data(), f.StreamEnded()) } case *http2.HeadersFrame: if !f.HeadersEnded() { r.headerBuffer.Reset() r.headerBuffer.Write(f.HeaderBlockFragment()) r.continuationState = &headerContinuation{f.Priority} } else { var headers []hpack.HeaderField headers, err = r.decodeFull(f.HeaderBlockFragment()) if err != nil { return fmt.Errorf("decoding header %v: %w", f, err) } err = r.processor(f.StreamID).Header(headers, f.StreamEnded(), f.Priority) } case *http2.PriorityFrame: err = r.processor(f.StreamID).Priority(f.PriorityParam) case *http2.RSTStreamFrame: err = r.processor(f.StreamID).RSTStream(f.ErrCode) case *http2.SettingsFrame: if f.IsAck() { r.destMu.Lock() err = r.dest.WriteSettingsAck() r.destMu.Unlock() } else { var settings []http2.Setting if err = f.ForeachSetting(func(s http2.Setting) error { switch s.ID { case http2.SettingHeaderTableSize: r.peer.updateTableSize(s.Val) case http2.SettingInitialWindowSize: r.peer.updateInitialWindowSize(s.Val) case http2.SettingMaxFrameSize: r.peer.updateMaxFrameSize(s.Val) } settings = append(settings, s) return nil }); err == nil { r.destMu.Lock() err = r.dest.WriteSettings(settings...) 
r.destMu.Unlock() } } case *http2.PushPromiseFrame: if !f.HeadersEnded() { r.headerBuffer.Reset() r.headerBuffer.Write(f.HeaderBlockFragment()) r.continuationState = &pushPromiseContinuation{f.PromiseID} } else { var headers []hpack.HeaderField headers, err = r.decodeFull(f.HeaderBlockFragment()) if err != nil { return fmt.Errorf("decoding push promise %v: %w", f, err) } err = r.processor(f.StreamID).PushPromise(f.PromiseID, headers) } case *http2.PingFrame: r.destMu.Lock() err = r.dest.WritePing(f.IsAck(), f.Data) r.destMu.Unlock() case *http2.GoAwayFrame: r.destMu.Lock() err = r.dest.WriteGoAway(f.LastStreamID, f.ErrCode, f.DebugData()) r.destMu.Unlock() case *http2.WindowUpdateFrame: r.peer.updateWindow(f) case *http2.ContinuationFrame: r.headerBuffer.Write(f.HeaderBlockFragment()) if f.HeadersEnded() { var headers []hpack.HeaderField headers, err = r.decodeFull(r.headerBuffer.Bytes()) if err != nil { return fmt.Errorf("decoding headers for continuation %v: %w", f, err) } err = r.continuationState.complete(r.processor(f.StreamID), headers) } default: err = errors.New("unrecognized frame type") } return err } func (r *relay) processor(id uint32) Processor { return r.processors.Get(id, r.dir) } func (r *relay) updateTableSize(v uint32) { r.decoderMu.Lock() r.decoder.SetMaxDynamicTableSize(v) r.decoderMu.Unlock() r.encoderMu.Lock() r.encoder.SetMaxDynamicTableSize(v) r.encoderMu.Unlock() } func (r *relay) updateMaxFrameSize(v uint32) { atomic.StoreUint32(&r.maxFrameSize, v) } // updateInitialWindowSize updates the initial window size and updates all stream windows based on // the difference. Note that this should not include the connection window. // See: https://tools.ietf.org/html/rfc7540#section-6.9.2 // // This is called by `peer`, so requires a thread-safe implementation. 
func (r *relay) updateInitialWindowSize(v uint32) { r.flowMu.Lock() delta := int(v) - int(r.initialWindowSize) r.initialWindowSize = v for _, w := range r.outputBuffers { w.windowSize += delta } r.flowMu.Unlock() // Since all the stream windows may be impacted, all the queues need to be checked for newly // eligible frames. r.sendQueuedFramesUnderWindowSize() } // updateWindow updates the specified window size and may result in the sending of data frames. func (r *relay) updateWindow(f *http2.WindowUpdateFrame) { if f.StreamID == 0 { // A stream ID of 0 means updating the global connection window size. This may cause any // queued frame belonging to any stream to become eligible for sending. r.flowMu.Lock() r.connectionWindowSize += int(f.Increment) r.flowMu.Unlock() r.sendQueuedFramesUnderWindowSize() } r.flowMu.Lock() w := r.outputBuffer(f.StreamID) w.windowSize += int(f.Increment) w.emitEligibleFrames(r.output, &r.connectionWindowSize) r.flowMu.Unlock() } func (r *relay) data(id uint32, data []byte, streamEnded bool) error { // This implementation only allows `WriteData` without padding. Padding is used to improve the // security against attacks like CRIME, but this isn't relevant for a development proxy. // // If padding were allowed, this length would need to vary depending on whether the padding // length octet is present. maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) r.flowMu.Lock() w := r.outputBuffer(id) r.flowMu.Unlock() // If data is larger than what would be permitted at the current max frame size setting, the data // is split across multiple frames. 
for { nextPayloadLength := uint32(len(data)) if nextPayloadLength > maxPayloadLength { nextPayloadLength = maxPayloadLength } nextPayload := make([]byte, nextPayloadLength) copy(nextPayload, data) data = data[nextPayloadLength:] f := &queuedDataFrame{id, streamEnded && len(data) == 0, nextPayload} r.flowMu.Lock() w.enqueue(f) w.emitEligibleFrames(r.output, &r.connectionWindowSize) r.flowMu.Unlock() // Some protocols send empty data frames with END_STREAM so the check is done here at the end // of the loop instead of at the beginning of the loop. if len(data) == 0 { break } } return nil } func (r *relay) header( id uint32, headers []hpack.HeaderField, streamEnded bool, priority http2.PriorityParam, ) error { encoded, err := r.encodeFull(headers) if err != nil { return fmt.Errorf("encoding headers %v: %w", headers, err) } maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) // Padding is not implemented because the extra security is not needed for a development proxy. // If it were used, a single padding length octet should be deducted from the max header fragment // length. 
maxHeaderFragmentLength := maxPayloadLength if !priority.IsZero() { maxHeaderFragmentLength -= headersPriorityMetadataLength } chunks := splitIntoChunks(int(maxHeaderFragmentLength), int(maxPayloadLength), encoded) r.enqueueFrame(&queuedHeaderFrame{ streamID: id, endStream: streamEnded, priority: priority, chunks: chunks, }) return nil } func (r *relay) priority(id uint32, priority http2.PriorityParam) { r.enqueueFrame(&queuedPriorityFrame{ streamID: id, priority: priority, }) } func (r *relay) rstStream(id uint32, errCode http2.ErrCode) { r.enqueueFrame(&queuedRSTStreamFrame{ streamID: id, errCode: errCode, }) } func (r *relay) pushPromise(id, promiseID uint32, headers []hpack.HeaderField) error { encoded, err := r.encodeFull(headers) if err != nil { return fmt.Errorf("encoding push promise headers %v: %w", headers, err) } maxPayloadLength := atomic.LoadUint32(&r.maxFrameSize) maxHeaderFragmentLength := maxPayloadLength - pushPromiseMetadataLength chunks := splitIntoChunks(int(maxHeaderFragmentLength), int(maxPayloadLength), encoded) r.enqueueFrame(&queuedPushPromiseFrame{ streamID: id, promiseID: promiseID, chunks: chunks, }) return nil } func (r *relay) enqueueFrame(f queuedFrame) { // The frame is first added to the appropriate stream. r.flowMu.Lock() w := r.outputBuffer(f.StreamID()) w.enqueue(f) w.emitEligibleFrames(r.output, &r.connectionWindowSize) r.flowMu.Unlock() } func (r *relay) sendQueuedFramesUnderWindowSize() { r.flowMu.Lock() for _, w := range r.outputBuffers { w.emitEligibleFrames(r.output, &r.connectionWindowSize) } r.flowMu.Unlock() } // outputBuffer returns the outputBuffer instance for the given stream, creating one if needed. // // This method is not thread-safe. The caller should be holding `flowMu`. 
func (r *relay) outputBuffer(streamID uint32) *outputBuffer { w, ok := r.outputBuffers[streamID] if !ok { w = &outputBuffer{ windowSize: int(r.initialWindowSize), } r.outputBuffers[streamID] = w } return w } // sendWindowUpdates sends WINDOW_UPDATE frames effectively acknowledging consumption of the // given data frame. func (r *relay) sendWindowUpdates(f *http2.DataFrame) error { if len(f.Data()) <= 0 { return nil } r.destMu.Lock() defer r.destMu.Unlock() // First updates the connection level window. if err := r.dest.WriteWindowUpdate(0, uint32(len(f.Data()))); err != nil { return err } // Next updates the stream specific window. return r.dest.WriteWindowUpdate(f.StreamID, uint32(len(f.Data()))) } func (r *relay) decodeFull(data []byte) ([]hpack.HeaderField, error) { r.decoderMu.Lock() defer r.decoderMu.Unlock() return r.decoder.DecodeFull(data) } func (r *relay) encodeFull(headers []hpack.HeaderField) ([]byte, error) { r.encoderMu.Lock() defer r.encoderMu.Unlock() r.reencoded.Reset() var buf bytes.Buffer for _, h := range headers { if *r.enableDebugLogs { if h.Name == "content-type" && h.Value == "application/grpc" { fmt.Fprintf(&buf, " \u001b[1;36m%v\u001b[0m\n", h) } else { fmt.Fprintf(&buf, " %v\n", h) } } if err := r.encoder.WriteField(h); err != nil { return nil, fmt.Errorf("reencoding header field %v in %v: %w", h, headers, err) } } if *r.enableDebugLogs { log.Infof("sending headers %s -> %s:\n%s", r.srcLabel, r.destLabel, buf.Bytes()) } return r.reencoded.Bytes(), nil } // outputBuffer stores enqueued output frames for a given stream. type outputBuffer struct { // windowSize indicates how much data the receiver is ready to process. windowSize int queue list.List // contains queuedFrame elements } // emitEligibleFrames emits frames that would fit under both the stream window size and the // given connection window size. It updates the given connectionWindowSize if applicable. // // This is not thread-safe. The caller should be holding `relay.flowMu`. 
func (w *outputBuffer) emitEligibleFrames(output chan queuedFrame, connectionWindowSize *int) { for e := w.queue.Front(); e != nil; { f := e.Value.(queuedFrame) if f.flowControlSize() > *connectionWindowSize || f.flowControlSize() > w.windowSize { break } output <- f *connectionWindowSize -= f.flowControlSize() w.windowSize -= f.flowControlSize() next := e.Next() w.queue.Remove(e) e = next } } // enqueue adds the frame to this stream output. This is not thread-safe. The caller must hold // relay.flowMu. func (w *outputBuffer) enqueue(f queuedFrame) { w.queue.PushBack(f) } // continuationState holds the context needed to interpret CONTINUATION frames, specifically whether // the parents were HEADERS or PUSH_PROMISE frames. type continuationState interface { complete(s Processor, headers []hpack.HeaderField) error } type headerContinuation struct { priority http2.PriorityParam } func (h *headerContinuation) complete(s Processor, headers []hpack.HeaderField) error { return s.Header(headers, true, h.priority) } type pushPromiseContinuation struct { promiseID uint32 } func (p *pushPromiseContinuation) complete(s Processor, headers []hpack.HeaderField) error { return s.PushPromise(p.promiseID, headers) } // splitIntoChunks splits header payloads into chunks that respect frame size limits. 
func splitIntoChunks(firstChunkMax, continuationMax int, data []byte) [][]byte { var chunks [][]byte firstChunkLength := len(data) if firstChunkLength > firstChunkMax { firstChunkLength = firstChunkMax } buf := make([]byte, firstChunkLength) copy(buf, data[:firstChunkLength]) chunks = append(chunks, buf) remaining := data[firstChunkLength:] for len(remaining) > 0 { nextChunkLength := len(remaining) if nextChunkLength > continuationMax { nextChunkLength = continuationMax } buf = make([]byte, nextChunkLength) copy(buf, remaining[:nextChunkLength]) chunks = append(chunks, buf) remaining = remaining[nextChunkLength:] } return chunks } martian-3.3.2/h2/testing/000077500000000000000000000000001421371434000151375ustar00rootroot00000000000000martian-3.3.2/h2/testing/certs.go000066400000000000000000000076141421371434000166160ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package testing import ( "crypto" "crypto/rand" "crypto/rsa" "crypto/sha256" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "fmt" "log" "os" "time" "github.com/google/martian/v3/cybervillains" "github.com/google/martian/v3/mitm" "google.golang.org/grpc/credentials" ) var ( // CA is the certificate authority. It uses the Cybervillains key pair. CA *x509.Certificate // CAKey is the private key of the certificate authority. CAKey crypto.PrivateKey // RootCAs is a certificate pool containing `CA`. 
RootCAs *x509.CertPool // ClientTLS is a set of transport credentials to use with chains signed by `CA`. ClientTLS credentials.TransportCredentials // Localhost is a certificate for "localhost" signed by `CA`. Localhost *tls.Certificate ) func init() { var err error CA, CAKey, err = initCA() if err != nil { log.Fatalf("Error initializing Cybervillains CA: %v", err) } RootCAs = x509.NewCertPool() RootCAs.AddCert(CA) ClientTLS = credentials.NewClientTLSFromCert(RootCAs, "") Localhost, err = initLocalhostCert(CA, CAKey) if err != nil { log.Fatalf("Error creating localhost server certificate: %v", err) } } func initCA() (*x509.Certificate, crypto.PrivateKey, error) { chain, err := tls.X509KeyPair([]byte(cybervillains.Cert), []byte(cybervillains.Key)) if err != nil { return nil, nil, fmt.Errorf("creating Cybervillains root: %w", err) } cert, err := x509.ParseCertificate(chain.Certificate[0]) if err != nil { return nil, nil, fmt.Errorf("parsing Cybervillains certificate: %w", err) } return cert, chain.PrivateKey, nil } func initLocalhostCert(ca *x509.Certificate, caPriv crypto.PrivateKey) (*tls.Certificate, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, fmt.Errorf("generating random key: %w", err) } // Subject Key Identifier support for end entity certificate. 
// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2) pkixpub, err := x509.MarshalPKIXPublicKey(priv.Public()) if err != nil { return nil, fmt.Errorf("marshalling public key: %w", err) } hasher := sha256.New() hasher.Write(pkixpub) keyID := hasher.Sum(nil) serial, err := rand.Int(rand.Reader, mitm.MaxSerialNumber) if err != nil { return nil, fmt.Errorf("generating serial number: %w", err) } hostname, err := os.Hostname() if err != nil { return nil, fmt.Errorf("getting hostname for creating cert: %w", err) } tmpl := &x509.Certificate{ SerialNumber: serial, Subject: pkix.Name{ CommonName: hostname, Organization: []string{"Martian Proxy"}, }, SubjectKeyId: keyID, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(time.Hour), DNSNames: []string{hostname}, } der, err := x509.CreateCertificate(rand.Reader, tmpl, ca, priv.Public(), caPriv) if err != nil { return nil, fmt.Errorf("creating X509 server certificate: %w", err) } x509c, err := x509.ParseCertificate(der) if err != nil { return nil, fmt.Errorf("parsing DER encoded certificate: %w", err) } return &tls.Certificate{ Certificate: [][]byte{x509c.Raw, ca.Raw}, PrivateKey: priv, Leaf: x509c, }, nil } martian-3.3.2/h2/testing/fixture.go000066400000000000000000000132621421371434000171600ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. // Package testing contains a test fixture for working with gRPC over HTTP/2. package testing import ( "crypto/tls" "fmt" "io/ioutil" "net" "net/http" "os" "strconv" "sync" "time" "github.com/google/martian/v3" "github.com/google/martian/v3/h2" "github.com/google/martian/v3/mitm" "google.golang.org/grpc" "google.golang.org/grpc/credentials" tspb "github.com/google/martian/v3/h2/testservice" ) var ( // proxyPort is a global variable that stores the listener used by the proxy. This value is // shared globally because golang http transport code caches the environment variable values, in // particular HTTPS_PROXY. proxyPort int ) // Fixture encapsulates the TestService gRPC server, a proxy and a gRPC client. type Fixture struct { // TestServiceClient is a client pointing at the service and redirected through the proxy. tspb.TestServiceClient wg sync.WaitGroup server *grpc.Server // serverErr is any error returned by invoking `Serve` on the gRPC server. serverErr error proxyListener net.Listener proxy *martian.Proxy conn *grpc.ClientConn } // New creates a new instance of the Fixture. It is not possible for there to be more than one // instance concurrently because clients decide whether to use the proxy based on the global // HTTPS_PROXY environment variable. func New(spf []h2.StreamProcessorFactory) (*Fixture, error) { f := &Fixture{} // Starts the gRPC server. f.server = grpc.NewServer(grpc.Creds(credentials.NewServerTLSFromCert(Localhost))) tspb.RegisterTestServiceServer(f.server, &Server{}) lis, err := net.Listen("tcp", ":0") if err != nil { return nil, fmt.Errorf("creating listener for gRPC service: %w", err) } f.wg.Add(1) go func() { defer f.wg.Done() f.serverErr = f.server.Serve(lis) }() hostname, err := os.Hostname() if err != nil { return nil, fmt.Errorf("getting hostname: %w", err) } // Creates a listener for the proxy, obtaining a new port if needed. 
if proxyPort == 0 { // Attempts a query to port server first, falling back if it is unavailable. Ports that are // provided by listening on ":0" can be recyled by the OS leading to flakiness in certain // environments since we need the same port to be available across multiple instances of the // test fixture. proxyPort = queryPortServer() if proxyPort == 0 { var err error f.proxyListener, err = net.Listen("tcp", ":0") if err != nil { return nil, fmt.Errorf("creating listener for proxy; %w", err) } proxyPort = f.proxyListener.Addr().(*net.TCPAddr).Port } proxyTarget := hostname + ":" + strconv.Itoa(proxyPort) // Sets the HTTPS_PROXY environment variable so that http requests will go through the proxy. os.Setenv("HTTPS_PROXY", fmt.Sprintf("http://%s", proxyTarget)) fmt.Printf("proxy at %s\n", proxyTarget) } if f.proxyListener == nil { var err error f.proxyListener, err = net.Listen("tcp", fmt.Sprintf(":%d", proxyPort)) if err != nil { return nil, fmt.Errorf("creating listener for proxy; %w", err) } } // Starts the proxy. f.proxy, err = newProxy(spf) if err != nil { return nil, fmt.Errorf("creating proxy: %w", err) } go func() { f.proxy.Serve(f.proxyListener) }() port := lis.Addr().(*net.TCPAddr).Port target := hostname + ":" + strconv.Itoa(port) fmt.Printf("server at %s\n", target) // Connects a gRPC client with the service via the proxy. f.conn, err = grpc.Dial(target, grpc.WithTransportCredentials(ClientTLS)) if err != nil { return nil, fmt.Errorf("error dialing %s: %w", target, err) } f.TestServiceClient = tspb.NewTestServiceClient(f.conn) return f, nil } // Close cleans up the servers and connections. 
func (f *Fixture) Close() error { f.conn.Close() f.server.Stop() f.proxy.Close() f.wg.Wait() if err := f.proxyListener.Close(); err != nil { return fmt.Errorf("closing proxy listener: %w", err) } return f.serverErr } func newProxy(spf []h2.StreamProcessorFactory) (*martian.Proxy, error) { p := martian.NewProxy() mc, err := mitm.NewConfig(CA, CAKey) if err != nil { return nil, fmt.Errorf("creating mitm config: %w", err) } mc.SetValidity(time.Hour) mc.SetOrganization("Martian Proxy") mc.SetH2Config(&h2.Config{ AllowedHostsFilter: func(_ string) bool { return true }, RootCAs: RootCAs, StreamProcessorFactories: spf, EnableDebugLogs: true, }) p.SetMITM(mc) tr := &http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: RootCAs, }, } p.SetRoundTripper(tr) return p, nil } func queryPortServer() int { // portpicker isn't available in third_party. if portServer := os.Getenv("PORTSERVER_ADDRESS"); portServer != "" { c, err := net.Dial("unix", portServer) if err != nil { // failed connection to portServer; this is normal in many circumstances. return 0 } defer c.Close() if _, err := fmt.Fprintf(c, "%d\n", os.Getpid()); err != nil { return 0 } buf, err := ioutil.ReadAll(c) if err != nil || len(buf) == 0 { return 0 } buf = buf[:len(buf)-1] // remove newline char port, err := strconv.Atoi(string(buf)) if err != nil { return 0 } fmt.Printf("got port %d\n", port) return port } return 0 } martian-3.3.2/h2/testing/test_service.go000066400000000000000000000032251421371434000201670ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package testing import ( "context" "io" tspb "github.com/google/martian/v3/h2/testservice" ) // Server is a testing gRPC server. type Server struct { tspb.UnimplementedTestServiceServer } // Echo handles TestService.Echo RPCs. func (s *Server) Echo(ctx context.Context, in *tspb.EchoRequest) (*tspb.EchoResponse, error) { return &tspb.EchoResponse{ Payload: in.GetPayload(), }, nil } // Sum handles TestService.Sum RPCs. func (s *Server) Sum(_ context.Context, in *tspb.SumRequest) (*tspb.SumResponse, error) { sum := int32(0) for _, v := range in.GetValues() { sum += v } return &tspb.SumResponse{ Value: sum, }, nil } // DoubleEcho handles TestService.DoubleEcho RPCs. func (s *Server) DoubleEcho(stream tspb.TestService_DoubleEchoServer) error { for { req, err := stream.Recv() if err == io.EOF { return nil } if err != nil { return err } resp := &tspb.EchoResponse{ Payload: req.GetPayload(), } if err := stream.Send(resp); err != nil { return err } if err := stream.Send(resp); err != nil { return err } } } martian-3.3.2/h2/testservice/000077500000000000000000000000001421371434000160225ustar00rootroot00000000000000martian-3.3.2/h2/testservice/test_service.pb.go000066400000000000000000000273751421371434000214660ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.26.0 // protoc v3.6.1 // source: test_service.proto package testservice import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) type EchoRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Payload string `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` } func (x *EchoRequest) Reset() { *x = EchoRequest{} if protoimpl.UnsafeEnabled { mi := &file_test_service_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *EchoRequest) String() string { return protoimpl.X.MessageStringOf(x) } func (*EchoRequest) ProtoMessage() {} func (x *EchoRequest) ProtoReflect() protoreflect.Message { mi := &file_test_service_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use EchoRequest.ProtoReflect.Descriptor instead. 
func (*EchoRequest) Descriptor() ([]byte, []int) { return file_test_service_proto_rawDescGZIP(), []int{0} } func (x *EchoRequest) GetPayload() string { if x != nil { return x.Payload } return "" } type EchoResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Payload string `protobuf:"bytes,1,opt,name=payload,proto3" json:"payload,omitempty"` } func (x *EchoResponse) Reset() { *x = EchoResponse{} if protoimpl.UnsafeEnabled { mi := &file_test_service_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *EchoResponse) String() string { return protoimpl.X.MessageStringOf(x) } func (*EchoResponse) ProtoMessage() {} func (x *EchoResponse) ProtoReflect() protoreflect.Message { mi := &file_test_service_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use EchoResponse.ProtoReflect.Descriptor instead. 
func (*EchoResponse) Descriptor() ([]byte, []int) { return file_test_service_proto_rawDescGZIP(), []int{1} } func (x *EchoResponse) GetPayload() string { if x != nil { return x.Payload } return "" } type SumRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Values []int32 `protobuf:"varint,1,rep,packed,name=values,proto3" json:"values,omitempty"` } func (x *SumRequest) Reset() { *x = SumRequest{} if protoimpl.UnsafeEnabled { mi := &file_test_service_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *SumRequest) String() string { return protoimpl.X.MessageStringOf(x) } func (*SumRequest) ProtoMessage() {} func (x *SumRequest) ProtoReflect() protoreflect.Message { mi := &file_test_service_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use SumRequest.ProtoReflect.Descriptor instead. 
func (*SumRequest) Descriptor() ([]byte, []int) { return file_test_service_proto_rawDescGZIP(), []int{2} } func (x *SumRequest) GetValues() []int32 { if x != nil { return x.Values } return nil } type SumResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Value int32 `protobuf:"varint,1,opt,name=value,proto3" json:"value,omitempty"` } func (x *SumResponse) Reset() { *x = SumResponse{} if protoimpl.UnsafeEnabled { mi := &file_test_service_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *SumResponse) String() string { return protoimpl.X.MessageStringOf(x) } func (*SumResponse) ProtoMessage() {} func (x *SumResponse) ProtoReflect() protoreflect.Message { mi := &file_test_service_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use SumResponse.ProtoReflect.Descriptor instead. 
func (*SumResponse) Descriptor() ([]byte, []int) { return file_test_service_proto_rawDescGZIP(), []int{3} } func (x *SumResponse) GetValue() int32 { if x != nil { return x.Value } return 0 } var File_test_service_proto protoreflect.FileDescriptor var file_test_service_proto_rawDesc = []byte{ 0x0a, 0x12, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x22, 0x27, 0x0a, 0x0b, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x28, 0x0a, 0x0c, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x24, 0x0a, 0x0a, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x05, 0x52, 0x06, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x73, 0x22, 0x23, 0x0a, 0x0b, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x32, 0xd7, 0x01, 0x0a, 0x0b, 0x54, 0x65, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x3f, 0x0a, 0x04, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3c, 0x0a, 0x03, 0x53, 0x75, 0x6d, 0x12, 0x18, 0x2e, 0x74, 
0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x53, 0x75, 0x6d, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x49, 0x0a, 0x0a, 0x44, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x45, 0x63, 0x68, 0x6f, 0x12, 0x19, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x74, 0x65, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x45, 0x63, 0x68, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x2a, 0x5a, 0x28, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x6d, 0x61, 0x72, 0x74, 0x69, 0x61, 0x6e, 0x2f, 0x68, 0x32, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( file_test_service_proto_rawDescOnce sync.Once file_test_service_proto_rawDescData = file_test_service_proto_rawDesc ) func file_test_service_proto_rawDescGZIP() []byte { file_test_service_proto_rawDescOnce.Do(func() { file_test_service_proto_rawDescData = protoimpl.X.CompressGZIP(file_test_service_proto_rawDescData) }) return file_test_service_proto_rawDescData } var file_test_service_proto_msgTypes = make([]protoimpl.MessageInfo, 4) var file_test_service_proto_goTypes = []interface{}{ (*EchoRequest)(nil), // 0: test_service.EchoRequest (*EchoResponse)(nil), // 1: test_service.EchoResponse (*SumRequest)(nil), // 2: test_service.SumRequest (*SumResponse)(nil), // 3: test_service.SumResponse } var file_test_service_proto_depIdxs = []int32{ 0, // 0: test_service.TestService.Echo:input_type -> test_service.EchoRequest 2, // 1: test_service.TestService.Sum:input_type -> test_service.SumRequest 0, // 2: 
test_service.TestService.DoubleEcho:input_type -> test_service.EchoRequest 1, // 3: test_service.TestService.Echo:output_type -> test_service.EchoResponse 3, // 4: test_service.TestService.Sum:output_type -> test_service.SumResponse 1, // 5: test_service.TestService.DoubleEcho:output_type -> test_service.EchoResponse 3, // [3:6] is the sub-list for method output_type 0, // [0:3] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } func init() { file_test_service_proto_init() } func file_test_service_proto_init() { if File_test_service_proto != nil { return } if !protoimpl.UnsafeEnabled { file_test_service_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*EchoRequest); i { case 0: return &v.state case 1: return &v.sizeCache case 2: return &v.unknownFields default: return nil } } file_test_service_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*EchoResponse); i { case 0: return &v.state case 1: return &v.sizeCache case 2: return &v.unknownFields default: return nil } } file_test_service_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SumRequest); i { case 0: return &v.state case 1: return &v.sizeCache case 2: return &v.unknownFields default: return nil } } file_test_service_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { switch v := v.(*SumResponse); i { case 0: return &v.state case 1: return &v.sizeCache case 2: return &v.unknownFields default: return nil } } } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_test_service_proto_rawDesc, NumEnums: 0, NumMessages: 4, NumExtensions: 0, NumServices: 1, }, GoTypes: file_test_service_proto_goTypes, DependencyIndexes: file_test_service_proto_depIdxs, MessageInfos: 
file_test_service_proto_msgTypes, }.Build() File_test_service_proto = out.File file_test_service_proto_rawDesc = nil file_test_service_proto_goTypes = nil file_test_service_proto_depIdxs = nil } martian-3.3.2/h2/testservice/test_service.proto000066400000000000000000000023431421371434000216100ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto3"; package test_service; option go_package = "github.com/google/martian/h2/testservice"; message EchoRequest { string payload = 1; } message EchoResponse { string payload = 1; } message SumRequest { repeated int32 values = 1; } message SumResponse { int32 value = 1; } service TestService { // The server returns the client message as-is. rpc Echo(EchoRequest) returns (EchoResponse) {} // The server returns the sum of the input values. rpc Sum(SumRequest) returns (SumResponse) {} // The server returns every message twice. rpc DoubleEcho(stream EchoRequest) returns (stream EchoResponse) {} } martian-3.3.2/h2/testservice/test_service_grpc.pb.go000066400000000000000000000155141421371434000224710ustar00rootroot00000000000000// Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
package testservice import ( context "context" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" ) // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. // Requires gRPC-Go v1.32.0 or later. const _ = grpc.SupportPackageIsVersion7 // TestServiceClient is the client API for TestService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type TestServiceClient interface { // The server returns the client message as-is. Echo(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) // The server returns the sum of the input values. Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error) // The server returns every message twice. DoubleEcho(ctx context.Context, opts ...grpc.CallOption) (TestService_DoubleEchoClient, error) } type testServiceClient struct { cc grpc.ClientConnInterface } func NewTestServiceClient(cc grpc.ClientConnInterface) TestServiceClient { return &testServiceClient{cc} } func (c *testServiceClient) Echo(ctx context.Context, in *EchoRequest, opts ...grpc.CallOption) (*EchoResponse, error) { out := new(EchoResponse) err := c.cc.Invoke(ctx, "/test_service.TestService/Echo", in, out, opts...) if err != nil { return nil, err } return out, nil } func (c *testServiceClient) Sum(ctx context.Context, in *SumRequest, opts ...grpc.CallOption) (*SumResponse, error) { out := new(SumResponse) err := c.cc.Invoke(ctx, "/test_service.TestService/Sum", in, out, opts...) 
if err != nil { return nil, err } return out, nil } func (c *testServiceClient) DoubleEcho(ctx context.Context, opts ...grpc.CallOption) (TestService_DoubleEchoClient, error) { stream, err := c.cc.NewStream(ctx, &TestService_ServiceDesc.Streams[0], "/test_service.TestService/DoubleEcho", opts...) if err != nil { return nil, err } x := &testServiceDoubleEchoClient{stream} return x, nil } type TestService_DoubleEchoClient interface { Send(*EchoRequest) error Recv() (*EchoResponse, error) grpc.ClientStream } type testServiceDoubleEchoClient struct { grpc.ClientStream } func (x *testServiceDoubleEchoClient) Send(m *EchoRequest) error { return x.ClientStream.SendMsg(m) } func (x *testServiceDoubleEchoClient) Recv() (*EchoResponse, error) { m := new(EchoResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } // TestServiceServer is the server API for TestService service. // All implementations must embed UnimplementedTestServiceServer // for forward compatibility type TestServiceServer interface { // The server returns the client message as-is. Echo(context.Context, *EchoRequest) (*EchoResponse, error) // The server returns the sum of the input values. Sum(context.Context, *SumRequest) (*SumResponse, error) // The server returns every message twice. DoubleEcho(TestService_DoubleEchoServer) error mustEmbedUnimplementedTestServiceServer() } // UnimplementedTestServiceServer must be embedded to have forward compatible implementations. 
type UnimplementedTestServiceServer struct { } func (UnimplementedTestServiceServer) Echo(context.Context, *EchoRequest) (*EchoResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Echo not implemented") } func (UnimplementedTestServiceServer) Sum(context.Context, *SumRequest) (*SumResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Sum not implemented") } func (UnimplementedTestServiceServer) DoubleEcho(TestService_DoubleEchoServer) error { return status.Errorf(codes.Unimplemented, "method DoubleEcho not implemented") } func (UnimplementedTestServiceServer) mustEmbedUnimplementedTestServiceServer() {} // UnsafeTestServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to TestServiceServer will // result in compilation errors. type UnsafeTestServiceServer interface { mustEmbedUnimplementedTestServiceServer() } func RegisterTestServiceServer(s grpc.ServiceRegistrar, srv TestServiceServer) { s.RegisterService(&TestService_ServiceDesc, srv) } func _TestService_Echo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(EchoRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(TestServiceServer).Echo(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/test_service.TestService/Echo", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(TestServiceServer).Echo(ctx, req.(*EchoRequest)) } return interceptor(ctx, in, info, handler) } func _TestService_Sum_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(SumRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(TestServiceServer).Sum(ctx, in) } info := 
&grpc.UnaryServerInfo{ Server: srv, FullMethod: "/test_service.TestService/Sum", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(TestServiceServer).Sum(ctx, req.(*SumRequest)) } return interceptor(ctx, in, info, handler) } func _TestService_DoubleEcho_Handler(srv interface{}, stream grpc.ServerStream) error { return srv.(TestServiceServer).DoubleEcho(&testServiceDoubleEchoServer{stream}) } type TestService_DoubleEchoServer interface { Send(*EchoResponse) error Recv() (*EchoRequest, error) grpc.ServerStream } type testServiceDoubleEchoServer struct { grpc.ServerStream } func (x *testServiceDoubleEchoServer) Send(m *EchoResponse) error { return x.ServerStream.SendMsg(m) } func (x *testServiceDoubleEchoServer) Recv() (*EchoRequest, error) { m := new(EchoRequest) if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err } return m, nil } // TestService_ServiceDesc is the grpc.ServiceDesc for TestService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var TestService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "test_service.TestService", HandlerType: (*TestServiceServer)(nil), Methods: []grpc.MethodDesc{ { MethodName: "Echo", Handler: _TestService_Echo_Handler, }, { MethodName: "Sum", Handler: _TestService_Sum_Handler, }, }, Streams: []grpc.StreamDesc{ { StreamName: "DoubleEcho", Handler: _TestService_DoubleEcho_Handler, ServerStreams: true, ClientStreams: true, }, }, Metadata: "test_service.proto", } martian-3.3.2/har/000077500000000000000000000000001421371434000137235ustar00rootroot00000000000000martian-3.3.2/har/har.go000066400000000000000000000526141421371434000150340ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package har collects HTTP requests and responses and stores them in HAR format. // // For more information on HAR, see: // https://w3c.github.io/web-performance/specs/HAR/Overview.html package har import ( "bytes" "encoding/base64" "encoding/json" "fmt" "io" "io/ioutil" "mime" "mime/multipart" "net/http" "net/url" "strings" "sync" "time" "unicode/utf8" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/messageview" "github.com/google/martian/v3/proxyutil" ) // Logger maintains request and response log entries. type Logger struct { bodyLogging func(*http.Response) bool postDataLogging func(*http.Request) bool creator *Creator mu sync.Mutex entries map[string]*Entry tail *Entry } // HAR is the top level object of a HAR log. type HAR struct { Log *Log `json:"log"` } // Log is the HAR HTTP request and response log. type Log struct { // Version number of the HAR format. Version string `json:"version"` // Creator holds information about the log creator application. Creator *Creator `json:"creator"` // Entries is a list containing requests and responses. Entries []*Entry `json:"entries"` } // Creator is the program responsible for generating the log. Martian, in this case. type Creator struct { // Name of the log creator application. Name string `json:"name"` // Version of the log creator application. Version string `json:"version"` } // Entry is a individual log entry for a request or response. type Entry struct { // ID is the unique ID for the entry. 
ID string `json:"_id"` // StartedDateTime is the date and time stamp of the request start (ISO 8601). StartedDateTime time.Time `json:"startedDateTime"` // Time is the total elapsed time of the request in milliseconds. Time int64 `json:"time"` // Request contains the detailed information about the request. Request *Request `json:"request"` // Response contains the detailed information about the response. Response *Response `json:"response,omitempty"` // Cache contains information about a request coming from browser cache. Cache *Cache `json:"cache"` // Timings describes various phases within request-response round trip. All // times are specified in milliseconds. Timings *Timings `json:"timings"` next *Entry } // Request holds data about an individual HTTP request. type Request struct { // Method is the request method (GET, POST, ...). Method string `json:"method"` // URL is the absolute URL of the request (fragments are not included). URL string `json:"url"` // HTTPVersion is the Request HTTP version (HTTP/1.1). HTTPVersion string `json:"httpVersion"` // Cookies is a list of cookies. Cookies []Cookie `json:"cookies"` // Headers is a list of headers. Headers []Header `json:"headers"` // QueryString is a list of query parameters. QueryString []QueryString `json:"queryString"` // PostData is the posted data information. PostData *PostData `json:"postData,omitempty"` // HeaderSize is the Total number of bytes from the start of the HTTP request // message until (and including) the double CLRF before the body. Set to -1 // if the info is not available. HeadersSize int64 `json:"headersSize"` // BodySize is the size of the request body (POST data payload) in bytes. Set // to -1 if the info is not available. BodySize int64 `json:"bodySize"` } // Response holds data about an individual HTTP response. type Response struct { // Status is the response status code. Status int `json:"status"` // StatusText is the response status description. 
StatusText string `json:"statusText"` // HTTPVersion is the Response HTTP version (HTTP/1.1). HTTPVersion string `json:"httpVersion"` // Cookies is a list of cookies. Cookies []Cookie `json:"cookies"` // Headers is a list of headers. Headers []Header `json:"headers"` // Content contains the details of the response body. Content *Content `json:"content"` // RedirectURL is the target URL from the Location response header. RedirectURL string `json:"redirectURL"` // HeadersSize is the total number of bytes from the start of the HTTP // request message until (and including) the double CLRF before the body. // Set to -1 if the info is not available. HeadersSize int64 `json:"headersSize"` // BodySize is the size of the request body (POST data payload) in bytes. Set // to -1 if the info is not available. BodySize int64 `json:"bodySize"` } // Cache contains information about a request coming from browser cache. type Cache struct { // Has no fields as they are not supported, but HAR requires the "cache" // object to exist. } // Timings describes various phases within request-response round trip. All // times are specified in milliseconds type Timings struct { // Send is the time required to send HTTP request to the server. Send int64 `json:"send"` // Wait is the time spent waiting for a response from the server. Wait int64 `json:"wait"` // Receive is the time required to read entire response from server or cache. Receive int64 `json:"receive"` } // Cookie is the data about a cookie on a request or response. type Cookie struct { // Name is the cookie name. Name string `json:"name"` // Value is the cookie value. Value string `json:"value"` // Path is the path pertaining to the cookie. Path string `json:"path,omitempty"` // Domain is the host of the cookie. Domain string `json:"domain,omitempty"` // Expires contains cookie expiration time. Expires time.Time `json:"-"` // Expires8601 contains cookie expiration time in ISO 8601 format. 
Expires8601 string `json:"expires,omitempty"` // HTTPOnly is set to true if the cookie is HTTP only, false otherwise. HTTPOnly bool `json:"httpOnly,omitempty"` // Secure is set to true if the cookie was transmitted over SSL, false // otherwise. Secure bool `json:"secure,omitempty"` } // Header is an HTTP request or response header. type Header struct { // Name is the header name. Name string `json:"name"` // Value is the header value. Value string `json:"value"` } // QueryString is a query string parameter on a request. type QueryString struct { // Name is the query parameter name. Name string `json:"name"` // Value is the query parameter value. Value string `json:"value"` } // PostData describes posted data on a request. type PostData struct { // MimeType is the MIME type of the posted data. MimeType string `json:"mimeType"` // Params is a list of posted parameters (in case of URL encoded parameters). Params []Param `json:"params"` // Text contains the posted data. Although its type is string, it may contain // binary data. Text string `json:"text"` } // pdBinary is the JSON representation of binary PostData. type pdBinary struct { MimeType string `json:"mimeType"` // Params is a list of posted parameters (in case of URL encoded parameters). Params []Param `json:"params"` Text []byte `json:"text"` Encoding string `json:"encoding"` } // MarshalJSON returns a JSON representation of binary PostData. func (p *PostData) MarshalJSON() ([]byte, error) { if utf8.ValidString(p.Text) { type noMethod PostData // avoid infinite recursion return json.Marshal((*noMethod)(p)) } return json.Marshal(pdBinary{ MimeType: p.MimeType, Params: p.Params, Text: []byte(p.Text), Encoding: "base64", }) } // UnmarshalJSON populates PostData based on the []byte representation of // the binary PostData. 
func (p *PostData) UnmarshalJSON(data []byte) error { if bytes.Equal(data, []byte("null")) { // conform to json.Unmarshaler spec return nil } var enc struct { Encoding string `json:"encoding"` } if err := json.Unmarshal(data, &enc); err != nil { return err } if enc.Encoding != "base64" { type noMethod PostData // avoid infinite recursion return json.Unmarshal(data, (*noMethod)(p)) } var pb pdBinary if err := json.Unmarshal(data, &pb); err != nil { return err } p.MimeType = pb.MimeType p.Params = pb.Params p.Text = string(pb.Text) return nil } // Param describes an individual posted parameter. type Param struct { // Name of the posted parameter. Name string `json:"name"` // Value of the posted parameter. Value string `json:"value,omitempty"` // Filename of a posted file. Filename string `json:"fileName,omitempty"` // ContentType is the content type of a posted file. ContentType string `json:"contentType,omitempty"` } // Content describes details about response content. type Content struct { // Size is the length of the returned content in bytes. Should be equal to // response.bodySize if there is no compression and bigger when the content // has been compressed. Size int64 `json:"size"` // MimeType is the MIME type of the response text (value of the Content-Type // response header). MimeType string `json:"mimeType"` // Text contains the response body sent from the server or loaded from the // browser cache. This field is populated with fully decoded version of the // respose body. Text []byte `json:"text,omitempty"` // The desired encoding to use for the text field when encoding to JSON. Encoding string `json:"encoding,omitempty"` } // For marshaling Content to and from json. This works around the json library's // default conversion of []byte to base64 encoded string. type contentJSON struct { Size int64 `json:"size"` MimeType string `json:"mimeType"` // Text contains the response body sent from the server or loaded from the // browser cache. 
This field is populated with textual content only. The text // field is either HTTP decoded text or a encoded (e.g. "base64") // representation of the response body. Leave out this field if the // information is not available. Text string `json:"text,omitempty"` // Encoding used for response text field e.g "base64". Leave out this field // if the text field is HTTP decoded (decompressed & unchunked), than // trans-coded from its original character set into UTF-8. Encoding string `json:"encoding,omitempty"` } // MarshalJSON marshals the byte slice into json after encoding based on c.Encoding. func (c Content) MarshalJSON() ([]byte, error) { var txt string switch c.Encoding { case "base64": txt = base64.StdEncoding.EncodeToString(c.Text) case "": txt = string(c.Text) default: return nil, fmt.Errorf("unsupported encoding for Content.Text: %s", c.Encoding) } cj := contentJSON{ Size: c.Size, MimeType: c.MimeType, Text: txt, Encoding: c.Encoding, } return json.Marshal(cj) } // UnmarshalJSON unmarshals the bytes slice into Content. func (c *Content) UnmarshalJSON(data []byte) error { var cj contentJSON if err := json.Unmarshal(data, &cj); err != nil { return err } var txt []byte var err error switch cj.Encoding { case "base64": txt, err = base64.StdEncoding.DecodeString(cj.Text) if err != nil { return fmt.Errorf("failed to decode base64-encoded Content.Text: %v", err) } case "": txt = []byte(cj.Text) default: return fmt.Errorf("unsupported encoding for Content.Text: %s", cj.Encoding) } c.Size = cj.Size c.MimeType = cj.MimeType c.Text = txt c.Encoding = cj.Encoding return nil } // Option is a configurable setting for the logger. type Option func(l *Logger) // PostDataLogging returns an option that configures request post data logging. 
func PostDataLogging(enabled bool) Option {
	return func(l *Logger) {
		l.postDataLogging = func(*http.Request) bool {
			return enabled
		}
	}
}

// PostDataLoggingForContentTypes returns an option that logs request bodies based
// on opting in to the Content-Type of the request. A request is logged when its
// Content-Type header starts (case-insensitively) with any entry of cts.
func PostDataLoggingForContentTypes(cts ...string) Option {
	// Lower-case the configured types once, when the option is built, rather
	// than on every request; this also copies cts so later mutation of the
	// caller's slice cannot change the option.
	lcts := lowerContentTypes(cts)

	return func(l *Logger) {
		l.postDataLogging = func(req *http.Request) bool {
			return hasContentTypePrefix(req.Header.Get("Content-Type"), lcts)
		}
	}
}

// SkipPostDataLoggingForContentTypes returns an option that logs request bodies based
// on opting out of the Content-Type of the request. A request is logged unless
// its Content-Type header starts (case-insensitively) with any entry of cts.
func SkipPostDataLoggingForContentTypes(cts ...string) Option {
	lcts := lowerContentTypes(cts)

	return func(l *Logger) {
		l.postDataLogging = func(req *http.Request) bool {
			return !hasContentTypePrefix(req.Header.Get("Content-Type"), lcts)
		}
	}
}

// BodyLogging returns an option that configures response body logging.
func BodyLogging(enabled bool) Option {
	return func(l *Logger) {
		l.bodyLogging = func(*http.Response) bool {
			return enabled
		}
	}
}

// BodyLoggingForContentTypes returns an option that logs response bodies based
// on opting in to the Content-Type of the response. A response is logged when
// its Content-Type header starts (case-insensitively) with any entry of cts.
func BodyLoggingForContentTypes(cts ...string) Option {
	lcts := lowerContentTypes(cts)

	return func(l *Logger) {
		l.bodyLogging = func(res *http.Response) bool {
			return hasContentTypePrefix(res.Header.Get("Content-Type"), lcts)
		}
	}
}

// lowerContentTypes returns a lower-cased copy of cts.
func lowerContentTypes(cts []string) []string {
	lcts := make([]string, len(cts))
	for i, ct := range cts {
		lcts[i] = strings.ToLower(ct)
	}
	return lcts
}

// hasContentTypePrefix reports whether ct, compared case-insensitively, starts
// with any of the already lower-cased prefixes in lcts.
func hasContentTypePrefix(ct string, lcts []string) bool {
	ct = strings.ToLower(ct)
	for _, prefix := range lcts {
		if strings.HasPrefix(ct, prefix) {
			return true
		}
	}
	return false
}

// SkipBodyLoggingForContentTypes returns an option that logs response bodies based
// on opting out of the Content-Type of the response.
func SkipBodyLoggingForContentTypes(cts ...string) Option { return func(l *Logger) { l.bodyLogging = func(res *http.Response) bool { rct := res.Header.Get("Content-Type") for _, ct := range cts { if strings.HasPrefix(strings.ToLower(rct), strings.ToLower(ct)) { return false } } return true } } } // NewLogger returns a HAR logger. The returned // logger logs all request post data and response bodies by default. func NewLogger() *Logger { l := &Logger{ creator: &Creator{ Name: "martian proxy", Version: "2.0.0", }, entries: make(map[string]*Entry), } l.SetOption(BodyLogging(true)) l.SetOption(PostDataLogging(true)) return l } // SetOption sets configurable options on the logger. func (l *Logger) SetOption(opts ...Option) { for _, opt := range opts { opt(l) } } // ModifyRequest logs requests. func (l *Logger) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) if ctx.SkippingLogging() { return nil } id := ctx.ID() return l.RecordRequest(id, req) } // RecordRequest logs the HTTP request with the given ID. The ID should be unique // per request/response pair. func (l *Logger) RecordRequest(id string, req *http.Request) error { hreq, err := NewRequest(req, l.postDataLogging(req)) if err != nil { return err } entry := &Entry{ ID: id, StartedDateTime: time.Now().UTC(), Request: hreq, Cache: &Cache{}, Timings: &Timings{}, } l.mu.Lock() defer l.mu.Unlock() if _, exists := l.entries[id]; exists { return fmt.Errorf("Duplicate request ID: %s", id) } l.entries[id] = entry if l.tail == nil { l.tail = entry } entry.next = l.tail.next l.tail.next = entry l.tail = entry return nil } // NewRequest constructs and returns a Request from req. If withBody is true, // req.Body is read to EOF and replaced with a copy in a bytes.Buffer. An error // is returned (and req.Body may be in an intermediate state) if an error is // returned from req.Body.Read. 
func NewRequest(req *http.Request, withBody bool) (*Request, error) { r := &Request{ Method: req.Method, URL: req.URL.String(), HTTPVersion: req.Proto, HeadersSize: -1, BodySize: req.ContentLength, QueryString: []QueryString{}, Headers: headers(proxyutil.RequestHeader(req).Map()), Cookies: cookies(req.Cookies()), } for n, vs := range req.URL.Query() { for _, v := range vs { r.QueryString = append(r.QueryString, QueryString{ Name: n, Value: v, }) } } pd, err := postData(req, withBody) if err != nil { return nil, err } r.PostData = pd return r, nil } // ModifyResponse logs responses. func (l *Logger) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) if ctx.SkippingLogging() { return nil } id := ctx.ID() return l.RecordResponse(id, res) } // RecordResponse logs an HTTP response, associating it with the previously-logged // HTTP request with the same ID. func (l *Logger) RecordResponse(id string, res *http.Response) error { hres, err := NewResponse(res, l.bodyLogging(res)) if err != nil { return err } l.mu.Lock() defer l.mu.Unlock() if e, ok := l.entries[id]; ok { e.Response = hres e.Time = time.Since(e.StartedDateTime).Nanoseconds() / 1000000 } return nil } // NewResponse constructs and returns a Response from resp. If withBody is true, // resp.Body is read to EOF and replaced with a copy in a bytes.Buffer. An error // is returned (and resp.Body may be in an intermediate state) if an error is // returned from resp.Body.Read. 
func NewResponse(res *http.Response, withBody bool) (*Response, error) { r := &Response{ HTTPVersion: res.Proto, Status: res.StatusCode, StatusText: http.StatusText(res.StatusCode), HeadersSize: -1, BodySize: res.ContentLength, Headers: headers(proxyutil.ResponseHeader(res).Map()), Cookies: cookies(res.Cookies()), } if res.StatusCode >= 300 && res.StatusCode < 400 { r.RedirectURL = res.Header.Get("Location") } r.Content = &Content{ Encoding: "base64", MimeType: res.Header.Get("Content-Type"), } if withBody { mv := messageview.New() if err := mv.SnapshotResponse(res); err != nil { return nil, err } br, err := mv.BodyReader(messageview.Decode()) if err != nil { return nil, err } body, err := ioutil.ReadAll(br) if err != nil { return nil, err } r.Content.Text = body r.Content.Size = int64(len(body)) } return r, nil } // Export returns the in-memory log. func (l *Logger) Export() *HAR { l.mu.Lock() defer l.mu.Unlock() es := make([]*Entry, 0, len(l.entries)) curr := l.tail for curr != nil { curr = curr.next es = append(es, curr) if curr == l.tail { break } } return l.makeHAR(es) } // ExportAndReset returns the in-memory log for completed requests, clearing them. func (l *Logger) ExportAndReset() *HAR { l.mu.Lock() defer l.mu.Unlock() es := make([]*Entry, 0, len(l.entries)) curr := l.tail prev := l.tail var first *Entry for curr != nil { curr = curr.next if curr.Response != nil { es = append(es, curr) delete(l.entries, curr.ID) } else { if first == nil { first = curr } prev.next = curr prev = curr } if curr == l.tail { break } } if len(l.entries) == 0 { l.tail = nil } else { l.tail = prev l.tail.next = first } return l.makeHAR(es) } func (l *Logger) makeHAR(es []*Entry) *HAR { return &HAR{ Log: &Log{ Version: "1.2", Creator: l.creator, Entries: es, }, } } // Reset clears the in-memory log of entries. 
func (l *Logger) Reset() { l.mu.Lock() defer l.mu.Unlock() l.entries = make(map[string]*Entry) l.tail = nil } func cookies(cs []*http.Cookie) []Cookie { hcs := make([]Cookie, 0, len(cs)) for _, c := range cs { var expires string if !c.Expires.IsZero() { expires = c.Expires.Format(time.RFC3339) } hcs = append(hcs, Cookie{ Name: c.Name, Value: c.Value, Path: c.Path, Domain: c.Domain, HTTPOnly: c.HttpOnly, Secure: c.Secure, Expires: c.Expires, Expires8601: expires, }) } return hcs } func headers(hs http.Header) []Header { hhs := make([]Header, 0, len(hs)) for n, vs := range hs { for _, v := range vs { hhs = append(hhs, Header{ Name: n, Value: v, }) } } return hhs } func postData(req *http.Request, logBody bool) (*PostData, error) { // If the request has no body (no Content-Length and Transfer-Encoding isn't // chunked), skip the post data. if req.ContentLength <= 0 && len(req.TransferEncoding) == 0 { return nil, nil } ct := req.Header.Get("Content-Type") mt, ps, err := mime.ParseMediaType(ct) if err != nil { log.Errorf("har: cannot parse Content-Type header %q: %v", ct, err) mt = ct } pd := &PostData{ MimeType: mt, Params: []Param{}, } if !logBody { return pd, nil } mv := messageview.New() if err := mv.SnapshotRequest(req); err != nil { return nil, err } br, err := mv.BodyReader() if err != nil { return nil, err } switch mt { case "multipart/form-data": mpr := multipart.NewReader(br, ps["boundary"]) for { p, err := mpr.NextPart() if err == io.EOF { break } if err != nil { return nil, err } defer p.Close() body, err := ioutil.ReadAll(p) if err != nil { return nil, err } pd.Params = append(pd.Params, Param{ Name: p.FormName(), Filename: p.FileName(), ContentType: p.Header.Get("Content-Type"), Value: string(body), }) } case "application/x-www-form-urlencoded": body, err := ioutil.ReadAll(br) if err != nil { return nil, err } vs, err := url.ParseQuery(string(body)) if err != nil { return nil, err } for n, vs := range vs { for _, v := range vs { pd.Params = 
append(pd.Params, Param{ Name: n, Value: v, }) } } default: body, err := ioutil.ReadAll(br) if err != nil { return nil, err } pd.Text = string(body) } return pd, nil } martian-3.3.2/har/har_handlers.go000066400000000000000000000052711421371434000167110ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package har import ( "encoding/json" "net/http" "net/url" "strconv" "github.com/google/martian/v3/log" ) type exportHandler struct { logger *Logger } type resetHandler struct { logger *Logger } // NewExportHandler returns an http.Handler for requesting HAR logs. func NewExportHandler(l *Logger) http.Handler { return &exportHandler{ logger: l, } } // NewResetHandler returns an http.Handler for clearing in-memory log entries. func NewResetHandler(l *Logger) http.Handler { return &resetHandler{ logger: l, } } // ServeHTTP writes the log in HAR format to the response body. func (h *exportHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.Method != "GET" { rw.Header().Add("Allow", "GET") rw.WriteHeader(http.StatusMethodNotAllowed) log.Errorf("har.ServeHTTP: method not allowed: %s", req.Method) return } log.Debugf("exportHandler.ServeHTTP: writing HAR logs to ResponseWriter") rw.Header().Set("Content-Type", "application/json; charset=utf-8") hl := h.logger.Export() json.NewEncoder(rw).Encode(hl) } // ServeHTTP resets the log, which clears its entries. 
func (h *resetHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if !(req.Method == "POST" || req.Method == "DELETE") { rw.Header().Add("Allow", "POST") rw.Header().Add("Allow", "DELETE") rw.WriteHeader(http.StatusMethodNotAllowed) log.Errorf("har: method not allowed: %s", req.Method) return } v, err := parseBoolQueryParam(req.URL.Query(), "return") if err != nil { log.Errorf("har: invalid value for return param: %s", err) rw.WriteHeader(http.StatusBadRequest) return } if v { rw.Header().Set("Content-Type", "application/json; charset=utf-8") hl := h.logger.ExportAndReset() json.NewEncoder(rw).Encode(hl) } else { h.logger.Reset() rw.WriteHeader(http.StatusNoContent) } log.Infof("resetHandler.ServeHTTP: HAR logs cleared") } func parseBoolQueryParam(params url.Values, name string) (bool, error) { if params[name] == nil { return false, nil } v, err := strconv.ParseBool(params.Get("return")) if err != nil { return false, err } return v, nil } martian-3.3.2/har/har_handlers_test.go000066400000000000000000000101521421371434000177420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package har

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/proxyutil"
)

// TestExportHandlerServeHTTP drives the export and reset handlers against a
// single Logger. It records one request/response pair, exports it, resets the
// log, then exercises the reset handler's "return" query parameter with true,
// false and non-boolean values. The steps share logger state, so their order
// matters.
func TestExportHandlerServeHTTP(t *testing.T) {
	logger := NewLogger()
	// Record one complete request/response pair so the export has an entry.
	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("martian.TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := logger.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := logger.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	h := NewExportHandler(logger)

	// A GET to the export handler should return the single recorded entry.
	req, err = http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw := httptest.NewRecorder()
	h.ServeHTTP(rw, req)

	if got, want := rw.Code, http.StatusOK; got != want {
		t.Errorf("rw.Code: got %d, want %d", got, want)
	}

	hl := &HAR{}
	if err := json.Unmarshal(rw.Body.Bytes(), hl); err != nil {
		t.Fatalf("json.Unmarshal(): got %v, want no error", err)
	}

	if got, want := len(hl.Log.Entries), 1; got != want {
		t.Fatalf("len(hl.Log.Entries): got %v, want %v", got, want)
	}

	entry := hl.Log.Entries[0]
	if got, want := entry.Request.URL, "http://example.com"; got != want {
		t.Errorf("Request.URL: got %q, want %q", got, want)
	}
	if got, want := entry.Response.Status, 200; got != want {
		t.Errorf("Response.Status: got %d, want %d", got, want)
	}

	// DELETE without a "return" parameter clears the log (its status code is
	// not checked here).
	rh := NewResetHandler(logger)
	req, err = http.NewRequest("DELETE", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw = httptest.NewRecorder()
	rh.ServeHTTP(rw, req)

	// After the reset, a fresh export should contain no entries.
	req, err = http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw = httptest.NewRecorder()
	h.ServeHTTP(rw, req)

	if got, want := rw.Code, http.StatusOK; got != want {
		t.Errorf("rw.Code: got %d, want %d", got, want)
	}

	hl = &HAR{}
	if err := json.Unmarshal(rw.Body.Bytes(), hl); err != nil {
		t.Fatalf("json.Unmarshal(): got %v, want no error", err)
	}

	if got, want := len(hl.Log.Entries), 0; got != want {
		t.Errorf("len(Log.Entries): got %v, want %v", got, want)
	}

	// return=1 asks the reset handler to export what it clears. The log is
	// already empty, so expect 200 with a HAR body holding zero entries.
	req, err = http.NewRequest("DELETE", "/?return=1", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw = httptest.NewRecorder()
	rh.ServeHTTP(rw, req)
	if got, want := rw.Code, http.StatusOK; got != want {
		t.Errorf("rw.Code: got %d, want %d", got, want)
	}
	hl = &HAR{}
	if err := json.Unmarshal(rw.Body.Bytes(), hl); err != nil {
		t.Fatalf("json.Unmarshal(): got %v, want no error", err)
	}
	if got, want := len(hl.Log.Entries), 0; got != want {
		t.Errorf("len(Log.Entries): got %v, want %v", got, want)
	}

	// return=0 performs a plain reset, answered with 204 No Content.
	req, err = http.NewRequest("DELETE", "/?return=0", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw = httptest.NewRecorder()
	rh.ServeHTTP(rw, req)
	if got, want := rw.Code, http.StatusNoContent; got != want {
		t.Errorf("rw.Code: got %d, want %d", got, want)
	}

	// A value strconv.ParseBool rejects is answered with 400 Bad Request.
	req, err = http.NewRequest("DELETE", "/?return=notboolean", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	rw = httptest.NewRecorder()
	rh.ServeHTTP(rw, req)
	if got, want := rw.Code, http.StatusBadRequest; got != want {
		t.Errorf("rw.Code: got %d, want %d", got, want)
	}
}
martian-3.3.2/har/har_test.go000066400000000000000000000635021421371434000160710ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package har import ( "bytes" "encoding/json" "mime/multipart" "net/http" "reflect" "strings" "testing" "time" "github.com/google/martian/v3" "github.com/google/martian/v3/proxyutil" ) func TestModifyRequest(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com/path?query=true", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Add("Request-Header", "first") req.Header.Add("Request-Header", "second") cookie := &http.Cookie{ Name: "request", Value: "cookie", } req.AddCookie(cookie) _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() logger := NewLogger() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := log.Version, "1.2"; got != want { t.Errorf("log.Version: got %q, want %q", got, want) } if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } entry := log.Entries[0] if got, want := time.Since(entry.StartedDateTime), time.Second; got > want { t.Errorf("entry.StartedDateTime: got %s, want less than %s", got, want) } hreq := entry.Request if got, want := hreq.Method, "GET"; got != want { t.Errorf("hreq.Method: got %q, want %q", got, want) } if got, want := hreq.URL, "http://example.com/path?query=true"; got != want { t.Errorf("hreq.URL: got %q, want %q", got, want) } if got, want := hreq.HTTPVersion, "HTTP/1.1"; got != want { t.Errorf("hreq.HTTPVersion: 
got %q, want %q", got, want) } if got, want := hreq.BodySize, int64(0); got != want { t.Errorf("hreq.BodySize: got %d, want %d", got, want) } if got, want := hreq.HeadersSize, int64(-1); got != want { t.Errorf("hreq.HeadersSize: got %d, want %d", got, want) } if got, want := len(hreq.QueryString), 1; got != want { t.Fatalf("len(hreq.QueryString): got %d, want %q", got, want) } qs := hreq.QueryString[0] if got, want := qs.Name, "query"; got != want { t.Errorf("qs.Name: got %q, want %q", got, want) } if got, want := qs.Value, "true"; got != want { t.Errorf("qs.Value: got %q, want %q", got, want) } wantHeaders := http.Header{ "Request-Header": {"first", "second"}, "Cookie": {cookie.String()}, "Host": {"example.com"}, } if got := headersToHTTP(hreq.Headers); !reflect.DeepEqual(got, wantHeaders) { t.Errorf("headers:\ngot:\n%+v\nwant:\n%+v", got, wantHeaders) } if got, want := len(hreq.Cookies), 1; got != want { t.Fatalf("len(hreq.Cookies): got %d, want %d", got, want) } hcookie := hreq.Cookies[0] if got, want := hcookie.Name, "request"; got != want { t.Errorf("hcookie.Name: got %q, want %q", got, want) } if got, want := hcookie.Value, "cookie"; got != want { t.Errorf("hcookie.Value: got %q, want %q", got, want) } } func headersToHTTP(hs []Header) http.Header { hh := http.Header{} for _, h := range hs { hh[h.Name] = append(hh[h.Name], h.Value) } return hh } func TestModifyResponse(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(301, strings.NewReader("response body"), req) res.ContentLength = 13 res.Header.Add("Response-Header", "first") res.Header.Add("Response-Header", "second") res.Header.Set("Location", "google.com") expires := time.Now() cookie := &http.Cookie{ Name: "response", Value: 
"cookie", Path: "/", Domain: "example.com", Expires: expires, Secure: true, HttpOnly: true, } res.Header.Set("Set-Cookie", cookie.String()) logger := NewLogger() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } hres := log.Entries[0].Response if got, want := hres.Status, 301; got != want { t.Errorf("hres.Status: got %d, want %d", got, want) } if got, want := hres.StatusText, "Moved Permanently"; got != want { t.Errorf("hres.StatusText: got %q, want %q", got, want) } if got, want := hres.HTTPVersion, "HTTP/1.1"; got != want { t.Errorf("hres.HTTPVersion: got %q, want %q", got, want) } if got, want := hres.Content.Text, []byte("response body"); !bytes.Equal(got, want) { t.Errorf("hres.Content.Text: got %q, want %q", got, want) } wantHeaders := http.Header{ "Response-Header": {"first", "second"}, "Set-Cookie": {cookie.String()}, "Location": {"google.com"}, "Content-Length": {"13"}, } if got := headersToHTTP(hres.Headers); !reflect.DeepEqual(got, wantHeaders) { t.Errorf("headers:\ngot:\n%+v\nwant:\n%+v", got, wantHeaders) } if got, want := len(hres.Cookies), 1; got != want { t.Fatalf("len(hres.Cookies): got %d, want %d", got, want) } hcookie := hres.Cookies[0] if got, want := hcookie.Name, "response"; got != want { t.Errorf("hcookie.Name: got %q, want %q", got, want) } if got, want := hcookie.Value, "cookie"; got != want { t.Errorf("hcookie.Value: got %q, want %q", got, want) } if got, want := hcookie.Path, "/"; got != want { t.Errorf("hcookie.Path: got %q, want %q", got, want) } if got, want := hcookie.Domain, "example.com"; got != want { t.Errorf("hcookie.Domain: got %q, want %q", got, want) } if got, want := hcookie.Expires, expires; got.Equal(want) { 
t.Errorf("hcookie.Expires: got %s, want %s", got, want) } if !hcookie.HTTPOnly { t.Error("hcookie.HTTPOnly: got false, want true") } if !hcookie.Secure { t.Error("hcookie.Secure: got false, want true") } } func TestModifyRequestBodyURLEncoded(t *testing.T) { logger := NewLogger() body := strings.NewReader("first=true&second=false") req, err := http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %v, want %v", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, "application/x-www-form-urlencoded"; got != want { t.Errorf("PostData.MimeType: got %v, want %v", got, want) } if got, want := len(pd.Params), 2; got != want { t.Fatalf("len(PostData.Params): got %d, want %d", got, want) } for _, p := range pd.Params { var want string switch p.Name { case "first": want = "true" case "second": want = "false" default: t.Errorf("PostData.Params: got %q, want to not be present", p.Name) continue } if got := p.Value; got != want { t.Errorf("PostData.Params[%q]: got %q, want %q", p.Name, got, want) } } } func TestModifyRequestBodyArbitraryContentType(t *testing.T) { logger := NewLogger() body := "arbitrary binary data" req, err := http.NewRequest("POST", "http://www.example.com", strings.NewReader(body)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := 
logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, ""; got != want { t.Errorf("PostData.MimeType: got %q, want %q", got, want) } if got, want := len(pd.Params), 0; got != want { t.Errorf("len(PostData.Params): got %d, want %d", got, want) } if got, want := pd.Text, body; got != want { t.Errorf("PostData.Text: got %q, want %q", got, want) } } func TestModifyRequestBodyMultipart(t *testing.T) { logger := NewLogger() body := new(bytes.Buffer) mpw := multipart.NewWriter(body) mpw.SetBoundary("boundary") if err := mpw.WriteField("key", "value"); err != nil { t.Errorf("mpw.WriteField(): got %v, want no error", err) } w, err := mpw.CreateFormFile("file", "test.txt") if _, err = w.Write([]byte("file contents")); err != nil { t.Fatalf("Write(): got %v, want no error", err) } mpw.Close() req, err := http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", mpw.FormDataContentType()) _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } pd := log.Entries[0].Request.PostData if got, want := pd.MimeType, "multipart/form-data"; got != want { t.Errorf("PostData.MimeType: got %q, want %q", got, want) } if got, want := len(pd.Params), 2; got != want { t.Errorf("PostData.Params: got %d, want %d", got, want) } for _, p := range pd.Params { var want Param switch p.Name { case "key": want = Param{ Filename: 
"", ContentType: "", Value: "value", } case "file": want = Param{ Filename: "test.txt", ContentType: "application/octet-stream", Value: "file contents", } default: t.Errorf("pd.Params: got %q, want not to be present", p.Name) continue } if got, want := p.Filename, want.Filename; got != want { t.Errorf("p.Filename: got %q, want %q", got, want) } if got, want := p.ContentType, want.ContentType; got != want { t.Errorf("p.ContentType: got %q, want %q", got, want) } if got, want := p.Value, want.Value; got != want { t.Errorf("p.Value: got %q, want %q", got, want) } } } func TestModifyRequestErrorsOnDuplicateRequest(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if logger.ModifyRequest(req) == nil { t.Fatalf("ModifyRequest(): was supposed to error") } } func TestHARExportsTime(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Simulate fast network round trip. 
time.Sleep(10 * time.Millisecond) res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %v, want %v", got, want) } entry := log.Entries[0] min, max := int64(10), int64(100) if got := entry.Time; got < min || got > max { t.Errorf("entry.Time: got %dms, want between %dms and %vms", got, min, max) } } func TestReset(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } logger.Reset() log = logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestExportSortsEntries(t *testing.T) { logger := NewLogger() count := 10 for i := 0; i < count; i++ { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } } log := logger.Export().Log for i := 0; i < count-1; i++ { first := log.Entries[i] second := log.Entries[i+1] if got, want := first.StartedDateTime, second.StartedDateTime; got.After(want) { t.Errorf("entry.StartedDateTime: got %s, want to be before %s", got, 
want) } } } func TestExportIgnoresOrphanedResponse(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } // Reset before the response comes back. logger.Reset() res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestExportAndResetResetsCompleteRequests(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } logger.ExportAndReset() log := logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestExportAndResetLeavesPendingRequests(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if 
err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } logger.ExportAndReset() log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestExportAndResetExportsCompleteRequests(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.ExportAndReset().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestExportAndResetExportsCompleteRequestsWithPendingLeft(t *testing.T) { logger := NewLogger() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err = martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } req, err = http.NewRequest("GET", "http://example.com", nil) if 
err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err = martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.ExportAndReset().Log if got, want := len(log.Entries), 1; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } log = logger.Export().Log if got, want := len(log.Entries), 2; got != want { t.Errorf("len(log.Entries): got %d, want %d", got, want) } } func TestSkippingLogging(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() ctx.SkipLogging() logger := NewLogger() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 0; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } } func TestOptionResponseBodyLogging(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() bdr := strings.NewReader("{\"response\": \"body\"}") res := proxyutil.NewResponse(200, bdr, req) res.ContentLength = int64(bdr.Len()) 
res.Header.Set("Content-Type", "application/json") logger := NewLogger() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log := logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } if got, want := string(log.Entries[0].Response.Content.Text), "{\"response\": \"body\"}"; got != want { t.Fatalf("log.Entries[0].Response.Content.Text: got %s, want %s", got, want) } logger = NewLogger() logger.SetOption(BodyLogging(false)) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log = logger.Export().Log if got, want := len(log.Entries), 1; got != want { t.Fatalf("len(log.Entries): got %d, want %d", got, want) } if got, want := string(log.Entries[0].Response.Content.Text), ""; got != want { t.Fatalf("log.Entries[0].Response.Content: got %s, want %s", got, want) } logger = NewLogger() logger.SetOption(BodyLoggingForContentTypes("application/json")) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log = logger.Export().Log if got, want := string(log.Entries[0].Response.Content.Text), "{\"response\": \"body\"}"; got != want { t.Fatalf("log.Entries[0].Response.Content: got %s, want %s", got, want) } logger = NewLogger() logger.SetOption(SkipBodyLoggingForContentTypes("application/json")) if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := logger.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } log = 
logger.Export().Log if got, want := string(log.Entries[0].Response.Content.Text), ""; got != want { t.Fatalf("log.Entries[0].Response.Content: got %v, want %v", got, want) } } func TestOptionRequestPostDataLogging(t *testing.T) { logger := NewLogger() logger.SetOption(PostDataLoggingForContentTypes("application/x-www-form-urlencoded")) body := strings.NewReader("first=true&second=false") req, err := http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log := logger.Export().Log for _, param := range log.Entries[0].Request.PostData.Params { if param.Name == "first" { if got, want := param.Value, "true"; got != want { t.Fatalf("Params[%q].Value: got %s, want %s", param.Name, got, want) } } if param.Name == "second" { if got, want := param.Value, "false"; got != want { t.Fatalf("Params[%q].Value: got %s, want %s", param.Name, got, want) } } } logger = NewLogger() logger.SetOption(SkipPostDataLoggingForContentTypes("application/x-www-form-urlencoded")) body = strings.NewReader("first=true&second=false") req, err = http.NewRequest("POST", "http://example.com", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") _, remove, err = martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := logger.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } log = logger.Export().Log if got, want := len(log.Entries[0].Request.PostData.Params), 0; got != want { 
t.Fatalf("len(log.Entries[0].Request.PostData.Params): got %v, want %v", got, want) } } func TestJSONMarshalPostData(t *testing.T) { // Verify that encoding/json round-trips har.PostData with both text and binary data. for _, text := range []string{"hello", string([]byte{150, 151, 152})} { want := &PostData{ MimeType: "m", Params: []Param{{Name: "n", Value: "v"}}, Text: text, } data, err := json.Marshal(want) if err != nil { t.Fatal(err) } var got PostData if err := json.Unmarshal(data, &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(&got, want) { t.Errorf("got %+v, want %+v", &got, want) } } } func TestJSONMarshalContent(t *testing.T) { testCases := []struct { name string text []byte encoding string }{ { name: "binary data with base64 encoding", text: []byte{120, 31, 99, 3}, encoding: "base64", }, { name: "ascii data with no encoding", text: []byte("hello martian"), }, { name: "ascii data with base64 encoding", text: []byte("hello martian"), encoding: "base64", }, } for _, c := range testCases { want := Content{ Size: int64(len(c.text)), MimeType: "application/x-test", Text: c.text, Encoding: c.encoding, } data, err := json.Marshal(want) if err != nil { t.Fatal(err) } var got Content if err := json.Unmarshal(data, &got); err != nil { t.Fatal(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %+v, want %+v", got, want) } } } martian-3.3.2/header/000077500000000000000000000000001421371434000144015ustar00rootroot00000000000000martian-3.3.2/header/copy_modifier.go000066400000000000000000000045251421371434000175660ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func init() { parse.Register("header.Copy", copyModifierFromJSON) } type copyModifier struct { from, to string } type copyModifierJSON struct { From string `json:"from"` To string `json:"to"` Scope []parse.ModifierType `json:"scope"` } // ModifyRequest copies the header in from to the request header for to. func (m *copyModifier) ModifyRequest(req *http.Request) error { log.Debugf("header: copyModifier.ModifyRequest %s, from: %s, to: %s", req.URL, m.from, m.to) h := proxyutil.RequestHeader(req) return h.Set(m.to, h.Get(m.from)) } // ModifyResponse copies the header in from to the response header for to. func (m *copyModifier) ModifyResponse(res *http.Response) error { log.Debugf("header: copyModifier.ModifyResponse %s, from: %s, to: %s", res.Request.URL, m.from, m.to) h := proxyutil.ResponseHeader(res) return h.Set(m.to, h.Get(m.from)) } // NewCopyModifier returns a modifier that will copy the header in from to the // header in to. func NewCopyModifier(from, to string) martian.RequestResponseModifier { return ©Modifier{ from: from, to: to, } } // copyModifierFromJSON builds a copy modifier from JSON. 
// // Example JSON: // { // "header.Copy": { // "scope": ["request", "response"], // "from": "Original-Header", // "to": "Copy-Header" // } // } func copyModifierFromJSON(b []byte) (*parse.Result, error) { msg := ©ModifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewCopyModifier(msg.From, msg.To), msg.Scope) } martian-3.3.2/header/copy_modifier_test.go000066400000000000000000000054041421371434000206220ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package header import ( "net/http" "testing" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestCopyModifier(t *testing.T) { m := NewCopyModifier("Original", "Copy") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Original", "test") if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Copy"), "test"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Copy", got, want) } res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Original", "test") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Copy"), "test"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Copy", got, want) } } func TestCopyModifierFromJSON(t *testing.T) { msg := []byte(`{ "header.Copy": { "from": "Original", "to": "Copy", "scope": ["request", "response"] } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %q, want no error", err) } req.Header.Set("Original", "test") reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Copy"), "test"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Copy", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) res.Header.Set("Original", "test") if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := 
res.Header.Get("Copy"), "test"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Copy", got, want) } } martian-3.3.2/header/forwarded_modifier.go000066400000000000000000000034211421371434000205630ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net" "net/http" "github.com/google/martian/v3" ) // NewForwardedModifier sets the X-Forwarded-For, X-Forwarded-Proto, // X-Forwarded-Host, and X-Forwarded-Url headers. // // If X-Forwarded-For is already present, the client IP is appended to // the existing value. X-Forwarded-Proto, X-Forwarded-Host, and // X-Forwarded-Url are preserved if already present. // // TODO: Support "Forwarded" header. 
// see: http://tools.ietf.org/html/rfc7239 func NewForwardedModifier() martian.RequestModifier { return martian.RequestModifierFunc( func(req *http.Request) error { if v := req.Header.Get("X-Forwarded-Proto"); v == "" { req.Header.Set("X-Forwarded-Proto", req.URL.Scheme) } if v := req.Header.Get("X-Forwarded-Host"); v == "" { req.Header.Set("X-Forwarded-Host", req.Host) } if v := req.Header.Get("X-Forwarded-Url"); v == "" { req.Header.Set("X-Forwarded-Url", req.URL.String()) } xff, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { xff = req.RemoteAddr } if v := req.Header.Get("X-Forwarded-For"); v != "" { xff = v + ", " + xff } req.Header.Set("X-Forwarded-For", xff) return nil }) } martian-3.3.2/header/forwarded_modifier_test.go000066400000000000000000000055311421371434000216260ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package header import ( "net/http" "testing" ) func TestSetForwardHeaders(t *testing.T) { xfp := "X-Forwarded-Proto" xff := "X-Forwarded-For" xfh := "X-Forwarded-Host" xfu := "X-Forwarded-Url" m := NewForwardedModifier() req, err := http.NewRequest("GET", "http://martian.local?key=value", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.RemoteAddr = "10.0.0.1:8112" if m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get(xfp), "http"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfp, got, want) } if got, want := req.Header.Get(xff), "10.0.0.1"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xff, got, want) } if got, want := req.Header.Get(xfh), "martian.local"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfh, got, want) } if got, want := req.Header.Get(xfu), "http://martian.local?key=value"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfh, got, want) } // Test with existing X-Forwarded-For. req.RemoteAddr = "12.12.12.12" if m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get(xff), "10.0.0.1, 12.12.12.12"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xff, got, want) } // Test that proto, host, and URL headers are preserved if already present. 
req, err = http.NewRequest("GET", "http://example.com/path?k=v", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set(xfp, "https") req.Header.Set(xfh, "preserved.host.com") req.Header.Set(xfu, "https://preserved.host.com/foo?x=y") if m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get(xfp), "https"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfh, got, want) } if got, want := req.Header.Get(xfh), "preserved.host.com"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfh, got, want) } if got, want := req.Header.Get(xfu), "https://preserved.host.com/foo?x=y"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", xfh, got, want) } } martian-3.3.2/header/framing_modifier.go000066400000000000000000000050741421371434000202370ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "fmt" "net/http" "strings" "github.com/google/martian/v3" ) // NewBadFramingModifier makes a best effort to fix inconsistencies in the // request such as multiple Content-Lengths or the lack of Content-Length and // improper Transfer-Encoding. If it is unable to determine a proper resolution // it returns an error. 
// // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-3.3 func NewBadFramingModifier() martian.RequestModifier { return martian.RequestModifierFunc( func(req *http.Request) error { cls := req.Header["Content-Length"] if len(cls) > 0 { var length string // Iterate over all Content-Length headers, splitting any we find with // commas, and check that all Content-Lengths are equal. for _, ls := range cls { for _, l := range strings.Split(ls, ",") { // First length, set it as the canonical Content-Length. if length == "" { length = strings.TrimSpace(l) continue } // Mismatched Content-Lengths. if length != strings.TrimSpace(l) { return fmt.Errorf(`bad request framing: multiple mismatched "Content-Length" headers: %v`, cls) } } } // All Content-Lengths are equal, remove extras and set it to the // canonical value. req.Header.Set("Content-Length", length) } tes := req.Header["Transfer-Encoding"] if len(tes) > 0 { // Extract the last Transfer-Encoding value, and split on commas. last := strings.Split(tes[len(tes)-1], ",") // Check that the last, potentially comma-delimited, value is // "chunked", else we have no way to determine when the request is // finished. if strings.TrimSpace(last[len(last)-1]) != "chunked" { return fmt.Errorf(`bad request framing: "Transfer-Encoding" header is present, but does not end in "chunked"`) } // Transfer-Encoding "chunked" takes precedence over // Content-Length. req.Header.Del("Content-Length") } return nil }) } martian-3.3.2/header/framing_modifier_test.go000066400000000000000000000041231421371434000212700ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "reflect" "testing" ) func TestBadFramingMultipleContentLengths(t *testing.T) { m := NewBadFramingModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header["Content-Length"] = []string{"42", "42, 42"} if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header["Content-Length"], []string{"42"}; !reflect.DeepEqual(got, want) { t.Errorf("req.Header[%q]: got %v, want %v", "Content-Length", got, want) } req.Header["Content-Length"] = []string{"42", "32, 42"} if err := m.ModifyRequest(req); err == nil { t.Error("ModifyRequest(): got nil, want error") } } func TestBadFramingTransferEncodingAndContentLength(t *testing.T) { m := NewBadFramingModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header["Transfer-Encoding"] = []string{"gzip, chunked"} req.Header["Content-Length"] = []string{"42"} if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if _, ok := req.Header["Content-Length"]; ok { t.Fatalf("req.Header[%q]: got ok, want !ok", "Content-Length") } req.Header.Set("Transfer-Encoding", "gzip, identity") req.Header.Del("Content-Length") if err := m.ModifyRequest(req); err == nil { t.Error("ModifyRequest(): got nil, want error") } } 
martian-3.3.2/header/header_append_modifier.go000066400000000000000000000044451421371434000213740ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func init() { parse.Register("header.Append", appendModifierFromJSON) } type appendModifier struct { name, value string } type appendModifierJSON struct { Name string `json:"name"` Value string `json:"value"` Scope []parse.ModifierType `json:"scope"` } // ModifyRequest appends the header at name with value to the request. func (m *appendModifier) ModifyRequest(req *http.Request) error { return proxyutil.RequestHeader(req).Add(m.name, m.value) } // ModifyResponse appends the header at name with value to the response. func (m *appendModifier) ModifyResponse(res *http.Response) error { return proxyutil.ResponseHeader(res).Add(m.name, m.value) } // NewAppendModifier returns an appendModifier that will append a header with // with the given name and value for both requests and responses. Existing // headers with the same name will be left in place. 
func NewAppendModifier(name, value string) martian.RequestResponseModifier { return &appendModifier{ name: http.CanonicalHeaderKey(name), value: value, } } // appendModifierFromJSON takes a JSON message as a byte slice and returns // an appendModifier and an error. // // Example JSON configuration message: // { // "scope": ["request", "result"], // "name": "X-Martian", // "value": "true" // } func appendModifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } modifier := NewAppendModifier(msg.Name, msg.Value) return parse.NewResult(modifier, msg.Scope) } martian-3.3.2/header/header_append_modifier_test.go000066400000000000000000000051561421371434000224330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package header import ( "net/http" "testing" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestModifyRequestWithMultipleHeaders(t *testing.T) { m := NewAppendModifier("X-Repeated", "modifier") req, err := http.NewRequest("GET", "www.example.com", nil) req.Header.Add("X-Repeated", "original") if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header["X-Repeated"][0], "original"; got != want { t.Errorf("req.Header[\"X-Repeated\"][0]: got %q, want %q", got, want) } if got, want := req.Header["X-Repeated"][1], "modifier"; got != want { t.Errorf("req.Header[\"X-Repeated\"][1]: got %q, want %q", got, want) } } func TestAppendModifierFromJSON(t *testing.T) { msg := []byte(`{ "header.Append": { "scope": ["request", "response"], "name": "X-Martian", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://martian.test", nil) req.Header.Add("X-Martian", "false") if err != nil { t.Fatalf("http.NewRequest(): got %q, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatalf("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if n := len(req.Header["X-Martian"]); n != 2 { t.Errorf("res.Header[%q]: got len %d, want 2", "X-Martian", n) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) res.Header.Add("X-Martian", "false") if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if n := len(res.Header["X-Martian"]); n != 2 { t.Errorf("res.Header[%q]: got len %d, want 2", "X-Martian", n) } } 
martian-3.3.2/header/header_blacklist_modifier.go000066400000000000000000000043271421371434000220740ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package header

import (
	"encoding/json"
	"net/http"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/proxyutil"
)

func init() {
	parse.Register("header.Blacklist", blacklistModifierFromJSON)
}

// blacklistModifier deletes every header whose name appears in names.
type blacklistModifier struct {
	names []string
}

// blacklistModifierJSON is the wire form of a "header.Blacklist" message.
type blacklistModifierJSON struct {
	Names []string             `json:"names"`
	Scope []parse.ModifierType `json:"scope"`
}

// ModifyRequest deletes all request headers based on the header name.
func (m *blacklistModifier) ModifyRequest(req *http.Request) error {
	h := proxyutil.RequestHeader(req)

	for _, name := range m.names {
		h.Del(name)
	}

	return nil
}

// ModifyResponse deletes all response headers based on the header name.
func (m *blacklistModifier) ModifyResponse(res *http.Response) error {
	h := proxyutil.ResponseHeader(res)

	for _, name := range m.names {
		h.Del(name)
	}

	return nil
}

// NewBlacklistModifier returns a modifier that will delete any header that
// matches a name contained in the names parameter.
func NewBlacklistModifier(names ...string) martian.RequestResponseModifier {
	return &blacklistModifier{
		// Copy so a caller that passes a slice with names... and later mutates
		// it cannot change which headers this modifier deletes.
		names: append([]string(nil), names...),
	}
}

// blacklistModifierFromJSON takes a JSON message as a byte slice and returns
// a blacklistModifier and an error.
//
// Example JSON configuration message:
// {
//   "names": ["X-Header", "Y-Header"],
//   "scope": ["request", "response"]
// }
func blacklistModifierFromJSON(b []byte) (*parse.Result, error) {
	msg := &blacklistModifierJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	return parse.NewResult(NewBlacklistModifier(msg.Names...), msg.Scope)
}
martian-3.3.2/header/header_blacklist_modifier_test.go000066400000000000000000000070721421371434000231330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package header

import (
	"net/http"
	"testing"

	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/proxyutil"
)

// TestBlacklistModifierOnRequest checks that only the listed request header
// is removed.
func TestBlacklistModifierOnRequest(t *testing.T) {
	mod := NewBlacklistModifier("X-Testing")

	req, err := http.NewRequest("GET", "/", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}

	req.Header.Set("X-Testing", "value")
	req.Header.Set("Y-Testing", "value")

	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if _, ok := req.Header["X-Testing"]; ok {
		t.Errorf("req.Header[%q]: got true, want false", "X-Testing")
	}

	if _, ok := req.Header["Y-Testing"]; !ok {
		t.Errorf("req.Header[%q]: got false, want true", "Y-Testing")
	}
}

// TestBlacklistModifierOnResponse checks that only the listed response header
// is removed.
func TestBlacklistModifierOnResponse(t *testing.T) {
	mod := NewBlacklistModifier("X-Testing")

	res := proxyutil.NewResponse(200, nil, nil)
	res.Header.Set("X-Testing", "value")
	res.Header.Set("Y-Testing", "value")

	if err := mod.ModifyResponse(res); err != nil {
		t.Fatalf("ModifyResponse(): got %v, want no error", err)
	}

	if _, ok := res.Header["X-Testing"]; ok {
		t.Errorf("res.Header[%q]: got true, want false", "X-Testing")
	}

	if _, ok := res.Header["Y-Testing"]; !ok {
		t.Errorf("res.Header[%q]: got false, want true", "Y-Testing")
	}
}

// TestBlacklistModifierFromJSON checks that a "header.Blacklist" JSON message
// removes the listed headers from both requests and responses.
func TestBlacklistModifierFromJSON(t *testing.T) {
	msg := []byte(`{
		"header.Blacklist": {
			"scope": ["request", "response"],
			"names": ["X-Testing", "Y-Testing"]
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	req, err := http.NewRequest("GET", "http://martian.test", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("X-Testing", "value")
	req.Header.Set("Y-Testing", "value")
	req.Header.Set("Z-Testing", "value")

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatalf("reqmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	res.Header.Set("X-Testing", "value")
	res.Header.Set("Y-Testing", "value")
	res.Header.Set("Z-Testing", "value")

	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatalf("resmod: got nil, want not nil")
	}

	tt := []struct {
		header string
		want   string
	}{
		{
			header: "X-Testing",
			want:   "",
		},
		{
			header: "Y-Testing",
			want:   "",
		},
		{
			header: "Z-Testing",
			want:   "value",
		},
	}

	for i, tc := range tt {
		if err := reqmod.ModifyRequest(req); err != nil {
			t.Fatalf("%d. reqmod.ModifyRequest(): got %v, want no error", i, err)
		}

		if got, want := req.Header.Get(tc.header), tc.want; got != want {
			t.Errorf("%d. req.Header.Get(%q): got %q, want %q", i, tc.header, got, want)
		}

		if err := resmod.ModifyResponse(res); err != nil {
			t.Fatalf("%d. resmod.ModifyResponse(): got %v, want no error", i, err)
		}

		if got, want := res.Header.Get(tc.header), tc.want; got != want {
			t.Errorf("%d. res.Header.Get(%q): got %q, want %q", i, tc.header, got, want)
		}
	}
}
martian-3.3.2/header/header_filter.go000066400000000000000000000045201421371434000175260ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package header

import (
	"encoding/json"
	"net/http"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/filter"
	"github.com/google/martian/v3/parse"
)

var noop = martian.Noop("header.Filter")

// Filter filters requests and responses based on header name and value.
type Filter struct {
	*filter.Filter
}

// filterJSON is the wire form of a "header.Filter" message.
type filterJSON struct {
	Name         string               `json:"name"`
	Value        string               `json:"value"`
	Modifier     json.RawMessage      `json:"modifier"`
	ElseModifier json.RawMessage      `json:"else"`
	Scope        []parse.ModifierType `json:"scope"`
}

func init() {
	parse.Register("header.Filter", filterFromJSON)
}

// NewFilter builds a new header filter. The same name/value matcher is used
// as both the request and the response condition; name is canonicalized
// with http.CanonicalHeaderKey.
func NewFilter(name, value string) *Filter {
	m := NewMatcher(http.CanonicalHeaderKey(name), value)

	f := filter.New()
	f.SetRequestCondition(m)
	f.SetResponseCondition(m)

	return &Filter{f}
}

// filterFromJSON builds a header.Filter from JSON.
//
// Example JSON:
// {
//   "scope": ["request", "response"],
//   "name": "Martian-Testing",
//   "value": "true",
//   "modifier": { ... },
//   "else": { ... }
// }
func filterFromJSON(b []byte) (*parse.Result, error) {
	msg := &filterJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	// Named f rather than filter so the local does not shadow the imported
	// filter package for the rest of the function.
	f := NewFilter(msg.Name, msg.Value)

	m, err := parse.FromJSON(msg.Modifier)
	if err != nil {
		return nil, err
	}

	f.RequestWhenTrue(m.RequestModifier())
	f.ResponseWhenTrue(m.ResponseModifier())

	// The "else" branch is optional; only parse it when present.
	if len(msg.ElseModifier) > 0 {
		em, err := parse.FromJSON(msg.ElseModifier)
		if err != nil {
			return nil, err
		}

		if em != nil {
			f.RequestWhenFalse(em.RequestModifier())
			f.ResponseWhenFalse(em.ResponseModifier())
		}
	}

	return parse.NewResult(f, msg.Scope)
}
martian-3.3.2/header/header_filter_test.go000066400000000000000000000170151421371434000205700ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "testing" "github.com/google/martian/v3/filter" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestFilterFromJSON(t *testing.T) { msg := []byte(`{ "header.Filter": { "scope": ["request", "response"], "name": "Martian-Passthrough", "value": "true", "modifier": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "true" } }, "else": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "false" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } // Matching condition for request req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Martian-Passthrough", "true") if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Testing"), "true"; got != want { t.Fatalf("req.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } // Else condition for request req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Martian-Passthrough", "false") if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Testing"), "false"; got != want { t.Fatalf("req.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } // Matching condition for response resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := 
proxyutil.NewResponse(200, nil, req) res.Header.Set("Martian-Passthrough", "true") if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Testing"), "true"; got != want { t.Fatalf("res.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } // Else condition for response resmod = r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res = proxyutil.NewResponse(200, nil, req) res.Header.Set("Martian-Passthrough", "false") if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Testing"), "false"; got != want { t.Fatalf("res.Header.Get(%q): got %q, want %q", "Martian-Testing", got, want) } } func TestFilterFromJSONWithoutElse(t *testing.T) { msg := []byte(`{ "header.Filter": { "scope": ["request", "response"], "name": "Martian-Passthrough", "value": "true", "modifier": { "header.Modifier" : { "scope": ["request", "response"], "name": "Martian-Testing", "value": "true" } } } }`) _, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } } func TestRequestWhenTrueCondition(t *testing.T) { hm := NewMatcher("Martian-Testing", "true") tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: false, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetRequestCondition(hm) f.RequestWhenTrue(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header[tc.name] = tc.values if err := f.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. 
tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestRequestWhenFalse(t *testing.T) { hm := NewMatcher("Martian-Testing", "true") tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: true, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: false, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetRequestCondition(hm) f.RequestWhenFalse(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header[tc.name] = tc.values if err := f.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestResponseWhenTrue(t *testing.T) { hm := NewMatcher("Martian-Testing", "true") tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: false, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: true, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetResponseCondition(hm) f.ResponseWhenTrue(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) res.Header[tc.name] = tc.values if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. 
tm.ResponseModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestResponseWhenFalse(t *testing.T) { hm := NewMatcher("Martian-Testing", "true") tt := []struct { name string values []string want bool }{ { name: "Martian-Production", values: []string{"true"}, want: true, }, { name: "Martian-Testing", values: []string{"see-next-value", "true"}, want: false, }, } for i, tc := range tt { tm := martiantest.NewModifier() f := filter.New() f.SetResponseCondition(hm) f.ResponseWhenFalse(tm) req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) res.Header[tc.name] = tc.values if err := f.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } martian-3.3.2/header/header_matcher.go000066400000000000000000000033221421371434000176630ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "github.com/google/martian/v3/proxyutil" ) // Matcher is a conditonal evalutor of request or // response headers to be used in structs that take conditions. type Matcher struct { name, value string } // NewMatcher builds a new header matcher. 
func NewMatcher(name, value string) *Matcher { return &Matcher{ name: name, value: value, } } // MatchRequest evaluates a request and returns whether or not // the request contains a header that matches the provided name // and value. func (m *Matcher) MatchRequest(req *http.Request) bool { h := proxyutil.RequestHeader(req) vs, ok := h.All(m.name) if !ok { return false } for _, v := range vs { if v == m.value { return true } } return false } // MatchResponse evaluates a response and returns whether or not // the response contains a header that matches the provided name // and value. func (m *Matcher) MatchResponse(res *http.Response) bool { h := proxyutil.ResponseHeader(res) vs, ok := h.All(m.name) if !ok { return false } for _, v := range vs { if v == m.value { return true } } return false } martian-3.3.2/header/header_modifier.go000066400000000000000000000043271421371434000200440ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func init() { parse.Register("header.Modifier", modifierFromJSON) } type modifier struct { name, value string } type modifierJSON struct { Name string `json:"name"` Value string `json:"value"` Scope []parse.ModifierType `json:"scope"` } // ModifyRequest sets the header at name with value on the request. 
func (m *modifier) ModifyRequest(req *http.Request) error { return proxyutil.RequestHeader(req).Set(m.name, m.value) } // ModifyResponse sets the header at name with value on the response. func (m *modifier) ModifyResponse(res *http.Response) error { return proxyutil.ResponseHeader(res).Set(m.name, m.value) } // NewModifier returns a modifier that will set the header at name with // the given value for both requests and responses. If the header name already // exists all values will be overwritten. func NewModifier(name, value string) martian.RequestResponseModifier { return &modifier{ name: http.CanonicalHeaderKey(name), value: value, } } // modifierFromJSON takes a JSON message as a byte slice and returns // a headerModifier and an error. // // Example JSON configuration message: // { // "scope": ["request", "result"], // "name": "X-Martian", // "value": "true" // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } modifier := NewModifier(msg.Name, msg.Value) return parse.NewResult(modifier, msg.Scope) } martian-3.3.2/header/header_modifier_test.go000066400000000000000000000062621421371434000211030ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package header import ( "net/http" "testing" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestNewHeaderModifier(t *testing.T) { mod := NewModifier("testing", "true") req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("testing"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "testing", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := req.Header.Get("testing"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "testing", got, want) } } func TestModifyRequestWithHostHeader(t *testing.T) { m := NewModifier("Host", "www.google.com") req, err := http.NewRequest("GET", "www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Host, "www.google.com"; got != want { t.Errorf("req.Host: got %q, want %q", got, want) } } func TestModifierFromJSON(t *testing.T) { msg := []byte(`{ "header.Modifier": { "scope": ["request", "response"], "name": "X-Martian", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://martian.test", nil) req.Header.Add("X-Martian", "false") if err != nil { t.Fatalf("http.NewRequest(): got %q, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatalf("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("X-Martian"), "true"; 
got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Martian", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) res.Header.Add("X-Martian", "false") if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("X-Martian"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "X-Martian", got, want) } } martian-3.3.2/header/header_value_regex_filter.go000066400000000000000000000061631421371434000221210ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "encoding/json" "net/http" "regexp" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) // ValueRegexFilter executes resmod and reqmod when the header // value matches regex. type ValueRegexFilter struct { regex *regexp.Regexp header string reqmod martian.RequestModifier resmod martian.ResponseModifier } type headerValueRegexFilterJSON struct { Regex string `json:"regex"` HeaderName string `json:"header"` Modifier json.RawMessage `json:"modifier"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("header.RegexFilter", headerValueRegexFilterFromJSON) } // NewValueRegexFilter builds a new header value regex filter. 
func NewValueRegexFilter(regex *regexp.Regexp, header string) *ValueRegexFilter { return &ValueRegexFilter{ regex: regex, header: header, reqmod: noop, resmod: noop, } } func headerValueRegexFilterFromJSON(b []byte) (*parse.Result, error) { msg := &headerValueRegexFilterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } cr, err := regexp.Compile(msg.Regex) if err != nil { return nil, err } filter := NewValueRegexFilter(cr, msg.HeaderName) r, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } reqmod := r.RequestModifier() filter.SetRequestModifier(reqmod) resmod := r.ResponseModifier() filter.SetResponseModifier(resmod) return parse.NewResult(filter, msg.Scope) } // ModifyRequest runs reqmod iff the value of header matches regex. func (f *ValueRegexFilter) ModifyRequest(req *http.Request) error { hvalue := req.Header.Get(f.header) if hvalue == "" { return nil } if f.regex.MatchString(hvalue) { return f.reqmod.ModifyRequest(req) } return nil } // ModifyResponse runs resmod iff the value of request header matches regex. func (f *ValueRegexFilter) ModifyResponse(res *http.Response) error { hvalue := res.Request.Header.Get(f.header) if hvalue == "" { return nil } if f.regex.MatchString(hvalue) { return f.resmod.ModifyResponse(res) } return nil } // SetRequestModifier sets the request modifier of HeaderValueRegexFilter. func (f *ValueRegexFilter) SetRequestModifier(reqmod martian.RequestModifier) { if reqmod == nil { f.reqmod = noop return } f.reqmod = reqmod } // SetResponseModifier sets the response modifier of HeaderValueRegexFilter. func (f *ValueRegexFilter) SetResponseModifier(resmod martian.ResponseModifier) { if resmod == nil { f.resmod = noop return } f.resmod = resmod } martian-3.3.2/header/header_verifier.go000066400000000000000000000104661421371434000200620ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package header provides utilities for modifying, filtering, and // verifying headers in martian.Proxy. package header import ( "encoding/json" "fmt" "net/http" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/verify" ) const ( headerErrFormat = "%s(%s) header verify failure: got no header, want %s header" valueErrFormat = "%s(%s) header verify failure: got %s with value %s, want value %s" ) type verifier struct { name, value string reqerr *martian.MultiError reserr *martian.MultiError } type verifierJSON struct { Name string `json:"name"` Value string `json:"value"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("header.Verifier", verifierFromJSON) } // NewVerifier creates a new header verifier for the given name and value. func NewVerifier(name, value string) verify.RequestResponseVerifier { return &verifier{ name: name, value: value, reqerr: martian.NewMultiError(), reserr: martian.NewMultiError(), } } // ModifyRequest verifies that the header for name is present in all modified // requests. If value is non-empty the value must be present in at least one // header for name. An error will be added to the contained *MultiError for // every unmatched request. 
func (v *verifier) ModifyRequest(req *http.Request) error { h := proxyutil.RequestHeader(req) vs, ok := h.All(v.name) if !ok { v.reqerr.Add(fmt.Errorf(headerErrFormat, "request", req.URL, v.name)) return nil } for _, value := range vs { switch v.value { case "", value: return nil } } v.reqerr.Add(fmt.Errorf(valueErrFormat, "request", req.URL, v.name, strings.Join(vs, ", "), v.value)) return nil } // ModifyResponse verifies that the header for name is present in all modified // responses. If value is non-empty the value must be present in at least one // header for name. An error will be added to the contained *MultiError for // every unmatched response. func (v *verifier) ModifyResponse(res *http.Response) error { h := proxyutil.ResponseHeader(res) vs, ok := h.All(v.name) if !ok { v.reserr.Add(fmt.Errorf(headerErrFormat, "response", res.Request.URL, v.name)) return nil } for _, value := range vs { switch v.value { case "", value: return nil } } v.reserr.Add(fmt.Errorf(valueErrFormat, "response", res.Request.URL, v.name, strings.Join(vs, ", "), v.value)) return nil } // VerifyRequests returns an error if verification for any request failed. // If an error is returned it will be of type *martian.MultiError. func (v *verifier) VerifyRequests() error { if v.reqerr.Empty() { return nil } return v.reqerr } // VerifyResponses returns an error if verification for any request failed. // If an error is returned it will be of type *martian.MultiError. func (v *verifier) VerifyResponses() error { if v.reserr.Empty() { return nil } return v.reserr } // ResetRequestVerifications clears all failed request verifications. func (v *verifier) ResetRequestVerifications() { v.reqerr = martian.NewMultiError() } // ResetResponseVerifications clears all failed response verifications. func (v *verifier) ResetResponseVerifications() { v.reserr = martian.NewMultiError() } // verifierFromJSON builds a header.Verifier from JSON. 
// // Example JSON: // { // "name": "header.Verifier", // "scope": ["request", "result"], // "modifier": { // "name": "Martian-Testing", // "value": "true" // } // } func verifierFromJSON(b []byte) (*parse.Result, error) { msg := &verifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewVerifier(msg.Name, msg.Value), msg.Scope) } martian-3.3.2/header/header_verifier_test.go000066400000000000000000000172361421371434000211230ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "fmt" "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/verify" ) func TestVerifyRequestsBlankValue(t *testing.T) { v := NewVerifier("Martian-Test", "") for i := 0; i < 4; i++ { req, err := http.NewRequest("GET", fmt.Sprintf("http://www.example.com/%d", i), nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // Request 1, 3 should fail verification. if i%2 == 0 { req.Header.Set("Martian-Test", "true") } if err := v.ModifyRequest(req); err != nil { t.Fatalf("%d. 
ModifyRequest(): got %v, want no error", i, err) } } merr, ok := v.VerifyRequests().(*martian.MultiError) if !ok { t.Fatal("VerifyRequests(): got no error, want *verify.MultiError") } if got, want := len(merr.Errors()), 2; got != want { t.Fatalf("len(merr.Errors()): got %d, want %d", got, want) } wants := []string{ `request(http://www.example.com/1) header verify failure: got no header, want Martian-Test header`, `request(http://www.example.com/3) header verify failure: got no header, want Martian-Test header`, } for i, err := range merr.Errors() { if got := err.Error(); got != wants[i] { t.Errorf("Errors()[%d]: got %q, want %q", i, got, wants[i]) } } v.ResetRequestVerifications() if err := v.VerifyRequests(); err != nil { t.Errorf("VerifyRequests(): got %v, want no error", err) } } func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "header.Verifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } reqv, ok := reqmod.(verify.RequestVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqv.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := reqv.VerifyRequests(); err == nil { t.Error("VerifyRequests(): got nil, want not nil") } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } resv, ok := resmod.(verify.ResponseVerifier) if !ok { t.Fatal("resmod.(verify.ResponseVerifier): got !ok, want ok") } res := proxyutil.NewResponse(200, nil, req) if err := resv.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := 
resv.VerifyResponses(); err == nil { t.Error("VerifyResponses(): got nil, want not nil") } } func TestVerifyRequests(t *testing.T) { v := NewVerifier("Martian-Test", "testing-even") for i := 0; i < 4; i++ { req, err := http.NewRequest("GET", fmt.Sprintf("http://www.example.com/%d", i), nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Add("Martian-Test", fmt.Sprintf("test-%d", i)) // Request 1, 3 should fail verification. if i%2 == 0 { req.Header.Add("Martian-Test", "testing-even") } else { req.Header.Add("Martian-Test", "testing-odd") } if err := v.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err) } } merr, ok := v.VerifyRequests().(*martian.MultiError) if !ok { t.Fatal("VerifyRequests(): got no error, want *verify.MultiError") } if got, want := len(merr.Errors()), 2; got != want { t.Fatalf("len(merr.Errors()): got %d, want %d", got, want) } wants := []string{ `request(http://www.example.com/1) header verify failure: got Martian-Test with value test-1, testing-odd, want value testing-even`, `request(http://www.example.com/3) header verify failure: got Martian-Test with value test-3, testing-odd, want value testing-even`, } for i, err := range merr.Errors() { if got := err.Error(); got != wants[i] { t.Errorf("Errors()[%d]: got %q, want %q", i, got, wants[i]) } } v.ResetRequestVerifications() if err := v.VerifyRequests(); err != nil { t.Errorf("VerifyRequests(): got %v, want no error", err) } } func TestVerifyResponsesBlankValue(t *testing.T) { v := NewVerifier("Martian-Test", "") for i := 0; i < 4; i++ { req, err := http.NewRequest("GET", fmt.Sprintf("http://www.example.com/%d", i), nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) // Response 1, 3 should fail verification. if i%2 == 0 { res.Header.Set("Martian-Test", "true") } if err := v.ModifyResponse(res); err != nil { t.Fatalf("%d. 
ModifyResponse(): got %v, want no error", i, err) } } merr, ok := v.VerifyResponses().(*martian.MultiError) if !ok { t.Fatal("VerifyResponses(): got no error, want *verify.MultiError") } if got, want := len(merr.Errors()), 2; got != want { t.Fatalf("len(merr.Errors()): got %d, want %d", got, want) } wants := []string{ `response(http://www.example.com/1) header verify failure: got no header, want Martian-Test header`, `response(http://www.example.com/3) header verify failure: got no header, want Martian-Test header`, } for i, err := range merr.Errors() { if got := err.Error(); got != wants[i] { t.Errorf("Errors()[%d]: got %q, want %q", i, got, wants[i]) } } v.ResetResponseVerifications() if err := v.VerifyResponses(); err != nil { t.Errorf("VerifyResponses(): got %v, want no error", err) } } func TestVerifyResponses(t *testing.T) { v := NewVerifier("Martian-Test", "testing-even") for i := 0; i < 4; i++ { req, err := http.NewRequest("GET", fmt.Sprintf("http://www.example.com/%d", i), nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) res.Header.Add("Martian-Test", fmt.Sprintf("test-%d", i)) // Response 1, 3 should fail verification. if i%2 == 0 { res.Header.Add("Martian-Test", "testing-even") } else { res.Header.Add("Martian-Test", "testing-odd") } if err := v.ModifyResponse(res); err != nil { t.Fatalf("%d. 
ModifyResponse(): got %v, want no error", i, err) } } merr, ok := v.VerifyResponses().(*martian.MultiError) if !ok { t.Fatal("VerifyResponses(): got no error, want *verify.MultiError") } if got, want := len(merr.Errors()), 2; got != want { t.Fatalf("len(merr.Errors()): got %d, want %d", got, want) } wants := []string{ `response(http://www.example.com/1) header verify failure: got Martian-Test with value test-1, testing-odd, want value testing-even`, `response(http://www.example.com/3) header verify failure: got Martian-Test with value test-3, testing-odd, want value testing-even`, } for i, err := range merr.Errors() { if got := err.Error(); got != wants[i] { t.Errorf("Errors()[%d]: got %q, want %q", i, got, wants[i]) } } v.ResetResponseVerifications() if err := v.VerifyResponses(); err != nil { t.Errorf("VerifyResponses(): got %v, want no error", err) } } martian-3.3.2/header/hopbyhop_modifier.go000066400000000000000000000042701421371434000204410ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "strings" "github.com/google/martian/v3" ) // Hop-by-hop headers as defined by RFC2616. // // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-7.1.3.1 var hopByHopHeaders = []string{ "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization", "Proxy-Connection", // Non-standard, but required for HTTP/2. 
"Te", "Trailer", "Transfer-Encoding", "Upgrade", } type hopByHopModifier struct{} // NewHopByHopModifier removes Hop-By-Hop headers from requests and // responses. func NewHopByHopModifier() martian.RequestResponseModifier { return &hopByHopModifier{} } // ModifyRequest removes all hop-by-hop headers defined by RFC2616 as // well as any additional hop-by-hop headers specified in the // Connection header. func (m *hopByHopModifier) ModifyRequest(req *http.Request) error { removeHopByHopHeaders(req.Header) return nil } // ModifyResponse removes all hop-by-hop headers defined by RFC2616 as // well as any additional hop-by-hop headers specified in the // Connection header. func (m *hopByHopModifier) ModifyResponse(res *http.Response) error { removeHopByHopHeaders(res.Header) return nil } func removeHopByHopHeaders(header http.Header) { // Additional hop-by-hop headers may be specified in `Connection` headers. // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.1 for _, vs := range header["Connection"] { for _, v := range strings.Split(vs, ",") { k := http.CanonicalHeaderKey(strings.TrimSpace(v)) header.Del(k) } } for _, k := range hopByHopHeaders { header.Del(k) } } martian-3.3.2/header/hopbyhop_modifier_test.go000066400000000000000000000044521421371434000215020ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package header import ( "net/http" "testing" "github.com/google/martian/v3/proxyutil" ) func TestRemoveHopByHopHeaders(t *testing.T) { m := NewHopByHopModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header = http.Header{ // Additional hop-by-hop headers are listed in the // Connection header. "Connection": []string{ "X-Connection", "X-Hop-By-Hop, close", }, // RFC hop-by-hop headers. "Keep-Alive": []string{}, "Proxy-Authenticate": []string{}, "Proxy-Authorization": []string{}, "Te": []string{}, "Trailer": []string{}, "Transfer-Encoding": []string{}, "Upgrade": []string{}, "Proxy-Connection": []string{}, // Hop-by-hop headers listed in the Connection header. "X-Connection": []string{}, "X-Hop-By-Hop": []string{}, // End-to-end header that should not be removed. "X-End-To-End": []string{}, } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := len(req.Header), 1; got != want { t.Fatalf("len(req.Header): got %d, want %d", got, want) } if _, ok := req.Header["X-End-To-End"]; !ok { t.Errorf("req.Header[%q]: got !ok, want ok", "X-End-To-End") } res := proxyutil.NewResponse(200, nil, req) res.Header = req.Header if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := len(res.Header), 1; got != want { t.Fatalf("len(res.Header): got %d, want %d", got, want) } if _, ok := res.Header["X-End-To-End"]; !ok { t.Errorf("res.Header[%q]: got !ok, want ok", "X-End-To-End") } } martian-3.3.2/header/id_modifier.go000066400000000000000000000035661421371434000172140ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) const idHeaderName string = "X-Martian-ID" func init() { parse.Register("header.Id", idModifierFromJSON) } type idModifier struct{} type idModifierJSON struct { Scope []parse.ModifierType `json:"scope"` } // NewIDModifier returns a request modifier that will set a header with the name // X-Martian-ID with a value that is a unique identifier for the request. In the case // that the X-Martian-ID header is already set, the header is unmodified. func NewIDModifier() martian.RequestModifier { return &idModifier{} } // ModifyRequest sets the X-Martian-ID header with a unique identifier. In the case // that the X-Martian-ID header is already set, the header is unmodified. func (im *idModifier) ModifyRequest(req *http.Request) error { // Do not rewrite an ID if req already has one if req.Header.Get(idHeaderName) != "" { return nil } ctx := martian.NewContext(req) req.Header.Set(idHeaderName, ctx.ID()) return nil } func idModifierFromJSON(b []byte) (*parse.Result, error) { msg := &idModifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } modifier := NewIDModifier() return parse.NewResult(modifier, msg.Scope) } martian-3.3.2/header/id_modifier_test.go000066400000000000000000000042231421371434000202420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) func TestIdModifier(t *testing.T) { mod := NewIDModifier() req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("X-Martian-ID"), ctx.ID(); got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Martian-ID", got, want) } } func TestFromJSON(t *testing.T) { msg := []byte(`{ "header.Id": {} }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if _, ok := reqmod.(*idModifier); !ok { t.Fatal("reqmod.(*idModifier): got !ok, want ok") } req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("X-Martian-ID"), ctx.ID(); got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Martian-ID", got, want) } } 
martian-3.3.2/header/via_modifier.go000066400000000000000000000065271421371434000173770ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "crypto/rand" "fmt" "io" "net/http" "regexp" "strings" "github.com/google/martian/v3" ) const viaLoopKey = "via.LoopDetection" var whitespace = regexp.MustCompile("[\t ]+") // ViaModifier is a header modifier that checks for proxy redirect loops. type ViaModifier struct { requestedBy string boundary string } // NewViaModifier returns a new Via modifier. func NewViaModifier(requestedBy string) *ViaModifier { return &ViaModifier{ requestedBy: requestedBy, boundary: randomBoundary(), } } // ModifyRequest sets the Via header and provides loop-detection. If Via is // already present, it will be appended to the existing value. If a loop is // detected an error is added to the context and the request round trip is // skipped. 
// // http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-14#section-9.9 func (m *ViaModifier) ModifyRequest(req *http.Request) error { via := fmt.Sprintf("%d.%d %s-%s", req.ProtoMajor, req.ProtoMinor, m.requestedBy, m.boundary) if v := req.Header.Get("Via"); v != "" { if m.hasLoop(v) { err := fmt.Errorf("via: detected request loop, header contains %s", via) ctx := martian.NewContext(req) ctx.Set(viaLoopKey, err) ctx.SkipRoundTrip() return err } via = fmt.Sprintf("%s, %s", v, via) } req.Header.Set("Via", via) return nil } // ModifyResponse sets the status code to 400 Bad Request if a loop was // detected in the request. func (m *ViaModifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) if err, _ := ctx.Get(viaLoopKey); err != nil { res.StatusCode = 400 res.Status = http.StatusText(400) return err.(error) } return nil } // hasLoop parses via and attempts to match requestedBy against the contained // pseudonyms/host:port pairs. func (m *ViaModifier) hasLoop(via string) bool { for _, v := range strings.Split(via, ",") { parts := whitespace.Split(strings.TrimSpace(v), 3) // No pseudonym or host:port, assume there is no loop. if len(parts) < 2 { continue } if fmt.Sprintf("%s-%s", m.requestedBy, m.boundary) == parts[1] { return true } } return false } // SetBoundary sets the boundary string (random 10 character by default) used to // disabiguate Martians that are chained together with identical requestedBy values. // This should only be used for testing. func (m *ViaModifier) SetBoundary(boundary string) { m.boundary = boundary } // randomBoundary generates a 10 character string to ensure that Martians that // are chained together with the same requestedBy value do not collide. This func // panics if io.Readfull fails. 
func randomBoundary() string { var buf [10]byte _, err := io.ReadFull(rand.Reader, buf[:]) if err != nil { panic(err) } return fmt.Sprintf("%x", buf[:]) } martian-3.3.2/header/via_modifier_test.go000066400000000000000000000047651421371434000204400ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package header import ( "net/http" "strings" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/proxyutil" ) func TestViaModifier(t *testing.T) { m := NewViaModifier("martian") req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.1 martian"; !strings.HasPrefix(got, want) { t.Errorf("req.Header.Get(%q): got %q, want prefixed with %q", "Via", got, want) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } req.Header.Set("Via", "1.0\talpha\t(martian)") if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Via"), "1.0\talpha\t(martian), 1.1 martian"; 
!strings.HasPrefix(got, want) { t.Errorf("req.Header.Get(%q): got %q, want prefixed with %q", "Via", got, want) } m.SetBoundary("boundary") req.Header.Set("Via", "1.0\talpha\t(martian), 1.1 martian-boundary, 1.1 beta") if err := m.ModifyRequest(req); err == nil { t.Fatal("ModifyRequest(): got nil, want request loop error") } if !ctx.SkippingRoundTrip() { t.Errorf("ctx.SkippingRoundTrip(): got false, want true") } if err := m.ModifyResponse(res); err == nil { t.Fatal("ModifyResponse(): got nil, want request loop error") } if got, want := res.StatusCode, 400; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Status, http.StatusText(400); got != want { t.Errorf("res.Status: got %q, want %q", got, want) } } martian-3.3.2/httpspec/000077500000000000000000000000001421371434000150035ustar00rootroot00000000000000martian-3.3.2/httpspec/httpspec.go000066400000000000000000000031311421371434000171620ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package httpspec provides a modifier stack that has been preconfigured to // provide spec-compliant HTTP proxy behavior. 
// // Related: https://www.mnot.net/blog/2011/07/11/what_proxies_must_do package httpspec import ( "github.com/google/martian/v3/fifo" "github.com/google/martian/v3/header" ) // NewStack returns a martian modifier stack that handles ensuring proper proxy // behavior, in addition to a fifo.Group that can be used to add additional // modifiers within the stack. func NewStack(via string) (outer *fifo.Group, inner *fifo.Group) { outer = fifo.NewGroup() hbhm := header.NewHopByHopModifier() outer.AddRequestModifier(hbhm) outer.AddRequestModifier(header.NewForwardedModifier()) outer.AddRequestModifier(header.NewBadFramingModifier()) vm := header.NewViaModifier(via) outer.AddRequestModifier(vm) inner = fifo.NewGroup() outer.AddRequestModifier(inner) outer.AddResponseModifier(inner) outer.AddResponseModifier(vm) outer.AddResponseModifier(hbhm) return outer, inner } martian-3.3.2/httpspec/httpspec_test.go000066400000000000000000000045451421371434000202330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package httpspec import ( "net/http" "strings" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" ) func TestNewStack(t *testing.T) { stack, fg := NewStack("martian") tm := martiantest.NewModifier() fg.AddRequestModifier(tm) fg.AddResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() // Hop-by-hop header to be removed. req.Header.Set("Hop-By-Hop", "true") req.Header.Set("Connection", "Hop-By-Hop") req.RemoteAddr = "10.0.0.1:5000" if err := stack.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Hop-By-Hop"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Hop-By-Hop", got, want) } if got, want := req.Header.Get("X-Forwarded-For"), "10.0.0.1"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Forwarded-For", got, want) } if got, want := req.Header.Get("Via"), "1.1 martian"; !strings.HasPrefix(got, want) { t.Errorf("req.Header.Get(%q): got %q, want %q", "Via", got, want) } res := proxyutil.NewResponse(200, nil, req) // Hop-by-hop header to be removed. res.Header.Set("Hop-By-Hop", "true") res.Header.Set("Connection", "Hop-By-Hop") if err := stack.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Hop-By-Hop"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Hop-By-Hop", got, want) } } martian-3.3.2/init.go000066400000000000000000000015051421371434000144440ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "flag" mlog "github.com/google/martian/v3/log" ) var ( level = flag.Int("v", 0, "log level") ) // Init runs common initialization code for a martian proxy. func Init() { mlog.SetLevel(*level) } martian-3.3.2/ipauth/000077500000000000000000000000001421371434000144435ustar00rootroot00000000000000martian-3.3.2/ipauth/ipauth.go000066400000000000000000000046771421371434000163020ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package ipauth provides a martian.Modifier that sets auth based on IP. package ipauth import ( "net" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/auth" ) var noop = martian.Noop("ipauth.Modifier") // Modifier is the IP authentication modifier. 
type Modifier struct { reqmod martian.RequestModifier resmod martian.ResponseModifier } // NewModifier returns a new IP authentication modifier. func NewModifier() *Modifier { return &Modifier{ reqmod: noop, resmod: noop, } } // SetRequestModifier sets the request modifier. func (m *Modifier) SetRequestModifier(reqmod martian.RequestModifier) { if reqmod == nil { reqmod = noop } m.reqmod = reqmod } // SetResponseModifier sets the response modifier. func (m *Modifier) SetResponseModifier(resmod martian.ResponseModifier) { if resmod == nil { resmod = noop } m.resmod = resmod } // ModifyRequest sets the auth ID in the context from the request iff it has // not already been set and runs reqmod.ModifyRequest. If the underlying // modifier has indicated via auth error that no valid auth credentials // have been found we set ctx.SkipRoundTrip. func (m *Modifier) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) actx := auth.FromContext(ctx) ip, _, err := net.SplitHostPort(req.RemoteAddr) if err != nil { ip = req.RemoteAddr } actx.SetID(ip) err = m.reqmod.ModifyRequest(req) if actx.Error() != nil { ctx.SkipRoundTrip() } return err } // ModifyResponse runs resmod.ModifyResponse. // // If an error is returned from resmod.ModifyResponse it is returned. func (m *Modifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) err := m.resmod.ModifyResponse(res) if actx.Error() != nil { res.StatusCode = 403 res.Status = http.StatusText(403) } return err } martian-3.3.2/ipauth/ipauth_test.go000066400000000000000000000073611421371434000173320ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package ipauth import ( "errors" "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/auth" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" ) func TestModifyRequest(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) req, err := http.NewRequest("CONNECT", "https://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP with port and modifier with error. tm := martiantest.NewModifier() reqerr := errors.New("request error") tm.RequestError(reqerr) req.RemoteAddr = "1.1.1.1:8111" m.SetRequestModifier(tm) if err := m.ModifyRequest(req); err != reqerr { t.Fatalf("ModifyConnectRequest(): got %v, want %v", err, reqerr) } if got, want := actx.ID(), "1.1.1.1"; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } // IP without port and modifier with auth error. 
req.RemoteAddr = "4.4.4.4" authErr := errors.New("auth error") tm.RequestError(nil) tm.RequestFunc(func(req *http.Request) { ctx := martian.NewContext(req) actx := auth.FromContext(ctx) actx.SetError(authErr) }) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := actx.ID(), ""; got != want { t.Errorf("actx.ID(): got %q, want %q", got, want) } } func TestModifyResponse(t *testing.T) { m := NewModifier() m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } // Modifier with error. tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) m.SetResponseModifier(tm) if err := m.ModifyResponse(res); err != reserr { t.Fatalf("ModifyResponse(): got %v, want %v", err, reserr) } // Modifier with auth error. 
tm.ResponseError(nil) authErr := errors.New("auth error") tm.ResponseFunc(func(res *http.Response) { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) actx.SetError(authErr) }) actx := auth.FromContext(ctx) actx.SetID("bad-auth") if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 403; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := actx.Error(), authErr; got != want { t.Errorf("actx.Error(): got %v, want %v", got, want) } } martian-3.3.2/log/000077500000000000000000000000001421371434000137325ustar00rootroot00000000000000martian-3.3.2/log/log.go000066400000000000000000000036461421371434000150530ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package log provides a universal logger for martian packages. package log import ( "fmt" "log" "sync" ) const ( // Silent is a level that logs nothing. Silent int = iota // Error is a level that logs error logs. Error // Info is a level that logs error, and info logs. Info // Debug is a level that logs error, info, and debug logs. Debug ) // Default log level is Error. var ( level = Error lock sync.Mutex ) // SetLevel sets the global log level. func SetLevel(l int) { lock.Lock() defer lock.Unlock() level = l } // Infof logs an info message. 
func Infof(format string, args ...interface{}) { lock.Lock() defer lock.Unlock() if level < Info { return } msg := fmt.Sprintf("INFO: %s", format) if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } log.Println(msg) } // Debugf logs a debug message. func Debugf(format string, args ...interface{}) { lock.Lock() defer lock.Unlock() if level < Debug { return } msg := fmt.Sprintf("DEBUG: %s", format) if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } log.Println(msg) } // Errorf logs an error message. func Errorf(format string, args ...interface{}) { lock.Lock() defer lock.Unlock() if level < Error { return } msg := fmt.Sprintf("ERROR: %s", format) if len(args) > 0 { msg = fmt.Sprintf(msg, args...) } log.Println(msg) } martian-3.3.2/log/log_test.go000066400000000000000000000026611421371434000161060ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package log import ( "bytes" "os" "strings" "testing" stdlog "log" ) func TestLog(t *testing.T) { buf := new(bytes.Buffer) stdlog.SetOutput(buf) defer stdlog.SetOutput(os.Stdout) // Reset log level after tests. 
defer func(l int) { level = l }(level) level = Debug Infof("log: %s test", "info") if got, want := buf.String(), "INFO: log: info test\n"; !strings.HasSuffix(got, want) { t.Errorf("Infof(): got %q, want to contain %q", got, want) } Debugf("log: %s test", "debug") if got, want := buf.String(), "DEBUG: log: debug test\n"; !strings.HasSuffix(got, want) { t.Errorf("Debugf(): got %q, want to contain %q", got, want) } Errorf("log: %s test", "error") if got, want := buf.String(), "ERROR: log: error test\n"; !strings.HasSuffix(got, want) { t.Errorf("Errorf(): got %q, want to contain %q", got, want) } } martian-3.3.2/marbl/000077500000000000000000000000001421371434000142465ustar00rootroot00000000000000martian-3.3.2/marbl/handler.go000066400000000000000000000053251421371434000162170ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package marbl import ( "crypto/rand" "encoding/hex" "net/http" "sync" "github.com/google/martian/v3/log" "golang.org/x/net/websocket" ) // Handler exposes marbl logs over websockets. type Handler struct { mu sync.RWMutex subs map[string]chan<- []byte } // NewHandler instantiates a Handler with an empty set of subscriptions. func NewHandler() *Handler { return &Handler{ subs: make(map[string]chan<- []byte), } } // Write writes frames to all websocket subscribers and returns the number // of bytes written and an error. 
func (h *Handler) Write(b []byte) (int, error) { h.mu.RLock() defer h.mu.RUnlock() var wg sync.WaitGroup for id, framec := range h.subs { wg.Add(1) go func(id string, fc chan<- []byte) { defer wg.Done() select { case fc <- b: default: log.Errorf("logstream: buffer full for connection, dropping") go h.unsubscribe(id) } }(id, framec) } wg.Wait() return len(b), nil } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { websocket.Server{Handler: h.streamLogs}.ServeHTTP(rw, req) } func (h *Handler) streamLogs(conn *websocket.Conn) { defer conn.Close() id, err := newID() if err != nil { log.Errorf("logstream: failed to create ID: %v", err) return } framec := make(chan []byte, 16384) h.subscribe(id, framec) defer h.unsubscribe(id) for b := range framec { if err := websocket.Message.Send(conn, b); err != nil { log.Errorf("logstream: failed to send message: %v", err) return } } } func newID() (string, error) { src := make([]byte, 8) if _, err := rand.Read(src); err != nil { return "", err } return hex.EncodeToString(src), nil } func (h *Handler) unsubscribe(id string) { h.mu.Lock() defer h.mu.Unlock() if fc, ok := h.subs[id]; ok { close(fc) delete(h.subs, id) } } func (h *Handler) subscribe(id string, framec chan<- []byte) { h.mu.Lock() defer h.mu.Unlock() if fc, ok := h.subs[id]; ok { // TODO: Re-pick the id. log.Errorf("Resubscribing with ID: %v", id) // Close the channel for now so the websocket gets disconnected, // instead of silently failing. close(fc) } h.subs[id] = framec } martian-3.3.2/marbl/handler_test.go000066400000000000000000000061011421371434000172470ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package marbl import ( "fmt" "math/rand" "net" "net/http" "strconv" "testing" "time" "golang.org/x/net/websocket" ) func TestStreamsInSentOrder(t *testing.T) { l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } handler := NewHandler() go http.Serve(l, handler) ws, err := websocket.Dial(fmt.Sprintf("ws://%s", l.Addr()), "", "http://localhost/") if err != nil { t.Fatalf("websocket.Dial(): got %v, want no error", err) } defer ws.Close() // Gives handler time to create the subscription channel. time.Sleep(200 * time.Millisecond) ws.SetDeadline(time.Now().Add(5 * time.Second)) // server could still be in the processs of registering the client // no easy way to synchronize so we just wait a bit time.Sleep(300 * time.Millisecond) var iterations int64 = 5000 go func() { for i := int64(0); i < iterations; i++ { hex := strconv.FormatInt(int64(i), 16) handler.Write([]byte(hex)) } }() for i := int64(0); i < iterations; i++ { var bytes []byte err = websocket.Message.Receive(ws, &bytes) if err != nil { t.Fatalf("websocket.Conn.Read(): got %v, want no error", err) } parsed, err := strconv.ParseInt(string(bytes), 16, 64) if err != nil { t.Fatalf("strconv.ParseInt(): got %v, want no error", err) } if parsed != i { t.Fatalf("Messages arrived out of order, got %d want %d", parsed, i) } } } func TestUnreadsDontBlock(t *testing.T) { l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } handler := NewHandler() go http.Serve(l, handler) ws, err := 
websocket.Dial(fmt.Sprintf("ws://%s", l.Addr()), "", "http://localhost/") if err != nil { t.Fatalf("websocket.Dial(): got %v, want no error", err) } defer ws.Close() // Gives handler time to create the subscription channel. time.Sleep(200 * time.Millisecond) bytes := make([]byte, 1024) _, err = rand.Read(bytes) if err != nil { t.Fatalf("rand.Read(): got %v, want no error", err) } // Purposely using more iterations than frame channel size. var iterations int64 = 50000 for i := int64(0); i < iterations; i++ { to := doOrTimeout(3*time.Second, func() { handler.Write(bytes) }) if to { t.Fatalf("handler.Write() Timed out") } } } func doOrTimeout(d time.Duration, f func()) bool { done := make(chan interface{}) go func() { f() done <- 1 }() select { case <-done: return false case <-time.After(d): return true } } martian-3.3.2/marbl/marbl.go000066400000000000000000000141741421371434000157010ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package marbl provides HTTP traffic logs streamed over websockets // that can be added to any point within a Martian modifier tree. 
// Marbl transmits HTTP logs that are serialized based on the following // schema: // // Frame Header // FrameType uint8 // MessageType uint8 // ID [8]byte // Payload HeaderFrame/DataFrame // // Header Frame // NameLen uint32 // ValueLen uint32 // Name variable // Value variable // // Data Frame // Index uint32 // Terminal uint8 // Len uint32 // Data variable package marbl import ( "io" "net/http" "strconv" "sync/atomic" "time" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/proxyutil" ) // MessageType incicates whether the message represents an HTTP request or response. type MessageType uint8 const ( // Unknown type of Message. Unknown MessageType = 0x0 // Request indicates a message that contains an HTTP request. Request MessageType = 0x1 // Response indicates a message that contains an HTTP response. Response MessageType = 0x2 ) // FrameType indicates whether the frame contains a Header or Data. type FrameType uint8 const ( // UnknownFrame indicates an unknown type of Frame. UnknownFrame FrameType = 0x0 // HeaderFrame indicates a frame that contains a header. HeaderFrame FrameType = 0x1 // DataFrame indicates a frame that contains the payload, usually the body. DataFrame FrameType = 0x2 ) // Stream writes logs of requests and responses to a writer. type Stream struct { w io.Writer framec chan []byte closec chan struct{} } // NewStream initializes a Stream with an io.Writer to log requests and // responses to. Upon construction, a goroutine is started that listens for frames // and writes them to w. func NewStream(w io.Writer) *Stream { s := &Stream{ w: w, framec: make(chan []byte), closec: make(chan struct{}), } go s.loop() return s } func (s *Stream) loop() { for { select { case f := <-s.framec: _, err := s.w.Write(f) if err != nil { log.Errorf("martian: Error while writing frame") } case <-s.closec: return } } } // Close signals Stream to stop listening for frames in the log loop and stop writing logs. 
func (s *Stream) Close() error { s.closec <- struct{}{} close(s.closec) return nil } func newFrame(id string, ft FrameType, mt MessageType, plen uint32) []byte { f := make([]byte, 0, 10+plen) f = append(f, byte(ft), byte(mt)) f = append(f, id[:8]...) return f } func (s *Stream) sendHeader(id string, mt MessageType, key, value string) { kl := uint32(len(key)) vl := uint32(len(value)) f := newFrame(id, HeaderFrame, mt, 64+kl+vl) f = append(f, byte(kl>>24), byte(kl>>16), byte(kl>>8), byte(kl)) f = append(f, byte(vl>>24), byte(vl>>16), byte(vl>>8), byte(vl)) f = append(f, key[:kl]...) f = append(f, value[:vl]...) s.framec <- f } func (s *Stream) sendData(id string, mt MessageType, i uint32, terminal bool, b []byte, bl int) { var ti uint8 if terminal { ti = 1 } f := newFrame(id, DataFrame, mt, 72+uint32(bl)) f = append(f, byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) f = append(f, byte(ti)) f = append(f, byte(bl>>24), byte(bl>>16), byte(bl>>8), byte(bl)) f = append(f, b[:bl]...) s.framec <- f } // LogRequest writes an http.Request to Stream with an id unique for the request / response pair. 
func (s *Stream) LogRequest(id string, req *http.Request) error { s.sendHeader(id, Request, ":method", req.Method) s.sendHeader(id, Request, ":scheme", req.URL.Scheme) s.sendHeader(id, Request, ":authority", req.URL.Host) s.sendHeader(id, Request, ":path", req.URL.EscapedPath()) s.sendHeader(id, Request, ":query", req.URL.RawQuery) s.sendHeader(id, Request, ":proto", req.Proto) s.sendHeader(id, Request, ":remote", req.RemoteAddr) ts := strconv.FormatInt(time.Now().UnixNano()/1000/1000, 10) s.sendHeader(id, Request, ":timestamp", ts) ctx := martian.NewContext(req) if ctx.IsAPIRequest() { s.sendHeader(id, Request, ":api", "true") } h := proxyutil.RequestHeader(req) for k, vs := range h.Map() { for _, v := range vs { s.sendHeader(id, Request, k, v) } } req.Body = &bodyLogger{ s: s, id: id, mt: Request, body: req.Body, } return nil } // LogResponse writes an http.Response to Stream with an id unique for the request / response pair. func (s *Stream) LogResponse(id string, res *http.Response) error { s.sendHeader(id, Response, ":proto", res.Proto) s.sendHeader(id, Response, ":status", strconv.Itoa(res.StatusCode)) s.sendHeader(id, Response, ":reason", res.Status) ts := strconv.FormatInt(time.Now().UnixNano()/1000/1000, 10) s.sendHeader(id, Response, ":timestamp", ts) ctx := martian.NewContext(res.Request) if ctx.IsAPIRequest() { s.sendHeader(id, Response, ":api", "true") } h := proxyutil.ResponseHeader(res) for k, vs := range h.Map() { for _, v := range vs { s.sendHeader(id, Response, k, v) } } res.Body = &bodyLogger{ s: s, id: id, mt: Response, body: res.Body, } return nil } type bodyLogger struct { index uint32 // atomic s *Stream id string mt MessageType body io.ReadCloser } // Read implements the standard Reader interface. Read reads the bytes of the body // and returns the number of bytes read and an error. 
func (bl *bodyLogger) Read(b []byte) (int, error) { var terminal bool n, err := bl.body.Read(b) if err == io.EOF { terminal = true } bl.s.sendData(bl.id, bl.mt, atomic.AddUint32(&bl.index, 1)-1, terminal, b, n) return n, err } // Close closes the bodyLogger. func (bl *bodyLogger) Close() error { return bl.body.Close() } martian-3.3.2/marbl/marbl_test.go000066400000000000000000000261641421371434000167420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package marbl import ( "bytes" "io" "net/http" "strconv" "strings" "testing" "time" "github.com/google/martian/v3" "github.com/google/martian/v3/proxyutil" ) func TestMarkAPIRequestsWithHeader(t *testing.T) { areq, err := http.NewRequest("POST", "http://localhost:8080/configure", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(areq, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() ctx.APIRequest() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, removereq, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer removereq() var b bytes.Buffer s := NewStream(&b) s.LogRequest("00000000", areq) s.LogRequest("00000001", req) s.Close() headers := make(map[string]string) reader := NewReader(&b) for { frame, err := reader.ReadFrame() if frame == nil { break } if err != nil && err != io.EOF { t.Fatalf("reader.ReadFrame(): got %v, want no error or io.EOF", err) } headerFrame, ok := frame.(Header) if !ok { t.Fatalf("frame.(Header): couldn't convert frame '%v' to a headerFrame", frame) } headers[headerFrame.ID+headerFrame.Name] = headerFrame.Value } apih, ok := headers["00000000:api"] if !ok { t.Errorf("headers[00000000:api]: got no such header, want :api (headers were: %v)", headers) } if got, want := apih, "true"; got != want { t.Errorf("headers[%q]: got %v, want %q", "00000000:api", got, want) } _, ok = headers["00000001:api"] if got, want := ok, false; got != want { t.Error("headers[00000001:api]: got :api header, want no header for non-api requests") } } func TestSendTimestampWithLogRequest(t *testing.T) { req, err := http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if 
err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() var b bytes.Buffer s := NewStream(&b) before := time.Now().UnixNano() / 1000 / 1000 s.LogRequest("Fake_Id0", req) s.Close() after := time.Now().UnixNano() / 1000 / 1000 headers := make(map[string]string) reader := NewReader(&b) for { frame, err := reader.ReadFrame() if frame == nil { break } if err != nil && err != io.EOF { t.Fatalf("reader.ReadFrame(): got %v, want no error or io.EOF", err) } headerFrame, ok := frame.(Header) if !ok { t.Fatalf("frame.(Header): couldn't convert frame '%v' to a headerFrame", frame) } headers[headerFrame.Name] = headerFrame.Value } timestr, ok := headers[":timestamp"] if !ok { t.Fatalf("headers[:timestamp]: got no such header, want :timestamp (headers were: %v)", headers) } ts, err := strconv.ParseInt(timestr, 10, 64) if err != nil { t.Fatalf("strconv.ParseInt: got %s, want no error. Invalidly formatted timestamp ('%s')", err, timestr) } if ts < before || ts > after { t.Fatalf("headers[:timestamp]: got %d, want timestamp between %d and %d", ts, before, after) } } func TestSendTimestampWithLogResponse(t *testing.T) { req, err := http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(200, nil, req) var b bytes.Buffer s := NewStream(&b) before := time.Now().UnixNano() / 1000 / 1000 s.LogResponse("Fake_Id1", res) s.Close() after := time.Now().UnixNano() / 1000 / 1000 headers := make(map[string]string) reader := NewReader(&b) for { frame, err := reader.ReadFrame() if frame == nil { break } if err != nil && err != io.EOF { t.Fatalf("reader.ReadFrame(): got %v, want no error or io.EOF", err) } headerFrame, ok := frame.(Header) if !ok { t.Fatalf("frame.(Header): couldn't convert frame '%v' to a headerFrame", frame) } 
headers[headerFrame.Name] = headerFrame.Value } timestr, ok := headers[":timestamp"] if !ok { t.Fatalf("headers[:timestamp]: got no such header, want :timestamp (headers were: %v)", headers) } ts, err := strconv.ParseInt(timestr, 10, 64) if err != nil { t.Fatalf("strconv.ParseInt: got %s, want no error. Invalidly formatted timestamp ('%s')", err, timestr) } if ts < before || ts > after { t.Fatalf("headers[:timestamp]: got %d, want timestamp between %d and %d (headers were: %v)", ts, before, after, headers) } } func TestBodyLoggingWithOneRead(t *testing.T) { // Test scenario: // 1. Prepare HTTP request with body containing a string. // 2. Initialize marbl logging on this request. // 3. Read body of the request in single Read() and verity that it matches // original string. // 4. Parse marbl data, extract DataFrames and verify that they match // . original string. body := "hello, world" req, err := http.NewRequest("POST", "http://example.com", strings.NewReader(body)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() var b bytes.Buffer s := NewStream(&b) s.LogRequest("Fake_Id0", req) // Read request body into big slice. bodybytes := make([]byte, 100) // First read. Due to implementation details of strings.Read // it reads all bytes but doesn't return EOF. n, err := req.Body.Read(bodybytes) if n != len(body) { t.Fatalf("req.Body.Read(): expected to read %v bytes but read %v", len(body), n) } if body != string(bodybytes[:n]) { t.Fatalf("req.Body.Read(): expected to read %v but read %v", body, string(bodybytes[:n])) } if err != nil { t.Fatalf("req.Body.Read(): first read expected to be successful but got error %v", err) } // second read. We already consumed the whole string on the first read // so now it should be 0 bytes and EOF. 
n, err = req.Body.Read(bodybytes) if n != 0 { t.Fatalf("req.Body.Read(): expected to read 0 bytes but read %v", n) } if err != io.EOF { t.Fatalf("req.Body.Read(): expected EOF but got %v", err) } s.Close() reader := NewReader(&b) bodybytes = readAllDataFrames(reader, "Fake_Id0", t) if len(bodybytes) != len(body) { t.Fatalf("readAllDataFrames(): expected .marbl data to have %v bytes, but got %v", len(body), len(bodybytes)) } if body != string(bodybytes) { t.Fatalf("readAllDataFrames(): expected .marbl data to have string %v but got %v", body, string(bodybytes)) } } func TestBodyLogging_ManyReads(t *testing.T) { // Test scenario: // 1. Prepare HTTP request with body containing a string. // 2. Initialize marbl logging on this request. // 3. Read body of the request in many reads, 1 byte per read and // . verify that it matches original string. // 4. Parse marbl data, extract DataFrames and verify that they match // . original string. body := "hello, world" req, err := http.NewRequest("POST", "http://example.com", strings.NewReader(body)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() var b bytes.Buffer s := NewStream(&b) s.LogRequest("Fake_Id0", req) // Read request body into single byte slice. bodybytes := make([]byte, 1) for i := 0; i < len(body); i++ { // first read n, err := req.Body.Read(bodybytes) if n != 1 { t.Fatalf("req.Body.Read(): expected to read 1 byte but read %v", n) } if body[i] != bodybytes[0] { t.Fatalf("req.Body.Read(): expected to read %v but read %v", body[i], bodybytes[0]) } if err != nil { t.Fatalf("req.Body.Read(): read expected to be successfully but got error %v", err) } } // last read. We already consumed the whole string on the previous reads // so now it should be 0 bytes and EOF. 
n, err := req.Body.Read(bodybytes) if n != 0 { t.Fatalf("req.Body.Read(): expected to read 0 bytes but read %v", n) } if err != io.EOF { t.Fatalf("req.Body.Read(): expected EOF but got %v", err) } s.Close() reader := NewReader(&b) bodybytes = readAllDataFrames(reader, "Fake_Id0", t) if len(bodybytes) != len(body) { t.Fatalf("readAllDataFrames(): expected .marbl data to have %v bytes, but got %v", len(body), len(bodybytes)) } if body != string(bodybytes) { t.Fatalf("readAllDataFrames(): expected .marbl data to have string %v but got %v", body, string(bodybytes)) } } func TestReturnOriginalRequestPathAndQuery(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com/foo%20bar?baz%20qux", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() var b bytes.Buffer s := NewStream(&b) s.LogRequest("Fake_Id0", req) s.Close() headers := make(map[string]string) reader := NewReader(&b) for { frame, err := reader.ReadFrame() if frame == nil { break } if err != nil && err != io.EOF { t.Fatalf("reader.ReadFrame(): got %v, want no error or io.EOF", err) } headerFrame, ok := frame.(Header) if !ok { t.Fatalf("frame.(Header): couldn't convert frame '%v' to a headerFrame", frame) } headers[headerFrame.Name] = headerFrame.Value } path := headers[":path"] if path != "/foo%20bar" { t.Fatalf("headers[:path]: expected /foo%%20bar but got %s", path) } query := headers[":query"] if query != "baz%20qux" { t.Fatalf("headers[:query]: expected baz%%20qux but got %s", query) } } // readAllDataFrames reads all DataFrames with reader, filters the one that match provided // id and assembles data from all frames into single slice. It expects that // there is only one slice of DataFrames with provided id. 
func readAllDataFrames(reader *Reader, id string, t *testing.T) []byte { res := make([]byte, 0) term := false var i uint32 for { frame, _ := reader.ReadFrame() if frame == nil { break } if frame.FrameType() == DataFrame { df := frame.(Data) if df.ID != id { continue } if term { t.Fatal("DataFrame after terminal frame are not allowed.") } if df.Index != i { t.Fatalf("expected DataFrame index %v but got %v", i, df.Index) } term = df.Terminal res = append(res, df.Data...) i++ } } if !term { t.Fatal("didn't see terminal DataFrame") } return res } martian-3.3.2/marbl/modifier.go000066400000000000000000000026051421371434000163760ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package marbl import ( "io" "net/http" "github.com/google/martian/v3" ) // Modifier implements the Martian modifier interface so that marbl logs // can be captured at any point in a Martian modifier tree. type Modifier struct { s *Stream } // NewModifier returns a marbl.Modifier initialized with a marbl.Stream. func NewModifier(w io.Writer) *Modifier { return &Modifier{ s: NewStream(w), } } // ModifyRequest writes an HTTP request to the log stream. func (m *Modifier) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) return m.s.LogRequest(ctx.ID(), req) } // ModifyResponse writes an HTTP response to the log stream. 
func (m *Modifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) return m.s.LogResponse(ctx.ID(), res) } martian-3.3.2/marbl/reader.go000066400000000000000000000071171421371434000160450ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package marbl import ( "bufio" "encoding/binary" "fmt" "io" ) // Header is either an HTTP header or meta-data pertaining to the request or response. type Header struct { ID string MessageType MessageType Name string Value string } // String returns the contents of a Header frame in a format appropriate for debugging and runtime logging. func (hf Header) String() string { return fmt.Sprintf("ID=%s; Type=%d; Name=%s; Value=%s", hf.ID, hf.MessageType, hf.Name, hf.Value) } // FrameType returns HeaderFrame func (hf Header) FrameType() FrameType { return HeaderFrame } // Data is the payload (body) of the request or response. type Data struct { ID string MessageType MessageType Index uint32 Terminal bool Data []byte } // String returns the contents of a Data frame in a format appropriate for debugging and runtime logging. The // contents of the data content slice (df.Data) is not printed, instead the length of Data is printed. 
func (df Data) String() string { return fmt.Sprintf("ID=%s; Type=%d; Index=%d; Terminal=%t; Data length=%d", df.ID, df.MessageType, df.Index, df.Terminal, len(df.Data)) } // FrameType returns DataFrame func (df Data) FrameType() FrameType { return DataFrame } // Frame describes the interface for a frame (either Data or Header). type Frame interface { String() string FrameType() FrameType } // Reader wraps a buffered Reader that reads from the io.Reader and emits Frames. type Reader struct { r io.Reader } // NewReader returns a Reader initialized with a buffered reader. func NewReader(r io.Reader) *Reader { return &Reader{ r: bufio.NewReader(r), } } // ReadFrame reads from r, determines the FrameType, and returns either a Header or Data and an error. func (r *Reader) ReadFrame() (Frame, error) { fh := make([]byte, 10) if _, err := io.ReadFull(r.r, fh); err != nil { return nil, err } switch FrameType(fh[0]) { case HeaderFrame: hf := Header{ ID: string(fh[2:]), MessageType: MessageType(fh[1]), } lens := make([]byte, 8) if _, err := io.ReadFull(r.r, lens); err != nil { return nil, err } nl := binary.BigEndian.Uint32(lens[:4]) vl := binary.BigEndian.Uint32(lens[4:]) nv := make([]byte, int(nl+vl)) if _, err := io.ReadFull(r.r, nv); err != nil { return nil, err } hf.Name = string(nv[:nl]) hf.Value = string(nv[nl:]) return hf, nil case DataFrame: df := Data{ ID: string(fh[2:]), MessageType: MessageType(fh[1]), } // Reading 9 bytes: // 4 bytes index // 1 byte terminal // 4 bytes data length desc := make([]byte, 9) if _, err := io.ReadFull(r.r, desc); err != nil { return nil, err } df.Index = binary.BigEndian.Uint32(desc[:4]) df.Terminal = desc[4] == 1 dl := binary.BigEndian.Uint32(desc[5:]) data := make([]byte, int(dl)) if _, err := io.ReadFull(r.r, data); err != nil { return nil, err } df.Data = data return df, nil default: return nil, fmt.Errorf("marbl: unknown type of frame") } } 
martian-3.3.2/martian.go000066400000000000000000000037351421371434000151430ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package martian provides an HTTP/1.1 proxy with an API for configurable // request and response modifiers. package martian import "net/http" // RequestModifier is an interface that defines a request modifier that can be // used by a proxy. type RequestModifier interface { // ModifyRequest modifies the request. ModifyRequest(req *http.Request) error } // ResponseModifier is an interface that defines a response modifier that can // be used by a proxy. type ResponseModifier interface { // ModifyResponse modifies the response. ModifyResponse(res *http.Response) error } // RequestResponseModifier is an interface that is both a ResponseModifier and // a RequestModifier. type RequestResponseModifier interface { RequestModifier ResponseModifier } // RequestModifierFunc is an adapter for using a function with the given // signature as a RequestModifier. type RequestModifierFunc func(req *http.Request) error // ResponseModifierFunc is an adapter for using a function with the given // signature as a ResponseModifier. type ResponseModifierFunc func(res *http.Response) error // ModifyRequest modifies the request using the given function. 
func (f RequestModifierFunc) ModifyRequest(req *http.Request) error { return f(req) } // ModifyResponse modifies the response using the given function. func (f ResponseModifierFunc) ModifyResponse(res *http.Response) error { return f(res) } martian-3.3.2/martian_test.go000066400000000000000000000033441421371434000161760ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "net/http" "testing" "github.com/google/martian/v3/proxyutil" ) func TestModifierFuncs(t *testing.T) { reqmod := RequestModifierFunc( func(req *http.Request) error { req.Header.Set("Request-Modified", "true") return nil }) resmod := ResponseModifierFunc( func(res *http.Response) error { res.Header.Set("Response-Modified", "true") return nil }) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Request-Modified"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Request-Modified", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Response-Modified"), "true"; got != want { t.Errorf("res.Header.Get(%q): 
got %q, want %q", "Response-Modified", got, want) } } martian-3.3.2/martianhttp/000077500000000000000000000000001421371434000155045ustar00rootroot00000000000000martian-3.3.2/martianhttp/authority_handler.go000066400000000000000000000023671421371434000215700ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martianhttp import ( "crypto/x509" "encoding/pem" "net/http" ) type authorityHandler struct { cert []byte } // NewAuthorityHandler returns an http.Handler that will present the client // with the CA certificate to use in browser. func NewAuthorityHandler(ca *x509.Certificate) http.Handler { return &authorityHandler{ cert: pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: ca.Raw, }), } } // ServeHTTP writes the CA certificate in PEM format to the client. func (h *authorityHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Content-Type", "application/x-x509-ca-cert") rw.Write(h.cert) } martian-3.3.2/martianhttp/authority_handler_test.go000066400000000000000000000035771421371434000226330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martianhttp import ( "crypto/x509" "encoding/pem" "net/http" "net/http/httptest" "testing" "time" "github.com/google/martian/v3/mitm" ) func TestAuthorityHandler(t *testing.T) { ca, _, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } rw := httptest.NewRecorder() req, err := http.NewRequest("GET", "/martian/authority.cer", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } h := NewAuthorityHandler(ca) h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } if got, want := rw.Header().Get("Content-Type"), "application/x-x509-ca-cert"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Content-Type", got, want) } blk, _ := pem.Decode(rw.Body.Bytes()) if got, want := blk.Type, "CERTIFICATE"; got != want { t.Errorf("rw.Body: got PEM type %q, want %q", got, want) } cert, err := x509.ParseCertificate(blk.Bytes) if err != nil { t.Fatalf("x509.ParseCertificate(res.Body): got %v, want no error", err) } if got, want := cert.Subject.CommonName, "martian.proxy"; got != want { t.Errorf("cert.Subject.CommonName: got %q, want %q", got, want) } } martian-3.3.2/martianhttp/martianhttp.go000066400000000000000000000115141421371434000203700ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package martianhttp provides HTTP handlers for managing the state of a martian.Proxy. package martianhttp import ( "bytes" "encoding/json" "io/ioutil" "net/http" "sync" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) var noop = martian.Noop("martianhttp.Modifier") // Modifier is a locking modifier that is configured via http.Handler. type Modifier struct { mu sync.RWMutex config []byte reqmod martian.RequestModifier resmod martian.ResponseModifier } // NewModifier returns a new martianhttp.Modifier. func NewModifier() *Modifier { return &Modifier{ reqmod: noop, resmod: noop, } } // SetRequestModifier sets the request modifier. func (m *Modifier) SetRequestModifier(reqmod martian.RequestModifier) { m.mu.Lock() defer m.mu.Unlock() m.setRequestModifier(reqmod) } func (m *Modifier) setRequestModifier(reqmod martian.RequestModifier) { if reqmod == nil { reqmod = noop } m.reqmod = reqmod } // SetResponseModifier sets the response modifier. func (m *Modifier) SetResponseModifier(resmod martian.ResponseModifier) { m.mu.Lock() defer m.mu.Unlock() m.setResponseModifier(resmod) } func (m *Modifier) setResponseModifier(resmod martian.ResponseModifier) { if resmod == nil { resmod = noop } m.resmod = resmod } // ModifyRequest runs reqmod. func (m *Modifier) ModifyRequest(req *http.Request) error { m.mu.RLock() defer m.mu.RUnlock() return m.reqmod.ModifyRequest(req) } // ModifyResponse runs resmod. 
func (m *Modifier) ModifyResponse(res *http.Response) error { m.mu.RLock() defer m.mu.RUnlock() return m.resmod.ModifyResponse(res) } // VerifyRequests verifies reqmod, iff reqmod is a RequestVerifier. func (m *Modifier) VerifyRequests() error { m.mu.RLock() defer m.mu.RUnlock() if reqv, ok := m.reqmod.(verify.RequestVerifier); ok { return reqv.VerifyRequests() } return nil } // VerifyResponses verifies resmod, iff resmod is a ResponseVerifier. func (m *Modifier) VerifyResponses() error { m.mu.RLock() defer m.mu.RUnlock() if resv, ok := m.resmod.(verify.ResponseVerifier); ok { return resv.VerifyResponses() } return nil } // ResetRequestVerifications resets verifications on reqmod, iff reqmod is a // RequestVerifier. func (m *Modifier) ResetRequestVerifications() { m.mu.Lock() defer m.mu.Unlock() if reqv, ok := m.reqmod.(verify.RequestVerifier); ok { reqv.ResetRequestVerifications() } } // ResetResponseVerifications resets verifications on resmod, iff resmod is a // ResponseVerifier. func (m *Modifier) ResetResponseVerifications() { m.mu.Lock() defer m.mu.Unlock() if resv, ok := m.resmod.(verify.ResponseVerifier); ok { resv.ResetResponseVerifications() } } // ServeHTTP sets or retrieves the JSON-encoded modifier configuration // depending on request method. POST requests are expected to provide a JSON // modifier message in the body which will be used to update the contained // request and response modifiers. GET requests will return the JSON // (pretty-printed) for the most recent configuration. 
func (m *Modifier) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch req.Method { case "POST": m.servePOST(rw, req) return case "GET": m.serveGET(rw, req) return default: rw.Header().Set("Allow", "GET, POST") rw.WriteHeader(405) } } func (m *Modifier) servePOST(rw http.ResponseWriter, req *http.Request) { body, err := ioutil.ReadAll(req.Body) if err != nil { http.Error(rw, err.Error(), 500) log.Errorf("martianhttp: error reading request body: %v", err) return } req.Body.Close() r, err := parse.FromJSON(body) if err != nil { http.Error(rw, err.Error(), 400) log.Errorf("martianhttp: error parsing JSON: %v", err) return } buf := new(bytes.Buffer) if err := json.Indent(buf, body, "", " "); err != nil { http.Error(rw, err.Error(), 400) log.Errorf("martianhttp: error formatting JSON: %v", err) return } m.mu.Lock() defer m.mu.Unlock() m.config = buf.Bytes() m.setRequestModifier(r.RequestModifier()) m.setResponseModifier(r.ResponseModifier()) } func (m *Modifier) serveGET(rw http.ResponseWriter, req *http.Request) { m.mu.RLock() defer m.mu.RUnlock() rw.Header().Set("Content-Type", "application/json") rw.Write(m.config) } martian-3.3.2/martianhttp/martianhttp_integration_test.go000066400000000000000000000044541421371434000240370ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package martianhttp import ( "net" "net/http" "net/http/httptest" "net/url" "strings" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/martiantest" _ "github.com/google/martian/v3/header" ) func TestIntegration(t *testing.T) { ptr := martiantest.NewTransport() proxy := martian.NewProxy() defer proxy.Close() proxy.SetRoundTripper(ptr) l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } go proxy.Serve(l) m := NewModifier() proxy.SetRequestModifier(m) proxy.SetResponseModifier(m) mux := http.NewServeMux() mux.Handle("/", m) s := httptest.NewServer(mux) defer s.Close() body := strings.NewReader(`{ "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) res, err := http.Post(s.URL, "application/json", body) if err != nil { t.Fatalf("http.Post(%s): got %v, want no error", s.URL, err) } res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } tr := &http.Transport{ Proxy: http.ProxyURL(&url.URL{ Scheme: "http", Host: l.Addr().String(), }), } defer tr.CloseIdleConnections() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") res, err = tr.RoundTrip(req) if err != nil { t.Fatalf("transport.RoundTrip(%q): got %v, want no error", req.URL, err) } res.Body.Close() if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } } martian-3.3.2/martianhttp/martianhttp_test.go000066400000000000000000000145631421371434000214360ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martianhttp import ( "bytes" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/verify" _ "github.com/google/martian/v3/header" ) func TestNoModifiers(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } } func TestModifyRequest(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetRequestModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } m.SetRequestModifier(nil) if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } } func TestModifyResponse(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() m.SetResponseModifier(tm) res := proxyutil.NewResponse(200, nil, nil) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want 
no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } m.SetResponseModifier(nil) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } } func TestVerifyRequests(t *testing.T) { m := NewModifier() if err := m.VerifyRequests(); err != nil { t.Errorf("VerifyRequests(): got %v, want no error", err) } verr := fmt.Errorf("request verification failure") m.SetRequestModifier(&verify.TestVerifier{ RequestError: verr, }) if err := m.VerifyRequests(); err != verr { t.Errorf("VerifyRequests(): got %v, want %v", err, verr) } m.ResetRequestVerifications() if err := m.VerifyRequests(); err != nil { t.Errorf("m.VerifyRequests(): got %v, want no error", err) } } func TestVerifyResponses(t *testing.T) { m := NewModifier() if err := m.VerifyResponses(); err != nil { t.Errorf("VerifyResponses(): got %v, want no error", err) } verr := fmt.Errorf("response verification failure") m.SetResponseModifier(&verify.TestVerifier{ ResponseError: verr, }) if err := m.VerifyResponses(); err != verr { t.Errorf("VerifyResponses(): got %v, want %v", err, verr) } m.ResetResponseVerifications() if err := m.VerifyResponses(); err != nil { t.Errorf("m.VerifyResponses(): got %v, want no error", err) } } func TestServeHTTPInvalidMethod(t *testing.T) { m := NewModifier() req, err := http.NewRequest("PATCH", "/configure", nil) if err != nil { t.Fatalf("http.NewRequest(%q, ...): got %v, want no error", "GET", err) } rw := httptest.NewRecorder() m.ServeHTTP(rw, req) if got, want := rw.Code, 405; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } if got, want := rw.Header().Get("Allow"), "GET, POST"; got != want { t.Errorf("rw.Header().Get(%q): got %q, want %q", "Allow", got, want) } } func TestServeHTTPInvalidJSON(t *testing.T) { m := NewModifier() req, err := http.NewRequest("POST", "/configure", bytes.NewReader([]byte("not-json"))) if err != nil { t.Fatalf("http.NewRequest(%q, %q, ...): got %v, 
want no error", "POST", "/configure", err) } rw := httptest.NewRecorder() m.ServeHTTP(rw, req) if got, want := rw.Code, 400; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } } func TestServeHTTP(t *testing.T) { m := NewModifier() body := []byte(`{ "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Test", "value": "true" } }`) req, err := http.NewRequest("POST", "/configure", bytes.NewReader(body)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Content-Type", "application/json") rw := httptest.NewRecorder() m.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } req, err = http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("m.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("m.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } req, err = http.NewRequest("GET", "/configure", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw = httptest.NewRecorder() m.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } got := new(bytes.Buffer) want := new(bytes.Buffer) if err := json.Compact(got, body); err != nil { t.Fatalf("json.Compact(body): got %v, want no error", err) } if err := json.Compact(want, rw.Body.Bytes()); err != nil { t.Fatalf("json.Compact(rw.Body): got %v, want no error", err) } if !bytes.Equal(got.Bytes(), want.Bytes()) { 
t.Errorf("rw.Body: got %q, want %q", got.Bytes(), want.Bytes()) } } martian-3.3.2/martianlog/000077500000000000000000000000001421371434000153065ustar00rootroot00000000000000martian-3.3.2/martianlog/logger.go000066400000000000000000000114611421371434000171170ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package martianlog provides a Martian modifier that logs the request and response. package martianlog import ( "bytes" "encoding/json" "fmt" "io" "net/http" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/messageview" "github.com/google/martian/v3/parse" ) // Logger is a modifier that logs requests and responses. type Logger struct { log func(line string) headersOnly bool decode bool } type loggerJSON struct { Scope []parse.ModifierType `json:"scope"` HeadersOnly bool `json:"headersOnly"` Decode bool `json:"decode"` } func init() { parse.Register("log.Logger", loggerFromJSON) } // NewLogger returns a logger that logs requests and responses, optionally // logging the body. Log function defaults to martian.Infof. func NewLogger() *Logger { return &Logger{ log: func(line string) { log.Infof(line) }, } } // SetHeadersOnly sets whether to log the request/response body in the log. func (l *Logger) SetHeadersOnly(headersOnly bool) { l.headersOnly = headersOnly } // SetDecode sets whether to decode the request/response body in the log. 
func (l *Logger) SetDecode(decode bool) { l.decode = decode } // SetLogFunc sets the logging function for the logger. func (l *Logger) SetLogFunc(logFunc func(line string)) { l.log = logFunc } // ModifyRequest logs the request, optionally including the body. // // The format logged is: // -------------------------------------------------------------------------------- // Request to http://www.google.com/path?querystring // -------------------------------------------------------------------------------- // GET /path?querystring HTTP/1.1 // Host: www.google.com // Connection: close // Other-Header: values // // request content // -------------------------------------------------------------------------------- func (l *Logger) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) if ctx.SkippingLogging() { return nil } b := &bytes.Buffer{} fmt.Fprintln(b, "") fmt.Fprintln(b, strings.Repeat("-", 80)) fmt.Fprintf(b, "Request to %s\n", req.URL) fmt.Fprintln(b, strings.Repeat("-", 80)) mv := messageview.New() mv.SkipBody(l.headersOnly) if err := mv.SnapshotRequest(req); err != nil { return err } var opts []messageview.Option if l.decode { opts = append(opts, messageview.Decode()) } r, err := mv.Reader(opts...) if err != nil { return err } io.Copy(b, r) fmt.Fprintln(b, "") fmt.Fprintln(b, strings.Repeat("-", 80)) l.log(b.String()) return nil } // ModifyResponse logs the response, optionally including the body. 
// // The format logged is: // -------------------------------------------------------------------------------- // Response from http://www.google.com/path?querystring // -------------------------------------------------------------------------------- // HTTP/1.1 200 OK // Date: Tue, 15 Nov 1994 08:12:31 GMT // Other-Header: values // // response content // -------------------------------------------------------------------------------- func (l *Logger) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) if ctx.SkippingLogging() { return nil } b := &bytes.Buffer{} fmt.Fprintln(b, "") fmt.Fprintln(b, strings.Repeat("-", 80)) fmt.Fprintf(b, "Response from %s\n", res.Request.URL) fmt.Fprintln(b, strings.Repeat("-", 80)) mv := messageview.New() mv.SkipBody(l.headersOnly) if err := mv.SnapshotResponse(res); err != nil { return err } var opts []messageview.Option if l.decode { opts = append(opts, messageview.Decode()) } r, err := mv.Reader(opts...) if err != nil { return err } io.Copy(b, r) fmt.Fprintln(b, "") fmt.Fprintln(b, strings.Repeat("-", 80)) l.log(b.String()) return nil } // loggerFromJSON builds a logger from JSON. // // Example JSON: // { // "log.Logger": { // "scope": ["request", "response"], // "headersOnly": true, // "decode": true // } // } func loggerFromJSON(b []byte) (*parse.Result, error) { msg := &loggerJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } l := NewLogger() l.SetHeadersOnly(msg.HeadersOnly) l.SetDecode(msg.Decode) return parse.NewResult(l, msg.Scope) } martian-3.3.2/martianlog/logger_test.go000066400000000000000000000067601421371434000201640ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martianlog import ( "bytes" "compress/gzip" "fmt" "net/http" "strings" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func ExampleLogger() { l := NewLogger() l.SetLogFunc(func(line string) { // Remove \r to make it easier to test with examples. fmt.Print(strings.Replace(line, "\r", "", -1)) }) l.SetDecode(true) buf := new(bytes.Buffer) gw := gzip.NewWriter(buf) gw.Write([]byte("request content")) gw.Close() req, err := http.NewRequest("GET", "http://example.com/path?querystring", buf) if err != nil { fmt.Println(err) return } req.TransferEncoding = []string{"chunked"} req.Header.Set("Content-Encoding", "gzip") _, remove, err := martian.TestContext(req, nil, nil) if err != nil { fmt.Println(err) return } defer remove() if err := l.ModifyRequest(req); err != nil { fmt.Println(err) return } res := proxyutil.NewResponse(200, strings.NewReader("response content"), req) res.ContentLength = 16 res.Header.Set("Date", "Tue, 15 Nov 1994 08:12:31 GMT") res.Header.Set("Other-Header", "values") if err := l.ModifyResponse(res); err != nil { fmt.Println(err) return } // Output: // -------------------------------------------------------------------------------- // Request to http://example.com/path?querystring // -------------------------------------------------------------------------------- // GET http://example.com/path?querystring HTTP/1.1 // Host: example.com // Transfer-Encoding: chunked // Content-Encoding: gzip // // request content // // 
-------------------------------------------------------------------------------- // // -------------------------------------------------------------------------------- // Response from http://example.com/path?querystring // -------------------------------------------------------------------------------- // HTTP/1.1 200 OK // Content-Length: 16 // Date: Tue, 15 Nov 1994 08:12:31 GMT // Other-Header: values // // response content // -------------------------------------------------------------------------------- } func TestLoggerFromJSON(t *testing.T) { msg := []byte(`{ "log.Logger": { "scope": ["request", "response"], "headersOnly": true, "decode": true } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("r.RequestModifier(): got nil, want not nil") } if _, ok := reqmod.(*Logger); !ok { t.Error("reqmod.(*Logger): got !ok, want ok") } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("r.ResponseModifier(): got nil, want not nil") } l, ok := resmod.(*Logger) if !ok { t.Error("resmod.(*Logger); got !ok, want ok") } if !l.headersOnly { t.Error("l.headersOnly: got false, want true") } if !l.decode { t.Error("l.decode: got false, want true") } } martian-3.3.2/martiantest/000077500000000000000000000000001421371434000155045ustar00rootroot00000000000000martian-3.3.2/martiantest/martiantest.go000066400000000000000000000074751421371434000204030ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package martiantest provides helper utilities for testing // modifiers. package martiantest import ( "net/http" "sync/atomic" ) // Modifier keeps track of the number of requests and responses it has modified // and can be configured to return errors or run custom functions. type Modifier struct { reqcount int32 // atomic rescount int32 // atomic reqerr error reserr error reqfunc func(*http.Request) resfunc func(*http.Response) } // NewModifier returns a new test modifier. func NewModifier() *Modifier { return &Modifier{} } // RequestCount returns the number of requests modified. func (m *Modifier) RequestCount() int32 { return atomic.LoadInt32(&m.reqcount) } // ResponseCount returns the number of responses modified. func (m *Modifier) ResponseCount() int32 { return atomic.LoadInt32(&m.rescount) } // RequestModified returns whether a request has been modified. func (m *Modifier) RequestModified() bool { return m.RequestCount() != 0 } // ResponseModified returns whether a response has been modified. func (m *Modifier) ResponseModified() bool { return m.ResponseCount() != 0 } // RequestError overrides the error returned by ModifyRequest. func (m *Modifier) RequestError(err error) { m.reqerr = err } // ResponseError overrides the error returned by ModifyResponse. func (m *Modifier) ResponseError(err error) { m.reserr = err } // RequestFunc is a function to run during ModifyRequest. func (m *Modifier) RequestFunc(reqfunc func(req *http.Request)) { m.reqfunc = reqfunc } // ResponseFunc is a function to run during ModifyResponse. 
func (m *Modifier) ResponseFunc(resfunc func(res *http.Response)) {
	m.resfunc = resfunc
}

// ModifyRequest increases the count of requests seen and runs reqfunc if configured.
// It returns the error configured with RequestError (nil by default).
func (m *Modifier) ModifyRequest(req *http.Request) error {
	// Counted with sync/atomic, matching the atomic load in RequestCount.
	atomic.AddInt32(&m.reqcount, 1)

	if m.reqfunc != nil {
		m.reqfunc(req)
	}

	return m.reqerr
}

// ModifyResponse increases the count of responses seen and runs resfunc if configured.
// It returns the error configured with ResponseError (nil by default).
func (m *Modifier) ModifyResponse(res *http.Response) error {
	// Counted with sync/atomic, matching the atomic load in ResponseCount.
	atomic.AddInt32(&m.rescount, 1)

	if m.resfunc != nil {
		m.resfunc(res)
	}

	return m.reserr
}

// Reset resets the request and response counts, the custom
// functions, and the modifier errors, returning the modifier to the same
// state a freshly constructed NewModifier() has.
func (m *Modifier) Reset() {
	atomic.StoreInt32(&m.reqcount, 0)
	atomic.StoreInt32(&m.rescount, 0)

	m.reqfunc = nil
	m.resfunc = nil

	m.reqerr = nil
	m.reserr = nil
}

// Matcher is a stubbed matcher used in tests. MatchRequest and MatchResponse
// return fixed values that can be changed with RequestEvaluatesTo and
// ResponseEvaluatesTo.
type Matcher struct {
	resval bool // value returned by MatchResponse
	reqval bool // value returned by MatchRequest
}

// NewMatcher returns a pointer to martiantest.Matcher with the return values
// for MatchRequest and MatchResponse initialized to true.
func NewMatcher() *Matcher {
	return &Matcher{resval: true, reqval: true}
}

// ResponseEvaluatesTo sets the value returned by MatchResponse.
func (tm *Matcher) ResponseEvaluatesTo(value bool) {
	tm.resval = value
}

// RequestEvaluatesTo sets the value returned by MatchRequest.
func (tm *Matcher) RequestEvaluatesTo(value bool) {
	tm.reqval = value
}

// MatchRequest returns the stubbed value in tm.reqval; the request is ignored.
func (tm *Matcher) MatchRequest(*http.Request) bool {
	return tm.reqval
}

// MatchResponse returns the stubbed value in tm.resval; the response is ignored.
func (tm *Matcher) MatchResponse(*http.Response) bool {
	return tm.resval
}
martian-3.3.2/martiantest/martiantest_test.go000066400000000000000000000041341421371434000214270ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package martiantest

import (
	"errors"
	"net/http"
	"testing"

	"github.com/google/martian/v3/proxyutil"
)

// TestModifier configures a Modifier with errors and callbacks, then checks
// that ModifyRequest/ModifyResponse return the configured error, bump the
// counts, invoke the callbacks, and that Reset clears the modified state.
func TestModifier(t *testing.T) {
	// Flags flipped by the callbacks to prove they were invoked.
	var reqrun bool
	var resrun bool

	moderr := errors.New("modifier error")

	tm := NewModifier()
	tm.RequestError(moderr)
	tm.RequestFunc(func(*http.Request) {
		reqrun = true
	})
	tm.ResponseError(moderr)
	tm.ResponseFunc(func(*http.Response) {
		resrun = true
	})

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := tm.ModifyRequest(req); err != moderr {
		t.Fatalf("tm.ModifyRequest(): got %v, want %v", err, moderr)
	}

	if !tm.RequestModified() {
		t.Errorf("tm.RequestModified(): got false, want true")
	}

	if tm.RequestCount() != 1 {
		t.Errorf("tm.RequestCount(): got %d, want %d", tm.RequestCount(), 1)
	}

	if !reqrun {
		t.Error("reqrun: got false, want true")
	}

	res := proxyutil.NewResponse(200, nil, req)

	if err := tm.ModifyResponse(res); err != moderr {
		t.Fatalf("tm.ModifyResponse(): got %v, want %v", err, moderr)
	}

	if !tm.ResponseModified() {
		t.Errorf("tm.ResponseModified(): got false, want true")
	}

	if tm.ResponseCount() != 1 {
		t.Errorf("tm.ResponseCount(): got %d, want %d", tm.ResponseCount(), 1)
	}

	if !resrun {
		t.Error("resrun: got false, want true")
	}

	// After Reset both sides must report "not modified" again.
	tm.Reset()

	if tm.RequestModified() {
		t.Error("tm.RequestModified(): got true, want false")
	}

	if tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got true, want false")
	}
}
martian-3.3.2/martiantest/transport.go000066400000000000000000000032371421371434000200740ustar00rootroot00000000000000package
martiantest import ( "net/http" "github.com/google/martian/v3/proxyutil" ) // Transport is an http.RoundTripper for testing. type Transport struct { rtfunc func(*http.Request) (*http.Response, error) } // NewTransport builds a new transport that will respond with a 200 OK // response. func NewTransport() *Transport { tr := &Transport{} tr.Respond(200) return tr } // Respond sets the transport to respond with response with statusCode. func (tr *Transport) Respond(statusCode int) { tr.rtfunc = func(req *http.Request) (*http.Response, error) { // Force CONNECT requests to 200 to test CONNECT with downstream proxy. if req.Method == "CONNECT" { statusCode = 200 } res := proxyutil.NewResponse(statusCode, nil, req) return res, nil } } // RespondError sets the transport to respond with an error on round trip. func (tr *Transport) RespondError(err error) { tr.rtfunc = func(*http.Request) (*http.Response, error) { return nil, err } } // CopyHeaders sets the transport to respond with a 200 OK response with // headers copied from the request to the response verbatim. func (tr *Transport) CopyHeaders(names ...string) { tr.rtfunc = func(req *http.Request) (*http.Response, error) { res := proxyutil.NewResponse(200, nil, req) for _, n := range names { res.Header.Set(n, req.Header.Get(n)) } return res, nil } } // Func sets the transport to use the rtfunc. func (tr *Transport) Func(rtfunc func(*http.Request) (*http.Response, error)) { tr.rtfunc = rtfunc } // RoundTrip runs the stored round trip func and returns the response. 
func (tr *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	return tr.rtfunc(req)
}
martian-3.3.2/martiantest/transport_test.go000066400000000000000000000044731421371434000211360ustar00rootroot00000000000000package martiantest

import (
	"errors"
	"net/http"
	"testing"

	"github.com/google/martian/v3/proxyutil"
)

// TestTransport exercises each response mode of Transport in turn (default
// 200, Respond, RespondError, CopyHeaders and Func), reusing one request.
func TestTransport(t *testing.T) {
	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	tr := NewTransport()

	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.Roundtrip(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	// Respond with 301 response.
	tr.Respond(301)

	res, err = tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.Roundtrip(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 301; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	// Respond with error.
	trerr := errors.New("transport error")
	tr.RespondError(trerr)

	if _, err := tr.RoundTrip(req); err != trerr {
		t.Fatalf("tr.Roundtrip(): got %v, want %v", err, trerr)
	}

	// Copy headers from request to response.
	req.Header.Set("First-Header", "first")
	req.Header.Set("Second-Header", "second")

	tr.CopyHeaders("First-Header", "Second-Header")

	res, err = tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.Roundtrip(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("First-Header"), "first"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "First-Header", got, want)
	}
	if got, want := res.Header.Get("Second-Header"), "second"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Second-Header", got, want)
	}

	// Custom round trip function.
	tr.Func(func(req *http.Request) (*http.Response, error) {
		res := proxyutil.NewResponse(200, nil, req)
		res.Header.Set("Request-Method", req.Method)

		return res, nil
	})

	res, err = tr.RoundTrip(req)
	if err != nil {
		t.Fatalf("tr.Roundtrip(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Request-Method"), "GET"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Method", got, want)
	}
}
martian-3.3.2/martianurl/000077500000000000000000000000001421371434000153275ustar00rootroot00000000000000martian-3.3.2/martianurl/host.go000066400000000000000000000032031421371434000166310ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package martianurl

// MatchHost matches two URL hosts with support for wildcards.
//
// host is compared with match from right to left. When a '*' is reached in
// match, host characters are skipped leftward until a '.' (or the start of
// host) is reached. In the common "*.suffix" form the wildcard therefore
// stands for a single label: "*.example.com" matches "www.example.com" but
// not "one.two.example.com". An empty host never matches, not even an empty
// match.
func MatchHost(host, match string) bool {
	// Short circuit if host is empty.
	if host == "" {
		return false
	}

	// Exact match, no need to loop.
	if host == match {
		return true
	}

	// Walk backward over the host; hi indexes host, mi indexes match.
	hi := len(host) - 1
	for mi := len(match) - 1; mi >= 0; mi-- {
		// Found wildcard, skip to next period.
		if match[mi] == '*' {
			for hi > 0 && host[hi] != '.' {
				hi--
			}

			// Wildcard was the leftmost part and we have walked the entire host,
			// success.
			if mi == 0 && hi == 0 {
				return true
			}

			continue
		}

		if host[hi] != match[mi] {
			return false
		}

		// We have walked the entire host, if we have not walked the entire matcher
		// (mi != 0) that means the matcher has remaining characters to match and
		// thus the host cannot match.
		if hi == 0 {
			return mi == 0
		}

		hi--
	}

	// We have walked the entire length of the matcher, but haven't finished
	// walking the host thus they cannot match.
	return false
}
martian-3.3.2/martianurl/host_test.go000066400000000000000000000024471421371434000177010ustar00rootroot00000000000000// copyright 2016 google inc. all rights reserved.
//
// licensed under the apache license, version 2.0 (the "license");
// you may not use this file except in compliance with the license.
// you may obtain a copy of the license at
//
//     http://www.apache.org/licenses/license-2.0
//
// unless required by applicable law or agreed to in writing, software
// distributed under the license is distributed on an "as is" basis,
// without warranties or conditions of any kind, either express or implied.
// see the license for the specific language governing permissions and
// limitations under the license.

package martianurl

import "testing"

// TestMatchHost is a table test of MatchHost covering exact matches,
// leading/trailing wildcards, multi-label wildcards and empty inputs.
func TestMatchHost(t *testing.T) {
	tt := []struct {
		host, match string
		want        bool
	}{
		{"example.com", "example.com", true},
		{"example.com", "example.org", false},
		{"ample.com", "example.com", false},
		{"example.com", "ample.com", false},
		{"example.com", "example.*", true},
		{"www.example.com", "*.example.com", true},
		{"one.two.example.com", "*.example.com", false},
		{"one.two.example.com", "*.*.example.com", true},
		{"", "", false},
		{"", "foo", false},
	}

	for i, tc := range tt {
		if got := MatchHost(tc.host, tc.match); got != tc.want {
			t.Errorf("%d. MatchHost(%s, %s): got %t, want %t", i, tc.host, tc.match, got, tc.want)
		}
	}
}
martian-3.3.2/martianurl/url_filter.go000066400000000000000000000054751421371434000200400ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package martianurl

import (
	"encoding/json"
	"net/url"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/filter"
	"github.com/google/martian/v3/log"
	"github.com/google/martian/v3/parse"
)

// NOTE(review): noop is not referenced in the visible part of this package;
// confirm it is still needed.
var noop = martian.Noop("url.Filter")

// init registers filterFromJSON under the "url.Filter" key used in JSON
// modifier configurations.
func init() {
	parse.Register("url.Filter", filterFromJSON)
}

// Filter runs modifiers iff the request URL matches all of the segments in url.
// Modifiers installed for the false branch (the JSON "else" key) run when the
// URL does not match.
type Filter struct {
	*filter.Filter
}

// filterJSON is the JSON configuration accepted by filterFromJSON. Empty URL
// fields are ignored when matching.
type filterJSON struct {
	Scheme       string               `json:"scheme"`
	Host         string               `json:"host"`
	Path         string               `json:"path"`
	Query        string               `json:"query"`
	Modifier     json.RawMessage      `json:"modifier"`
	ElseModifier json.RawMessage      `json:"else"`
	Scope        []parse.ModifierType `json:"scope"`
}

// NewFilter constructs a filter that applies the modifier when the
// request URL matches all of the provided URL segments. The same Matcher is
// used as both the request and the response condition.
func NewFilter(u *url.URL) *Filter {
	log.Debugf("martianurl.NewFilter: %s", u)

	m := NewMatcher(u)
	f := filter.New()
	f.SetRequestCondition(m)
	f.SetResponseCondition(m)

	return &Filter{f}
}

// filterFromJSON takes a JSON message as a byte slice and returns a
// parse.Result that contains a Filter scoped by the "scope" field.
//
// Example JSON configuration message:
// {
//   "scheme": "https",
//   "host": "example.com",
//   "path": "/foo/bar",
//   "query": "q=value",
//   "scope": ["request", "response"],
//   "modifier": { ... },
//   "else": { ... }
// }
//
// "modifier" is always passed to parse.FromJSON; "else" is only parsed when it
// is present (non-empty).
func filterFromJSON(b []byte) (*parse.Result, error) {
	msg := &filterJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	// The "query" JSON field is matched against the raw (still-encoded) query.
	filter := NewFilter(&url.URL{
		Scheme:   msg.Scheme,
		Host:     msg.Host,
		Path:     msg.Path,
		RawQuery: msg.Query,
	})

	m, err := parse.FromJSON(msg.Modifier)
	if err != nil {
		return nil, err
	}

	filter.RequestWhenTrue(m.RequestModifier())
	filter.ResponseWhenTrue(m.ResponseModifier())

	if len(msg.ElseModifier) > 0 {
		em, err := parse.FromJSON(msg.ElseModifier)
		if err != nil {
			return nil, err
		}

		if em != nil {
			filter.RequestWhenFalse(em.RequestModifier())
			filter.ResponseWhenFalse(em.ResponseModifier())
		}
	}

	return parse.NewResult(filter, msg.Scope)
}
martian-3.3.2/martianurl/url_filter_test.go000066400000000000000000000210441421371434000210650ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martianurl

import (
	"errors"
	"net/http"
	"net/url"
	"testing"

	"github.com/google/martian/v3/martiantest"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/proxyutil"
	"github.com/google/martian/v3/verify"

	_ "github.com/google/martian/v3/header"
)

// TestFilterModifyRequest checks that the request modifier only runs when
// every non-empty segment of the filter URL matches the request URL.
func TestFilterModifyRequest(t *testing.T) {
	tt := []struct {
		want  bool
		match string
		url   *url.URL
	}{
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "https"},
			want:  true,
		},
		{
			match: "http://www.martian.local",
			url:   &url.URL{Host: "*.martian.local"},
			want:  true,
		},
		{
			match: "http://www.example.com/test",
			url:   &url.URL{Path: "/test"},
			want:  true,
		},
		{
			match: "http://www.example.com?test=true",
			url:   &url.URL{RawQuery: "test=true"},
			want:  true,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "test"},
			want:  true,
		},
		{
			match: "https://martian.local/test?test=true#test",
			url: &url.URL{
				Scheme:   "https",
				Host:     "martian.local",
				Path:     "/test",
				RawQuery: "test=true",
				Fragment: "test",
			},
			want: true,
		},
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "http"},
			want:  false,
		},
		{
			match: "http://www.martian.external",
			url:   &url.URL{Host: "www.martian.local"},
			want:  false,
		},
		{
			match: "http://www.example.com/testing",
			url:   &url.URL{Path: "/test"},
			want:  false,
		},
		{
			match: "http://www.example.com?test=false",
			url:   &url.URL{RawQuery: "test=true"},
			want:  false,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "testing"},
			want:  false,
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", tc.match, nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}

		mod := NewFilter(tc.url)
		tm := martiantest.NewModifier()
		mod.SetRequestModifier(tm)

		if err := mod.ModifyRequest(req); err != nil {
			t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err)
		}

		if tm.RequestModified() != tc.want {
			t.Errorf("%d. tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want)
		}
	}
}

// TestFilterModifyResponse mirrors TestFilterModifyRequest for the response
// side, where matching uses the URL of the response's originating request.
func TestFilterModifyResponse(t *testing.T) {
	tt := []struct {
		want  bool
		match string
		url   *url.URL
	}{
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "https"},
			want:  true,
		},
		{
			match: "http://www.martian.local",
			url:   &url.URL{Host: "www.martian.local"},
			want:  true,
		},
		{
			match: "http://www.example.com/test",
			url:   &url.URL{Path: "/test"},
			want:  true,
		},
		{
			match: "http://www.example.com?test=true",
			url:   &url.URL{RawQuery: "test=true"},
			want:  true,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "test"},
			want:  true,
		},
		{
			match: "https://martian.local/test?test=true#test",
			url: &url.URL{
				Scheme:   "https",
				Host:     "martian.local",
				Path:     "/test",
				RawQuery: "test=true",
				Fragment: "test",
			},
			want: true,
		},
		{
			match: "https://www.example.com",
			url:   &url.URL{Scheme: "http"},
			want:  false,
		},
		{
			match: "http://www.martian.external",
			url:   &url.URL{Host: "www.martian.local"},
			want:  false,
		},
		{
			match: "http://www.example.com/testing",
			url:   &url.URL{Path: "/test"},
			want:  false,
		},
		{
			match: "http://www.example.com?test=false",
			url:   &url.URL{RawQuery: "test=true"},
			want:  false,
		},
		{
			match: "http://www.example.com#test",
			url:   &url.URL{Fragment: "testing"},
			want:  false,
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", tc.match, nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}
		res := proxyutil.NewResponse(200, nil, req)

		mod := NewFilter(tc.url)
		tm := martiantest.NewModifier()
		mod.SetResponseModifier(tm)

		if err := mod.ModifyResponse(res); err != nil {
			t.Fatalf("%d. ModifyResponse(): got %q, want no error", i, err)
		}

		// Include the case index, like TestFilterModifyRequest, so failures
		// can be traced back to a table row.
		if tm.ResponseModified() != tc.want {
			t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.ResponseModified(), tc.want)
		}
	}
}

// TestFilterFromJSON builds a url.Filter from JSON and checks that the
// "modifier" branch runs for matching (https) URLs and the "else" branch runs
// for non-matching (http) URLs, on both requests and responses.
func TestFilterFromJSON(t *testing.T) {
	msg := []byte(`{
		"url.Filter": {
			"scope": ["request", "response"],
			"scheme": "https",
			"modifier": {
				"header.Modifier": {
					"scope": ["request", "response"],
					"name": "Mod-Run",
					"value": "true"
				}
			},
			"else": {
				"header.Modifier": {
					"scope": ["request", "response"],
					"name": "Else-Run",
					"value": "true"
				}
			}
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("FilterFromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "https://martian.test", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.Header.Get("Mod-Run"), "true"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want)
	}

	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}

	if got, want := res.Header.Get("Mod-Run"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want)
	}

	// test else conditional modifier with scheme of http
	req, err = http.NewRequest("GET", "http://martian.test", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.Header.Get("Mod-Run"), ""; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want)
	}

	if got, want := req.Header.Get("Else-Run"), "true"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Else-Run", got, want)
	}

	// The else branch must also run on the response side.
	res = proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}

	if got, want := res.Header.Get("Mod-Run"), ""; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want)
	}

	if got, want := res.Header.Get("Else-Run"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Else-Run", got, want)
	}
}

// TestPassThroughVerifyRequests checks that Filter surfaces request
// verification errors from its wrapped modifier.
func TestPassThroughVerifyRequests(t *testing.T) {
	u := &url.URL{Host: "www.martian.local"}
	f := NewFilter(u)

	if err := f.VerifyRequests(); err != nil {
		t.Fatalf("VerifyRequest(): got %v, want no error", err)
	}

	tv := &verify.TestVerifier{
		RequestError: errors.New("verify request failure"),
	}

	f.SetRequestModifier(tv)

	// Check for nil before calling Error so a regression fails the test
	// instead of panicking.
	err := f.VerifyRequests()
	if err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}
	if got, want := err.Error(), "verify request failure"; got != want {
		t.Fatalf("VerifyRequests(): got %s, want %s", got, want)
	}
}

// TestPassThroughVerifyResponses checks that Filter surfaces response
// verification errors from its wrapped modifier.
func TestPassThroughVerifyResponses(t *testing.T) {
	u := &url.URL{Host: "www.martian.local"}
	f := NewFilter(u)

	if err := f.VerifyResponses(); err != nil {
		t.Fatalf("VerifyResponses(): got %v, want no error", err)
	}

	tv := &verify.TestVerifier{
		ResponseError: errors.New("verify response failure"),
	}

	f.SetResponseModifier(tv)

	err := f.VerifyResponses()
	if err == nil {
		t.Fatal("VerifyResponses(): got nil, want error")
	}
	if got, want := err.Error(), "verify response failure"; got != want {
		t.Fatalf("VerifyResponses(): got %s, want %s", got, want)
	}
}

// TestResets checks that ResetRequestVerifications and
// ResetResponseVerifications clear previously reported verification errors.
func TestResets(t *testing.T) {
	u := &url.URL{Host: "www.martian.local"}
	f := NewFilter(u)

	tv := &verify.TestVerifier{
		ResponseError: errors.New("verify response failure"),
	}
	f.SetResponseModifier(tv)

	tv = &verify.TestVerifier{
		RequestError: errors.New("verify request failure"),
	}
	f.SetRequestModifier(tv)

	if err := f.VerifyRequests(); err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}
	if err := f.VerifyResponses(); err == nil {
		t.Fatal("VerifyResponses(): got nil, want error")
	}

	f.ResetRequestVerifications()
	f.ResetResponseVerifications()

	if err := f.VerifyRequests(); err != nil {
		t.Errorf("VerifyRequests(): got %v, want no error", err)
	}
	if err := f.VerifyResponses(); err != nil {
		t.Errorf("VerifyResponses(): got %v, want no error", err)
	}
}
martian-3.3.2/martianurl/url_matcher.go000066400000000000000000000040131421371434000201610ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martianurl import ( "net/http" "net/url" "github.com/google/martian/v3/log" ) // Matcher is a conditional evaluator of request urls to be used in // filters that take conditionals. type Matcher struct { url *url.URL } // NewMatcher builds a new url matcher. func NewMatcher(url *url.URL) *Matcher { return &Matcher{ url: url, } } // MatchRequest retuns true if all non-empty URL segments in m.url match the // request URL. func (m *Matcher) MatchRequest(req *http.Request) bool { matched := m.matches(req.URL) if matched { log.Debugf("martianurl.Matcher.MatchRequest: matched: %s", req.URL) } return matched } // MatchResponse retuns true if all non-empty URL segments in m.url match the // request URL. func (m *Matcher) MatchResponse(res *http.Response) bool { matched := m.matches(res.Request.URL) if matched { log.Debugf("martianurl.Matcher.MatchResponse: matched: %s", res.Request.URL) } return matched } // matches forces all non-empty URL segments to match or it returns false. 
func (m *Matcher) matches(u *url.URL) bool {
	// Each case rejects u as soon as one configured (non-empty) segment of
	// m.url differs; empty segments of m.url match anything.
	switch {
	case m.url.Scheme != "" && m.url.Scheme != u.Scheme:
		return false
	// Host uses MatchHost so m.url.Host may contain '*' wildcards.
	case m.url.Host != "" && !MatchHost(u.Host, m.url.Host):
		return false
	case m.url.Path != "" && m.url.Path != u.Path:
		return false
	case m.url.RawQuery != "" && m.url.RawQuery != u.RawQuery:
		return false
	case m.url.Fragment != "" && m.url.Fragment != u.Fragment:
		return false
	}

	return true
}
martian-3.3.2/martianurl/url_modifier.go000066400000000000000000000050251421371434000203400ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package martianurl provides utilities for modifying, filtering,
// and verifying URLs in martian.Proxy.
package martianurl

import (
	"encoding/json"
	"net/http"
	"net/url"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
)

// Modifier alters the request URL fields to match the non-empty fields of
// url. It only rewrites fields of req.URL; it does not add headers and does
// not change req.Host.
type Modifier struct {
	url *url.URL
}

// modifierJSON is the JSON configuration accepted by modifierFromJSON.
type modifierJSON struct {
	Scheme string               `json:"scheme"`
	Host   string               `json:"host"`
	Path   string               `json:"path"`
	Query  string               `json:"query"`
	Scope  []parse.ModifierType `json:"scope"`
}

// init registers modifierFromJSON under the "url.Modifier" key used in JSON
// modifier configurations.
func init() {
	parse.Register("url.Modifier", modifierFromJSON)
}

// ModifyRequest sets each field of req.URL to the corresponding field of
// m.url when that field of m.url is non-empty. It always returns nil.
func (m *Modifier) ModifyRequest(req *http.Request) error {
	// Only non-empty segments are copied, so unset fields of m.url leave the
	// corresponding request URL segments untouched.
	if m.url.Scheme != "" {
		req.URL.Scheme = m.url.Scheme
	}
	if m.url.Host != "" {
		req.URL.Host = m.url.Host
	}
	if m.url.Path != "" {
		req.URL.Path = m.url.Path
	}
	if m.url.RawQuery != "" {
		req.URL.RawQuery = m.url.RawQuery
	}
	if m.url.Fragment != "" {
		req.URL.Fragment = m.url.Fragment
	}

	return nil
}

// NewModifier returns a request modifier that overrides the non-empty
// segments of the request URL with those of url.
func NewModifier(url *url.URL) martian.RequestModifier {
	return &Modifier{
		url: url,
	}
}

// modifierFromJSON builds a martianurl.Modifier from JSON. The modifier is
// registered under the "url.Modifier" key.
//
// Example modifier JSON:
// {
//   "url.Modifier": {
//     "scope": ["request"],
//     "scheme": "https",
//     "host": "www.google.com",
//     "path": "/proxy",
//     "query": "testing=true"
//   }
// }
func modifierFromJSON(b []byte) (*parse.Result, error) {
	msg := &modifierJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	// The "query" JSON field is written verbatim as the raw (encoded) query.
	mod := NewModifier(&url.URL{
		Scheme:   msg.Scheme,
		Host:     msg.Host,
		Path:     msg.Path,
		RawQuery: msg.Query,
	})

	return parse.NewResult(mod, msg.Scope)
}
martian-3.3.2/martianurl/url_modifier_test.go000066400000000000000000000077511421371434000214070ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martianurl

import (
	"net/http"
	"net/http/httptest"
	"net/url"
	"testing"

	"github.com/google/martian/v3/parse"
)

// TestNewModifier checks that each non-empty segment of the modifier URL
// replaces the corresponding segment of the request URL.
func TestNewModifier(t *testing.T) {
	tt := []struct {
		want string
		url  *url.URL
	}{
		{
			want: "https://www.example.com",
			url:  &url.URL{Scheme: "https"},
		},
		{
			want: "http://www.martian.local",
			url:  &url.URL{Host: "www.martian.local"},
		},
		{
			want: "http://www.example.com/test",
			url:  &url.URL{Path: "/test"},
		},
		{
			want: "http://www.example.com?test=true",
			url:  &url.URL{RawQuery: "test=true"},
		},
		{
			want: "http://www.example.com#test",
			url:  &url.URL{Fragment: "test"},
		},
		{
			want: "https://martian.local/test?test=true#test",
			url: &url.URL{
				Scheme:   "https",
				Host:     "martian.local",
				Path:     "/test",
				RawQuery: "test=true",
				Fragment: "test",
			},
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", "http://www.example.com", nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}

		mod := NewModifier(tc.url)

		if err := mod.ModifyRequest(req); err != nil {
			t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err)
		}

		if got := req.URL.String(); got != tc.want {
			t.Errorf("%d. req.URL: got %q, want %q", i, got, tc.want)
		}
	}
}

// TestIntegration sends a request rewritten by the Modifier to a local test
// server and checks which URL the server observed.
func TestIntegration(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, r *http.Request) {
			// Rebuild an absolute URL from the Host header so the client can
			// see which host the request was addressed to.
			r.URL.Scheme = "http"
			r.URL.Host = r.Host

			w.Header().Set("Martian-URL", r.URL.String())
		}))
	defer server.Close()

	u := &url.URL{
		Scheme: "http",
		Host:   server.Listener.Addr().String(),
	}
	m := NewModifier(u)

	const rawURL = "https://example.com/test"
	req, err := http.NewRequest("GET", rawURL, nil)
	if err != nil {
		t.Fatalf("http.NewRequest(%q, %q, nil): got %v, want no error", "GET", rawURL, err)
	}

	if err := m.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatalf("http.DefaultClient.Do(): got %v, want no error", err)
	}
	// Close the body so the client connection is released.
	defer res.Body.Close()

	// The modifier only rewrote req.URL; req.Host still carries the original
	// host, which is what the server echoes back.
	want := "http://example.com/test"
	if got := res.Header.Get("Martian-URL"); got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-URL", got, want)
	}
}

// TestModifierFromJSON builds a url.Modifier from JSON and checks every
// configured URL segment is applied to the request.
func TestModifierFromJSON(t *testing.T) {
	msg := []byte(`{
		"url.Modifier": {
			"scope": ["request"],
			"scheme": "https",
			"host": "www.martian.proxy",
			"path": "/testing",
			"query": "test=true"
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if got, want := req.URL.Scheme, "https"; got != want {
		t.Errorf("req.URL.Scheme: got %q, want %q", got, want)
	}
	if got, want := req.URL.Host, "www.martian.proxy"; got != want {
		t.Errorf("req.URL.Host: got %q, want %q", got, want)
	}
	if got, want := req.URL.Path, "/testing"; got != want {
		t.Errorf("req.URL.Path: got %q, want %q", got, want)
	}
	if got, want := req.URL.RawQuery, "test=true"; got != want {
		t.Errorf("req.URL.RawQuery: got %q, want %q", got, want)
	}
}
martian-3.3.2/martianurl/url_regex_filter.go000066400000000000000000000051441421371434000212230ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package martianurl

import (
	"encoding/json"
	"regexp"

	"github.com/google/martian/v3/filter"
	"github.com/google/martian/v3/parse"
)

// init registers regexFilterFromJSON under the "url.RegexFilter" key.
func init() {
	parse.Register("url.RegexFilter", regexFilterFromJSON)
}

// URLRegexFilter runs Modifier if the request URL matches the regex, and runs ElseModifier if not.
// This is not to be confused with url.Filter that does string matching on URL segments.
type URLRegexFilter struct {
	*filter.Filter
}

// regexFilterJSON is the JSON configuration accepted by regexFilterFromJSON.
type regexFilterJSON struct {
	Regex        string               `json:"regex"`
	Modifier     json.RawMessage      `json:"modifier"`
	ElseModifier json.RawMessage      `json:"else"`
	Scope        []parse.ModifierType `json:"scope"`
}

// NewRegexFilter constructs a filter that matches on regular expressions.
// The same matcher is used as both the request and the response condition.
func NewRegexFilter(r *regexp.Regexp) *URLRegexFilter {
	// Named f so the local does not shadow the imported filter package.
	f := filter.New()
	matcher := NewRegexMatcher(r)
	f.SetRequestCondition(matcher)
	f.SetResponseCondition(matcher)

	return &URLRegexFilter{f}
}

// regexFilterFromJSON takes a JSON message as a byte slice and returns a
// parse.Result that contains a URLRegexFilter and a scope. The regex syntax is RE2
// as described at https://golang.org/s/re2syntax.
// // Example JSON configuration message: // { // "scope": ["request", "response"], // "regex": ".*www.example.com.*" // } func regexFilterFromJSON(b []byte) (*parse.Result, error) { msg := ®exFilterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } matcher, err := regexp.Compile(msg.Regex) if err != nil { return nil, err } filter := NewRegexFilter(matcher) m, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } filter.RequestWhenTrue(m.RequestModifier()) filter.ResponseWhenTrue(m.ResponseModifier()) if len(msg.ElseModifier) > 0 { em, err := parse.FromJSON(msg.ElseModifier) if err != nil { return nil, err } if em != nil { filter.RequestWhenFalse(em.RequestModifier()) filter.ResponseWhenFalse(em.ResponseModifier()) } } return parse.NewResult(filter, msg.Scope) } martian-3.3.2/martianurl/url_regex_filter_test.go000066400000000000000000000133071421371434000222620ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package martianurl

import (
	"net/http"
	"regexp"
	"testing"

	"github.com/google/martian/v3"
	_ "github.com/google/martian/v3/header"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/proxyutil"
)

// TestRegexFilterModifyRequest checks that the request modifier runs exactly
// when the full request URL (including query and fragment) matches the regex.
func TestRegexFilterModifyRequest(t *testing.T) {
	tt := []struct {
		want   bool
		match  string
		regstr string
	}{
		{
			match:  "https://www.example.com",
			regstr: "https://.*",
			want:   true,
		},
		{
			match:  "http://www.example.com/subpath",
			regstr: ".*www.example.com.*",
			want:   true,
		},
		{
			match:  "https://www.example.com/subpath",
			regstr: ".*www.example.com.*",
			want:   true,
		},
		{
			match:  "http://www.example.com/test",
			regstr: ".*/test",
			want:   true,
		},
		{
			match:  "http://www.example.com?test=true",
			regstr: ".*test=true.*",
			want:   true,
		},
		{
			match:  "http://www.example.com#test",
			regstr: ".*test.*",
			want:   true,
		},
		{
			match:  "https://martian.local/test?test=true#test",
			regstr: "https://martian.local/test\\?test=true#test",
			want:   true,
		},
		{
			match:  "http://www.youtube.com/get_tags?tagone=yes",
			regstr: ".*www.youtube.com/get_tags\\?.*",
			want:   true,
		},
		{
			match:  "https://www.example.com",
			regstr: "http://.*",
			want:   false,
		},
		{
			match:  "http://www.martian.external",
			regstr: ".*www.martian.local.*",
			want:   false,
		},
		{
			match:  "http://www.example.com/testing",
			regstr: ".*/test$",
			want:   false,
		},
		{
			match:  "http://www.example.com?test=false",
			regstr: ".*test=true.*",
			want:   false,
		},
		{
			match:  "http://www.example.com#test",
			regstr: ".*#testing.*",
			want:   false,
		},
		{
			match: "https://martian.local/test?test=true#test",
			// "\\\\" was the old way of adding a backslash in SAVR
			regstr: "https://martian.local/test\\\\?test=true#test",
			want:   false,
		},
		{
			match:  "http://www.youtube.com/get_tags/nope",
			regstr: ".*www.youtube.com/get_ad_tags\\?.*",
			want:   false,
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", tc.match, nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}

		regex, err := regexp.Compile(tc.regstr)
		if err != nil {
			t.Fatalf("%d. regexp.Compile(): got %v, want no error", i, err)
		}

		var modRun bool
		mod := NewRegexFilter(regex)
		mod.SetRequestModifier(martian.RequestModifierFunc(
			func(*http.Request) error {
				modRun = true
				return nil
			}))

		if err := mod.ModifyRequest(req); err != nil {
			t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err)
		}

		if modRun != tc.want {
			t.Errorf("%d. modRun: got %t, want %t", i, modRun, tc.want)
		}
	}
}

// TestRegexFilterModifyResponse checks that the response modifier runs only
// when the originating request URL matches. Matching itself is covered by
// TestRegexFilterModifyRequest, so only the response wiring is tested here.
func TestRegexFilterModifyResponse(t *testing.T) {
	tt := []struct {
		want   bool
		match  string
		regstr string
	}{
		{
			match:  "https://www.example.com",
			regstr: ".*www.example.com.*",
			want:   true,
		},
		{
			match:  "http://www.martian.external",
			regstr: ".*www.martian.local.*",
			want:   false,
		},
	}

	for i, tc := range tt {
		req, err := http.NewRequest("GET", tc.match, nil)
		if err != nil {
			t.Fatalf("%d. NewRequest(): got %v, want no error", i, err)
		}
		res := proxyutil.NewResponse(200, nil, req)

		regex, err := regexp.Compile(tc.regstr)
		if err != nil {
			t.Fatalf("%d. regexp.Compile(): got %v, want no error", i, err)
		}

		var modRun bool
		mod := NewRegexFilter(regex)
		mod.SetResponseModifier(martian.ResponseModifierFunc(
			func(*http.Response) error {
				modRun = true
				return nil
			}))

		if err := mod.ModifyResponse(res); err != nil {
			t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err)
		}

		if modRun != tc.want {
			t.Errorf("%d. modRun: got %t, want %t", i, modRun, tc.want)
		}
	}
}

// TestRegexFilterFromJSON checks that a url.RegexFilter built from JSON
// applies its nested modifier to matching requests and responses.
func TestRegexFilterFromJSON(t *testing.T) {
	rawMsg := `
	{
		"url.RegexFilter": {
			"scope": ["request", "response"],
			"regex": ".*martian.test.*",
			"modifier": {
				"header.Modifier": {
					"name": "Martian-Test",
					"value": "true",
					"scope": ["request", "response"]
				}
			}
		}
	}`

	r, err := parse.FromJSON([]byte(rawMsg))
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		// Fatal, not Error: reqmod is dereferenced below.
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "https://martian.test", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.Header.Get("Martian-Test"), "true"; got != want {
		t.Errorf("req.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
	}

	resmod := r.ResponseModifier()
	if resmod == nil {
		t.Fatal("resmod: got nil, want not nil")
	}

	res := proxyutil.NewResponse(200, nil, req)
	if err := resmod.ModifyResponse(res); err != nil {
		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
	}
	if got, want := res.Header.Get("Martian-Test"), "true"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want)
	}
}martian-3.3.2/martianurl/url_regex_matcher.go000066400000000000000000000026251421371434000213620ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // limitations under the License. package martianurl import ( "net/http" "net/url" "regexp" ) // RegexMatcher is a conditional evaluator of request urls to be used in // filters that take conditionals. type RegexMatcher struct { r *regexp.Regexp } // NewRegexMatcher builds a new url matcher from a compiled Regexp. func NewRegexMatcher(r *regexp.Regexp) *RegexMatcher { return &RegexMatcher{ r: r, } } // MatchRequest retuns true if the request URL matches r. func (m *RegexMatcher) MatchRequest(req *http.Request) bool { return m.matches(req.URL) } // MatchResponse retuns true if the response URL matches r. func (m *RegexMatcher) MatchResponse(res *http.Response) bool { return m.matches(res.Request.URL) } // matches checks if a url matches r. func (m *RegexMatcher) matches(u *url.URL) bool { return m.r.MatchString(u.String()) } martian-3.3.2/martianurl/url_verifier.go000066400000000000000000000072471421371434000203650ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package martianurl

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/verify"
)

// Formats used to build verification failures: errFormat wraps the request
// URL and the newline-joined errPartFormat lines, one per mismatching part.
const (
	errFormat     = "request(%s) url verify failure:\n%s"
	errPartFormat = "\t%s: got %q, want %q"
)

func init() {
	parse.Register("url.Verifier", verifierFromJSON)
}

// Verifier verifies the structure of URLs.
type Verifier struct {
	// url holds the expected URL parts; empty parts are not checked.
	url *url.URL
	// err accumulates failures until ResetRequestVerifications is called.
	err *martian.MultiError
}

type verifierJSON struct {
	Scheme string               `json:"scheme"`
	Host   string               `json:"host"`
	Path   string               `json:"path"`
	Query  string               `json:"query"`
	Scope  []parse.ModifierType `json:"scope"`
}

// NewVerifier returns a new URL verifier that checks requests against the
// non-empty parts of url.
func NewVerifier(url *url.URL) verify.RequestVerifier {
	return &Verifier{
		url: url,
		err: martian.NewMultiError(),
	}
}

// ModifyRequest verifies that the request URL matches all non-empty parts of
// the verifier's URL. Scheme, path, query and fragment must match exactly;
// the host is compared with MatchHost, so it may contain wildcards such as
// "*.example.com". Failures are recorded for VerifyRequests rather than
// returned, so ModifyRequest always returns nil.
func (v *Verifier) ModifyRequest(req *http.Request) error {
	// Skip requests addressed to the proxy API.
	// NOTE(review): the nil check assumes martian.NewContext may return nil
	// for a request without a registered context; it is a no-op otherwise.
	if ctx := martian.NewContext(req); ctx != nil && ctx.IsAPIRequest() {
		return nil
	}

	var failures []string

	u := req.URL

	if v.url.Scheme != "" && v.url.Scheme != u.Scheme {
		f := fmt.Sprintf(errPartFormat, "Scheme", u.Scheme, v.url.Scheme)
		failures = append(failures, f)
	}
	if v.url.Host != "" && !MatchHost(u.Host, v.url.Host) {
		f := fmt.Sprintf(errPartFormat, "Host", u.Host, v.url.Host)
		failures = append(failures, f)
	}
	if v.url.Path != "" && v.url.Path != u.Path {
		f := fmt.Sprintf(errPartFormat, "Path", u.Path, v.url.Path)
		failures = append(failures, f)
	}
	if v.url.RawQuery != "" && v.url.RawQuery != u.RawQuery {
		f := fmt.Sprintf(errPartFormat, "Query", u.RawQuery, v.url.RawQuery)
		failures = append(failures, f)
	}
	if v.url.Fragment != "" && v.url.Fragment != u.Fragment {
		f := fmt.Sprintf(errPartFormat, "Fragment", u.Fragment, v.url.Fragment)
		failures = append(failures, f)
	}

	if len(failures) > 0 {
		err := fmt.Errorf(errFormat, u, strings.Join(failures, "\n"))
		v.err.Add(err)
	}

	return nil
}

// VerifyRequests returns an error if verification for any request failed.
// If an error is returned it will be of type *martian.MultiError.
func (v *Verifier) VerifyRequests() error {
	if v.err.Empty() {
		return nil
	}

	return v.err
}

// ResetRequestVerifications clears all failed request verifications.
func (v *Verifier) ResetRequestVerifications() {
	v.err = martian.NewMultiError()
}

// verifierFromJSON builds a martianurl.Verifier from JSON. It is registered
// under the "url.Verifier" key.
//
// Example modifier JSON:
// {
//   "url.Verifier": {
//     "scope": ["request"],
//     "scheme": "https",
//     "host": "www.google.com",
//     "path": "/proxy",
//     "query": "testing=true"
//   }
// }
func verifierFromJSON(b []byte) (*parse.Result, error) {
	msg := &verifierJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	v := NewVerifier(&url.URL{
		Scheme:   msg.Scheme,
		Host:     msg.Host,
		Path:     msg.Path,
		RawQuery: msg.Query,
	})

	return parse.NewResult(v, msg.Scope)
}
martian-3.3.2/martianurl/url_verifier_test.go000066400000000000000000000115031421371434000214120ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package martianurl

import (
	"net/http"
	"net/url"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/verify"
)

// TestVerifyRequests checks that every mismatching URL part is reported, in
// order, and that a fully matching request passes after a reset.
func TestVerifyRequests(t *testing.T) {
	u := &url.URL{
		Scheme:   "https",
		Host:     "*.example.com",
		Path:     "/test",
		RawQuery: "testing=true",
		Fragment: "test",
	}
	v := NewVerifier(u)

	// Each want is errFormat wrapping one errPartFormat line per failing
	// part: the parts are joined with "\n" and each is prefixed with "\t".
	tt := []struct {
		got, want string
	}{
		{
			got: "http://www.example.com/test?testing=true#test",
			want: "request(http://www.example.com/test?testing=true#test) url verify failure:\n" +
				"\tScheme: got \"http\", want \"https\"",
		},
		{
			got: "http://www.martian.test/test?testing=true#test",
			want: "request(http://www.martian.test/test?testing=true#test) url verify failure:\n" +
				"\tScheme: got \"http\", want \"https\"\n" +
				"\tHost: got \"www.martian.test\", want \"*.example.com\"",
		},
		{
			got: "http://www.martian.test/prod?testing=true#test",
			want: "request(http://www.martian.test/prod?testing=true#test) url verify failure:\n" +
				"\tScheme: got \"http\", want \"https\"\n" +
				"\tHost: got \"www.martian.test\", want \"*.example.com\"\n" +
				"\tPath: got \"/prod\", want \"/test\"",
		},
		{
			got: "http://www.martian.test/prod#test",
			want: "request(http://www.martian.test/prod#test) url verify failure:\n" +
				"\tScheme: got \"http\", want \"https\"\n" +
				"\tHost: got \"www.martian.test\", want \"*.example.com\"\n" +
				"\tPath: got \"/prod\", want \"/test\"\n" +
				"\tQuery: got \"\", want \"testing=true\"",
		},
		{
			got: "http://www.martian.test/prod#fake",
			want: "request(http://www.martian.test/prod#fake) url verify failure:\n" +
				"\tScheme: got \"http\", want \"https\"\n" +
				"\tHost: got \"www.martian.test\", want \"*.example.com\"\n" +
				"\tPath: got \"/prod\", want \"/test\"\n" +
				"\tQuery: got \"\", want \"testing=true\"\n" +
				"\tFragment: got \"fake\", want \"test\"",
		},
	}

	for i, tc := range tt {
		// Each case runs in its own function so the context registered by
		// TestContext is removed at the end of the iteration, not when the
		// whole test returns.
		func() {
			req, err := http.NewRequest("GET", tc.got, nil)
			if err != nil {
				t.Fatalf("%d. http.NewRequest(..., %s, ...): got %v, want no error", i, tc.got, err)
			}

			_, remove, err := martian.TestContext(req, nil, nil)
			if err != nil {
				t.Fatalf("%d. TestContext(): got %v, want no error", i, err)
			}
			defer remove()

			if err := v.ModifyRequest(req); err != nil {
				t.Fatalf("%d. ModifyRequest(): got %v, want no error", i, err)
			}
		}()
	}

	merr, ok := v.VerifyRequests().(*martian.MultiError)
	if !ok {
		t.Fatal("VerifyRequests(): got nil, want *martian.MultiError")
	}

	errs := merr.Errors()
	if got, want := len(errs), len(tt); got != want {
		t.Fatalf("len(merr.Errors()): got %d, want %d", got, want)
	}

	for i, tc := range tt {
		if got, want := errs[i].Error(), tc.want; got != want {
			t.Errorf("%d. err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want)
		}
	}

	v.ResetRequestVerifications()

	// A valid request.
	req, err := http.NewRequest("GET", "https://www.example.com/test?testing=true#test", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := v.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if err := v.VerifyRequests(); err != nil {
		t.Errorf("VerifyRequests(): got %v, want no error", err)
	}
}

// TestVerifierFromJSON checks that a url.Verifier built from JSON records a
// failure for a request whose query does not match.
func TestVerifierFromJSON(t *testing.T) {
	msg := []byte(`{
		"url.Verifier": {
			"scope": ["request"],
			"scheme": "https",
			"host": "www.martian.proxy",
			"path": "/testing",
			"query": "test=true"
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	reqv, ok := reqmod.(verify.RequestVerifier)
	if !ok {
		t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok")
	}

	req, err := http.NewRequest("GET", "https://www.martian.proxy/testing?test=false", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, remove, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer remove()

	if err := reqv.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if err := reqv.VerifyRequests(); err == nil {
		t.Error("VerifyRequests(): got nil, want not nil")
	}
}
martian-3.3.2/messageview/000077500000000000000000000000001421371434000154705ustar00rootroot00000000000000martian-3.3.2/messageview/messageview.go000066400000000000000000000160761421371434000203500ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package messageview provides no-op snapshots for HTTP requests and
// responses.
package messageview

import (
	"bytes"
	"compress/flate"
	"compress/gzip"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httputil"
	"strings"
)

// MessageView is a static view of an HTTP request or response.
type MessageView struct {
	// message holds the serialized message from the latest snapshot;
	// bodyoffset and traileroffset mark where its body and trailers begin.
	message []byte
	// cts lists Content-Type prefixes whose bodies are read even when
	// skipBody is set.
	cts           []string
	chunked       bool
	skipBody      bool
	compress      string
	bodyoffset    int64
	traileroffset int64
}

type config struct {
	decode bool
}

// Option is a configuration option for a MessageView.
type Option func(*config)

// Decode sets an option to decode the message body for logging purposes.
func Decode() Option {
	return func(c *config) {
		c.decode = true
	}
}

// New returns a new MessageView.
func New() *MessageView {
	return &MessageView{}
}

// SkipBody will skip reading the body when the view is loaded with a request
// or response.
func (mv *MessageView) SkipBody(skipBody bool) {
	mv.skipBody = skipBody
}

// SkipBodyUnlessContentType will skip reading the body unless the
// Content-Type starts with one of the prefixes in cts.
func (mv *MessageView) SkipBodyUnlessContentType(cts ...string) { mv.skipBody = true mv.cts = cts } // SnapshotRequest reads the request into the MessageView. If mv.skipBody is false // it will also read the body into memory and replace the existing body with // the in-memory copy. This method is semantically a no-op. func (mv *MessageView) SnapshotRequest(req *http.Request) error { buf := new(bytes.Buffer) fmt.Fprintf(buf, "%s %s HTTP/%d.%d\r\n", req.Method, req.URL, req.ProtoMajor, req.ProtoMinor) if req.Host != "" { fmt.Fprintf(buf, "Host: %s\r\n", req.Host) } if tec := len(req.TransferEncoding); tec > 0 { mv.chunked = req.TransferEncoding[tec-1] == "chunked" fmt.Fprintf(buf, "Transfer-Encoding: %s\r\n", strings.Join(req.TransferEncoding, ", ")) } if !mv.chunked && req.ContentLength >= 0 { fmt.Fprintf(buf, "Content-Length: %d\r\n", req.ContentLength) } mv.compress = req.Header.Get("Content-Encoding") req.Header.WriteSubset(buf, map[string]bool{ "Host": true, "Content-Length": true, "Transfer-Encoding": true, }) fmt.Fprint(buf, "\r\n") mv.bodyoffset = int64(buf.Len()) mv.traileroffset = int64(buf.Len()) ct := req.Header.Get("Content-Type") if mv.skipBody && !mv.matchContentType(ct) || req.Body == nil { mv.message = buf.Bytes() return nil } data, err := ioutil.ReadAll(req.Body) if err != nil { return err } req.Body.Close() if mv.chunked { cw := httputil.NewChunkedWriter(buf) cw.Write(data) cw.Close() } else { buf.Write(data) } mv.traileroffset = int64(buf.Len()) req.Body = ioutil.NopCloser(bytes.NewReader(data)) if req.Trailer != nil { req.Trailer.Write(buf) } else if mv.chunked { fmt.Fprint(buf, "\r\n") } mv.message = buf.Bytes() return nil } // SnapshotResponse reads the response into the MessageView. If mv.headersOnly // is false it will also read the body into memory and replace the existing // body with the in-memory copy. This method is semantically a no-op. 
func (mv *MessageView) SnapshotResponse(res *http.Response) error { buf := new(bytes.Buffer) fmt.Fprintf(buf, "HTTP/%d.%d %s\r\n", res.ProtoMajor, res.ProtoMinor, res.Status) if tec := len(res.TransferEncoding); tec > 0 { mv.chunked = res.TransferEncoding[tec-1] == "chunked" fmt.Fprintf(buf, "Transfer-Encoding: %s\r\n", strings.Join(res.TransferEncoding, ", ")) } if !mv.chunked && res.ContentLength >= 0 { fmt.Fprintf(buf, "Content-Length: %d\r\n", res.ContentLength) } mv.compress = res.Header.Get("Content-Encoding") // Do not uncompress if we have don't have the full contents. if res.StatusCode == http.StatusNoContent || res.StatusCode == http.StatusPartialContent { mv.compress = "" } res.Header.WriteSubset(buf, map[string]bool{ "Content-Length": true, "Transfer-Encoding": true, }) fmt.Fprint(buf, "\r\n") mv.bodyoffset = int64(buf.Len()) mv.traileroffset = int64(buf.Len()) ct := res.Header.Get("Content-Type") if mv.skipBody && !mv.matchContentType(ct) || res.Body == nil { mv.message = buf.Bytes() return nil } data, err := ioutil.ReadAll(res.Body) if err != nil { return err } res.Body.Close() if mv.chunked { cw := httputil.NewChunkedWriter(buf) cw.Write(data) cw.Close() } else { buf.Write(data) } mv.traileroffset = int64(buf.Len()) res.Body = ioutil.NopCloser(bytes.NewReader(data)) if res.Trailer != nil { res.Trailer.Write(buf) } else if mv.chunked { fmt.Fprint(buf, "\r\n") } mv.message = buf.Bytes() return nil } // Reader returns the an io.ReadCloser that reads the full HTTP message. func (mv *MessageView) Reader(opts ...Option) (io.ReadCloser, error) { hr := mv.HeaderReader() br, err := mv.BodyReader(opts...) if err != nil { return nil, err } tr := mv.TrailerReader() return struct { io.Reader io.Closer }{ Reader: io.MultiReader(hr, br, tr), Closer: br, }, nil } // HeaderReader returns an io.Reader that reads the HTTP Status-Line or // HTTP Request-Line and headers. 
func (mv *MessageView) HeaderReader() io.Reader { r := bytes.NewReader(mv.message) return io.NewSectionReader(r, 0, mv.bodyoffset) } // BodyReader returns an io.ReadCloser that reads the HTTP request or response // body. If mv.skipBody was set the reader will immediately return io.EOF. // // If the Decode option is passed the body will be unchunked if // Transfer-Encoding is set to "chunked", and will decode the following // Content-Encodings: gzip, deflate. func (mv *MessageView) BodyReader(opts ...Option) (io.ReadCloser, error) { var r io.Reader conf := &config{} for _, o := range opts { o(conf) } br := bytes.NewReader(mv.message) r = io.NewSectionReader(br, mv.bodyoffset, mv.traileroffset-mv.bodyoffset) if !conf.decode { return ioutil.NopCloser(r), nil } if mv.chunked { r = httputil.NewChunkedReader(r) } switch mv.compress { case "gzip": gr, err := gzip.NewReader(r) if err != nil { return nil, err } return gr, nil case "deflate": return flate.NewReader(r), nil default: return ioutil.NopCloser(r), nil } } // TrailerReader returns an io.Reader that reads the HTTP request or response // trailers, if present. func (mv *MessageView) TrailerReader() io.Reader { r := bytes.NewReader(mv.message) end := int64(len(mv.message)) - mv.traileroffset return io.NewSectionReader(r, mv.traileroffset, end) } func (mv *MessageView) matchContentType(mct string) bool { for _, ct := range mv.cts { if strings.HasPrefix(mct, ct) { return true } } return false } martian-3.3.2/messageview/messageview_test.go000066400000000000000000000507551421371434000214110ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package messageview import ( "bufio" "bytes" "compress/flate" "compress/gzip" "io" "io/ioutil" "net/http" "strings" "testing" "github.com/google/martian/v3/proxyutil" ) func TestRequestViewHeadersOnly(t *testing.T) { body := strings.NewReader("body content") req, err := http.NewRequest("GET", "http://example.com/path?k=v", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.ContentLength = int64(body.Len()) req.Header.Set("Request-Header", "true") mv := New() mv.SkipBody(true) if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "GET http://example.com/path?k=v HTTP/1.1\r\n" + "Host: example.com\r\n" + "Content-Length: 12\r\n" + "Request-Header: true\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } if _, err := br.Read(nil); err != io.EOF { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want io.EOF", err) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } func TestRequestView(t *testing.T) { body 
:= strings.NewReader("body content") req, err := http.NewRequest("GET", "http://example.com/path?k=v", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Request-Header", "true") // Force Content Length to be unset to simulate lack of Content-Length and // Transfer-Encoding which is valid. req.ContentLength = -1 mv := New() if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "GET http://example.com/path?k=v HTTP/1.1\r\n" + "Host: example.com\r\n" + "Request-Header: true\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } // Sanity check to ensure it still parses. 
if _, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(got))); err != nil { t.Fatalf("http.ReadRequest(): got %v, want no error", err) } } func TestRequestViewSkipBodyUnlessContentType(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", strings.NewReader("body content")) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.ContentLength = 12 req.Header.Set("Content-Type", "text/plain; charset=utf-8") mv := New() mv.SkipBodyUnlessContentType("text/plain") if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } req.Header.Set("Content-Type", "image/png") mv = New() mv.SkipBodyUnlessContentType("text/plain") if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } br, err = mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } if _, err := br.Read(nil); err != io.EOF { t.Fatalf("br.Read(): got %v, want io.EOF", err) } } func TestRequestViewChunkedTransferEncoding(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com/path?k=v", strings.NewReader("body content")) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.TransferEncoding = []string{"chunked"} req.Header.Set("Trailer", "Trailer-Header") req.Trailer = http.Header{ "Trailer-Header": []string{"true"}, } mv := New() if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, 
want no error", err) } hdrwant := "GET http://example.com/path?k=v HTTP/1.1\r\n" + "Host: example.com\r\n" + "Transfer-Encoding: chunked\r\n" + "Trailer: Trailer-Header\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "c\r\nbody content\r\n0\r\n" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } got, err = ioutil.ReadAll(mv.TrailerReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.TrailerReader()): got %v, want no error", err) } trailerwant := "Trailer-Header: true\r\n" if !bytes.Equal(got, []byte(trailerwant)) { t.Fatalf("mv.TrailerReader(): got %q, want %q", got, trailerwant) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + trailerwant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } // Sanity check to ensure it still parses. 
if _, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(got))); err != nil { t.Fatalf("http.ReadRequest(): got %v, want no error", err) } } func TestRequestViewDecodeGzipContentEncoding(t *testing.T) { body := new(bytes.Buffer) gw := gzip.NewWriter(body) gw.Write([]byte("body content")) gw.Flush() gw.Close() req, err := http.NewRequest("GET", "http://example.com/path?k=v", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.TransferEncoding = []string{"chunked"} req.Header.Set("Content-Encoding", "gzip") mv := New() if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "GET http://example.com/path?k=v HTTP/1.1\r\n" + "Host: example.com\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Encoding: gzip\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, wt o error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader(Decode()) if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } func TestRequestViewDecodeDeflateContentEncoding(t *testing.T) { body := new(bytes.Buffer) dw, err := flate.NewWriter(body, -1) if err != nil { t.Fatalf("flate.NewWriter(): got %v, want no error", err) } dw.Write([]byte("body 
content")) dw.Flush() dw.Close() req, err := http.NewRequest("GET", "http://example.com/path?k=v", body) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.TransferEncoding = []string{"chunked"} req.Header.Set("Content-Encoding", "deflate") mv := New() if err := mv.SnapshotRequest(req); err != nil { t.Fatalf("SnapshotRequest(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "GET http://example.com/path?k=v HTTP/1.1\r\n" + "Host: example.com\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Encoding: deflate\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader(Decode()) if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } func TestResponseViewHeadersOnly(t *testing.T) { body := strings.NewReader("body content") res := proxyutil.NewResponse(200, body, nil) res.ContentLength = 12 res.Header.Set("Response-Header", "true") mv := New() mv.SkipBody(true) if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 
OK\r\n" + "Content-Length: 12\r\n" + "Response-Header: true\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } if _, err := br.Read(nil); err != io.EOF { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want io.EOF", err) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } func TestResponseView(t *testing.T) { body := strings.NewReader("body content") res := proxyutil.NewResponse(200, body, nil) res.ContentLength = 12 res.Header.Set("Response-Header", "true") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Content-Length: 12\r\n" + "Response-Header: true\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got 
%q, want %q", got, want) } // Sanity check to ensure it still parses. if _, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(got)), nil); err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } } func TestResponseViewSkipBodyUnlessContentType(t *testing.T) { res := proxyutil.NewResponse(200, strings.NewReader("body content"), nil) res.ContentLength = 12 res.Header.Set("Content-Type", "text/plain; charset=utf-8") mv := New() mv.SkipBodyUnlessContentType("text/plain") if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } res.Header.Set("Content-Type", "image/png") mv = New() mv.SkipBodyUnlessContentType("text/plain") if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } br, err = mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } if _, err := br.Read(nil); err != io.EOF { t.Fatalf("br.Read(): got %v, want io.EOF", err) } } func TestResponseViewChunkedTransferEncoding(t *testing.T) { body := strings.NewReader("body content") res := proxyutil.NewResponse(200, body, nil) res.TransferEncoding = []string{"chunked"} res.Header.Set("Trailer", "Trailer-Header") res.Trailer = http.Header{ "Trailer-Header": []string{"true"}, } mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + 
"Trailer: Trailer-Header\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader() if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, want no error", err) } bodywant := "c\r\nbody content\r\n0\r\n" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } got, err = ioutil.ReadAll(mv.TrailerReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.TrailerReader()): got %v, want no error", err) } trailerwant := "Trailer-Header: true\r\n" if !bytes.Equal(got, []byte(trailerwant)) { t.Fatalf("mv.TrailerReader(): got %q, want %q", got, trailerwant) } r, err := mv.Reader() if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + trailerwant); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } // Sanity check to ensure it still parses. 
if _, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(got)), nil); err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } } func TestResponseViewDecodeGzipContentEncoding(t *testing.T) { body := new(bytes.Buffer) gw := gzip.NewWriter(body) gw.Write([]byte("body content")) gw.Flush() gw.Close() res := proxyutil.NewResponse(200, body, nil) res.TransferEncoding = []string{"chunked"} res.Header.Set("Content-Encoding", "gzip") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Encoding: gzip\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, wt o error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader(Decode()) if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } func TestResponseViewDecodeGzipContentEncodingPartial(t *testing.T) { bodywant := "partial content" res := proxyutil.NewResponse(206, strings.NewReader(bodywant), nil) res.TransferEncoding = []string{"chunked"} res.Header.Set("Content-Encoding", "gzip") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } br, 
err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err := ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, wt o error", err) } if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } } func TestResponseViewDecodeDeflateContentEncoding(t *testing.T) { body := new(bytes.Buffer) dw, err := flate.NewWriter(body, -1) if err != nil { t.Fatalf("flate.NewWriter(): got %v, want no error", err) } dw.Write([]byte("body content")) dw.Flush() dw.Close() res := proxyutil.NewResponse(200, body, nil) res.TransferEncoding = []string{"chunked"} res.Header.Set("Content-Encoding", "deflate") mv := New() if err := mv.SnapshotResponse(res); err != nil { t.Fatalf("SnapshotResponse(): got %v, want no error", err) } got, err := ioutil.ReadAll(mv.HeaderReader()) if err != nil { t.Fatalf("ioutil.ReadAll(mv.HeaderReader()): got %v, want no error", err) } hdrwant := "HTTP/1.1 200 OK\r\n" + "Transfer-Encoding: chunked\r\n" + "Content-Encoding: deflate\r\n\r\n" if !bytes.Equal(got, []byte(hdrwant)) { t.Fatalf("mv.HeaderReader(): got %q, want %q", got, hdrwant) } br, err := mv.BodyReader(Decode()) if err != nil { t.Fatalf("mv.BodyReader(): got %v, want no error", err) } got, err = ioutil.ReadAll(br) if err != nil { t.Fatalf("ioutil.ReadAll(mv.BodyReader()): got %v, wt o error", err) } bodywant := "body content" if !bytes.Equal(got, []byte(bodywant)) { t.Fatalf("mv.BodyReader(): got %q, want %q", got, bodywant) } r, err := mv.Reader(Decode()) if err != nil { t.Fatalf("mv.Reader(): got %v, want no error", err) } got, err = ioutil.ReadAll(r) if err != nil { t.Fatalf("ioutil.ReadAll(mv.Reader()): got %v, want no error", err) } if want := []byte(hdrwant + bodywant + "\r\n"); !bytes.Equal(got, want) { t.Fatalf("mv.Read(): got %q, want %q", got, want) } } 
martian-3.3.2/method/000077500000000000000000000000001421371434000144315ustar00rootroot00000000000000martian-3.3.2/method/method_filter.go000066400000000000000000000061371421371434000176140ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package method

import (
	"encoding/json"
	"net/http"
	"strings"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/filter"
	"github.com/google/martian/v3/log"
	"github.com/google/martian/v3/parse"
)

// noop holds the value returned by martian.Noop for the "method.Filter" id.
// NOTE(review): noop is not referenced by the filter, verifier, or test code
// of this package visible here, and it appears to be the only use of the
// martian import in this file — confirm whether it is still needed.
var noop = martian.Noop("method.Filter")

// init registers filterFromJSON under the "method.Filter" key, so that a
// JSON configuration containing that key can be turned into a method filter
// via parse.FromJSON.
func init() {
	parse.Register("method.Filter", filterFromJSON)
}

// Filter runs modifier iff the request method matches the specified method.
// Method names are compared case-insensitively.
type Filter struct { *filter.Filter } type filterJSON struct { Method string `json:"method"` Modifier json.RawMessage `json:"modifier"` ElseModifier json.RawMessage `json:"else"` Scope []parse.ModifierType `json:"scope"` } func filterFromJSON(b []byte) (*parse.Result, error) { msg := &filterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } filter := NewFilter(msg.Method) m, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } filter.RequestWhenTrue(m.RequestModifier()) filter.ResponseWhenTrue(m.ResponseModifier()) if len(msg.ElseModifier) > 0 { em, err := parse.FromJSON(msg.ElseModifier) if err != nil { return nil, err } if em != nil { filter.RequestWhenFalse(em.RequestModifier()) filter.ResponseWhenFalse(em.ResponseModifier()) } } return parse.NewResult(filter, msg.Scope) } // NewFilter constructs a filter that applies the modifer when the // request method matches meth. func NewFilter(meth string) *Filter { log.Debugf("method.NewFilter(%q)", meth) m := NewMatcher(meth) f := filter.New() f.SetRequestCondition(m) f.SetResponseCondition(m) return &Filter{f} } // Matcher is a conditional evaluator of request methods to be used in // filters that take conditionals. type Matcher struct { method string } // NewMatcher builds a new method matcher. func NewMatcher(method string) *Matcher { return &Matcher{ method: method, } } // MatchRequest retuns true if m.method matches the request method. func (m *Matcher) MatchRequest(req *http.Request) bool { matched := m.matches(req.Method) if matched { log.Debugf("method.MatchRequest: matched %s request: %s", req.Method, req.URL) } return matched } // MatchResponse retuns true if m.method matches res.Request.Method. 
func (m *Matcher) MatchResponse(res *http.Response) bool { matched := m.matches(res.Request.Method) if matched { log.Debugf("method.MatchResponse: matched %s request: %s", res.Request.Method, res.Request.URL) } return matched } func (m *Matcher) matches(method string) bool { return strings.EqualFold(method, m.method) } martian-3.3.2/method/method_filter_test.go000066400000000000000000000117371421371434000206550ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package method import ( "net/http" "testing" _ "github.com/google/martian/v3/header" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestFilterModifyRequest(t *testing.T) { tt := []struct { method string want bool }{ { method: "GET", want: true, }, { method: "get", want: true, }, { method: "POST", want: false, }, { method: "DELETE", want: false, }, { method: "CONNECT", want: false, }, { method: "connect", want: false, }, } for i, tc := range tt { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("%d. NewRequest(): got %v, want no error", i, err) } mod := NewFilter(tc.method) tm := martiantest.NewModifier() mod.SetRequestModifier(tm) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. 
tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestFilterModifyResponse(t *testing.T) { tt := []struct { method string want bool }{ { method: "GET", want: true, }, { method: "get", want: true, }, { method: "POST", want: false, }, { method: "DELETE", want: false, }, { method: "CONNECT", want: false, }, { method: "connect", want: false, }, } for i, tc := range tt { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("%d. NewRequest(): got %v, want no error", i, err) } res := proxyutil.NewResponse(200, nil, req) mod := NewFilter(tc.method) tm := martiantest.NewModifier() mod.SetResponseModifier(tm) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %q, want no error", i, err) } if tm.ResponseModified() != tc.want { t.Errorf("%d. tm.ResponseModified(): got %t, want %t", i, tm.ResponseModified(), tc.want) } } } func TestFilterFromJSON(t *testing.T) { j := `{ "method.Filter": { "scope": ["request", "response"], "method": "GET", "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Mod-Run", "value": "true" } }, "else": { "header.Modifier": { "scope": ["request", "response"], "name": "Else-Run", "value": "true" } } } }` msg := []byte(j) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("FilterFromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } if got, want := req.Header.Get("Else-Run"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Else-Run", got, want) 
} resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } // test else conditional modifier with POST req, err = http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Mod-Run"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } if got, want := req.Header.Get("Else-Run"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Else-Run", got, want) } } martian-3.3.2/method/method_verifier.go000066400000000000000000000051561421371434000201420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package method provides utilities for working with request methods. 
package method

import (
	"encoding/json"
	"fmt"
	"net/http"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/verify"
)

// verifier records an error for every request whose method differs from
// method; failures accumulate in err until reset.
type verifier struct {
	method string
	err    *martian.MultiError
}

// verifierJSON is the JSON shape accepted by verifierFromJSON.
type verifierJSON struct {
	Method string               `json:"method"`
	Scope  []parse.ModifierType `json:"scope"`
}

func init() {
	parse.Register("method.Verifier", verifierFromJSON)
}

// NewVerifier returns a request verifier that records a failure for every
// request whose method is not exactly method. The comparison is
// case-sensitive. It returns an error if method is empty.
func NewVerifier(method string) (verify.RequestVerifier, error) {
	if method == "" {
		// %q makes the empty method visible in the message.
		return nil, fmt.Errorf("%q is not a valid HTTP method", method)
	}

	return &verifier{
		method: method,
		err:    martian.NewMultiError(),
	}, nil
}

// ModifyRequest verifies that the request's method matches the given method
// in all modified requests. An error will be added to the contained *MultiError
// if a method is unmatched. It never returns an error itself.
func (v *verifier) ModifyRequest(req *http.Request) error {
	m := req.Method

	if v.method != "" && v.method != m {
		// "got" is the method actually observed on the request; "want" is the
		// method this verifier was configured with.
		err := fmt.Errorf("request(%v) method verification error: got %v, want %v", req.URL, m, v.method)
		v.err.Add(err)
	}

	return nil
}

// VerifyRequests returns an error if verification for any request failed.
// If an error is returned it will be of type *martian.MultiError.
func (v *verifier) VerifyRequests() error {
	if v.err.Empty() {
		return nil
	}

	return v.err
}

// ResetRequestVerifications clears all failed request verifications.
func (v *verifier) ResetRequestVerifications() {
	v.err = martian.NewMultiError()
}

// verifierFromJSON builds a method.Verifier from JSON.
//
// Example JSON:
// {
//   "method.Verifier": {
//     "scope": ["request"],
//     "method": "POST"
//   }
// }
func verifierFromJSON(b []byte) (*parse.Result, error) {
	msg := &verifierJSON{}
	if err := json.Unmarshal(b, msg); err != nil {
		return nil, err
	}

	v, err := NewVerifier(msg.Method)
	if err != nil {
		return nil, err
	}

	return parse.NewResult(v, msg.Scope)
}
martian-3.3.2/method/method_verifier_test.go000066400000000000000000000071231421371434000211750ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package method

import (
	"net/http"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/verify"
)

func TestVerifierFromJSON(t *testing.T) {
	msg := []byte(`{
		"method.Verifier": {
			"scope": ["request"],
			"method": "POST"
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	reqv, ok := reqmod.(verify.RequestVerifier)
	if !ok {
		t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok")
	}

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	if err := reqv.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	if err := reqv.VerifyRequests(); err == nil {
		t.Error("VerifyRequests(): got nil, want not nil")
	}
}

func TestVerifyRequestPasses(t *testing.T) {
	for _, m := range []string{
		"GET", "HEAD", "PUT", "POST", "DELETE", "TRACE", "OPTIONS", "CONNECT", "PATCH",
	} {
		v, err := NewVerifier(m)
		if err != nil {
			t.Fatalf("NewVerifier(%q): got %v, want no error", m, err)
		}

		req, err := http.NewRequest(m, "www.google.com", nil)
		if err != nil {
			t.Fatalf("http.NewRequest(): got %v, want no error", err)
		}

		if err := v.ModifyRequest(req); err != nil {
			t.Fatalf("ModifyRequest(): got %v, want no error", err)
		}

		if err := v.VerifyRequests(); err != nil {
			t.Fatalf("VerifyRequests(): got %v, want no error", err)
		}

		v.ResetRequestVerifications()
		if err := v.VerifyRequests(); err != nil {
			t.Errorf("v.VerifyRequests(): got %v, want no error", err)
		}
	}
}

func TestVerifyPostRequestFailsWithMultiFail(t *testing.T) {
	v, _ := NewVerifier("POST")

	req, err := http.NewRequest("GET", "http://www.google.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest got %v, want no error", err)
	}

	if err := v.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if err := v.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	merr, ok := v.VerifyRequests().(*martian.MultiError)
	if !ok {
		t.Fatalf("VerifyRequests(): got nil, want *martian.MultiError")
	}

	errs := merr.Errors()
	if len(errs) != 2 {
		t.Fatalf("len(merr.Errors()): got %d, want 2", len(errs))
	}

	// The request was a GET while the verifier expects POST.
	expectErr := "request(http://www.google.com) method verification error: got GET, want POST"
	for i := range errs {
		if got, want := errs[i].Error(), expectErr; got != want {
			t.Errorf("%d. err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want)
		}
	}

	v.ResetRequestVerifications()
	if err := v.VerifyRequests(); err != nil {
		t.Errorf("v.VerifyRequests(): got %v, want no error", err)
	}
}

func TestBadInputToConstructor(t *testing.T) {
	if _, err := NewVerifier(""); err == nil {
		t.Fatalf("NewVerifier(): no error returned for empty")
	}
}
martian-3.3.2/mitm/000077500000000000000000000000001421371434000141175ustar00rootroot00000000000000martian-3.3.2/mitm/mitm.go000066400000000000000000000211651421371434000154210ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package mitm provides tooling for MITMing TLS connections. It provides
// tooling to create CA certs and generate TLS configs that can be used to MITM
// a TLS connection with a provided CA certificate.
package mitm

import (
	"bytes"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"errors"
	"math/big"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/google/martian/v3/h2"
	"github.com/google/martian/v3/log"
)

// MaxSerialNumber is the upper boundary that is used to create unique serial
// numbers for the certificate. This can be any unsigned integer up to 20
// bytes (2^(8*20)-1).
var MaxSerialNumber = big.NewInt(0).SetBytes(bytes.Repeat([]byte{255}, 20))

// Config is a set of configuration values that are used to build TLS configs
// capable of MITM.
type Config struct {
	ca                     *x509.Certificate
	capriv                 interface{}
	priv                   *rsa.PrivateKey
	keyID                  []byte
	validity               time.Duration
	org                    string
	h2Config               *h2.Config
	getCertificate         func(*tls.ClientHelloInfo) (*tls.Certificate, error)
	roots                  *x509.CertPool
	skipVerify             bool
	handshakeErrorCallback func(*http.Request, error)

	// certmu guards certs, the per-hostname cache of generated leaf certificates.
	certmu sync.RWMutex
	certs  map[string]*tls.Certificate
}

// subjectKeyID returns the SHA-1 digest of the PKIX (ASN.1 DER) encoding of
// pub. The digest is used as the Subject Key Identifier of the certificates
// built in this package.
//
// NOTE(review): this hashes the whole SubjectPublicKeyInfo rather than only
// the subjectPublicKey bit string of RFC 3280 section 4.2.1.2 method (1). It
// is kept as-is so that generated key identifiers do not change.
func subjectKeyID(pub interface{}) ([]byte, error) {
	pkixpub, err := x509.MarshalPKIXPublicKey(pub)
	if err != nil {
		return nil, err
	}

	sum := sha1.Sum(pkixpub)
	return sum[:], nil
}

// NewAuthority creates a new self-signed CA certificate and its associated
// freshly generated 2048-bit RSA private key. The certificate uses name as
// both its common name and its DNS name and organization as its
// organization. It is valid from validity before now until validity after
// now.
func NewAuthority(name, organization string, validity time.Duration) (*x509.Certificate, *rsa.PrivateKey, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, err
	}
	pub := priv.Public()

	// Subject Key Identifier support for end entity certificate.
	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
	keyID, err := subjectKeyID(pub)
	if err != nil {
		return nil, nil, err
	}

	// TODO: keep a map of used serial numbers to avoid potentially reusing a
	// serial multiple times.
	serial, err := rand.Int(rand.Reader, MaxSerialNumber)
	if err != nil {
		return nil, nil, err
	}

	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName:   name,
			Organization: []string{organization},
		},
		SubjectKeyId:          keyID,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		NotBefore:             time.Now().Add(-validity),
		NotAfter:              time.Now().Add(validity),
		DNSNames:              []string{name},
		IsCA:                  true,
	}

	// Self-signed: the template acts as both subject and issuer.
	raw, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, pub, priv)
	if err != nil {
		return nil, nil, err
	}

	// Parse certificate bytes so that we have a leaf certificate.
	x509c, err := x509.ParseCertificate(raw)
	if err != nil {
		return nil, nil, err
	}

	return x509c, priv, nil
}

// NewConfig creates a MITM config using the CA certificate and
// private key to generate on-the-fly certificates. A single RSA key pair is
// generated here and shared by every leaf certificate the config issues.
// Generated certificates default to a one hour validity window and the
// "Martian Proxy" organization.
//
// It returns an error if ca is nil.
func NewConfig(ca *x509.Certificate, privateKey interface{}) (*Config, error) {
	if ca == nil {
		// x509.CertPool.AddCert panics on a nil certificate; report it instead.
		return nil, errors.New("mitm: CA certificate is nil")
	}

	roots := x509.NewCertPool()
	roots.AddCert(ca)

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, err
	}

	// Subject Key Identifier support for end entity certificate.
	// https://www.ietf.org/rfc/rfc3280.txt (section 4.2.1.2)
	keyID, err := subjectKeyID(priv.Public())
	if err != nil {
		return nil, err
	}

	return &Config{
		ca:       ca,
		capriv:   privateKey,
		priv:     priv,
		keyID:    keyID,
		validity: time.Hour,
		org:      "Martian Proxy",
		certs:    make(map[string]*tls.Certificate),
		roots:    roots,
	}, nil
}

// SetValidity sets the validity window around the current time that the
// certificate is valid for.
func (c *Config) SetValidity(validity time.Duration) {
	c.validity = validity
}

// SkipTLSVerify skips the TLS certification verification check.
func (c *Config) SkipTLSVerify(skip bool) {
	c.skipVerify = skip
}

// SetOrganization sets the organization name written into generated certificates.
func (c *Config) SetOrganization(org string) {
	c.org = org
}

// SetH2Config installs the configuration that controls HTTP/2 stream handling.
func (c *Config) SetH2Config(h2Config *h2.Config) {
	c.h2Config = h2Config
}

// H2Config reports the HTTP/2 configuration currently installed, if any.
func (c *Config) H2Config() *h2.Config {
	return c.h2Config
}

// SetHandshakeErrorCallback installs cb as the function invoked by
// HandshakeErrorCallback.
func (c *Config) SetHandshakeErrorCallback(cb func(*http.Request, error)) {
	c.handshakeErrorCallback = cb
}

// HandshakeErrorCallback forwards r and err to the installed callback and
// does nothing when no callback has been set. r is the CONNECT request that
// the failed handshake was running through.
func (c *Config) HandshakeErrorCallback(r *http.Request, err error) {
	callback := c.handshakeErrorCallback
	if callback == nil {
		return
	}
	callback(r, err)
}

// TLS builds a *tls.Config whose certificates are minted on demand from the
// SNI server name in each ClientHello. A handshake without SNI fails, and
// only "http/1.1" is advertised via ALPN.
func (c *Config) TLS() *tls.Config {
	cfg := &tls.Config{
		InsecureSkipVerify: c.skipVerify,
		NextProtos:         []string{"http/1.1"},
	}

	cfg.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
		name := hello.ServerName
		if name != "" {
			return c.cert(name)
		}
		return nil, errors.New("mitm: SNI not provided, failed to build certificate")
	}

	return cfg
}

// TLSForHost returns a *tls.Config that will generate certificates on-the-fly
// using SNI from the connection, or fall back to the provided hostname.
func (c *Config) TLSForHost(hostname string) *tls.Config { nextProtos := []string{"http/1.1"} if c.h2AllowedHost(hostname) { nextProtos = []string{"h2", "http/1.1"} } return &tls.Config{ InsecureSkipVerify: c.skipVerify, GetCertificate: func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { host := clientHello.ServerName if host == "" { host = hostname } return c.cert(host) }, NextProtos: nextProtos, } } func (c *Config) h2AllowedHost(host string) bool { return c.h2Config != nil && c.h2Config.AllowedHostsFilter != nil && c.h2Config.AllowedHostsFilter(host) } func (c *Config) cert(hostname string) (*tls.Certificate, error) { // Remove the port if it exists. host, _, err := net.SplitHostPort(hostname) if err == nil { hostname = host } c.certmu.RLock() tlsc, ok := c.certs[hostname] c.certmu.RUnlock() if ok { log.Debugf("mitm: cache hit for %s", hostname) // Check validity of the certificate for hostname match, expiry, etc. In // particular, if the cached certificate has expired, create a new one. 
if _, err := tlsc.Leaf.Verify(x509.VerifyOptions{ DNSName: hostname, Roots: c.roots, }); err == nil { return tlsc, nil } log.Debugf("mitm: invalid certificate in cache for %s", hostname) } log.Debugf("mitm: cache miss for %s", hostname) serial, err := rand.Int(rand.Reader, MaxSerialNumber) if err != nil { return nil, err } tmpl := &x509.Certificate{ SerialNumber: serial, Subject: pkix.Name{ CommonName: hostname, Organization: []string{c.org}, }, SubjectKeyId: c.keyID, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, NotBefore: time.Now().Add(-c.validity), NotAfter: time.Now().Add(c.validity), } if ip := net.ParseIP(hostname); ip != nil { tmpl.IPAddresses = []net.IP{ip} } else { tmpl.DNSNames = []string{hostname} } raw, err := x509.CreateCertificate(rand.Reader, tmpl, c.ca, c.priv.Public(), c.capriv) if err != nil { return nil, err } // Parse certificate bytes so that we have a leaf certificate. x509c, err := x509.ParseCertificate(raw) if err != nil { return nil, err } tlsc = &tls.Certificate{ Certificate: [][]byte{raw, c.ca.Raw}, PrivateKey: c.priv, Leaf: x509c, } c.certmu.Lock() c.certs[hostname] = tlsc c.certmu.Unlock() return tlsc, nil } martian-3.3.2/mitm/mitm_test.go000066400000000000000000000140141421371434000164530ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package mitm import ( "crypto/tls" "crypto/x509" "net" "reflect" "testing" "time" ) func TestMITM(t *testing.T) { ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour) if err != nil { t.Fatalf("NewAuthority(): got %v, want no error", err) } c, err := NewConfig(ca, priv) if err != nil { t.Fatalf("NewConfig(): got %v, want no error", err) } c.SetValidity(20 * time.Hour) c.SetOrganization("Test Organization") protos := []string{"http/1.1"} conf := c.TLS() if got := conf.NextProtos; !reflect.DeepEqual(got, protos) { t.Errorf("conf.NextProtos: got %v, want %v", got, protos) } if conf.InsecureSkipVerify { t.Error("conf.InsecureSkipVerify: got true, want false") } // Simulate a TLS connection without SNI. clientHello := &tls.ClientHelloInfo{ ServerName: "", } if _, err := conf.GetCertificate(clientHello); err == nil { t.Fatal("conf.GetCertificate(): got nil, want error") } // Simulate a TLS connection with SNI. clientHello.ServerName = "example.com" tlsc, err := conf.GetCertificate(clientHello) if err != nil { t.Fatalf("conf.GetCertificate(): got %v, want no error", err) } x509c := tlsc.Leaf if got, want := x509c.Subject.CommonName, "example.com"; got != want { t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) } c.SkipTLSVerify(true) conf = c.TLSForHost("example.com") if got := conf.NextProtos; !reflect.DeepEqual(got, protos) { t.Errorf("conf.NextProtos: got %v, want %v", got, protos) } if !conf.InsecureSkipVerify { t.Error("conf.InsecureSkipVerify: got false, want true") } // Set SNI, takes precedence over host. clientHello.ServerName = "google.com" tlsc, err = conf.GetCertificate(clientHello) if err != nil { t.Fatalf("conf.GetCertificate(): got %v, want no error", err) } x509c = tlsc.Leaf if got, want := x509c.Subject.CommonName, "google.com"; got != want { t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) } // Reset SNI to fallback to hostname. 
clientHello.ServerName = "" tlsc, err = conf.GetCertificate(clientHello) if err != nil { t.Fatalf("conf.GetCertificate(): got %v, want no error", err) } x509c = tlsc.Leaf if got, want := x509c.Subject.CommonName, "example.com"; got != want { t.Errorf("x509c.Subject.CommonName: got %q, want %q", got, want) } } func TestCert(t *testing.T) { ca, priv, err := NewAuthority("martian.proxy", "Martian Authority", 24*time.Hour) if err != nil { t.Fatalf("NewAuthority(): got %v, want no error", err) } c, err := NewConfig(ca, priv) if err != nil { t.Fatalf("NewConfig(): got %v, want no error", err) } tlsc, err := c.cert("example.com") if err != nil { t.Fatalf("c.cert(%q): got %v, want no error", "example.com:8080", err) } if tlsc.Certificate == nil { t.Error("tlsc.Certificate: got nil, want certificate bytes") } if tlsc.PrivateKey == nil { t.Error("tlsc.PrivateKey: got nil, want private key") } x509c := tlsc.Leaf if x509c == nil { t.Fatal("x509c: got nil, want *x509.Certificate") } if got := x509c.SerialNumber; got.Cmp(MaxSerialNumber) >= 0 { t.Errorf("x509c.SerialNumber: got %v, want <= MaxSerialNumber", got) } if got, want := x509c.Subject.CommonName, "example.com"; got != want { t.Errorf("X509c.Subject.CommonName: got %q, want %q", got, want) } if err := x509c.VerifyHostname("example.com"); err != nil { t.Errorf("x509c.VerifyHostname(%q): got %v, want no error", "example.com", err) } if got, want := x509c.Subject.Organization, []string{"Martian Proxy"}; !reflect.DeepEqual(got, want) { t.Errorf("x509c.Subject.Organization: got %v, want %v", got, want) } if got := x509c.SubjectKeyId; got == nil { t.Error("x509c.SubjectKeyId: got nothing, want key ID") } if !x509c.BasicConstraintsValid { t.Error("x509c.BasicConstraintsValid: got false, want true") } if got, want := x509c.KeyUsage, x509.KeyUsageKeyEncipherment; got&want == 0 { t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageKeyEncipherment") } if got, want := x509c.KeyUsage, x509.KeyUsageDigitalSignature; 
got&want == 0 { t.Error("x509c.KeyUsage: got nothing, want to include x509.KeyUsageDigitalSignature") } want := []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth} if got := x509c.ExtKeyUsage; !reflect.DeepEqual(got, want) { t.Errorf("x509c.ExtKeyUsage: got %v, want %v", got, want) } if got, want := x509c.DNSNames, []string{"example.com"}; !reflect.DeepEqual(got, want) { t.Errorf("x509c.DNSNames: got %v, want %v", got, want) } before := time.Now().Add(-2 * time.Hour) if got := x509c.NotBefore; before.After(got) { t.Errorf("x509c.NotBefore: got %v, want after %v", got, before) } after := time.Now().Add(2 * time.Hour) if got := x509c.NotAfter; !after.After(got) { t.Errorf("x509c.NotAfter: got %v, want before %v", got, want) } // Retrieve cached certificate. tlsc2, err := c.cert("example.com") if err != nil { t.Fatalf("c.cert(%q): got %v, want no error", "example.com", err) } if tlsc != tlsc2 { t.Error("tlsc2: got new certificate, want cached certificate") } // TLS certificate for IP. tlsc, err = c.cert("10.0.0.1:8227") if err != nil { t.Fatalf("c.cert(%q): got %v, want no error", "10.0.0.1:8227", err) } x509c = tlsc.Leaf if got, want := len(x509c.IPAddresses), 1; got != want { t.Fatalf("len(x509c.IPAddresses): got %d, want %d", got, want) } if got, want := x509c.IPAddresses[0], net.ParseIP("10.0.0.1"); !got.Equal(want) { t.Fatalf("x509c.IPAddresses: got %v, want %v", got, want) } } martian-3.3.2/mobile/000077500000000000000000000000001421371434000144205ustar00rootroot00000000000000martian-3.3.2/mobile/init.go000066400000000000000000000016671421371434000157240ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package mobile configures and instantiates a Martian Proxy.
// This package is a reference implementation of Martian Proxy intended to
// be cross compiled with gomobile for use on Android and iOS.
package mobile

// init runs common initialization code for a martian proxy. It is empty in
// this reference implementation.
func init() {
	// Add custom code for your environment here.
}
martian-3.3.2/mobile/proxy.go000066400000000000000000000200141421371434000161250ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package mobile

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io/ioutil"
	"log"
	"net"
	"net/http"
	"os"
	"path"
	"time"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/api"
	"github.com/google/martian/v3/cors"
	"github.com/google/martian/v3/cybervillains"
	"github.com/google/martian/v3/fifo"
	"github.com/google/martian/v3/har"
	"github.com/google/martian/v3/httpspec"
	mlog "github.com/google/martian/v3/log"
	"github.com/google/martian/v3/marbl"
	"github.com/google/martian/v3/martianhttp"
	"github.com/google/martian/v3/mitm"
	"github.com/google/martian/v3/servemux"
	"github.com/google/martian/v3/trafficshape"
	"github.com/google/martian/v3/verify"

	// side-effect importing to register with JSON API
	_ "github.com/google/martian/v3/body"
	_ "github.com/google/martian/v3/cookie"
	_ "github.com/google/martian/v3/failure"
	_ "github.com/google/martian/v3/header"
	_ "github.com/google/martian/v3/martianurl"
	_ "github.com/google/martian/v3/method"
	_ "github.com/google/martian/v3/pingback"
	_ "github.com/google/martian/v3/port"
	_ "github.com/google/martian/v3/priority"
	_ "github.com/google/martian/v3/querystring"
	_ "github.com/google/martian/v3/skip"
	_ "github.com/google/martian/v3/stash"
	_ "github.com/google/martian/v3/static"
	_ "github.com/google/martian/v3/status"
)

// Martian is a wrapper for the initialized Martian proxy. Set the exported
// configuration fields, then call Start.
type Martian struct {
	// Runtime state populated by Start.
	proxy       *martian.Proxy
	listener    net.Listener // proxy traffic listener
	apiListener net.Listener // configuration API listener
	mux         *http.ServeMux
	started     bool

	// HARLogging enables HAR capture of original and modified traffic,
	// exported under /logs/original and /logs.
	HARLogging bool
	// TrafficPort is the port the proxy listens on for traffic.
	TrafficPort int
	// TrafficShaping wraps the traffic listener with a trafficshape
	// listener configurable via /shape-traffic.
	TrafficShaping bool
	// APIPort is the port the configuration API is served on.
	APIPort int
	// APIOverTLS serves the API over TLS using Cert and Key (both required).
	APIOverTLS bool
	// BindLocalhost binds listeners to [::1] instead of all interfaces.
	BindLocalhost bool
	// Cert and Key are the PEM-encoded CA certificate and private key
	// (parsed with tls.X509KeyPair); MITM is enabled only when both are set.
	Cert string
	Key  string
	// AllowCORS wraps every registered API handler with cors.NewHandler.
	AllowCORS bool
	// RoundTripper, if non-nil, is installed on the proxy.
	// NOTE(review): Start only applies it when Cert and Key are set — confirm.
	RoundTripper *http.Transport
}

// EnableCybervillains configures Martian to use the Cybervillains
// certificate and key as its MITM CA.
func (m *Martian) EnableCybervillains() {
	m.Cert = cybervillains.Cert
	m.Key = cybervillains.Key
}

// NewProxy creates a new, zero-valued Martian struct for configuring and
// starting a martian.
func NewProxy() *Martian {
	return &Martian{}
}

// Start starts the proxy given the configured values of the Martian struct.
// Any setup failure terminates the process via log.Fatal.
func (m *Martian) Start() { var err error m.listener, err = net.Listen("tcp", m.bindAddress(m.TrafficPort)) if err != nil { log.Fatal(err) } mlog.Debugf("mobile: started listener on: %v", m.listener.Addr()) m.proxy = martian.NewProxy() m.mux = http.NewServeMux() if m.Cert != "" && m.Key != "" { tlsc, err := tls.X509KeyPair([]byte(m.Cert), []byte(m.Key)) if err != nil { log.Fatal(err) } mlog.Debugf("mobile: loaded cert and key") x509c, err := x509.ParseCertificate(tlsc.Certificate[0]) if err != nil { log.Fatal(err) } mlog.Debugf("mobile: parsed cert") mc, err := mitm.NewConfig(x509c, tlsc.PrivateKey) if err != nil { log.Fatal(err) } mc.SetValidity(12 * time.Hour) mc.SetOrganization("Martian Proxy") m.proxy.SetMITM(mc) if m.RoundTripper != nil { m.proxy.SetRoundTripper(m.RoundTripper) } m.handle("/authority.cer", martianhttp.NewAuthorityHandler(x509c)) } // Enable Traffic shaping if requested if m.TrafficShaping { tsl := trafficshape.NewListener(m.listener) tsh := trafficshape.NewHandler(tsl) m.handle("/shape-traffic", tsh) m.listener = tsl } // Forward traffic that pattern matches in m.mux before applying // httpspec modifiers (via modifier, specifically) topg := fifo.NewGroup() apif := servemux.NewFilter(m.mux) apif.SetRequestModifier(api.NewForwarder("", m.APIPort)) topg.AddRequestModifier(apif) stack, fg := httpspec.NewStack("martian.mobile") topg.AddRequestModifier(stack) topg.AddResponseModifier(stack) m.proxy.SetRequestModifier(topg) m.proxy.SetResponseModifier(topg) if m.HARLogging { // add HAR logger for unmodified logs. 
uhl := har.NewLogger() uhmuxf := servemux.NewFilter(m.mux) uhmuxf.RequestWhenFalse(uhl) uhmuxf.ResponseWhenFalse(uhl) fg.AddRequestModifier(uhmuxf) fg.AddResponseModifier(uhmuxf) // add HAR logger hl := har.NewLogger() hmuxf := servemux.NewFilter(m.mux) hmuxf.RequestWhenFalse(hl) hmuxf.ResponseWhenFalse(hl) stack.AddRequestModifier(hmuxf) stack.AddResponseModifier(hmuxf) // Retrieve Unmodified HAR logs m.handle("/logs/original", har.NewExportHandler(uhl)) m.handle("/logs/original/reset", har.NewResetHandler(uhl)) // Retrieve HAR logs m.handle("/logs", har.NewExportHandler(hl)) m.handle("/logs/reset", har.NewResetHandler(hl)) } lsh := marbl.NewHandler() // retrieve binary marbl logs m.handle("/binlogs", lsh) lsm := marbl.NewModifier(lsh) muxf := servemux.NewFilter(m.mux) muxf.RequestWhenFalse(lsm) muxf.ResponseWhenFalse(lsm) stack.AddRequestModifier(muxf) stack.AddResponseModifier(muxf) mod := martianhttp.NewModifier() fg.AddRequestModifier(mod) fg.AddResponseModifier(mod) // Proxy specific handlers. // These handlers take precendence over proxy traffic and will not be intercepted. // Update modifiers. m.handle("/configure", mod) // Verify assertions. vh := verify.NewHandler() vh.SetRequestVerifier(mod) vh.SetResponseVerifier(mod) m.handle("/verify", vh) // Reset verifications. 
rh := verify.NewResetHandler() rh.SetRequestVerifier(mod) rh.SetResponseVerifier(mod) m.handle("/verify/reset", rh) mlog.Infof("mobile: starting Martian proxy on listener") go m.proxy.Serve(m.listener) // start the API server apiAddr := m.bindAddress(m.APIPort) m.apiListener, err = net.Listen("tcp", apiAddr) if err != nil { log.Fatal(err) } if m.APIOverTLS { if m.Cert == "" || m.Key == "" { log.Fatal("mobile: APIOverTLS cannot be true without valid cert and key") } cerfile, err := ioutil.TempFile("", "martian-api.cert") if err != nil { log.Fatal(err) } keyfile, err := ioutil.TempFile("", "martian-api.key") if err != nil { log.Fatal(err) } if _, err := cerfile.Write([]byte(m.Cert)); err != nil { log.Fatal(err) } if _, err := keyfile.Write([]byte(m.Key)); err != nil { log.Fatal(err) } go func() { http.ServeTLS(m.apiListener, m.mux, cerfile.Name(), keyfile.Name()) defer os.Remove(cerfile.Name()) defer os.Remove(keyfile.Name()) }() mlog.Infof("mobile: proxy API started on %s over TLS", apiAddr) } else { go http.Serve(m.apiListener, m.mux) mlog.Infof("mobile: proxy API started on %s", apiAddr) } m.started = true } // IsStarted returns true if the proxy has finished starting. func (m *Martian) IsStarted() bool { return m.started } // Shutdown tells the Proxy to close. This function returns immediately, though // there may still be connection threads hanging around until they time out // depending on how the OS manages them. 
func (m *Martian) Shutdown() { mlog.Infof("mobile: shutting down proxy") m.listener.Close() m.apiListener.Close() m.proxy.Close() m.started = false mlog.Infof("mobile: proxy shut down") } // SetLogLevel sets the Martian log level (Silent = 0, Error, Info, Debug), controlling which Martian // log calls are displayed in the console func SetLogLevel(l int) { mlog.SetLevel(l) } func (m *Martian) handle(pattern string, handler http.Handler) { if m.AllowCORS { handler = cors.NewHandler(handler) } m.mux.Handle(pattern, handler) mlog.Infof("mobile: handler registered for %s", pattern) lhp := path.Join(fmt.Sprintf("localhost:%d", m.APIPort), pattern) m.mux.Handle(lhp, handler) mlog.Infof("mobile: handler registered for %s", lhp) } func (m *Martian) bindAddress(port int) string { if m.BindLocalhost { return fmt.Sprintf("[::1]:%d", port) } return fmt.Sprintf(":%d", port) } martian-3.3.2/multierror.go000066400000000000000000000034221421371434000157050ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "strings" "sync" ) // MultiError is a collection of errors that implements the error interface. type MultiError struct { mu sync.RWMutex errs []error } // NewMultiError returns a new MultiError. func NewMultiError() *MultiError { return &MultiError{} } // Error returns the list of errors separated by newlines. 
func (merr *MultiError) Error() string { merr.mu.RLock() defer merr.mu.RUnlock() var errs []string for _, err := range merr.errs { errs = append(errs, err.Error()) } return strings.Join(errs, "\n") } // Errors returns the error slice containing the error collection. func (merr *MultiError) Errors() []error { merr.mu.RLock() defer merr.mu.RUnlock() return merr.errs } // Add appends an error to the error collection. func (merr *MultiError) Add(err error) { merr.mu.Lock() defer merr.mu.Unlock() // Unwrap *MultiError to ensure that depth never exceeds 1. if merr2, ok := err.(*MultiError); ok { merr.errs = append(merr.errs, merr2.Errors()...) return } merr.errs = append(merr.errs, err) } // Empty returns whether the *MultiError contains any errors. func (merr *MultiError) Empty() bool { return len(merr.errs) == 0 } martian-3.3.2/multierror_test.go000066400000000000000000000023401421371434000167420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "fmt" "reflect" "testing" ) func TestMultiError(t *testing.T) { merr := NewMultiError() if !merr.Empty() { t.Fatal("Empty(): got false, want true") } var errs []error for i := 0; i < 3; i++ { err := fmt.Errorf("%d. 
error", i) errs = append(errs, err) merr.Add(err) } if merr.Empty() { t.Fatal("Empty(): got true, want false") } if got, want := merr.Errors(), errs; !reflect.DeepEqual(got, want) { t.Errorf("Errors(): got %v, want %v", got, want) } want := "0. error\n1. error\n2. error" if got := merr.Error(); got != want { t.Errorf("Error(): got %q, want %q", got, want) } } martian-3.3.2/noop.go000066400000000000000000000023211421371434000144510ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "net/http" "github.com/google/martian/v3/log" ) type noopModifier struct { id string } // Noop returns a modifier that does not change the request or the response. func Noop(id string) RequestResponseModifier { return &noopModifier{ id: id, } } // ModifyRequest logs a debug line. func (nm *noopModifier) ModifyRequest(*http.Request) error { log.Debugf("%s: no request modifier configured", nm.id) return nil } // ModifyResponse logs a debug line. func (nm *noopModifier) ModifyResponse(*http.Response) error { log.Debugf("%s: no response modifier configured", nm.id) return nil } martian-3.3.2/noop/000077500000000000000000000000001421371434000141245ustar00rootroot00000000000000martian-3.3.2/noop/noop.go000066400000000000000000000040531421371434000154300ustar00rootroot00000000000000// Copyright 2021 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package noop provides a martian.RequestResponseModifier that does not // modify the request or response. package noop import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/log" "github.com/google/martian/v3/parse" ) func init() { parse.Register("noop.Modifier", modifierFromJSON) } type noopModifier struct { id string } // Noop returns a modifier that does not change the request or the response. func Noop(id string) martian.RequestResponseModifier { return &noopModifier{ id: id, } } // ModifyRequest logs a debug line. func (nm *noopModifier) ModifyRequest(*http.Request) error { log.Debugf("noopModifier: %s: no request modification applied", nm.id) return nil } // ModifyResponse logs a debug line. func (nm *noopModifier) ModifyResponse(*http.Response) error { log.Debugf("noopModifier: %s: no response modification applied", nm.id) return nil } type modifierJSON struct { Name string `json:"name"` Scope []parse.ModifierType `json:"scope"` } // modifierFromJSON takes a JSON message as a byte slice and returns // a headerModifier and an error. 
// // Example JSON configuration message: // { // "scope": ["request", "result"], // "name": "noop-name", // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } modifier := Noop(msg.Name) return parse.NewResult(modifier, msg.Scope) } martian-3.3.2/nosigpipe/000077500000000000000000000000001421371434000151465ustar00rootroot00000000000000martian-3.3.2/nosigpipe/nosigpipe.go000066400000000000000000000004021421371434000174660ustar00rootroot00000000000000//+build !darwin !go1.9 package nosigpipe import "net" // IgnoreSIGPIPE prevents SIGPIPE from being raised on TCP sockets when remote hangs up // See: https://github.com/golang/go/issues/17393. Do nothing for non Darwin func IgnoreSIGPIPE(c net.Conn) { } martian-3.3.2/nosigpipe/nosigpipe_darwin.go000066400000000000000000000013371421371434000210420ustar00rootroot00000000000000//+build darwin,go1.9 package nosigpipe import ( "net" "syscall" "github.com/google/martian/v3/log" ) // IgnoreSIGPIPE prevents SIGPIPE from being raised on TCP sockets when remote hangs up // See: https://github.com/golang/go/issues/17393 func IgnoreSIGPIPE(c net.Conn) { if c == nil { return } s, ok := c.(syscall.Conn) if !ok { return } r, e := s.SyscallConn() if e != nil { log.Errorf("Failed to get SyscallConn: %s", e) return } e = r.Control(func(fd uintptr) { intfd := int(fd) if e := syscall.SetsockoptInt(intfd, syscall.SOL_SOCKET, syscall.SO_NOSIGPIPE, 1); e != nil { log.Errorf("Failed to set SO_NOSIGPIPE: %s", e) } }) if e != nil { log.Errorf("Failed to set SO_NOSIGPIPE: %s", e) } } martian-3.3.2/parse/000077500000000000000000000000001421371434000142635ustar00rootroot00000000000000martian-3.3.2/parse/parse.go000066400000000000000000000100151421371434000157210ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package parse constructs martian modifiers from JSON messages. package parse import ( "encoding/json" "fmt" "sync" "github.com/google/martian/v3" ) // ModifierType is the HTTP message type. type ModifierType string const ( // Request modifies an HTTP request. Request ModifierType = "request" // Response modifies an HTTP response. Response ModifierType = "response" ) // Result holds the parsed modifier and its type. type Result struct { reqmod martian.RequestModifier resmod martian.ResponseModifier } // NewResult returns a new parse.Result for a given interface{} that implements a modifier // and a slice of scopes to generate the result for. 
// // Returns nil, error if a given modifier does not support a given scope func NewResult(mod interface{}, scope []ModifierType) (*Result, error) { reqmod, reqOk := mod.(martian.RequestModifier) resmod, resOk := mod.(martian.ResponseModifier) result := &Result{} if scope == nil { result.reqmod = reqmod result.resmod = resmod return result, nil } for _, s := range scope { switch s { case Request: if !reqOk { return nil, fmt.Errorf("parse: invalid scope %q for modifier", "request") } result.reqmod = reqmod case Response: if !resOk { return nil, fmt.Errorf("parse: invalid scope %q for modifier", "response") } result.resmod = resmod default: return nil, fmt.Errorf("parse: invalid scope: %s not in [%q, %q]", s, "request", "response") } } return result, nil } // RequestModifier returns the parsed RequestModifier. // // Returns nil if the message has no request modifier. func (r *Result) RequestModifier() martian.RequestModifier { return r.reqmod } // ResponseModifier returns the parsed ResponseModifier. // // Returns nil if the message has no response modifier. func (r *Result) ResponseModifier() martian.ResponseModifier { return r.resmod } var ( parseMu sync.RWMutex parseFuncs = make(map[string]func(b []byte) (*Result, error)) ) // ErrUnknownModifier is the error returned when the message does not // contain a field representing a known modifier type. type ErrUnknownModifier struct { name string } // Error returns a formatted error message for an ErrUnknownModifier. func (e ErrUnknownModifier) Error() string { return fmt.Sprintf("parse: unknown modifier: %s", e.name) } // Register registers a parsing function for name that will be used to unmarshal // a JSON message into the appropriate modifier. 
func Register(name string, parseFunc func(b []byte) (*Result, error)) { parseMu.Lock() defer parseMu.Unlock() parseFuncs[name] = parseFunc } // FromJSON parses a Modifier JSON message by looking up the named modifier in parseFuncs // and passing its modifier to the registered parseFunc. Returns a parse.Result containing // the top-level parsed modifier. If no parser has been registered with the given name // it returns an error of type ErrUnknownModifier. func FromJSON(b []byte) (*Result, error) { msg := make(map[string]json.RawMessage) if err := json.Unmarshal(b, &msg); err != nil { return nil, err } if len(msg) != 1 { ks := "" for k := range msg { ks += ", " + k } return nil, fmt.Errorf("parse: expected one modifier, received %d: %s", len(msg), ks) } parseMu.RLock() defer parseMu.RUnlock() for k, m := range msg { parseFunc, ok := parseFuncs[k] if !ok { return nil, ErrUnknownModifier{name: k} } return parseFunc(m) } return nil, fmt.Errorf("parse: no modifiers found: %v", msg) } martian-3.3.2/parse/parse_test.go000066400000000000000000000103131421371434000167610ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package parse

import (
	"encoding/json"
	"net/http"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/martiantest"
)

func TestFromJSON(t *testing.T) {
	msg := []byte(`{
		"first.Modifier": { },
		"second.Modifier": { }
	}`)

	// Exactly one top-level key is allowed.
	if _, err := FromJSON(msg); err == nil {
		t.Error("FromJSON(): got nil, want more than one key error")
	}

	Register("martiantest.Modifier", func(b []byte) (*Result, error) {
		type testJSON struct {
			Scope []ModifierType `json:"scope"`
		}

		msg := &testJSON{}
		if err := json.Unmarshal(b, msg); err != nil {
			return nil, err
		}

		tm := martiantest.NewModifier()
		return NewResult(tm, msg.Scope)
	})

	// No scope: both request and response modifiers are returned.
	msg = []byte(`{
		"martiantest.Modifier": { }
	}`)

	r, err := FromJSON(msg)
	if err != nil {
		t.Fatalf("FromJSON(): got %v, want no error", err)
	}
	if _, ok := r.RequestModifier().(*martiantest.Modifier); !ok {
		t.Fatal("r.RequestModifier().(*martiantest.Modifier): got !ok, want ok")
	}
	if _, ok := r.ResponseModifier().(*martiantest.Modifier); !ok {
		t.Fatal("r.ResponseModifier().(*martiantest.Modifier): got !ok, want ok")
	}

	// Request scope only.
	msg = []byte(`{
		"martiantest.Modifier": {
			"scope": ["request"]
		}
	}`)

	r, err = FromJSON(msg)
	if err != nil {
		t.Fatalf("FromJSON(): got %v, want no error", err)
	}
	if _, ok := r.RequestModifier().(*martiantest.Modifier); !ok {
		t.Fatal("r.RequestModifier().(*martiantest.Modifier): got !ok, want ok")
	}
	resmod := r.ResponseModifier()
	if resmod != nil {
		t.Error("r.ResponseModifier(): got not nil, want nil")
	}

	// Response scope only.
	msg = []byte(`{
		"martiantest.Modifier": {
			"scope": ["response"]
		}
	}`)

	r, err = FromJSON(msg)
	if err != nil {
		t.Fatalf("FromJSON(): got %v, want no error", err)
	}
	if _, ok := r.ResponseModifier().(*martiantest.Modifier); !ok {
		t.Fatal("r.ResponseModifier().(*martiantest.Modifier): got !ok, want ok")
	}
	reqmod := r.RequestModifier()
	if reqmod != nil {
		t.Error("r.RequestModifier(): got not nil, want nil")
	}
}

func TestNewResultMismatchedScopes(t *testing.T) {
	reqmod := martian.RequestModifierFunc(
		func(*http.Request) error {
			return nil
		})
	resmod := martian.ResponseModifierFunc(
		func(*http.Response) error {
			return nil
		})

	if _, err := NewResult(reqmod, []ModifierType{Response}); err == nil {
		t.Error("NewResult(reqmod, RESPONSE): got nil, want error")
	}

	if _, err := NewResult(resmod, []ModifierType{Request}); err == nil {
		t.Error("NewResult(resmod, REQUEST): got nil, want error")
	}

	if _, err := NewResult(reqmod, []ModifierType{ModifierType("unknown")}); err == nil {
		t.Error("NewResult(reqmod, unknown): got nil, want error")
	}
}

func TestResultModifierAccessors(t *testing.T) {
	tm := martiantest.NewModifier()
	r := &Result{
		reqmod: tm,
		resmod: nil,
	}
	if reqmod := r.RequestModifier(); reqmod == nil {
		t.Error("r.RequestModifier: got nil, want reqmod")
	}
	if resmod := r.ResponseModifier(); resmod != nil {
		t.Error("r.ResponseModifier: got resmod, want nil")
	}

	r = &Result{
		reqmod: nil,
		resmod: tm,
	}
	if reqmod := r.RequestModifier(); reqmod != nil {
		t.Errorf("r.RequestModifier: got reqmod, want nil")
	}
	if resmod := r.ResponseModifier(); resmod == nil {
		t.Error("r.ResponseModifier: got nil, want resmod")
	}
}

func TestParseUnknownModifierReturnsError(t *testing.T) {
	msg := []byte(`{
		"unknown.Key": {
			"scope": ["request", "response"]
		}
	}`)

	_, err := FromJSON(msg)

	umerr, ok := err.(ErrUnknownModifier)
	if !ok {
		t.Fatalf("FromJSON(): got %v, want ErrUnknownModifier", err)
	}

	if got, want := umerr.Error(), "parse: unknown modifier: unknown.Key"; got != want {
		t.Errorf("Error(): got %q, want %q", got, want)
	}
}
martian-3.3.2/pingback/000077500000000000000000000000001421371434000147275ustar00rootroot00000000000000martian-3.3.2/pingback/pingback_verifier.go000066400000000000000000000062221421371434000207310ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package pingback provides verification that specific URLs have been seen by // the proxy. package pingback import ( "encoding/json" "fmt" "net/http" "net/url" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) const ( errFormat = "request(%s): pingback never occurred" ) func init() { parse.Register("pingback.Verifier", verifierFromJSON) } // Verifier verifies that the specific URL has been seen. type Verifier struct { url *url.URL err error } type verifierJSON struct { Scheme string `json:"scheme"` Host string `json:"host"` Path string `json:"path"` Query string `json:"query"` Scope []parse.ModifierType `json:"scope"` } // NewVerifier returns a new pingback verifier. func NewVerifier(url *url.URL) verify.RequestVerifier { return &Verifier{ url: url, err: fmt.Errorf(errFormat, url.String()), } } // ModifyRequest verifies that the request URL matches all parts of url. // // If the value in url is non-empty, it must be an exact match. If the URL // matches the pingback, it is recorded by setting the error to nil. The error // will continue to be nil until the verifier has been reset, regardless of // subsequent requests matching. 
func (v *Verifier) ModifyRequest(req *http.Request) error { // skip requests to API ctx := martian.NewContext(req) if ctx.IsAPIRequest() { return nil } u := req.URL switch { case v.url.Scheme != "" && v.url.Scheme != u.Scheme: case v.url.Host != "" && v.url.Host != u.Host: case v.url.Path != "" && v.url.Path != u.Path: case v.url.RawQuery != "" && v.url.RawQuery != u.RawQuery: default: v.err = nil } return nil } // VerifyRequests returns an error if pingback never occurred. func (v *Verifier) VerifyRequests() error { return v.err } // ResetRequestVerifications clears the failed request verification. func (v *Verifier) ResetRequestVerifications() { v.err = fmt.Errorf(errFormat, v.url.String()) } // verifierFromJSON builds a pingback.Verifier from JSON. // // Example JSON: // { // "pingback.Verifier": { // "scope": ["request"], // "scheme": "https", // "host": "www.google.com", // "path": "/proxy", // "query": "testing=true" // } // } func verifierFromJSON(b []byte) (*parse.Result, error) { msg := &verifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } v := NewVerifier(&url.URL{ Scheme: msg.Scheme, Host: msg.Host, Path: msg.Path, RawQuery: msg.Query, }) return parse.NewResult(v, msg.Scope) } martian-3.3.2/pingback/pingback_verifier_test.go000066400000000000000000000103141421371434000217650ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package pingback

import (
	"net/http"
	"net/url"
	"testing"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
	"github.com/google/martian/v3/verify"
)

func TestVerifyRequests(t *testing.T) {
	v := NewVerifier(&url.URL{
		Scheme:   "https",
		Host:     "example.com",
		Path:     "/test",
		RawQuery: "testing=true",
	})

	// A fresh verifier has not seen the pingback, so it must fail.
	err := v.VerifyRequests()
	if err == nil {
		t.Fatal("v.VerifyRequests(): got nil, want error")
	}
	want := "request(https://example.com/test?testing=true): pingback never occurred"
	if got := err.Error(); got != want {
		t.Errorf("err.Error(): got %q, want %q", got, want)
	}

	// send builds a request for rawURL, attaches a test context and feeds it
	// to the verifier. The returned func releases the context.
	send := func(rawURL string) func() {
		req, err := http.NewRequest("GET", rawURL, nil)
		if err != nil {
			t.Fatalf("http.NewRequest(): got %v, want no error", err)
		}
		_, release, err := martian.TestContext(req, nil, nil)
		if err != nil {
			t.Fatalf("TestContext(): got %v, want no error", err)
		}
		if err := v.ModifyRequest(req); err != nil {
			release()
			t.Fatalf("ModifyRequest(): got %v, want no error", err)
		}
		return release
	}

	// A non-matching request leaves the verifier failing.
	release1 := send("http://www.google.com")
	defer release1()
	if err := v.VerifyRequests(); err == nil {
		t.Fatal("v.VerifyRequests(): got nil, want error")
	}

	// A matching request satisfies it.
	release2 := send("https://example.com/test?testing=true")
	defer release2()
	if err := v.VerifyRequests(); err != nil {
		t.Fatalf("VerifyRequests(): got %v, want no error", err)
	}

	// Later non-matching traffic does not undo a seen pingback.
	release3 := send("http://www.google.com")
	defer release3()
	if err := v.VerifyRequests(); err != nil {
		t.Fatalf("VerifyRequests(): got %v, want no error", err)
	}

	// Resetting brings back the failure.
	v.ResetRequestVerifications()
	if err := v.VerifyRequests(); err == nil {
		t.Error("VerifyRequests(): got nil, want error")
	}
}

func TestVerifierFromJSON(t *testing.T) {
	msg := []byte(`{
		"pingback.Verifier": {
			"scope": ["request"],
			"scheme": "https",
			"host": "example.com",
			"path": "/testing",
			"query": "test=true"
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	reqv, ok := reqmod.(verify.RequestVerifier)
	if !ok {
		t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok")
	}

	// Nothing has been seen yet.
	if err := reqv.VerifyRequests(); err == nil {
		t.Fatal("VerifyRequests(): got nil, want error")
	}

	req, err := http.NewRequest("GET", "https://example.com/testing?test=true", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	_, release, err := martian.TestContext(req, nil, nil)
	if err != nil {
		t.Fatalf("TestContext(): got %v, want no error", err)
	}
	defer release()

	if err := reqv.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if err := reqv.VerifyRequests(); err != nil {
		t.Errorf("VerifyRequests(): got %v, want no error", err)
	}
}
martian-3.3.2/port/000077500000000000000000000000001421371434000141355ustar00rootroot00000000000000martian-3.3.2/port/port_filter.go000066400000000000000000000071741421371434000170260ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package port provides utilities for modifying and filtering // based on the port of request URLs. package port import ( "encoding/json" "net" "net/http" "strconv" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) var noop = martian.Noop("port.Filter") func init() { parse.Register("port.Filter", filterFromJSON) } // Filter runs modifiers iff the port in the request URL matches port. type Filter struct { reqmod martian.RequestModifier resmod martian.ResponseModifier port int } type filterJSON struct { Port int `json:"port"` Modifier json.RawMessage `json:"modifier"` Scope []parse.ModifierType `json:"scope"` } // NewFilter returns a filter that executes modifiers if the port of // request matches port. func NewFilter(port int) *Filter { return &Filter{ port: port, reqmod: noop, resmod: noop, } } // SetRequestModifier sets the request modifier. func (f *Filter) SetRequestModifier(reqmod martian.RequestModifier) { if reqmod == nil { reqmod = noop } f.reqmod = reqmod } // SetResponseModifier sets the response modifier. func (f *Filter) SetResponseModifier(resmod martian.ResponseModifier) { if resmod == nil { resmod = noop } f.resmod = resmod } // ModifyRequest runs the modifier if the port matches the provided port. 
func (f *Filter) ModifyRequest(req *http.Request) error { var defaultPort int if req.URL.Scheme == "http" { defaultPort = 80 } if req.URL.Scheme == "https" { defaultPort = 443 } hasPort := strings.Contains(req.URL.Host, ":") if hasPort { _, p, err := net.SplitHostPort(req.URL.Host) if err != nil { return err } pt, err := strconv.Atoi(p) if err != nil { return err } if pt == f.port { return f.reqmod.ModifyRequest(req) } return nil } // no port explictly declared - default port if f.port == defaultPort { return f.reqmod.ModifyRequest(req) } return nil } // ModifyResponse runs the modifier if the request URL matches urlMatcher. func (f *Filter) ModifyResponse(res *http.Response) error { var defaultPort int if res.Request.URL.Scheme == "http" { defaultPort = 80 } if res.Request.URL.Scheme == "https" { defaultPort = 443 } if !strings.Contains(res.Request.URL.Host, ":") && (f.port == defaultPort) { return f.resmod.ModifyResponse(res) } _, p, err := net.SplitHostPort(res.Request.URL.Host) if err != nil { return err } pt, err := strconv.Atoi(p) if err != nil { return err } if pt == f.port { return f.resmod.ModifyResponse(res) } return nil } func filterFromJSON(b []byte) (*parse.Result, error) { msg := &filterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } filter := NewFilter(msg.Port) r, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } reqmod := r.RequestModifier() if err != nil { return nil, err } if reqmod != nil { filter.SetRequestModifier(reqmod) } resmod := r.ResponseModifier() if resmod != nil { filter.SetResponseModifier(resmod) } return parse.NewResult(filter, msg.Scope) } martian-3.3.2/port/port_filter_test.go000066400000000000000000000070471421371434000200640ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package port import ( "net/http" "net/url" "testing" _ "github.com/google/martian/v3/header" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestFilterModifyRequest(t *testing.T) { tt := []struct { want bool url *url.URL port int }{ { url: &url.URL{Scheme: "http", Host: "example.com"}, port: 80, want: true, }, { url: &url.URL{Scheme: "http", Host: "example.com:80"}, port: 80, want: true, }, { url: &url.URL{Scheme: "http", Host: "example.com"}, port: 123, want: false, }, { url: &url.URL{Scheme: "http", Host: "example.com:8080"}, port: 123, want: false, }, { url: &url.URL{Scheme: "https", Host: "example.com"}, port: 443, want: true, }, { url: &url.URL{Scheme: "https", Host: "example.com:443"}, port: 443, want: true, }, { url: &url.URL{Scheme: "https", Host: "example.com"}, port: 123, want: false, }, { url: &url.URL{Scheme: "https", Host: "example.com:8080"}, port: 123, want: false, }, } for i, tc := range tt { req, err := http.NewRequest("GET", tc.url.String(), nil) if err != nil { t.Fatalf("%d. NewRequest(): got %v, want no error", i, err) } mod := NewFilter(tc.port) tm := martiantest.NewModifier() mod.SetRequestModifier(tm) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("%d. ModifyRequest(): got %q, want no error", i, err) } if tm.RequestModified() != tc.want { t.Errorf("%d. 
tm.RequestModified(): got %t, want %t", i, tm.RequestModified(), tc.want) } } } func TestFilterFromJSON(t *testing.T) { msg := []byte(`{ "port.Filter": { "scope": ["request", "response"], "port": 8080, "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Mod-Run", "value": "true" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("FilterFromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://example.com:8080", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Mod-Run"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Mod-Run", got, want) } } martian-3.3.2/port/port_modifier.go000066400000000000000000000111551421371434000173310ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package port import ( "encoding/json" "fmt" "net" "net/http" "strconv" "strings" "github.com/google/martian/v3/parse" ) func init() { parse.Register("port.Modifier", modifierFromJSON) } // Modifier alters the request URL and Host header to use the provided port. // Only one of port, defaultForScheme, or remove may be specified. Whichever is set last is the one that will take effect. // If remove is true, remove the port from the host string ('example.com'). // If defaultForScheme is true, explicitly specify 80 for HTTP or 443 for HTTPS ('http://example.com:80'). Do nothing for a scheme that is not 'http' or 'https'. // If port is specified, explicitly add it to the host string ('example.com:1234'). // If port is zero and the other fields are false, the request will not be modified. type Modifier struct { port int defaultForScheme bool remove bool } type modifierJSON struct { Port int `json:"port"` DefaultForScheme bool `json:"defaultForScheme"` Remove bool `json:"remove"` Scope []parse.ModifierType `json:"scope"` } // NewModifier returns a RequestModifier that can be configured to alter the request URL and Host header's port. // One of DefaultPortForScheme, UsePort, or RemovePort should be called to configure this modifier. func NewModifier() *Modifier { return &Modifier{} } // DefaultPortForScheme configures the modifier to explicitly specify 80 for HTTP or 443 for HTTPS ('http://example.com:80'). // The modifier will not modify requests with a scheme that is not 'http' or 'https'. // This overrides any previous configuration for this modifier. func (m *Modifier) DefaultPortForScheme() { m.defaultForScheme = true m.remove = false } // UsePort configures the modifier to add the specified port to the host string ('example.com:1234'). // This overrides any previous configuration for this modifier. 
func (m *Modifier) UsePort(port int) { m.port = port m.remove = false m.defaultForScheme = false } // RemovePort configures the modifier to remove the port from the host string ('example.com'). // This overrides any previous configuration for this modifier. func (m *Modifier) RemovePort() { m.remove = true m.defaultForScheme = false } // ModifyRequest alters the request URL and Host header to modify the port as specified. // See docs for Modifier for details. func (m *Modifier) ModifyRequest(req *http.Request) error { if m.port == 0 && !m.defaultForScheme && !m.remove { return nil } host := req.URL.Host if strings.Contains(host, ":") { h, _, err := net.SplitHostPort(host) if err != nil { return err } host = h } if m.remove { req.URL.Host = host req.Header.Set("Host", host) return nil } if m.defaultForScheme { switch req.URL.Scheme { case "http": hp := net.JoinHostPort(host, "80") req.URL.Host = hp req.Header.Set("Host", hp) return nil case "https": hp := net.JoinHostPort(host, "443") req.URL.Host = hp req.Header.Set("Host", hp) return nil default: // Unknown scheme, do nothing. return nil } } // Not removing or using default for the scheme, so use the provided port number. hp := net.JoinHostPort(host, strconv.Itoa(m.port)) req.URL.Host = hp req.Header.Set("Host", hp) return nil } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } errMsg := fmt.Errorf("Must specify only one of port, defaultForScheme or remove") mod := NewModifier() // Check that exactly one field of port, defaultForScheme, and remove is set. 
switch { case msg.Port != 0: if msg.DefaultForScheme || msg.Remove { return nil, errMsg } mod.UsePort(msg.Port) case msg.DefaultForScheme: if msg.Port != 0 || msg.Remove { return nil, errMsg } mod.DefaultPortForScheme() case msg.Remove: if msg.Port != 0 || msg.DefaultForScheme { return nil, errMsg } mod.RemovePort() default: return nil, errMsg } return parse.NewResult(mod, msg.Scope) } martian-3.3.2/port/port_modifier_test.go000066400000000000000000000123761421371434000203760ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package port

import (
	"net"
	"net/http"
	"testing"

	"github.com/google/martian/v3/parse"
)

func TestPortModifierOnPort(t *testing.T) {
	mod := NewModifier()
	mod.UsePort(8080)

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	_, port, err := net.SplitHostPort(req.URL.Host)
	if err != nil {
		t.Fatalf("net.SplitHostPort(%q): got %v, want no error", req.URL.Host, err)
	}
	if want := "8080"; port != want {
		t.Errorf("port: got %v, want %v", port, want)
	}
}

func TestPortModifierWithNoConfiguration(t *testing.T) {
	mod := NewModifier()

	// An unconfigured modifier must leave hosts alone, with or without a port.
	for _, tc := range []struct{ rawURL, want string }{
		{"http://example.com", "example.com"},
		{"http://example.com:80", "example.com:80"},
	} {
		req, err := http.NewRequest("GET", tc.rawURL, nil)
		if err != nil {
			t.Fatalf("NewRequest(): got %v, want no error", err)
		}
		if err := mod.ModifyRequest(req); err != nil {
			t.Fatalf("ModifyRequest(): got %v, want no error", err)
		}
		if got := req.URL.Host; got != tc.want {
			t.Errorf("req.URL.Host: got %v, want %v", got, tc.want)
		}
	}
}

func TestPortModifierDefaultForScheme(t *testing.T) {
	mod := NewModifier()
	mod.DefaultPortForScheme()

	// Mixed-case scheme: url parsing normalizes it to "http".
	req, err := http.NewRequest("GET", "HtTp://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.URL.Host, "example.com:80"; got != want {
		t.Errorf("req.URL.Host: got %v, want %v", got, want)
	}
}

func TestPortModifierRemove(t *testing.T) {
	mod := NewModifier()
	mod.RemovePort()

	req, err := http.NewRequest("GET", "http://example.com:8080", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.URL.Host, "example.com"; got != want {
		t.Errorf("req.URL.Host: got %v, want %v", got, want)
	}
}

func TestPortModifierAllFields(t *testing.T) {
	mod := NewModifier()
	mod.UsePort(8081)
	mod.DefaultPortForScheme()
	mod.RemovePort() // the last call wins

	req, err := http.NewRequest("GET", "http://example.com:8080", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	if err := mod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}
	if got, want := req.URL.Host, "example.com"; got != want {
		t.Errorf("req.URL.Host: got %v, want %v", got, want)
	}
}

func TestModiferFromJSON(t *testing.T) {
	msg := []byte(`{
		"port.Modifier": {
			"scope": ["request"],
			"port": 8080
		}
	}`)

	r, err := parse.FromJSON(msg)
	if err != nil {
		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
	}

	reqmod := r.RequestModifier()
	if reqmod == nil {
		t.Fatal("reqmod: got nil, want not nil")
	}

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("NewRequest(): got %v, want no error", err)
	}
	if err := reqmod.ModifyRequest(req); err != nil {
		t.Fatalf("ModifyRequest(): got %v, want no error", err)
	}

	_, port, err := net.SplitHostPort(req.URL.Host)
	if err != nil {
		t.Fatalf("net.SplitHostPort(%q): got %v, want no error", req.URL.Host, err)
	}
	if want := "8080"; port != want {
		t.Errorf("port: got %v, want %v", port, want)
	}
}

func TestModiferFromJSONInvalidConfigurations(t *testing.T) {
	invalid := [][]byte{
		[]byte(`{
			"port.Modifier": {
				"scope": ["request"],
				"port": 8080,
				"defaultForScheme": true,
				"remove": true
			}
		}`),
		[]byte(`{
			"port.Modifier": {
				"scope": ["request"],
				"port": 8080
				"remove": true
			}
		}`),
		[]byte(`{
			"port.Modifier": {
				"scope": ["request"],
				"port": 8080
				"defaultForScheme": true,
			}
		}`),
		[]byte(`{
			"port.Modifier": {
				"scope": ["request"],
				"defaultForScheme": true,
				"remove": true
			}
		}`),
		[]byte(`{
			"port.Modifier": {
				"scope": ["request"],
			}
		}`),
		[]byte(`{
			"port.Modifier": {
				"scope": ["response"],
				"remove": true
			}
		}`),
	}

	for i := range invalid {
		if _, err := parse.FromJSON(invalid[i]); err == nil {
			t.Fatalf("parseFromJSON(msg): Got no error, but should have gotten one.")
		}
	}
}
martian-3.3.2/priority/000077500000000000000000000000001421371434000150325ustar00rootroot00000000000000martian-3.3.2/priority/priority_group.go000066400000000000000000000144121421371434000204600ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package priority allows grouping modifiers and applying them in priority order.
package priority

import (
	"encoding/json"
	"errors"
	"net/http"
	"sync"

	"github.com/google/martian/v3"
	"github.com/google/martian/v3/parse"
)

var (
	// ErrModifierNotFound is the error returned when attempting to remove a
	// modifier when the modifier does not exist in the group.
	ErrModifierNotFound = errors.New("modifier not found in group")
)

// priorityRequestModifier pairs a request modifier with its priority.
type priorityRequestModifier struct {
	reqmod   martian.RequestModifier
	priority int64
}

// priorityResponseModifier pairs a response modifier with its priority.
type priorityResponseModifier struct { resmod martian.ResponseModifier priority int64 } // Group is a group of request and response modifiers ordered by their priority. type Group struct { reqmu sync.RWMutex reqmods []*priorityRequestModifier resmu sync.RWMutex resmods []*priorityResponseModifier } type groupJSON struct { Modifiers []modifierJSON `json:"modifiers"` Scope []parse.ModifierType `json:"scope"` } type modifierJSON struct { Priority int64 `json:"priority"` Modifier json.RawMessage `json:"modifier"` } func init() { parse.Register("priority.Group", groupFromJSON) } // NewGroup returns a priority group. func NewGroup() *Group { return &Group{} } // AddRequestModifier adds a RequestModifier with the given priority. // // If a modifier is added with a priority that is equal to an existing priority // the newer modifier will be added before the existing modifier in the chain. func (pg *Group) AddRequestModifier(reqmod martian.RequestModifier, priority int64) { pg.reqmu.Lock() defer pg.reqmu.Unlock() preqmod := &priorityRequestModifier{ reqmod: reqmod, priority: priority, } for i, m := range pg.reqmods { if preqmod.priority >= m.priority { pg.reqmods = append(pg.reqmods, nil) copy(pg.reqmods[i+1:], pg.reqmods[i:]) pg.reqmods[i] = preqmod return } } // Either this is the first modifier in the list, or the priority is less // than all existing modifiers. pg.reqmods = append(pg.reqmods, preqmod) } // RemoveRequestModifier removes the the highest priority given RequestModifier. // Returns ErrModifierNotFound if the given modifier does not exist in the group. func (pg *Group) RemoveRequestModifier(reqmod martian.RequestModifier) error { pg.reqmu.Lock() defer pg.reqmu.Unlock() for i, m := range pg.reqmods { if m.reqmod == reqmod { copy(pg.reqmods[i:], pg.reqmods[i+1:]) pg.reqmods[len(pg.reqmods)-1] = nil pg.reqmods = pg.reqmods[:len(pg.reqmods)-1] return nil } } return ErrModifierNotFound } // AddResponseModifier adds a ResponseModifier with the given priority. 
// // If a modifier is added with a priority that is equal to an existing priority // the newer modifier will be added before the existing modifier in the chain. func (pg *Group) AddResponseModifier(resmod martian.ResponseModifier, priority int64) { pg.resmu.Lock() defer pg.resmu.Unlock() presmod := &priorityResponseModifier{ resmod: resmod, priority: priority, } for i, m := range pg.resmods { if presmod.priority >= m.priority { pg.resmods = append(pg.resmods, nil) copy(pg.resmods[i+1:], pg.resmods[i:]) pg.resmods[i] = presmod return } } // Either this is the first modifier in the list, or the priority is less // than all existing modifiers. pg.resmods = append(pg.resmods, presmod) } // RemoveResponseModifier removes the the highest priority given ResponseModifier. // Returns ErrModifierNotFound if the given modifier does not exist in the group. func (pg *Group) RemoveResponseModifier(resmod martian.ResponseModifier) error { pg.resmu.Lock() defer pg.resmu.Unlock() for i, m := range pg.resmods { if m.resmod == resmod { copy(pg.resmods[i:], pg.resmods[i+1:]) pg.resmods[len(pg.resmods)-1] = nil pg.resmods = pg.resmods[:len(pg.resmods)-1] return nil } } return ErrModifierNotFound } // ModifyRequest modifies the request. Modifiers are run in descending order of // their priority. If an error is returned by a RequestModifier the error is // returned and no further modifiers are run. func (pg *Group) ModifyRequest(req *http.Request) error { pg.reqmu.RLock() defer pg.reqmu.RUnlock() for _, m := range pg.reqmods { if err := m.reqmod.ModifyRequest(req); err != nil { return err } } return nil } // ModifyResponse modifies the response. Modifiers are run in descending order // of their priority. If an error is returned by a ResponseModifier the error // is returned and no further modifiers are run. 
func (pg *Group) ModifyResponse(res *http.Response) error { pg.resmu.RLock() defer pg.resmu.RUnlock() for _, m := range pg.resmods { if err := m.resmod.ModifyResponse(res); err != nil { return err } } return nil } // groupFromJSON builds a priority.Group from JSON. // // Example JSON: // { // "priority.Group": { // "scope": ["request", "response"], // "modifiers": [ // { // "priority": 100, // Will run first. // "modifier": { ... }, // }, // { // "priority": 0, // Will run last. // "modifier": { ... }, // } // ] // } // } func groupFromJSON(b []byte) (*parse.Result, error) { msg := &groupJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } pg := NewGroup() for _, m := range msg.Modifiers { r, err := parse.FromJSON(m.Modifier) if err != nil { return nil, err } reqmod := r.RequestModifier() if reqmod != nil { pg.AddRequestModifier(reqmod, m.Priority) } resmod := r.ResponseModifier() if resmod != nil { pg.AddResponseModifier(resmod, m.Priority) } } return parse.NewResult(pg, msg.Scope) } martian-3.3.2/priority/priority_group_test.go000066400000000000000000000150461421371434000215230ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package priority import ( "errors" "net/http" "reflect" "testing" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" // Import to register header.Modifier with JSON parser. 
_ "github.com/google/martian/v3/header" ) func TestPriorityGroupModifyRequest(t *testing.T) { var order []string pg := NewGroup() tm50 := martiantest.NewModifier() tm50.RequestFunc(func(*http.Request) { order = append(order, "tm50") }) pg.AddRequestModifier(tm50, 50) tm100a := martiantest.NewModifier() tm100a.RequestFunc(func(*http.Request) { order = append(order, "tm100a") }) pg.AddRequestModifier(tm100a, 100) tm100b := martiantest.NewModifier() tm100b.RequestFunc(func(*http.Request) { order = append(order, "tm100b") }) pg.AddRequestModifier(tm100b, 100) tm75 := martiantest.NewModifier() tm75.RequestFunc(func(*http.Request) { order = append(order, "tm75") }) if err := pg.RemoveRequestModifier(tm75); err != ErrModifierNotFound { t.Fatalf("RemoveRequestModifier(): got %v, want ErrModifierNotFound", err) } pg.AddRequestModifier(tm75, 100) if err := pg.RemoveRequestModifier(tm75); err != nil { t.Fatalf("RemoveRequestModifier(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example.com/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := pg.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) } } func TestPriorityGroupModifyRequestHaltsOnError(t *testing.T) { pg := NewGroup() reqerr := errors.New("request error") tm := martiantest.NewModifier() tm.RequestError(reqerr) pg.AddRequestModifier(tm, 100) tm2 := martiantest.NewModifier() pg.AddRequestModifier(tm2, 75) req, err := http.NewRequest("GET", "http://example.com/", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := pg.ModifyRequest(req); err != reqerr { t.Fatalf("ModifyRequest(): got %v, want %v", err, reqerr) } if tm2.RequestModified() { t.Error("tm2.RequestModified(): got true, want false") } } func 
TestPriorityGroupModifyResponse(t *testing.T) { var order []string pg := NewGroup() tm50 := martiantest.NewModifier() tm50.ResponseFunc(func(*http.Response) { order = append(order, "tm50") }) pg.AddResponseModifier(tm50, 50) tm100a := martiantest.NewModifier() tm100a.ResponseFunc(func(*http.Response) { order = append(order, "tm100a") }) pg.AddResponseModifier(tm100a, 100) tm100b := martiantest.NewModifier() tm100b.ResponseFunc(func(*http.Response) { order = append(order, "tm100b") }) pg.AddResponseModifier(tm100b, 100) tm75 := martiantest.NewModifier() tm75.ResponseFunc(func(*http.Response) { order = append(order, "tm75") }) if err := pg.RemoveResponseModifier(tm75); err != ErrModifierNotFound { t.Fatalf("RemoveResponseModifier(): got %v, want ErrModifierNotFound", err) } pg.AddResponseModifier(tm75, 100) if err := pg.RemoveResponseModifier(tm75); err != nil { t.Fatalf("RemoveResponseModifier(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, nil) if err := pg.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := order, []string{"tm100b", "tm100a", "tm50"}; !reflect.DeepEqual(got, want) { t.Fatalf("reflect.DeepEqual(%v, %v): got false, want true", got, want) } } func TestPriorityGroupModifyResponseHaltsOnError(t *testing.T) { pg := NewGroup() reserr := errors.New("response error") tm := martiantest.NewModifier() tm.ResponseError(reserr) pg.AddResponseModifier(tm, 100) tm2 := martiantest.NewModifier() pg.AddResponseModifier(tm2, 75) res := proxyutil.NewResponse(200, nil, nil) if err := pg.ModifyResponse(res); err != reserr { t.Fatalf("ModifyRequest(): got %v, want %v", err, reserr) } if tm2.ResponseModified() { t.Error("tm2.ResponseModified(): got true, want false") } } func TestGroupFromJSON(t *testing.T) { msg := []byte(`{ "priority.Group": { "scope": ["request", "response"], "modifiers": [ { "priority": 100, "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": 
"X-Testing", "value": "true" } } }, { "priority": 0, "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Y-Testing", "value": "true" } } } ] } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("X-Testing"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "X-Testing", got, want) } if got, want := req.Header.Get("Y-Testing"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Y-Testing", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("X-Testing"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "X-Testing", got, want) } if got, want := res.Header.Get("Y-Testing"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Y-Testing", got, want) } } martian-3.3.2/proxy.go000066400000000000000000000416271421371434000146730ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "bufio" "bytes" "crypto/tls" "errors" "io" "net" "net/http" "net/http/httputil" "net/url" "regexp" "sync" "time" "github.com/google/martian/v3/log" "github.com/google/martian/v3/mitm" "github.com/google/martian/v3/nosigpipe" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/trafficshape" ) var errClose = errors.New("closing connection") var noop = Noop("martian") func isCloseable(err error) bool { if neterr, ok := err.(net.Error); ok && neterr.Timeout() { return true } switch err { case io.EOF, io.ErrClosedPipe, errClose: return true } return false } // Proxy is an HTTP proxy with support for TLS MITM and customizable behavior. type Proxy struct { roundTripper http.RoundTripper dial func(string, string) (net.Conn, error) timeout time.Duration mitm *mitm.Config proxyURL *url.URL conns sync.WaitGroup connsMu sync.Mutex // protects conns.Add/Wait from concurrent access closing chan bool reqmod RequestModifier resmod ResponseModifier } // NewProxy returns a new HTTP proxy. func NewProxy() *Proxy { proxy := &Proxy{ roundTripper: &http.Transport{ // TODO(adamtanner): This forces the http.Transport to not upgrade requests // to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2. 
TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper), Proxy: http.ProxyFromEnvironment, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: time.Second, }, timeout: 5 * time.Minute, closing: make(chan bool), reqmod: noop, resmod: noop, } proxy.SetDial((&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).Dial) return proxy } // GetRoundTripper gets the http.RoundTripper of the proxy. func (p *Proxy) GetRoundTripper() http.RoundTripper { return p.roundTripper } // SetRoundTripper sets the http.RoundTripper of the proxy. func (p *Proxy) SetRoundTripper(rt http.RoundTripper) { p.roundTripper = rt if tr, ok := p.roundTripper.(*http.Transport); ok { tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) tr.Proxy = http.ProxyURL(p.proxyURL) tr.Dial = p.dial } } // SetDownstreamProxy sets the proxy that receives requests from the upstream // proxy. func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) { p.proxyURL = proxyURL if tr, ok := p.roundTripper.(*http.Transport); ok { tr.Proxy = http.ProxyURL(p.proxyURL) } } // SetTimeout sets the request timeout of the proxy. func (p *Proxy) SetTimeout(timeout time.Duration) { p.timeout = timeout } // SetMITM sets the config to use for MITMing of CONNECT requests. func (p *Proxy) SetMITM(config *mitm.Config) { p.mitm = config } // SetDial sets the dial func used to establish a connection. func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) { p.dial = func(a, b string) (net.Conn, error) { c, e := dial(a, b) nosigpipe.IgnoreSIGPIPE(c) return c, e } if tr, ok := p.roundTripper.(*http.Transport); ok { tr.Dial = p.dial } } // Close sets the proxy to the closing state so it stops receiving new connections, // finishes processing any inflight requests, and closes existing connections without // reading anymore requests from them. 
func (p *Proxy) Close() { log.Infof("martian: closing down proxy") close(p.closing) log.Infof("martian: waiting for connections to close") p.connsMu.Lock() p.conns.Wait() p.connsMu.Unlock() log.Infof("martian: all connections closed") } // Closing returns whether the proxy is in the closing state. func (p *Proxy) Closing() bool { select { case <-p.closing: return true default: return false } } // SetRequestModifier sets the request modifier. func (p *Proxy) SetRequestModifier(reqmod RequestModifier) { if reqmod == nil { reqmod = noop } p.reqmod = reqmod } // SetResponseModifier sets the response modifier. func (p *Proxy) SetResponseModifier(resmod ResponseModifier) { if resmod == nil { resmod = noop } p.resmod = resmod } // Serve accepts connections from the listener and handles the requests. func (p *Proxy) Serve(l net.Listener) error { defer l.Close() var delay time.Duration for { if p.Closing() { return nil } conn, err := l.Accept() nosigpipe.IgnoreSIGPIPE(conn) if err != nil { if nerr, ok := err.(net.Error); ok && nerr.Temporary() { if delay == 0 { delay = 5 * time.Millisecond } else { delay *= 2 } if max := time.Second; delay > max { delay = max } log.Debugf("martian: temporary error on accept: %v", err) time.Sleep(delay) continue } if errors.Is(err, net.ErrClosed) { log.Debugf("martian: listener closed, returning") return err } log.Errorf("martian: failed to accept: %v", err) return err } delay = 0 log.Debugf("martian: accepted connection from %s", conn.RemoteAddr()) if tconn, ok := conn.(*net.TCPConn); ok { tconn.SetKeepAlive(true) tconn.SetKeepAlivePeriod(3 * time.Minute) } go p.handleLoop(conn) } } func (p *Proxy) handleLoop(conn net.Conn) { p.connsMu.Lock() p.conns.Add(1) p.connsMu.Unlock() defer p.conns.Done() defer conn.Close() if p.Closing() { return } brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn)) s, err := newSession(conn, brw) if err != nil { log.Errorf("martian: failed to create session: %v", err) return } ctx, err := 
withSession(s) if err != nil { log.Errorf("martian: failed to create context: %v", err) return } for { deadline := time.Now().Add(p.timeout) conn.SetDeadline(deadline) if err := p.handle(ctx, conn, brw); isCloseable(err) { log.Debugf("martian: closing connection: %v", conn.RemoteAddr()) return } } } func (p *Proxy) readRequest(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) (*http.Request, error) { var req *http.Request reqc := make(chan *http.Request, 1) errc := make(chan error, 1) go func() { r, err := http.ReadRequest(brw.Reader) if err != nil { errc <- err return } reqc <- r }() select { case err := <-errc: if isCloseable(err) { log.Debugf("martian: connection closed prematurely: %v", err) } else { log.Errorf("martian: failed to read request: %v", err) } // TODO: TCPConn.WriteClose() to avoid sending an RST to the client. return nil, errClose case req = <-reqc: case <-p.closing: return nil, errClose } return req, nil } func (p *Proxy) handleConnectRequest(ctx *Context, req *http.Request, session *Session, brw *bufio.ReadWriter, conn net.Conn) error { if err := p.reqmod.ModifyRequest(req); err != nil { log.Errorf("martian: error modifying CONNECT request: %v", err) proxyutil.Warning(req.Header, err) } if session.Hijacked() { log.Debugf("martian: connection hijacked by request modifier") return nil } if p.mitm != nil { log.Debugf("martian: attempting MITM for connection: %s / %s", req.Host, req.URL.String()) res := proxyutil.NewResponse(200, nil, req) if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } if session.Hijacked() { log.Infof("martian: connection hijacked by response modifier") return nil } if err := res.Write(brw); err != nil { log.Errorf("martian: got error while writing response back to client: %v", err) } if err := brw.Flush(); err != nil { log.Errorf("martian: got error while flushing response back to client: %v", err) } log.Debugf("martian: 
completed MITM for connection: %s", req.Host) b := make([]byte, 1) if _, err := brw.Read(b); err != nil { log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err) } // Drain all of the rest of the buffered data. buf := make([]byte, brw.Reader.Buffered()) brw.Read(buf) // 22 is the TLS handshake. // https://tools.ietf.org/html/rfc5246#section-6.2.1 if b[0] == 22 { // Prepend the previously read data to be read again by // http.ReadRequest. tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host)) if err := tlsconn.Handshake(); err != nil { p.mitm.HandshakeErrorCallback(req, err) return err } if tlsconn.ConnectionState().NegotiatedProtocol == "h2" { return p.mitm.H2Config().Proxy(p.closing, tlsconn, req.URL) } var nconn net.Conn nconn = tlsconn // If the original connection is a traffic shaped connection, wrap the tls // connection inside a traffic shaped connection too. if ptsconn, ok := conn.(*trafficshape.Conn); ok { nconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn) } brw.Writer.Reset(nconn) brw.Reader.Reset(nconn) return p.handle(ctx, nconn, brw) } // Prepend the previously read data to be read again by http.ReadRequest. 
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)) return p.handle(ctx, conn, brw) } log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host) res, cconn, cerr := p.connect(req) if cerr != nil { log.Errorf("martian: failed to CONNECT: %v", cerr) res = proxyutil.NewResponse(502, nil, req) proxyutil.Warning(res.Header, cerr) if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } if session.Hijacked() { log.Infof("martian: connection hijacked by response modifier") return nil } if err := res.Write(brw); err != nil { log.Errorf("martian: got error while writing response back to client: %v", err) } err := brw.Flush() if err != nil { log.Errorf("martian: got error while flushing response back to client: %v", err) } return err } defer res.Body.Close() defer cconn.Close() if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying CONNECT response: %v", err) proxyutil.Warning(res.Header, err) } if session.Hijacked() { log.Infof("martian: connection hijacked by response modifier") return nil } res.ContentLength = -1 if err := res.Write(brw); err != nil { log.Errorf("martian: got error while writing response back to client: %v", err) } if err := brw.Flush(); err != nil { log.Errorf("martian: got error while flushing response back to client: %v", err) } cbw := bufio.NewWriter(cconn) cbr := bufio.NewReader(cconn) defer cbw.Flush() copySync := func(w io.Writer, r io.Reader, donec chan<- bool) { if _, err := io.Copy(w, r); err != nil && err != io.EOF { log.Errorf("martian: failed to copy CONNECT tunnel: %v", err) } log.Debugf("martian: CONNECT tunnel finished copying") donec <- true } donec := make(chan bool, 2) go copySync(cbw, brw, donec) go copySync(brw, cbr, donec) log.Debugf("martian: established CONNECT tunnel, proxying traffic") <-donec <-donec log.Debugf("martian: closed CONNECT tunnel") 
return errClose } func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error { log.Debugf("martian: waiting for request: %v", conn.RemoteAddr()) req, err := p.readRequest(ctx, conn, brw) if err != nil { return err } defer req.Body.Close() session := ctx.Session() ctx, err = withSession(session) if err != nil { log.Errorf("martian: failed to build new context: %v", err) return err } link(req, ctx) defer unlink(req) if tsconn, ok := conn.(*trafficshape.Conn); ok { wrconn := tsconn.GetWrappedConn() if sconn, ok := wrconn.(*tls.Conn); ok { session.MarkSecure() cs := sconn.ConnectionState() req.TLS = &cs } } if tconn, ok := conn.(*tls.Conn); ok { session.MarkSecure() cs := tconn.ConnectionState() req.TLS = &cs } req.URL.Scheme = "http" if session.IsSecure() { log.Infof("martian: forcing HTTPS inside secure session") req.URL.Scheme = "https" } req.RemoteAddr = conn.RemoteAddr().String() if req.URL.Host == "" { req.URL.Host = req.Host } if req.Method == "CONNECT" { return p.handleConnectRequest(ctx, req, session, brw, conn) } // Not a CONNECT request if err := p.reqmod.ModifyRequest(req); err != nil { log.Errorf("martian: error modifying request: %v", err) proxyutil.Warning(req.Header, err) } if session.Hijacked() { return nil } // perform the HTTP roundtrip res, err := p.roundTrip(ctx, req) if err != nil { log.Errorf("martian: failed to round trip: %v", err) res = proxyutil.NewResponse(502, nil, req) proxyutil.Warning(res.Header, err) } defer res.Body.Close() // set request to original request manually, res.Request may be changed in transport. 
// see https://github.com/google/martian/issues/298 res.Request = req if err := p.resmod.ModifyResponse(res); err != nil { log.Errorf("martian: error modifying response: %v", err) proxyutil.Warning(res.Header, err) } if session.Hijacked() { log.Infof("martian: connection hijacked by response modifier") return nil } var closing error if req.Close || res.Close || p.Closing() { log.Debugf("martian: received close request: %v", req.RemoteAddr) res.Close = true closing = errClose } // check if conn is a traffic shaped connection. if ptsconn, ok := conn.(*trafficshape.Conn); ok { ptsconn.Context = &trafficshape.Context{} // Check if the request URL matches any URLRegex in Shapes. If so, set the connections's Context // with the required information, so that the Write() method of the Conn has access to it. for urlregex, buckets := range ptsconn.LocalBuckets { if match, _ := regexp.MatchString(urlregex, req.URL.String()); match { if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 { dump, err := httputil.DumpResponse(res, false) if err != nil { return err } ptsconn.Context = &trafficshape.Context{ Shaping: true, Buckets: buckets, GlobalBucket: ptsconn.GlobalBuckets[urlregex], URLRegex: urlregex, RangeStart: rangeStart, ByteOffset: rangeStart, HeaderLen: int64(len(dump)), HeaderBytesWritten: 0, } // Get the next action to perform, if there. ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart) // Check if response lies in a throttled byte range. ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart) if ptsconn.Context.ThrottleContext.ThrottleNow { ptsconn.Context.Buckets.WriteBucket.SetCapacity( ptsconn.Context.ThrottleContext.Bandwidth) } log.Infof( "trafficshape: Request %s with Range Start: %d matches a Shaping request %s. 
Enforcing Traffic shaping.", req.URL, rangeStart, urlregex) } break } } } err = res.Write(brw) if err != nil { log.Errorf("martian: got error while writing response back to client: %v", err) if _, ok := err.(*trafficshape.ErrForceClose); ok { closing = errClose } } err = brw.Flush() if err != nil { log.Errorf("martian: got error while flushing response back to client: %v", err) if _, ok := err.(*trafficshape.ErrForceClose); ok { closing = errClose } } return closing } // A peekedConn subverts the net.Conn.Read implementation, primarily so that // sniffed bytes can be transparently prepended. type peekedConn struct { net.Conn r io.Reader } // Read allows control over the embedded net.Conn's read data. By using an // io.MultiReader one can read from a conn, and then replace what they read, to // be read again. func (c *peekedConn) Read(buf []byte) (int, error) { return c.r.Read(buf) } func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) { if ctx.SkippingRoundTrip() { log.Debugf("martian: skipping round trip") return proxyutil.NewResponse(200, nil, req), nil } return p.roundTripper.RoundTrip(req) } func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) { if p.proxyURL != nil { log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host) conn, err := p.dial("tcp", p.proxyURL.Host) if err != nil { return nil, nil, err } pbw := bufio.NewWriter(conn) pbr := bufio.NewReader(conn) req.Write(pbw) pbw.Flush() res, err := http.ReadResponse(pbr, req) if err != nil { return nil, nil, err } return res, conn, nil } log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host) conn, err := p.dial("tcp", req.URL.Host) if err != nil { return nil, nil, err } return proxyutil.NewResponse(200, nil, req), conn, nil } martian-3.3.2/proxy_test.go000066400000000000000000001047621421371434000157320ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package martian import ( "bufio" "bytes" "crypto/tls" "crypto/x509" "errors" "fmt" "io" "io/ioutil" "net" "net/http" "net/url" "os" "strings" "testing" "time" "github.com/google/martian/v3/log" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/mitm" "github.com/google/martian/v3/proxyutil" ) type tempError struct{} func (e *tempError) Error() string { return "temporary" } func (e *tempError) Timeout() bool { return true } func (e *tempError) Temporary() bool { return true } type timeoutListener struct { net.Listener errCount int err error } func newTimeoutListener(l net.Listener, errCount int) net.Listener { return &timeoutListener{ Listener: l, errCount: errCount, err: &tempError{}, } } func (l *timeoutListener) Accept() (net.Conn, error) { if l.errCount > 0 { l.errCount-- return nil, l.err } return l.Listener.Accept() } func TestIntegrationTemporaryTimeout(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) // Start the proxy with a listener that will return a temporary error on // Accept() three times. 
go p.Serve(newTimeoutListener(l, 3)) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } } func TestIntegrationHTTP(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() p.SetRequestModifier(nil) p.SetResponseModifier(nil) tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(200 * time.Millisecond) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := NewContext(req) ctx.Set("martian.test", "true") }) tm.ResponseFunc(func(res *http.Response) { ctx := NewContext(res.Request) v, _ := ctx.Get("martian.test") res.Header.Set("Martian-Test", v.(string)) }) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { 
t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Martian-Test"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Test", got, want) } } func TestIntegrationHTTP100Continue(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() p.SetTimeout(2 * time.Second) sl, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } go func() { conn, err := sl.Accept() if err != nil { log.Errorf("proxy_test: failed to accept connection: %v", err) return } defer conn.Close() log.Infof("proxy_test: accepted connection: %s", conn.RemoteAddr()) req, err := http.ReadRequest(bufio.NewReader(conn)) if err != nil { log.Errorf("proxy_test: failed to read request: %v", err) return } if req.Header.Get("Expect") == "100-continue" { log.Infof("proxy_test: received 100-continue request") conn.Write([]byte("HTTP/1.1 100 Continue\r\n\r\n")) log.Infof("proxy_test: sent 100-continue response") } else { log.Infof("proxy_test: received non 100-continue request") res := proxyutil.NewResponse(417, nil, req) res.Header.Set("Connection", "close") res.Write(conn) return } res := proxyutil.NewResponse(200, req.Body, req) res.Header.Set("Connection", "close") res.Write(conn) log.Infof("proxy_test: sent 200 response") }() tm := martiantest.NewModifier() p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() host := sl.Addr().String() raw := fmt.Sprintf("POST http://%s/ HTTP/1.1\r\n"+ "Host: %s\r\n"+ "Content-Length: 12\r\n"+ "Expect: 100-continue\r\n\r\n", host, host) if _, err := conn.Write([]byte(raw)); err != nil { 
t.Fatalf("conn.Write(headers): got %v, want no error", err) } go func() { select { case <-time.After(time.Second): conn.Write([]byte("body content")) } }() res, err := http.ReadResponse(bufio.NewReader(conn), nil) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } defer res.Body.Close() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } if want := []byte("body content"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } } func TestIntegrationHTTPDownstreamProxy(t *testing.T) { t.Parallel() // Start first proxy to use as downstream. dl, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } downstream := NewProxy() defer downstream.Close() dtr := martiantest.NewTransport() dtr.Respond(299) downstream.SetRoundTripper(dtr) downstream.SetTimeout(600 * time.Millisecond) go downstream.Serve(dl) // Start second proxy as upstream proxy, will write to downstream proxy. ul, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } upstream := NewProxy() defer upstream.Close() // Set upstream proxy's downstream proxy to the host:port of the first proxy. upstream.SetDownstreamProxy(&url.URL{ Host: dl.Addr().String(), }) upstream.SetTimeout(600 * time.Millisecond) go upstream.Serve(ul) // Open connection to upstream proxy. 
conn, err := net.Dial("tcp", ul.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // GET http://example.com/ HTTP/1.1 // Host: example.com if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } // Response from downstream proxy. res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 299; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } } func TestIntegrationHTTPDownstreamProxyError(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Set proxy's downstream proxy to invalid host:port to force failure. p.SetDownstreamProxy(&url.URL{ Host: "[::]:0", }) p.SetTimeout(600 * time.Millisecond) tm := martiantest.NewModifier() reserr := errors.New("response error") tm.ResponseError(reserr) p.SetResponseModifier(tm) go p.Serve(l) // Open connection to upstream proxy. conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // Response from upstream proxy, assuming downstream proxy failed to CONNECT. 
res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 502; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header["Warning"][1], reserr.Error(); !strings.Contains(got, want) { t.Errorf("res.Header.get(%q): got %q, want to contain %q", "Warning", got, want) } } func TestIntegrationTLSHandshakeErrorCallback(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Test TLS server. ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour) if err != nil { t.Fatalf("mitm.NewAuthority(): got %v, want no error", err) } mc, err := mitm.NewConfig(ca, priv) if err != nil { t.Fatalf("mitm.NewConfig(): got %v, want no error", err) } var herr error mc.SetHandshakeErrorCallback(func(_ *http.Request, err error) { herr = fmt.Errorf("handshake error") }) p.SetMITM(mc) tl, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("tls.Listen(): got %v, want no error", err) } tl = tls.NewListener(tl, mc.TLS()) go http.Serve(tl, http.HandlerFunc( func(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(200) })) tm := martiantest.NewModifier() // Force the CONNECT request to dial the local TLS server. tm.RequestFunc(func(req *http.Request) { req.URL.Host = tl.Addr().String() }) go p.Serve(l) conn, err := net.Dial("tcp", l.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer conn.Close() req, err := http.NewRequest("CONNECT", "//example.com:443", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } // CONNECT example.com:443 HTTP/1.1 // Host: example.com // // Rewritten to CONNECT to host:port in CONNECT request modifier. 
if err := req.Write(conn); err != nil { t.Fatalf("req.Write(): got %v, want no error", err) } // CONNECT response after establishing tunnel. if _, err := http.ReadResponse(bufio.NewReader(conn), req); err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } tlsconn := tls.Client(conn, &tls.Config{ ServerName: "example.com", // Client has no cert so it will get "x509: certificate signed by unknown authority" from the // handshake and send "remote error: bad certificate" to the server. RootCAs: x509.NewCertPool(), }) defer tlsconn.Close() req, err = http.NewRequest("GET", "https://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Connection", "close") if got, want := req.Write(tlsconn), "x509: certificate signed by unknown authority"; !strings.Contains(got.Error(), want) { t.Fatalf("Got incorrect error from Client Handshake(), got: %v, want: %v", got, want) } // TODO: herr is not being asserted against. It should be pushed on to a channel // of err, and the assertion should pull off of it and assert. That design resulted in the test // hanging for unknown reasons. t.Skip("skipping assertion of handshake error callback error due to mysterious deadlock") if got, want := herr, "remote error: bad certificate"; !strings.Contains(got.Error(), want) { t.Fatalf("Got incorrect error from Server Handshake(), got: %v, want: %v", got, want) } } func TestIntegrationConnect(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } p := NewProxy() defer p.Close() // Test TLS server. 
ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// The MITM config here only supplies certificates for the local TLS
	// server; it is not installed on the proxy.
	tl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("tls.Listen(): got %v, want no error", err)
	}
	tl = tls.NewListener(tl, mc.TLS())

	go http.Serve(tl, http.HandlerFunc(
		func(rw http.ResponseWriter, req *http.Request) {
			rw.WriteHeader(299)
		}))

	tm := martiantest.NewModifier()
	reqerr := errors.New("request error")
	reserr := errors.New("response error")

	// Force the CONNECT request to dial the local TLS server.
	tm.RequestFunc(func(req *http.Request) {
		req.URL.Host = tl.Addr().String()
	})

	tm.RequestError(reqerr)
	tm.ResponseError(reserr)

	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	//
	// Rewritten to CONNECT to host:port in CONNECT request modifier.
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// CONNECT response after establishing tunnel.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}

	// The modifier errors do not fail the CONNECT; the response error is
	// expected to surface as a Warning header on the CONNECT response.
	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
	}

	roots := x509.NewCertPool()
	roots.AddCert(ca)

	tlsconn := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		RootCAs:    roots,
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}
	req.Header.Set("Connection", "close")

	// GET / HTTP/1.1
	// Host: example.com
	// Connection: close
	if err := req.Write(tlsconn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 299; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	// This response came through the tunnel, so the proxy's response
	// modifier is expected not to have added its Warning to it.
	if got, want := res.Header.Get("Warning"), reserr.Error(); strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %s, want to not contain %s", "Warning", got, want)
	}
}

// TestIntegrationConnectDownstreamProxy checks that a CONNECT sent to an
// upstream proxy is forwarded to a downstream proxy that has MITM enabled.
// The tunnel's certificate must verify against the downstream MITM CA, and
// the tunneled GET is answered by the downstream round tripper (299).
func TestIntegrationConnectDownstreamProxy(t *testing.T) {
	t.Parallel()

	// Start first proxy to use as downstream.
	dl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	downstream := NewProxy()
	defer downstream.Close()

	dtr := martiantest.NewTransport()
	dtr.Respond(299)
	downstream.SetRoundTripper(dtr)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	downstream.SetMITM(mc)

	go downstream.Serve(dl)

	// Start second proxy as upstream proxy, will CONNECT to downstream proxy.
	ul, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	upstream := NewProxy()
	defer upstream.Close()

	// Set upstream proxy's downstream proxy to the host:port of the first proxy.
	upstream.SetDownstreamProxy(&url.URL{
		Host: dl.Addr().String(),
	})

	go upstream.Serve(ul)

	// Open connection to upstream proxy.
	conn, err := net.Dial("tcp", ul.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from downstream proxy starting MITM.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	roots := x509.NewCertPool()
	roots.AddCert(ca)

	tlsconn := tls.Client(conn, &tls.Config{
		// Validate the hostname.
		ServerName: "example.com",
		// The certificate will have been MITM'd, verify using the MITM CA
		// certificate.
		RootCAs: roots,
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(tlsconn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from MITM in downstream proxy.
	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 299; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
}

// TestIntegrationMITM checks a MITM-enabled proxy end to end. The proxy
// answers the CONNECT with 200 and a Warning from the response modifier. It
// terminates TLS with a certificate signed by its CA. The decrypted GET then
// reaches the round tripper with an "https" scheme. The modifier errors are
// expected to appear as Warning headers, not as failed responses.
func TestIntegrationMITM(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Echo the scheme the round tripper saw so the test can check that the
	// MITM'd request was treated as https.
	tr := martiantest.NewTransport()
	tr.Func(func(req *http.Request) (*http.Response, error) {
		res := proxyutil.NewResponse(200, nil, req)
		res.Header.Set("Request-Scheme", req.URL.Scheme)

		return res, nil
	})

	p.SetRoundTripper(tr)
	p.SetTimeout(600 * time.Millisecond)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	p.SetMITM(mc)

	tm := martiantest.NewModifier()
	reqerr := errors.New("request error")
	reserr := errors.New("response error")
	tm.RequestError(reqerr)
	tm.ResponseError(reserr)

	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:443", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:443 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response MITM'd from proxy.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
	}

	roots := x509.NewCertPool()
	roots.AddCert(ca)

	tlsconn := tls.Client(conn, &tls.Config{
		ServerName: "example.com",
		RootCAs:    roots,
	})
	defer tlsconn.Close()

	req, err = http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(tlsconn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response from MITM proxy.
	res, err = http.ReadResponse(bufio.NewReader(tlsconn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
	}
	if got, want := res.Header.Get("Warning"), reserr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
	}
}

// TestIntegrationTransparentHTTP checks that an origin-form request (written
// with req.Write rather than req.WriteProxy) sent straight to the proxy is
// served and passes through both the request and response modifiers.
func TestIntegrationTransparentHTTP(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	if got, want := p.GetRoundTripper(), tr; got != want {
		t.Errorf("proxy.GetRoundTripper: got %v, want %v", got, want)
	}

	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()
p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}

	if !tm.RequestModified() {
		t.Error("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Error("tm.ResponseModified(): got false, want true")
	}
}

// TestIntegrationTransparentMITM checks that a client can speak TLS directly
// to a MITM-enabled listener, with no CONNECT first. The client verifies the
// generated certificate against the MITM CA. The decrypted request must reach
// the round tripper with an "https" scheme and pass through both modifiers.
func TestIntegrationTransparentMITM(t *testing.T) {
	t.Parallel()

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}

	// Start TLS listener with config that will generate certificates based on
	// SNI from connection.
	//
	// BUG: tls.Listen will not accept a tls.Config where Certificates is empty,
	// even though it is supported by tls.Server when GetCertificate is not nil.
	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}
	l = tls.NewListener(l, mc.TLS())

	p := NewProxy()
	defer p.Close()

	tr := martiantest.NewTransport()
	tr.Func(func(req *http.Request) (*http.Response, error) {
		res := proxyutil.NewResponse(200, nil, req)
		res.Header.Set("Request-Scheme", req.URL.Scheme)

		return res, nil
	})

	p.SetRoundTripper(tr)

	tm := martiantest.NewModifier()
	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(l)

	roots := x509.NewCertPool()
	roots.AddCert(ca)

	tlsconn, err := tls.Dial("tcp", l.Addr().String(), &tls.Config{
		// Verify the hostname is example.com.
		ServerName: "example.com",

		// The certificate will have been generated during MITM, so we need to
		// verify it with the generated CA certificate.
		RootCAs: roots,
	})
	if err != nil {
		t.Fatalf("tls.Dial(): got %v, want no error", err)
	}
	defer tlsconn.Close()

	req, err := http.NewRequest("GET", "https://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// Write Encrypted request directly, no CONNECT.
	// GET / HTTP/1.1
	// Host: example.com
	if err := req.Write(tlsconn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(tlsconn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
	if got, want := res.Header.Get("Request-Scheme"), "https"; got != want {
		t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Scheme", got, want)
	}

	if !tm.RequestModified() {
		t.Errorf("tm.RequestModified(): got false, want true")
	}
	if !tm.ResponseModified() {
		t.Errorf("tm.ResponseModified(): got false, want true")
	}
}

// TestIntegrationFailedRoundTrip checks that an error from the round tripper
// reaches the client as a 502 whose Warning header contains the error text.
func TestIntegrationFailedRoundTrip(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	tr := martiantest.NewTransport()
	trerr := errors.New("round trip error")
	tr.RespondError(trerr)
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Response from failed round trip.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 502; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	if got, want := res.Header.Get("Warning"), trerr.Error(); !strings.Contains(got, want) {
		t.Errorf("res.Header.Get(%q): got %q, want to contain %q", "Warning", got, want)
	}
}

// TestIntegrationSkipRoundTrip checks that a request modifier calling
// SkipRoundTrip bypasses the transport, which is configured to answer 500,
// and the client receives a 200 instead.
func TestIntegrationSkipRoundTrip(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Transport will be skipped, no 500.
	tr := martiantest.NewTransport()
	tr.Respond(500)
	p.SetRoundTripper(tr)
	p.SetTimeout(200 * time.Millisecond)

	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		ctx := NewContext(req)
		ctx.SkipRoundTrip()
	})
	p.SetRequestModifier(tm)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Response from skipped round trip.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
}

// TestHTTPThroughConnectWithMITM checks plain HTTP after a CONNECT to port 80
// on a MITM-enabled proxy. The client sends proxy-form HTTP requests on the
// same connection. Every request (the CONNECT and both GETs) skips the round
// trip, and the modifier reports any method other than GET or CONNECT.
func TestHTTPThroughConnectWithMITM(t *testing.T) {
	t.Parallel()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	p := NewProxy()
	defer p.Close()

	// Runs on the proxy's goroutine; t.Errorf (unlike t.Fatalf) is safe there.
	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		ctx := NewContext(req)
		ctx.SkipRoundTrip()

		if req.Method != "GET" && req.Method != "CONNECT" {
			t.Errorf("unexpected method on request handler: %v", req.Method)
		}
	})
	p.SetRequestModifier(tm)

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	p.SetMITM(mc)

	go p.Serve(l)

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", "//example.com:80", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT example.com:80 HTTP/1.1
	// Host: example.com
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	// Response skipped round trip.
	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	req, err = http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Response from skipped round trip.
	res, err = http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}

	// A second request on the same connection checks that it stays usable.
	req, err = http.NewRequest("GET", "http://example.com", nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// GET http://example.com/ HTTP/1.1
	// Host: example.com
	if err := req.WriteProxy(conn); err != nil {
		t.Fatalf("req.WriteProxy(): got %v, want no error", err)
	}

	// Response from skipped round trip.
res, err = http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	if got, want := res.StatusCode, 200; got != want {
		t.Errorf("res.StatusCode: got %d, want %d", got, want)
	}
}

// TestServerClosesConnection checks a MITM-enabled proxy serving from a
// listener whose Accept fails temporarily three times. The proxy must still
// relay the response from a destination server that writes only headers
// ("Connection: close") and then closes the connection.
func TestServerClosesConnection(t *testing.T) {
	t.Parallel()

	dstl, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("Failed to create http listener: %v", err)
	}
	defer dstl.Close()

	// Fake destination server. It runs on its own goroutine, where
	// t.Fatalf/FailNow must not be called, so failures are reported with
	// t.Errorf and the goroutine simply returns.
	go func() {
		t.Logf("Waiting for server side connection")
		conn, err := dstl.Accept()
		if err != nil {
			t.Errorf("Got error while accepting connection on destination listener: %v", err)
			return
		}
		// Close on every path so the proxy sees EOF even when a step fails.
		defer conn.Close()
		t.Logf("Accepted server side connection")

		buf := make([]byte, 16384)
		if _, err := conn.Read(buf); err != nil {
			t.Errorf("Error reading: %v", err)
			return
		}

		_, err = conn.Write([]byte("HTTP/1.1 301 MOVED PERMANENTLY\r\n" +
			"Server: \r\n" +
			"Date: \r\n" +
			"Referer: \r\n" +
			"Location: http://www.foo.com/\r\n" +
			"Content-type: text/html\r\n" +
			"Connection: close\r\n\r\n"))
		if err != nil {
			t.Errorf("Got error while writting to connection on destination listener: %v", err)
		}
	}()

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	ca, priv, err := mitm.NewAuthority("martian.proxy", "Martian Authority", 2*time.Hour)
	if err != nil {
		t.Fatalf("mitm.NewAuthority(): got %v, want no error", err)
	}

	mc, err := mitm.NewConfig(ca, priv)
	if err != nil {
		t.Fatalf("mitm.NewConfig(): got %v, want no error", err)
	}
	p := NewProxy()
	p.SetMITM(mc)
	defer p.Close()

	// Start the proxy with a listener that will return a temporary error on
	// Accept() three times.
	go p.Serve(newTimeoutListener(l, 3))

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	req, err := http.NewRequest("CONNECT", fmt.Sprintf("//%s", dstl.Addr().String()), nil)
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	// CONNECT <destination host:port> HTTP/1.1
	// Host: <destination host:port>
	if err := req.Write(conn); err != nil {
		t.Fatalf("req.Write(): got %v, want no error", err)
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	res.Body.Close()

	_, err = conn.Write([]byte("GET / HTTP/1.1\r\n" +
		"User-Agent: curl/7.35.0\r\n" +
		fmt.Sprintf("Host: %s\r\n", dstl.Addr()) +
		"Accept: */*\r\n\r\n"))
	if err != nil {
		t.Fatalf("Error while writing GET request: %v", err)
	}

	res, err = http.ReadResponse(bufio.NewReader(io.TeeReader(conn, os.Stderr)), req)
	if err != nil {
		t.Fatalf("http.ReadResponse(): got %v, want no error", err)
	}
	defer res.Body.Close()

	// The destination sent no Content-Length and closed the connection, so
	// reading to EOF must succeed.
	_, err = ioutil.ReadAll(res.Body)
	if err != nil {
		t.Fatalf("error while ReadAll: %v", err)
	}
}

// TestRacyClose checks that creating a proxy, serving from it, and closing
// it in rapid succession doesn't result in race warnings.
// See https://github.com/google/martian/issues/286.
func TestRacyClose(t *testing.T) {
	t.Parallel()

	log.SetLevel(log.Silent) // avoid "failed to accept" messages because we close l
	openAndConnect := func() {
		l, err := net.Listen("tcp", "[::]:0")
		if err != nil {
			t.Fatalf("net.Listen(): got %v, want no error", err)
		}
		defer l.Close() // to make p.Serve exit

		p := NewProxy()
		go p.Serve(l)
		defer p.Close()

		conn, err := net.Dial("tcp", l.Addr().String())
		if err != nil {
			t.Fatalf("net.Dial(): got %v, want no error", err)
		}
		defer conn.Close()
	}

	// Repeat a bunch of times to make failures more repeatable.
for i := 0; i < 100; i++ {
		openAndConnect()
	}
}
martian-3.3.2/proxy_trafficshaping_test.go000066400000000000000000000361311421371434000207740ustar00rootroot00000000000000package martian

import (
	"bufio"
	"bytes"
	"io/ioutil"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/google/martian/v3/log"
	"github.com/google/martian/v3/martiantest"
	"github.com/google/martian/v3/trafficshape"
)

// applyTrafficShapeConfig POSTs the JSON traffic shaping config to the
// traffic shaping handler h and fails the test unless it answers 200.
// It must be called from the test goroutine.
func applyTrafficShapeConfig(t *testing.T, h http.Handler, config string) {
	t.Helper()

	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(config))
	if err != nil {
		t.Fatalf("http.NewRequest(): got %v, want no error", err)
	}

	rw := httptest.NewRecorder()
	h.ServeHTTP(rw, tsReq)
	res := rw.Result()

	if got, want := res.StatusCode, 200; got != want {
		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
	}
}

// startTrafficShapedProxy starts a proxy serving on tsl with the given
// timeout. Every request skips the round trip, and modifyRes rewrites every
// response. The caller must Close the returned proxy.
func startTrafficShapedProxy(tsl net.Listener, timeout time.Duration, modifyRes func(*http.Response)) *Proxy {
	p := NewProxy()

	tr := martiantest.NewTransport()
	p.SetRoundTripper(tr)
	p.SetTimeout(timeout)

	tm := martiantest.NewModifier()
	tm.RequestFunc(func(req *http.Request) {
		ctx := NewContext(req)
		ctx.SkipRoundTrip()
	})
	tm.ResponseFunc(modifyRes)

	p.SetRequestModifier(tm)
	p.SetResponseModifier(tm)

	go p.Serve(tsl)

	return p
}

// getThroughShapedProxy sends a proxied GET for http://example/example on
// conn and returns the response body received. It is called from non-test
// goroutines, so failures are reported with t.Errorf (t.Fatalf must only be
// called from the test goroutine) and an empty body is returned.
func getThroughShapedProxy(t *testing.T, conn net.Conn) string {
	req, err := http.NewRequest("GET", "http://example/example", nil)
	if err != nil {
		t.Errorf("http.NewRequest(): got %v, want no error", err)
		return ""
	}

	if err := req.WriteProxy(conn); err != nil {
		t.Errorf("req.WriteProxy(): got %v, want no error", err)
		return ""
	}

	res, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		t.Errorf("http.ReadResponse(): got %v, want no error", err)
		return ""
	}
	defer res.Body.Close()

	// The read error is deliberately ignored: close-connection actions may cut
	// the body short, and the tests assert on whatever bytes arrived.
	body, _ := ioutil.ReadAll(res.Body)
	return string(body)
}

// Tests that sending data of length 600 bytes with max bandwidth of 100 bytes/s takes
// at least 4.9s. Uses the Close Connection action to immediately close the connection
// upon the proxy writing 600 bytes. (4.9s ~ 5s = 600b /100b/s - 1s)
func TestConstantThrottleAndClose(t *testing.T) {
	log.SetLevel(log.Info)

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	tsl := trafficshape.NewListener(l)
	tsh := trafficshape.NewHandler(tsl)

	// This is the data to be sent.
	testString := strings.Repeat("0", 600)

	// Traffic shaping config request.
	jsonString :=
		`{
			"trafficshape": {
				"shapes": [
					{
						"url_regex": "http://example/example",
						"throttles": [
							{
								"bytes": "0-",
								"bandwidth": 100
							}
						],
						"close_connections": [
							{
								"byte": 600,
								"count": 1
							}
						]
					}
				]
			}
		}`
	applyTrafficShapeConfig(t, tsh, jsonString)

	p := startTrafficShapedProxy(tsl, 15*time.Second, func(res *http.Response) {
		res.StatusCode = http.StatusOK
		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
	})
	defer p.Close()

	// Check the dial error before deferring Close: conn is nil on failure.
	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	// The goroutine always sends, even on failure, so the receives below
	// cannot block forever.
	c1 := make(chan string)
	go func() {
		c1 <- getThroughShapedProxy(t, conn)
	}()

	var bodystr string
	select {
	case bodystringc := <-c1:
		t.Errorf("took < 4.9s, should take at least 4.9s")
		bodystr = bodystringc
	case <-time.After(4900 * time.Millisecond):
		bodystr = <-c1
	}

	if bodystr != testString {
		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
	}
}

// Tests that sleeping for 5s and then closing the connection
// upon reading 200 bytes, with a bandwidth of 5000 bytes/s
// takes at least 4.9s, and results in a correctly trimmed
// response body. (200 0s instead of 500 0s)
func TestSleepAndClose(t *testing.T) {
	log.SetLevel(log.Info)

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	tsl := trafficshape.NewListener(l)
	tsh := trafficshape.NewHandler(tsl)

	// This is the data to be sent.
	testString := strings.Repeat("0", 500)

	// Traffic shaping config request.
	jsonString :=
		`{
			"trafficshape": {
				"shapes": [
					{
						"url_regex": "http://example/example",
						"throttles": [
							{
								"bytes": "0-",
								"bandwidth": 5000
							}
						],
						"halts": [
							{
								"byte": 100,
								"duration": 5000,
								"count": 1
							}
						],
						"close_connections": [
							{
								"byte": 200,
								"count": 1
							}
						]
					}
				]
			}
		}`
	applyTrafficShapeConfig(t, tsh, jsonString)

	p := startTrafficShapedProxy(tsl, 15*time.Second, func(res *http.Response) {
		res.StatusCode = http.StatusOK
		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
	})
	defer p.Close()

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	c1 := make(chan string)
	go func() {
		c1 <- getThroughShapedProxy(t, conn)
	}()

	var bodystr string
	select {
	case bodystringc := <-c1:
		t.Errorf("took < 4.9s, should take at least 4.9s")
		bodystr = bodystringc
	case <-time.After(4900 * time.Millisecond):
		bodystr = <-c1
	}

	if want := strings.Repeat("0", 200); bodystr != want {
		t.Errorf("res.Body: got %s, want %s", bodystr, want)
	}
}

// Similar to TestConstantThrottleAndClose, except that it applies
// the throttle only in a specific byte range, and modifies the
// response to lie in the byte range.
func TestConstantThrottleAndCloseByteRange(t *testing.T) {
	log.SetLevel(log.Info)

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	tsl := trafficshape.NewListener(l)
	tsh := trafficshape.NewHandler(tsl)

	// This is the data to be sent.
	testString := strings.Repeat("0", 600)

	// Traffic shaping config request.
	jsonString :=
		`{
			"trafficshape": {
				"shapes": [
					{
						"url_regex": "http://example/example",
						"throttles": [
							{
								"bytes": "500-",
								"bandwidth": 100
							}
						],
						"close_connections": [
							{
								"byte": 1100,
								"count": 1
							}
						]
					}
				]
			}
		}`
	applyTrafficShapeConfig(t, tsh, jsonString)

	p := startTrafficShapedProxy(tsl, 15*time.Second, func(res *http.Response) {
		res.StatusCode = http.StatusPartialContent
		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
		res.Header.Set("Content-Range", "bytes 500-1100/1100")
	})
	defer p.Close()

	conn, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		t.Fatalf("net.Dial(): got %v, want no error", err)
	}
	defer conn.Close()

	c1 := make(chan string)
	go func() {
		c1 <- getThroughShapedProxy(t, conn)
	}()

	var bodystr string
	select {
	case bodystringc := <-c1:
		t.Errorf("took < 4.9s, should take at least 4.9s")
		bodystr = bodystringc
	case <-time.After(4900 * time.Millisecond):
		bodystr = <-c1
	}

	if bodystr != testString {
		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
	}
}

// Opens up 5 concurrent connections, and sets the
// max global bandwidth for the url regex to be 250b/s.
// Every connection tries to read 500b of data, but since
// the global bandwidth for the particular regex is 250,
// it should take at least 5 * 500b / 250b/s -1s = 9s to read
// everything.
func TestMaxBandwidth(t *testing.T) {
	log.SetLevel(log.Info)

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	tsl := trafficshape.NewListener(l)
	tsh := trafficshape.NewHandler(tsl)

	// This is the data to be sent.
	testString := strings.Repeat("0", 500)

	// Traffic shaping config request.
	jsonString :=
		`{
			"trafficshape": {
				"shapes": [
					{
						"url_regex": "http://example/example",
						"max_global_bandwidth": 250,
						"close_connections": [
							{
								"byte": 500,
								"count": 5
							}
						]
					}
				]
			}
		}`
	applyTrafficShapeConfig(t, tsh, jsonString)

	p := startTrafficShapedProxy(tsl, 20*time.Second, func(res *http.Response) {
		res.StatusCode = http.StatusOK
		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
	})
	defer p.Close()

	numChannels := 5

	channels := make([]chan string, numChannels)
	for i := 0; i < numChannels; i++ {
		channels[i] = make(chan string)
	}

	for i := 0; i < numChannels; i++ {
		go func(i int) {
			var bodystr string
			if conn, err := net.Dial("tcp", l.Addr().String()); err != nil {
				t.Errorf("net.Dial(): got %v, want no error", err)
			} else {
				defer conn.Close()
				bodystr = getThroughShapedProxy(t, conn)
			}

			// Results are chained: goroutine i reports only after i-1. Send
			// even on failure so neither the chain nor the test can block
			// forever.
			if i != 0 {
				<-channels[i-1]
			}
			channels[i] <- bodystr
		}(i)
	}

	var bodystr string
	select {
	case bodystringc := <-channels[numChannels-1]:
		t.Errorf("took < 8.9s, should take at least 8.9s")
		bodystr = bodystringc
	case <-time.After(8900 * time.Millisecond):
		bodystr = <-channels[numChannels-1]
	}

	if bodystr != testString {
		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
	}
}

// Makes 2 requests, with the first one having a byte range starting
// at byte 250, and adds a close connection action at byte 450.
// The first request should hit the action sooner,
// and delete it. The second request should read the whole
// data (500b)
func TestConcurrentResponseActions(t *testing.T) {
	log.SetLevel(log.Info)

	l, err := net.Listen("tcp", "[::]:0")
	if err != nil {
		t.Fatalf("net.Listen(): got %v, want no error", err)
	}

	tsl := trafficshape.NewListener(l)
	tsh := trafficshape.NewHandler(tsl)

	// This is the data to be sent.
	testString := strings.Repeat("0", 500)

	// Traffic shaping config request.
jsonString := `{ "trafficshape": { "shapes": [ { "url_regex": "http://example/example", "throttles": [ { "bytes": "-", "bandwidth": 250 } ], "close_connections": [ { "byte": 450, "count": 1 }, { "byte": 500, "count": 1 } ] } ] } }` tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString)) rw := httptest.NewRecorder() tsh.ServeHTTP(rw, tsReq) res := rw.Result() if got, want := res.StatusCode, 200; got != want { t.Fatalf("res.StatusCode: got %d, want %d", got, want) } p := NewProxy() defer p.Close() p.SetRequestModifier(nil) p.SetResponseModifier(nil) tr := martiantest.NewTransport() p.SetRoundTripper(tr) p.SetTimeout(20 * time.Second) tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := NewContext(req) ctx.SkipRoundTrip() }) tm.ResponseFunc(func(res *http.Response) { cr := res.Request.Header.Get("ContentRange") res.StatusCode = http.StatusOK res.Body = ioutil.NopCloser(bytes.NewBufferString(testString)) if cr != "" { res.StatusCode = http.StatusPartialContent res.Header.Set("Content-Range", cr) } }) p.SetRequestModifier(tm) p.SetResponseModifier(tm) go p.Serve(tsl) c1 := make(chan string) c2 := make(chan string) go func() { conn, err := net.Dial("tcp", l.Addr().String()) defer conn.Close() if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://example/example", nil) req.Header.Set("ContentRange", "bytes 250-1000/1000") if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } body, _ := ioutil.ReadAll(res.Body) bodystr := string(body) c1 <- bodystr }() go func() { conn, err := net.Dial("tcp", l.Addr().String()) defer conn.Close() if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } req, err 
:= http.NewRequest("GET", "http://example/example", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := req.WriteProxy(conn); err != nil { t.Fatalf("req.WriteProxy(): got %v, want no error", err) } res, err := http.ReadResponse(bufio.NewReader(conn), req) if err != nil { t.Fatalf("http.ReadResponse(): got %v, want no error", err) } body, _ := ioutil.ReadAll(res.Body) bodystr := string(body) c2 <- bodystr }() bodystr1 := <-c1 bodystr2 := <-c2 if want1 := strings.Repeat("0", 200); bodystr1 != want1 { t.Errorf("res.Body: got %s, want %s", bodystr1, want1) } if want2 := strings.Repeat("0", 500); bodystr2 != want2 { t.Errorf("res.Body: got %s, want %s", bodystr2, want2) } } martian-3.3.2/proxyauth/000077500000000000000000000000001421371434000152145ustar00rootroot00000000000000martian-3.3.2/proxyauth/proxyauth.go000066400000000000000000000055361421371434000176170ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package proxyauth provides authentication support via the // Proxy-Authorization header. package proxyauth import ( "encoding/base64" "net/http" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/auth" ) var noop = martian.Noop("proxyauth.Modifier") // Modifier is the proxy authentication modifier. 
type Modifier struct { reqmod martian.RequestModifier resmod martian.ResponseModifier } // NewModifier returns a new proxy authentication modifier. func NewModifier() *Modifier { return &Modifier{ reqmod: noop, resmod: noop, } } // SetRequestModifier sets the request modifier. func (m *Modifier) SetRequestModifier(reqmod martian.RequestModifier) { if reqmod == nil { reqmod = noop } m.reqmod = reqmod } // SetResponseModifier sets the response modifier. func (m *Modifier) SetResponseModifier(resmod martian.ResponseModifier) { if resmod == nil { resmod = noop } m.resmod = resmod } // ModifyRequest sets the auth ID in the context from the request iff it has // not already been set and runs reqmod.ModifyRequest. If the underlying // modifier has indicated via auth error that no valid auth credentials // have been found we set ctx.SkipRoundTrip. func (m *Modifier) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) actx := auth.FromContext(ctx) actx.SetID(id(req.Header)) err := m.reqmod.ModifyRequest(req) if actx.Error() != nil { ctx.SkipRoundTrip() } return err } // ModifyResponse runs resmod.ModifyResponse and modifies the response to // include the correct status code and headers if auth error is present. // // If an error is returned from resmod.ModifyResponse it is returned. func (m *Modifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) err := m.resmod.ModifyResponse(res) if actx.Error() != nil { res.StatusCode = http.StatusProxyAuthRequired res.Header.Set("Proxy-Authenticate", "Basic") } return err } // id returns an ID derived from the Proxy-Authorization header username and password. 
func id(header http.Header) string { id := strings.TrimPrefix(header.Get("Proxy-Authorization"), "Basic ") data, err := base64.StdEncoding.DecodeString(id) if err != nil { return "" } return string(data) } martian-3.3.2/proxyauth/proxyauth_test.go000066400000000000000000000116731421371434000206550ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package proxyauth import ( "encoding/base64" "errors" "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/auth" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" ) func encode(v string) string { return base64.StdEncoding.EncodeToString([]byte(v)) } func TestNoModifiers(t *testing.T) { m := NewModifier() m.SetRequestModifier(nil) m.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := m.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } } func TestProxyAuth(t *testing.T) { m := NewModifier() tm := martiantest.NewModifier() 
m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } actx := auth.FromContext(ctx) if got, want := actx.ID(), "user:pass"; got != want { t.Fatalf("actx.ID(): got %q, want %q", got, want) } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } } func TestProxyAuthInvalidCredentials(t *testing.T) { m := NewModifier() authErr := errors.New("auth error") tm := martiantest.NewModifier() tm.RequestFunc(func(req *http.Request) { ctx := martian.NewContext(req) actx := auth.FromContext(ctx) actx.SetError(authErr) }) tm.ResponseFunc(func(res *http.Response) { ctx := martian.NewContext(res.Request) actx := auth.FromContext(ctx) actx.SetError(authErr) }) m.SetRequestModifier(tm) m.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): 
got %v, want no error", err) } req.Header.Set("Proxy-Authorization", "Basic "+encode("user:pass")) ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } actx := auth.FromContext(ctx) if actx.Error() != authErr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), authErr) } actx.SetError(nil) res := proxyutil.NewResponse(200, nil, req) if err := m.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } if actx.Error() != authErr { t.Fatalf("auth.Error(): got %v, want %v", actx.Error(), authErr) } if got, want := res.StatusCode, http.StatusProxyAuthRequired; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Header.Get("Proxy-Authenticate"), "Basic"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Proxy-Authenticate", got, want) } } martian-3.3.2/proxyutil/000077500000000000000000000000001421371434000152305ustar00rootroot00000000000000martian-3.3.2/proxyutil/header.go000066400000000000000000000106421421371434000170120ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package proxyutil import ( "fmt" "net/http" "strconv" ) // Header is a generic representation of a set of HTTP headers for requests and // responses. type Header struct { h http.Header host func() string cl func() int64 te func() []string setHost func(string) setCL func(int64) setTE func([]string) } // RequestHeader returns a new set of headers from a request. func RequestHeader(req *http.Request) *Header { return &Header{ h: req.Header, host: func() string { return req.Host }, cl: func() int64 { return req.ContentLength }, te: func() []string { return req.TransferEncoding }, setHost: func(host string) { req.Host = host }, setCL: func(cl int64) { req.ContentLength = cl }, setTE: func(te []string) { req.TransferEncoding = te }, } } // ResponseHeader returns a new set of headers from a request. func ResponseHeader(res *http.Response) *Header { return &Header{ h: res.Header, host: func() string { return "" }, cl: func() int64 { return res.ContentLength }, te: func() []string { return res.TransferEncoding }, setHost: func(string) {}, setCL: func(cl int64) { res.ContentLength = cl }, setTE: func(te []string) { res.TransferEncoding = te }, } } // Set sets value at header name for the request or response. func (h *Header) Set(name, value string) error { switch http.CanonicalHeaderKey(name) { case "Host": h.setHost(value) case "Content-Length": cl, err := strconv.ParseInt(value, 10, 64) if err != nil { return err } h.setCL(cl) case "Transfer-Encoding": h.setTE([]string{value}) default: h.h.Set(name, value) } return nil } // Add appends the value to the existing header at name for the request or // response. 
func (h *Header) Add(name, value string) error { switch http.CanonicalHeaderKey(name) { case "Host": if h.host() != "" { return fmt.Errorf("proxyutil: illegal header multiple: %s", "Host") } return h.Set(name, value) case "Content-Length": if h.cl() > 0 { return fmt.Errorf("proxyutil: illegal header multiple: %s", "Content-Length") } return h.Set(name, value) case "Transfer-Encoding": h.setTE(append(h.te(), value)) default: h.h.Add(name, value) } return nil } // Get returns the first value at header name for the request or response. func (h *Header) Get(name string) string { switch http.CanonicalHeaderKey(name) { case "Host": return h.host() case "Content-Length": if h.cl() < 0 { return "" } return strconv.FormatInt(h.cl(), 10) case "Transfer-Encoding": if len(h.te()) < 1 { return "" } return h.te()[0] default: return h.h.Get(name) } } // All returns all the values for header name. If the header does not exist it // returns nil, false. func (h *Header) All(name string) ([]string, bool) { switch http.CanonicalHeaderKey(name) { case "Host": if h.host() == "" { return nil, false } return []string{h.host()}, true case "Content-Length": if h.cl() <= 0 { return nil, false } return []string{strconv.FormatInt(h.cl(), 10)}, true case "Transfer-Encoding": if h.te() == nil { return nil, false } return h.te(), true default: vs, ok := h.h[http.CanonicalHeaderKey(name)] return vs, ok } } // Del deletes the header at name for the request or response. func (h *Header) Del(name string) { switch http.CanonicalHeaderKey(name) { case "Host": h.setHost("") case "Content-Length": h.setCL(-1) case "Transfer-Encoding": h.setTE(nil) default: h.h.Del(name) } } // Map returns an http.Header that includes Host, Content-Length, and // Transfer-Encoding. 
func (h *Header) Map() http.Header { hm := make(http.Header) for k, vs := range h.h { hm[k] = vs } for _, k := range []string{ "Host", "Content-Length", "Transfer-Encoding", } { vs, ok := h.All(k) if !ok { continue } hm[k] = vs } return hm } martian-3.3.2/proxyutil/header_test.go000066400000000000000000000220051421371434000200450ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package proxyutil import ( "net/http" "reflect" "testing" ) func TestRequestHeader(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } h := RequestHeader(req) tt := []struct { name string value string }{ { name: "Host", value: "example.com", }, { name: "Test-Header", value: "true", }, { name: "Content-Length", value: "100", }, { name: "Transfer-Encoding", value: "chunked", }, } for i, tc := range tt { if err := h.Set(tc.name, tc.value); err != nil { t.Errorf("%d. 
h.Set(%q, %q): got %v, want no error", i, tc.name, tc.value, err) } } if got, want := req.Host, "example.com"; got != want { t.Errorf("req.Host: got %q, want %q", got, want) } if got, want := req.Header.Get("Test-Header"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Test-Header", got, want) } if got, want := req.ContentLength, int64(100); got != want { t.Errorf("req.ContentLength: got %d, want %d", got, want) } if got, want := req.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(got, want) { t.Errorf("req.TransferEncoding: got %v, want %v", got, want) } if got, want := len(h.Map()), 4; got != want { t.Errorf("h.Map(): got %d entries, want %d entries", got, want) } for n, vs := range h.Map() { var want string switch n { case "Host": want = "example.com" case "Content-Length": want = "100" case "Transfer-Encoding": want = "chunked" case "Test-Header": want = "true" default: t.Errorf("h.Map(): got unexpected %s header", n) } if got := vs[0]; got != want { t.Errorf("h.Map(): got %s header with value %s, want value %s", n, got, want) } } for i, tc := range tt { got, ok := h.All(tc.name) if !ok { t.Errorf("%d. h.All(%q): got false, want true", i, tc.name) } if want := []string{tc.value}; !reflect.DeepEqual(got, want) { t.Errorf("%d. h.All(%q): got %v, want %v", i, tc.name, got, want) } if got, want := h.Get(tc.name), tc.value; got != want { t.Errorf("%d. 
h.Get(%q): got %q, want %q", i, tc.name, got, want) } h.Del(tc.name) } if got, want := req.Host, ""; got != want { t.Errorf("req.Host: got %q, want %q", got, want) } if got, want := req.Header.Get("Test-Header"), ""; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Test-Header", got, want) } if got, want := req.ContentLength, int64(-1); got != want { t.Errorf("req.ContentLength: got %d, want %d", got, want) } if got := req.TransferEncoding; got != nil { t.Errorf("req.TransferEncoding: got %v, want nil", got) } for i, tc := range tt { if got, want := h.Get(tc.name), ""; got != want { t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want) } got, ok := h.All(tc.name) if ok { t.Errorf("%d. h.All(%q): got ok, want !ok", i, tc.name) } if got != nil { t.Errorf("%d. h.All(%q): got %v, want nil", i, tc.name, got) } } } func TestRequestHeaderAdd(t *testing.T) { req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Host = "" // Set to empty so add may overwrite. h := RequestHeader(req) tt := []struct { name string values []string errOnSecondValue bool }{ { name: "Host", values: []string{"example.com", "invalid.com"}, errOnSecondValue: true, }, { name: "Test-Header", values: []string{"first", "second"}, }, { name: "Content-Length", values: []string{"100", "101"}, errOnSecondValue: true, }, { name: "Transfer-Encoding", values: []string{"chunked", "gzip"}, }, } for i, tc := range tt { if err := h.Add(tc.name, tc.values[0]); err != nil { t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[0], err) } if err := h.Add(tc.name, tc.values[1]); err != nil && !tc.errOnSecondValue { t.Errorf("%d. 
h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[1], err) } } if got, want := req.Host, "example.com"; got != want { t.Errorf("req.Host: got %q, want %q", got, want) } if got, want := req.Header["Test-Header"], []string{"first", "second"}; !reflect.DeepEqual(got, want) { t.Errorf("req.Header[%q]: got %v, want %v", "Test-Header", got, want) } if got, want := req.ContentLength, int64(100); got != want { t.Errorf("req.ContentLength: got %d, want %d", got, want) } if got, want := req.TransferEncoding, []string{"chunked", "gzip"}; !reflect.DeepEqual(got, want) { t.Errorf("req.TransferEncoding: got %v, want %v", got, want) } } func TestResponseHeader(t *testing.T) { res := NewResponse(200, nil, nil) h := ResponseHeader(res) tt := []struct { name string value string }{ { name: "Test-Header", value: "true", }, { name: "Content-Length", value: "100", }, { name: "Transfer-Encoding", value: "chunked", }, } for i, tc := range tt { if err := h.Set(tc.name, tc.value); err != nil { t.Errorf("%d. 
h.Set(%q, %q): got %v, want no error", i, tc.name, tc.value, err) } } if got, want := res.Header.Get("Test-Header"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Test-Header", got, want) } if got, want := res.ContentLength, int64(100); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(got, want) { t.Errorf("res.TransferEncoding: got %v, want %v", got, want) } if got, want := len(h.Map()), 3; got != want { t.Errorf("h.Map(): got %d entries, want %d entries", got, want) } for n, vs := range h.Map() { var want string switch n { case "Content-Length": want = "100" case "Transfer-Encoding": want = "chunked" case "Test-Header": want = "true" default: t.Errorf("h.Map(): got unexpected %s header", n) } if got := vs[0]; got != want { t.Errorf("h.Map(): got %s header with value %s, want value %s", n, got, want) } } for i, tc := range tt { got, ok := h.All(tc.name) if !ok { t.Errorf("%d. h.All(%q): got false, want true", i, tc.name) } if want := []string{tc.value}; !reflect.DeepEqual(got, want) { t.Errorf("%d. h.All(%q): got %v, want %v", i, tc.name, got, want) } if got, want := h.Get(tc.name), tc.value; got != want { t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want) } h.Del(tc.name) } if got, want := res.Header.Get("Test-Header"), ""; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Test-Header", got, want) } if got, want := res.ContentLength, int64(-1); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got := res.TransferEncoding; got != nil { t.Errorf("res.TransferEncoding: got %v, want nil", got) } for i, tc := range tt { if got, want := h.Get(tc.name), ""; got != want { t.Errorf("%d. h.Get(%q): got %q, want %q", i, tc.name, got, want) } got, ok := h.All(tc.name) if ok { t.Errorf("%d. h.All(%q): got ok, want !ok", i, tc.name) } if got != nil { t.Errorf("%d. 
h.All(%q): got %v, want nil", i, tc.name, got) } } } func TestResponseHeaderAdd(t *testing.T) { res := NewResponse(200, nil, nil) h := ResponseHeader(res) tt := []struct { name string values []string errOnSecondValue bool }{ { name: "Test-Header", values: []string{"first", "second"}, }, { name: "Content-Length", values: []string{"100", "101"}, errOnSecondValue: true, }, { name: "Transfer-Encoding", values: []string{"chunked", "gzip"}, }, } for i, tc := range tt { if err := h.Add(tc.name, tc.values[0]); err != nil { t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[0], err) } if err := h.Add(tc.name, tc.values[1]); err != nil && !tc.errOnSecondValue { t.Errorf("%d. h.Add(%q, %q): got %v, want no error", i, tc.name, tc.values[1], err) } } if got, want := res.Header["Test-Header"], []string{"first", "second"}; !reflect.DeepEqual(got, want) { t.Errorf("res.Header[%q]: got %v, want %v", "Test-Header", got, want) } if got, want := res.ContentLength, int64(100); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.TransferEncoding, []string{"chunked", "gzip"}; !reflect.DeepEqual(got, want) { t.Errorf("res.TransferEncoding: got %v, want %v", got, want) } } martian-3.3.2/proxyutil/proxyutil.go000066400000000000000000000050071421371434000176400ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
/* Package proxyutil provides functionality for building proxies. */ package proxyutil import ( "bytes" "fmt" "io" "io/ioutil" "net/http" "regexp" "strconv" "strings" "time" ) // NewResponse builds new HTTP responses. // If body is nil, an empty byte.Buffer will be provided to be consistent with // the guarantees provided by http.Transport and http.Client. func NewResponse(code int, body io.Reader, req *http.Request) *http.Response { if body == nil { body = &bytes.Buffer{} } rc, ok := body.(io.ReadCloser) if !ok { rc = ioutil.NopCloser(body) } res := &http.Response{ StatusCode: code, Status: fmt.Sprintf("%d %s", code, http.StatusText(code)), Proto: "HTTP/1.1", ProtoMajor: 1, ProtoMinor: 1, Header: http.Header{}, Body: rc, Request: req, } if req != nil { res.Close = req.Close res.Proto = req.Proto res.ProtoMajor = req.ProtoMajor res.ProtoMinor = req.ProtoMinor } return res } // Warning adds an error to the Warning header in the format: 199 "martian" // "error message" "date". func Warning(header http.Header, err error) { date := header.Get("Date") if date == "" { date = time.Now().Format(http.TimeFormat) } w := fmt.Sprintf(`199 "martian" %q %q`, err.Error(), date) header.Add("Warning", w) } // GetRangeStart returns the byte index of the start of the range, if it has one. // Returns 0 if the range header is absent, and -1 if the range header is invalid or // has multi-part ranges. func GetRangeStart(res *http.Response) int64 { if res.StatusCode != http.StatusPartialContent { return 0 } if strings.Contains(res.Header.Get("Content-Type"), "multipart/byteranges") { return -1 } re := regexp.MustCompile(`bytes (\d+)-\d+/\d+`) matchSlice := re.FindStringSubmatch(res.Header.Get("Content-Range")) if len(matchSlice) < 2 { return -1 } num, err := strconv.ParseInt(matchSlice[1], 10, 64) if err != nil { return -1 } return num } martian-3.3.2/proxyutil/proxyutil_test.go000066400000000000000000000051111421371434000206730ustar00rootroot00000000000000// Copyright 2015 Google Inc. 
All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package proxyutil import ( "fmt" "io" "net/http" "strings" "testing" ) func TestNewResponse(t *testing.T) { req, err := http.NewRequest("GET", "http://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } req.Close = true res := NewResponse(200, nil, req) if got, want := res.StatusCode, 200; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } if got, want := res.Status, "200 OK"; got != want { t.Errorf("res.Status: got %q, want %q", got, want) } if !res.Close { t.Error("res.Close: got false, want true") } if got, want := res.Proto, "HTTP/1.1"; got != want { t.Errorf("res.Proto: got %q, want %q", got, want) } if got, want := res.ProtoMajor, 1; got != want { t.Errorf("res.ProtoMajor: got %d, want %d", got, want) } if got, want := res.ProtoMinor, 1; got != want { t.Errorf("res.ProtoMinor: got %d, want %d", got, want) } if res.Header == nil { t.Error("res.Header: got nil, want header") } if _, ok := res.Body.(io.ReadCloser); !ok { t.Error("res.Body.(io.ReadCloser): got !ok, want ok") } if got, want := res.Request, req; got != want { t.Errorf("res.Request: got %v, want %v", got, want) } } func TestWarning(t *testing.T) { hdr := http.Header{} err := fmt.Errorf("modifier error") Warning(hdr, err) if got, want := len(hdr["Warning"]), 1; got != want { t.Fatalf("len(hdr[%q]): got %d, want %d", "Warning", got, want) } want := `199 "martian" 
"modifier error"` if got := hdr["Warning"][0]; !strings.HasPrefix(got, want) { t.Errorf("hdr[%q][0]: got %q, want to have prefix %q", "Warning", got, want) } hdr.Set("Date", "Mon, 02 Jan 2006 15:04:05 GMT") Warning(hdr, err) if got, want := len(hdr["Warning"]), 2; got != want { t.Fatalf("len(hdr[%q]): got %d, want %d", "Warning", got, want) } want = `199 "martian" "modifier error" "Mon, 02 Jan 2006 15:04:05 GMT"` if got := hdr["Warning"][1]; got != want { t.Errorf("hdr[%q][1]: got %q, want %q", "Warning", got, want) } } martian-3.3.2/querystring/000077500000000000000000000000001421371434000155455ustar00rootroot00000000000000martian-3.3.2/querystring/query_string_filter.go000066400000000000000000000045141421371434000222000ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package querystring import ( "encoding/json" "github.com/google/martian/v3" "github.com/google/martian/v3/filter" "github.com/google/martian/v3/parse" ) var noop = martian.Noop("querystring.Filter") func init() { parse.Register("querystring.Filter", filterFromJSON) } // Filter runs modifiers iff the request query parameter for name matches value. 
type Filter struct { *filter.Filter } type filterJSON struct { Name string `json:"name"` Value string `json:"value"` Modifier json.RawMessage `json:"modifier"` ElseModifier json.RawMessage `json:"else"` Scope []parse.ModifierType `json:"scope"` } // NewFilter builds a querystring.Filter that filters on name and optionally // value. func NewFilter(name, value string) *Filter { m := NewMatcher(name, value) f := filter.New() f.SetRequestCondition(m) f.SetResponseCondition(m) return &Filter{f} } // filterFromJSON takes a JSON message and returns a querystring.Filter. // // Example JSON: // { // "name": "param", // "value": "example", // "scope": ["request", "response"], // "modifier": { ... } // } func filterFromJSON(b []byte) (*parse.Result, error) { msg := &filterJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } f := NewFilter(msg.Name, msg.Value) r, err := parse.FromJSON(msg.Modifier) if err != nil { return nil, err } f.RequestWhenTrue(r.RequestModifier()) f.ResponseWhenTrue(r.ResponseModifier()) if len(msg.ElseModifier) > 0 { em, err := parse.FromJSON(msg.ElseModifier) if err != nil { return nil, err } if em != nil { f.RequestWhenFalse(em.RequestModifier()) f.ResponseWhenFalse(em.ResponseModifier()) } } return parse.NewResult(f, msg.Scope) } martian-3.3.2/querystring/query_string_filter_test.go000066400000000000000000000217161421371434000232420ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package querystring import ( "errors" "net/http" "strings" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/verify" // Import to register header.Modifier with JSON parser. _ "github.com/google/martian/v3/header" ) func TestNoModifiers(t *testing.T) { f := NewFilter("", "") f.SetRequestModifier(nil) f.SetResponseModifier(nil) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } } func TestQueryStringFilterWithQuery(t *testing.T) { // Name only, no value. 
f := NewFilter("match", "") tm := martiantest.NewModifier() f.SetRequestModifier(tm) f.SetResponseModifier(tm) req, err := http.NewRequest("GET", "http://martian.local?match=any", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } res := proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } tm.Reset() req, err = http.NewRequest("GET", "http://martian.local?nomatch", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res = proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } tm.Reset() // Name and value. 
f = NewFilter("match", "value") f.SetRequestModifier(tm) f.SetResponseModifier(tm) req, err = http.NewRequest("GET", "http://martian.local?match=value", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res = proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if !tm.RequestModified() { t.Error("tm.RequestModified(): got false, want true") } if !tm.ResponseModified() { t.Error("tm.ResponseModified(): got false, want true") } tm.Reset() req, err = http.NewRequest("GET", "http://martian.local?match=notvalue", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res = proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } tm.Reset() // Explicitly do not match POST data. 
req, err = http.NewRequest("GET", "http://martian.local", strings.NewReader("match=value")) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } res = proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if tm.RequestModified() { t.Error("tm.RequestModified(): got true, want false") } if tm.ResponseModified() { t.Error("tm.ResponseModified(): got true, want false") } tm.Reset() } func TestFilterFromJSON(t *testing.T) { msg := []byte(`{ "querystring.Filter": { "scope": ["request", "response"], "name": "param", "value": "true", "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Modified", "value": "true" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://martian.test?param=true", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Modified"), "true"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Martian-Modified", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Modified"), "true"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Modified", got, want) } } func TestElseCondition(t *testing.T) { msg := []byte(`{ "querystring.Filter": { "scope": 
["request", "response"], "name": "param", "value": "true", "modifier": { "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Modified", "value": "true" } }, "else": { "header.Modifier": { "scope": ["request", "response"], "name": "Martian-Modified", "value": "false" } } } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", "https://martian.test?param=false", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("Martian-Modified"), "false"; got != want { t.Errorf("req.Header.Get(%q): got %q, want %q", "Martian-Modified", got, want) } resmod := r.ResponseModifier() if resmod == nil { t.Fatalf("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Martian-Modified"), "false"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Martian-Modified", got, want) } } func TestVerifyRequests(t *testing.T) { f := NewFilter("", "") if err := f.VerifyRequests(); err != nil { t.Fatalf("VerifyRequest(): got %v, want no error", err) } tv := &verify.TestVerifier{ RequestError: errors.New("verify request failure"), } f.SetRequestModifier(tv) want := martian.NewMultiError() want.Add(tv.RequestError) if got := f.VerifyRequests(); got.Error() != want.Error() { t.Fatalf("VerifyRequests(): got %v, want %v", got, want) } f.ResetRequestVerifications() if err := f.VerifyRequests(); err != nil { t.Fatalf("VerifyRequest(): got %v, want no error", err) } } func TestVerifyResponses(t *testing.T) { f := NewFilter("", "") if err := 
f.VerifyResponses(); err != nil { t.Fatalf("VerifyResponses(): got %v, want no error", err) } tv := &verify.TestVerifier{ ResponseError: errors.New("verify response failure"), } f.SetResponseModifier(tv) want := martian.NewMultiError() want.Add(tv.ResponseError) if got := f.VerifyResponses(); got.Error() != want.Error() { t.Fatalf("VerifyResponses(): got %v, want %v", got, want) } f.ResetResponseVerifications() if err := f.VerifyResponses(); err != nil { t.Fatalf("VerifyResponses(): got %v, want no error", err) } } martian-3.3.2/querystring/query_string_matcher.go000066400000000000000000000031461421371434000223360ustar00rootroot00000000000000// Copyright 2017 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package querystring import "net/http" // Matcher is a conditonal evalutor of query string parameters // to be used in structs that take conditions. type Matcher struct { name, value string } // NewMatcher builds a new querystring matcher func NewMatcher(name, value string) *Matcher { return &Matcher{name: name, value: value} } // MatchRequest evaluates a request and returns whether or not // the request contains a querystring param that matches the provided name // and value. 
func (m *Matcher) MatchRequest(req *http.Request) bool { for n, vs := range req.URL.Query() { if m.name == n { if m.value == "" { return true } for _, v := range vs { if m.value == v { return true } } } } return false } // MatchResponse evaluates a response and returns whether or not // the request that resulted in that response contains a querystring param that matches the provided name // and value. func (m *Matcher) MatchResponse(res *http.Response) bool { return m.MatchRequest(res.Request) } martian-3.3.2/querystring/query_string_modifier.go000066400000000000000000000040541421371434000225100ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package querystring contains a modifier to rewrite query strings in a request. package querystring import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) func init() { parse.Register("querystring.Modifier", modifierFromJSON) } type modifier struct { key, value string } type modifierJSON struct { Name string `json:"name"` Value string `json:"value"` Scope []parse.ModifierType `json:"scope"` } // ModifyRequest modifies the query string of the request with the given key and value. 
func (m *modifier) ModifyRequest(req *http.Request) error { query := req.URL.Query() query.Set(m.key, m.value) req.URL.RawQuery = query.Encode() return nil } // NewModifier returns a request modifier that will set the query string // at key with the given value. If the query string key already exists all // values will be overwritten. func NewModifier(key, value string) martian.RequestModifier { return &modifier{ key: key, value: value, } } // modifierFromJSON takes a JSON message as a byte slice and returns // a querystring.modifier and an error. // // Example JSON: // { // "name": "param", // "value": "true", // "scope": ["request", "response"] // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewModifier(msg.Name, msg.Value), msg.Scope) } martian-3.3.2/querystring/query_string_modifier_test.go000066400000000000000000000061021421371434000235430ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package querystring import ( "net/http" "testing" "github.com/google/martian/v3/parse" ) func TestNewQueryStringModifier(t *testing.T) { mod := NewModifier("testing", "true") req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.URL.Query().Get("testing"), "true"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "testing", got, want) } } func TestQueryStringModifierQueryExists(t *testing.T) { mod := NewModifier("testing", "true") req, err := http.NewRequest("GET", "/?testing=false", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.URL.Query().Get("testing"), "true"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "testing", got, want) } } func TestQueryStringModifierQueryExistsMultipleKeys(t *testing.T) { mod := NewModifier("testing", "true") req, err := http.NewRequest("GET", "/?testing=false&testing=foo&foo=bar", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if got, want := req.URL.Query().Get("testing"), "true"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "testing", got, want) } if got, want := req.URL.Query().Get("foo"), "bar"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "testing", got, want) } } func TestModifierFromJSON(t *testing.T) { msg := []byte(` { "querystring.Modifier": { "scope": ["request"], "name": "param", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } req, err := http.NewRequest("GET", "http://martian.test", nil) if err != 
nil { t.Fatalf("http.NewRequest(): got %q, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatalf("reqmod: got nil, want not nil") } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("reqmod.ModifyRequest(): got %v, want no error", err) } if got, want := req.URL.Query().Get("param"), "true"; got != want { t.Errorf("req.URL.Query().Get(%q): got %q, want %q", "param", got, want) } } martian-3.3.2/querystring/query_string_verifier.go000066400000000000000000000064141421371434000225270ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package querystring import ( "encoding/json" "fmt" "net/http" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) func init() { parse.Register("querystring.Verifier", verifierFromJSON) } type verifier struct { key, value string err *martian.MultiError } type verifierJSON struct { Name string `json:"name"` Value string `json:"value"` Scope []parse.ModifierType `json:"scope"` } // NewVerifier returns a new param verifier. func NewVerifier(key, value string) (verify.RequestVerifier, error) { if key == "" { return nil, fmt.Errorf("no key provided to param verifier") } return &verifier{ key: key, value: value, err: martian.NewMultiError(), }, nil } // ModifyRequest verifies that the request's URL params match the given params // in all modified requests. 
If no value is provided, the verifier will only // check if the given key is present. An error will be added to the contained // *MultiError if the param is unmatched. func (v *verifier) ModifyRequest(req *http.Request) error { // skip requests to API ctx := martian.NewContext(req) if ctx.IsAPIRequest() { return nil } if err := req.ParseForm(); err != nil { err := fmt.Errorf("request(%v) parsing failed; could not parse query parameters", req.URL) v.err.Add(err) return nil } vals, ok := req.Form[v.key] if !ok { err := fmt.Errorf("request(%v) param verification error: key %v not found", req.URL, v.key) v.err.Add(err) return nil } if v.value == "" { return nil } for _, val := range vals { if v.value == val { return nil } } err := fmt.Errorf("request(%v) param verification error: got %v for key %v, want %v", req.URL, strings.Join(vals, ", "), v.key, v.value) v.err.Add(err) return nil } // VerifyRequests returns an error if verification for any request failed. // If an error is returned it will be of type *martian.MultiError. func (v *verifier) VerifyRequests() error { if v.err.Empty() { return nil } return v.err } // ResetRequestVerifications clears all failed request verifications. func (v *verifier) ResetRequestVerifications() { v.err = martian.NewMultiError() } // verifierFromJSON builds a querystring.Verifier from JSON. // // Example JSON: // { // "querystring.Verifier": { // "scope": ["request", "response"], // "name": "Martian-Testing", // "value": "true" // } // } func verifierFromJSON(b []byte) (*parse.Result, error) { msg := &verifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } v, err := NewVerifier(msg.Name, msg.Value) if err != nil { return nil, err } return parse.NewResult(v, msg.Scope) } martian-3.3.2/querystring/query_string_verifier_test.go000066400000000000000000000132721421371434000235660ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package querystring import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) func TestVerifyRequestPasses(t *testing.T) { v, err := NewVerifier("foo", "bar") if err != nil { t.Fatalf("NewVerifier(%q, %q): got %v, want no error", "foo", "bar", err) } req, err := http.NewRequest("GET", "http://www.google.com?foo=baz&foo=bar", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := v.VerifyRequests(); err != nil { t.Fatalf("VerifyRequests(): got %v, want no error", err) } } func TestVerifyEmptyValue(t *testing.T) { v, err := NewVerifier("foo", "") if err != nil { t.Fatalf("NewVerifier(%q, %q): got %v, want no error", "foo", "", err) } req, err := http.NewRequest("GET", "http://www.google.com?foo=bar", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := v.VerifyRequests(); err != nil { 
t.Fatalf("VerifyRequests(): got %v, want no error", err) } } func TestFailureWithMissingKey(t *testing.T) { v, err := NewVerifier("foo", "bar") if err != nil { t.Fatalf("NewVerifier(%q, %q): got %v, want no error", "foo", "bar", err) } req, err := http.NewRequest("GET", "http://www.google.com?fizz=bar", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } merr, ok := v.VerifyRequests().(*martian.MultiError) if !ok { t.Fatal("VerifyRequests(): got nil, want *verify.MultiError") } errs := merr.Errors() if len(errs) != 1 { t.Fatalf("len(merr.Errors()): got %d, want 1", len(errs)) } expectErr := "request(http://www.google.com?fizz=bar) param verification error: key foo not found" for i := range errs { if got, want := errs[i].Error(), expectErr; got != want { t.Errorf("%d. 
err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want) } } } func TestFailureWithMultiFail(t *testing.T) { v, err := NewVerifier("foo", "bar") if err != nil { t.Fatalf("NewVerifier(%q, %q): got %v, want no error", "foo", "bar", err) } req, err := http.NewRequest("GET", "http://www.google.com?foo=baz", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := v.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } merr, ok := v.VerifyRequests().(*martian.MultiError) if !ok { t.Fatalf("VerifyRequests(): got nil, want *verify.MultiError") } errs := merr.Errors() if len(errs) != 2 { t.Fatalf("len(merr.Errors()): got %d, want 2", len(errs)) } expectErr := "request(http://www.google.com?foo=baz) param verification error: got baz for key foo, want bar" for i := range errs { if got, want := errs[i].Error(), expectErr; got != want { t.Errorf("%d. 
err.Error(): mismatched error output\ngot: %s\nwant: %s", i, got, want) } } v.ResetRequestVerifications() if err := v.VerifyRequests(); err != nil { t.Fatalf("VerifyRequests(): got %v, want no error", err) } } func TestBadInputToConstructor(t *testing.T) { if _, err := NewVerifier("", "bar"); err == nil { t.Fatalf("NewVerifier(): no error returned for empty key") } } func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "querystring.Verifier": { "scope": ["request"], "name": "param", "value": "true" } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } reqv, ok := reqmod.(verify.RequestVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := reqv.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := reqv.VerifyRequests(); err == nil { t.Error("VerifyRequests(): got nil, want not nil") } } martian-3.3.2/servemux/000077500000000000000000000000001421371434000150275ustar00rootroot00000000000000martian-3.3.2/servemux/servemux_filter.go000066400000000000000000000025721421371434000206070ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package servemux contains a filter that executes modifiers when there is a // pattern match in a mux. package servemux import ( "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/filter" ) var noop = martian.Noop("mux.Filter") // Filter is a modifier that executes mod if a pattern is matched in mux. type Filter struct { *filter.Filter } // NewFilter constructs a filter that applies the modifier when the request // url matches a pattern in mux. If no mux is provided, the request is evaluated // against patterns in http.DefaultServeMux. func NewFilter(mux *http.ServeMux) *Filter { if mux == nil { mux = http.DefaultServeMux } m := NewMatcher(mux) f := filter.New() f.SetRequestCondition(m) f.SetResponseCondition(m) return &Filter{Filter: f} } martian-3.3.2/servemux/servemux_filter_test.go000066400000000000000000000071421421371434000216440ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package servemux import ( "net/http" "testing" "github.com/google/martian/v3/martiantest" "github.com/google/martian/v3/proxyutil" ) func TestModifyRequest(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("example.com/test", func(rw http.ResponseWriter, req *http.Request) { return }) f := NewFilter(mux) tm := martiantest.NewModifier() f.RequestWhenTrue(tm) fm := martiantest.NewModifier() f.RequestWhenFalse(fm) req, err := http.NewRequest("GET", "http://example.com/test", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if got, want := tm.RequestModified(), true; got != want { t.Errorf("tm.RequestModified(): got %v, want %v", got, want) } if got, want := fm.RequestModified(), false; got != want { t.Errorf("fm.RequestModified(): got %v, want %v", got, want) } tm.Reset() fm.Reset() req, err = http.NewRequest("GET", "http://example.com/nomatch", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } if err := f.ModifyRequest(req); err != nil { t.Errorf("ModifyRequest(): got %v, want no error", err) } if got, want := tm.RequestModified(), false; got != want { t.Errorf("tm.RequestModified(): got %v, want %v", got, want) } if got, want := fm.RequestModified(), true; got != want { t.Errorf("fm.RequestModified(): got %v, want %v", got, want) } } func TestModifyResponse(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("example.com/restest", func(rw http.ResponseWriter, req *http.Request) { return }) f := NewFilter(mux) tm := martiantest.NewModifier() f.ResponseWhenTrue(tm) fm := martiantest.NewModifier() f.ResponseWhenFalse(fm) req, err := http.NewRequest("GET", "http://example.com/restest", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want 
no error", err) } if got, want := tm.ResponseModified(), true; got != want { t.Errorf("tm.ResponseModified(): got %v, want %v", got, want) } if got, want := fm.ResponseModified(), false; got != want { t.Errorf("fm.ResponseModified(): got %v, want %v", got, want) } tm.Reset() fm.Reset() req, err = http.NewRequest("GET", "http://example.com/nomatch", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } res = proxyutil.NewResponse(200, nil, req) if err := f.ModifyResponse(res); err != nil { t.Errorf("ModifyResponse(): got %v, want no error", err) } if tm.ResponseModified() != false { t.Errorf("tm.ResponseModified(): got %t, want %t", tm.ResponseModified(), false) } if got, want := tm.ResponseModified(), false; got != want { t.Errorf("tm.ResponseModified(): got %v, want %v", got, want) } if got, want := fm.ResponseModified(), true; got != want { t.Errorf("fm.ResponseModified(): got %v, want %v", got, want) } } martian-3.3.2/servemux/servemux_matcher.go000066400000000000000000000027211421371434000207410ustar00rootroot00000000000000// Copyright 2016 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package servemux import "net/http" // Matcher is a conditional evaluator of request urls against patterns registered // in mux. type Matcher struct { mux *http.ServeMux } // NewMatcher builds a new servemux.Matcher. 
func NewMatcher(mux *http.ServeMux) *Matcher { return &Matcher{ mux: mux, } } // MatchRequest returns true if the request URL matches any pattern in mux. If no // pattern is matched, false is returned. func (m *Matcher) MatchRequest(req *http.Request) bool { if _, pattern := m.mux.Handler(req); pattern != "" { return true } return false } // MatchResponse returns true if the request URL associated with the response matches // any pattern in mux. If pattern is matched, false is returned. func (m *Matcher) MatchResponse(res *http.Response) bool { if _, pattern := m.mux.Handler(res.Request); pattern != "" { return true } return false } martian-3.3.2/skip/000077500000000000000000000000001421371434000141175ustar00rootroot00000000000000martian-3.3.2/skip/skip.go000066400000000000000000000031651421371434000154210ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package skip provides a request modifier to skip the HTTP round-trip. package skip import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) // RoundTrip is a modifier that skips the request round-trip. type RoundTrip struct{} type roundTripJSON struct { Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("skip.RoundTrip", roundTripFromJSON) } // NewRoundTrip returns a new modifier that skips round-trip. 
func NewRoundTrip() *RoundTrip { return &RoundTrip{} } // ModifyRequest skips the request round-trip. func (r *RoundTrip) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) ctx.SkipRoundTrip() return nil } // roundTripFromJSON builds a skip.RoundTrip from JSON. // Example JSON: // { // "skip.RoundTrip": { } // } func roundTripFromJSON(b []byte) (*parse.Result, error) { msg := &roundTripJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewRoundTrip(), msg.Scope) } martian-3.3.2/skip/skip_test.go000066400000000000000000000032501421371434000164530ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package skip import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) func TestRoundTrip(t *testing.T) { m := NewRoundTrip() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } ctx, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("martian.TestContext(): got %v, want no error", err) } defer remove() if ctx.SkippingRoundTrip() { t.Fatal("ctx.SkippingRoundTrip(): got true, want false") } if err := m.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if !ctx.SkippingRoundTrip() { t.Fatal("ctx.SkippingRoundTrip(): got false, want true") } } func TestFromJSON(t *testing.T) { msg := []byte(`{ "skip.RoundTrip": {} }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if _, ok := reqmod.(*RoundTrip); !ok { t.Fatal("reqmod.(*RoundTrip): got !ok, want ok") } } martian-3.3.2/stash/000077500000000000000000000000001421371434000142735ustar00rootroot00000000000000martian-3.3.2/stash/stash_modifier.go000066400000000000000000000046121421371434000176250ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package stash provides a modifier that stores the request URL in a // specified header. 
package stash import ( "encoding/json" "fmt" "net/http" "github.com/google/martian/v3/parse" ) func init() { parse.Register("stash.Modifier", modifierFromJSON) } // Modifier adds a header to the request containing the current state of the URL. // The header will be named with the value stored in headerName. // There will be no validation done on this header name. type Modifier struct { headerName string } type modifierJSON struct { HeaderName string `json:"headerName"` Scope []parse.ModifierType `json:"scope"` } // NewModifier returns a RequestModifier that write the current URL into a header. func NewModifier(headerName string) *Modifier { return &Modifier{headerName: headerName} } // ModifyRequest writes the current URL into a header. func (m *Modifier) ModifyRequest(req *http.Request) error { req.Header.Set(m.headerName, req.URL.String()) return nil } // ModifyResponse writes the same header written in the request into the response. func (m *Modifier) ModifyResponse(res *http.Response) error { res.Header.Set(m.headerName, res.Request.Header.Get(m.headerName)) return nil } func modifierFromJSON(b []byte) (*parse.Result, error) { // If you would like the saved state of the URL to be written in the response you must specify // this modifier's scope as both request and response. msg := &modifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } mod := NewModifier(msg.HeaderName) r, err := parse.NewResult(mod, msg.Scope) if err != nil { return nil, err } if r.ResponseModifier() != nil && r.RequestModifier() == nil { return nil, fmt.Errorf("to write header on a response, specify scope as both request and response") } return r, nil } martian-3.3.2/stash/stash_modifier_test.go000066400000000000000000000125371421371434000206710ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package stash import ( "net" "net/http" "testing" "github.com/google/martian/v3/fifo" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/port" "github.com/google/martian/v3/proxyutil" ) func TestStashRequest(t *testing.T) { fg := fifo.NewGroup() fg.AddRequestModifier(NewModifier("stashed-url")) pmod := port.NewModifier() pmod.UsePort(8080) fg.AddRequestModifier(pmod) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := fg.ModifyRequest(req); err != nil { t.Fatalf("smod.ModifyRequest(): got %v, want no error", err) } _, port, err := net.SplitHostPort(req.URL.Host) if err != nil { t.Fatalf("net.SplitHostPort(%q): got %v, want no error", req.URL.Host, err) } if got, want := port, "8080"; got != want { t.Errorf("port: got %v, want %v", got, want) } if got, want := req.Header.Get("stashed-url"), "http://example.com"; got != want { t.Errorf("stashed-url header: got %v, want %v", got, want) } } func TestStashRequestResponse(t *testing.T) { headerName := "stashed-url" originalURL := "http://example.com" fg := fifo.NewGroup() fg.AddRequestModifier(NewModifier(headerName)) fg.AddResponseModifier(NewModifier(headerName)) pmod := port.NewModifier() pmod.UsePort(8080) fg.AddRequestModifier(pmod) req, err := http.NewRequest("GET", originalURL, nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := fg.ModifyRequest(req); err != nil { t.Fatalf("smod.ModifyRequest(): got %v, want no error", err) } _, port, err := 
net.SplitHostPort(req.URL.Host) if err != nil { t.Fatalf("net.SplitHostPort(%q): got %v, want no error", req.URL.Host, err) } if got, want := port, "8080"; got != want { t.Errorf("port: got %v, want %v", got, want) } if got, want := req.Header.Get(headerName), originalURL; got != want { t.Errorf("res.Header.Get(%q): got %v, want %v", headerName, got, want) } res := proxyutil.NewResponse(200, nil, req) if err := fg.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get(headerName), originalURL; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", headerName, got, want) } } func TestStashInvalidHeaderName(t *testing.T) { mod := NewModifier("invalid-chars-actually-work-;><@") req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := mod.ModifyRequest(req); err != nil { t.Fatalf("smod.ModifyRequest(): got %v, want no error", err) } if got, want := req.Header.Get("invalid-chars-actually-work-;><@"), "http://example.com"; got != want { t.Errorf("stashed-url header: got %v, want %v", got, want) } } func TestModiferFromJSON(t *testing.T) { headerName := "stashed-url" originalURL := "http://example.com" msg := []byte(`{ "fifo.Group": { "scope": ["request", "response"], "modifiers": [ { "stash.Modifier": { "scope": ["request", "response"], "headerName": "stashed-url" } }, { "port.Modifier": { "scope": ["request"], "port": 8080 } } ] } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } req, err := http.NewRequest("GET", originalURL, nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } _, port, err := net.SplitHostPort(req.URL.Host) 
if err != nil { t.Fatalf("net.SplitHostPort(%q): got %v, want no error", req.URL.Host, err) } if got, want := port, "8080"; got != want { t.Errorf("port: got %v, want %v", got, want) } if got, want := req.Header.Get(headerName), originalURL; got != want { t.Errorf("req.Header.Get(%q) header: got %v, want %v", headerName, got, want) } resmod := r.ResponseModifier() res := proxyutil.NewResponse(200, nil, req) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get(headerName), originalURL; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", headerName, got, want) } } func TestModiferFromJSONInvalidConfigurations(t *testing.T) { msg := []byte(`{ "stash.Modifier": { "scope": ["response"], "headerName": "stash-header" } }`) _, err := parse.FromJSON(msg) if err == nil { t.Fatalf("parseFromJSON(msg): Got no error, but should have gotten one.") } } martian-3.3.2/static/000077500000000000000000000000001421371434000144405ustar00rootroot00000000000000martian-3.3.2/static/static_file_modifier.go000066400000000000000000000144651421371434000211450ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package static provides a modifier that allows Martian to return static files // local to Martian. The static modifier does not support setting explicit path // mappings via the JSON API. 
package static import ( "bytes" "encoding/json" "fmt" "io" "io/ioutil" "mime" "mime/multipart" "net/http" "net/textproto" "os" "path" "path/filepath" "strconv" "strings" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) // Modifier is a martian.RequestResponseModifier that routes reqeusts to rootPath // and serves the assets there, while skipping the HTTP roundtrip. type Modifier struct { rootPath string explicitPaths map[string]string } type staticJSON struct { ExplicitPaths map[string]string `json:"explicitPaths"` RootPath string `json:"rootPath"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("static.Modifier", modifierFromJSON) } // NewModifier constructs a static.Modifier that takes a path to serve files from, as well as an optional mapping of request paths to local // file paths (still rooted at rootPath). func NewModifier(rootPath string) *Modifier { return &Modifier{ rootPath: path.Clean(rootPath), explicitPaths: make(map[string]string), } } // ModifyRequest marks the context to skip the roundtrip and downgrades any https requests // to http. func (s *Modifier) ModifyRequest(req *http.Request) error { ctx := martian.NewContext(req) ctx.SkipRoundTrip() return nil } // ModifyResponse reads the file rooted at rootPath joined with the request URL // path. In the case that the the request path is a key in s.explicitPaths, ModifyRequest // will attempt to open the file located at s.rootPath joined by the value in s.explicitPaths // (keyed by res.Request.URL.Path). In the case that the file cannot be found, the response // will be a 404. ModifyResponse will return a 404 for any path that is defined in s.explictPaths // and that does not exist locally, even if that file does exist in s.rootPath. 
func (s *Modifier) ModifyResponse(res *http.Response) error { reqpth := filepath.Clean(res.Request.URL.Path) fpth := filepath.Join(s.rootPath, reqpth) if _, ok := s.explicitPaths[reqpth]; ok { fpth = filepath.Join(s.rootPath, s.explicitPaths[reqpth]) } f, err := os.Open(fpth) switch { case os.IsNotExist(err): res.StatusCode = http.StatusNotFound return nil case os.IsPermission(err): // This is returning a StatusUnauthorized to reflect that the Martian does // not have the appropriate permissions on the local file system. This is a // deviation from the standard assumption around an HTTP 401 response. res.StatusCode = http.StatusUnauthorized return err case err != nil: res.StatusCode = http.StatusInternalServerError return err } res.Body.Close() info, err := f.Stat() if err != nil { res.StatusCode = http.StatusInternalServerError return err } contentType := mime.TypeByExtension(filepath.Ext(fpth)) res.Header.Set("Content-Type", contentType) // If no range request header is present, return the file as the response body. if res.Request.Header.Get("Range") == "" { res.ContentLength = info.Size() res.Body = f return nil } rh := res.Request.Header.Get("Range") rh = strings.ToLower(rh) sranges := strings.Split(strings.TrimLeft(rh, "bytes="), ",") var ranges [][]int for _, rng := range sranges { if strings.HasSuffix(rng, "-") { rng = fmt.Sprintf("%s%d", rng, info.Size()-1) } rs := strings.Split(rng, "-") if len(rs) != 2 { res.StatusCode = http.StatusRequestedRangeNotSatisfiable return nil } start, err := strconv.Atoi(strings.TrimSpace(rs[0])) if err != nil { return err } end, err := strconv.Atoi(strings.TrimSpace(rs[1])) if err != nil { return err } if start > end { res.StatusCode = http.StatusRequestedRangeNotSatisfiable return nil } ranges = append(ranges, []int{start, end}) } // Range request. res.StatusCode = http.StatusPartialContent // Single range request. 
if len(ranges) == 1 { start := ranges[0][0] end := ranges[0][1] length := end - start + 1 seg := make([]byte, length) switch n, err := f.ReadAt(seg, int64(start)); err { case nil, io.EOF: res.ContentLength = int64(n) default: return err } res.Body = ioutil.NopCloser(bytes.NewReader(seg)) res.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, info.Size())) return nil } // Multipart range request. var mpbody bytes.Buffer mpw := multipart.NewWriter(&mpbody) for _, rng := range ranges { start, end := rng[0], rng[1] mimeh := make(textproto.MIMEHeader) mimeh.Set("Content-Type", contentType) mimeh.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, info.Size())) length := end - start + 1 seg := make([]byte, length) switch n, err := f.ReadAt(seg, int64(start)); err { case nil, io.EOF: res.ContentLength = int64(n) default: return err } pw, err := mpw.CreatePart(mimeh) if err != nil { return err } if _, err := pw.Write(seg); err != nil { return err } } mpw.Close() res.ContentLength = int64(len(mpbody.Bytes())) res.Body = ioutil.NopCloser(bytes.NewReader(mpbody.Bytes())) res.Header.Set("Content-Type", fmt.Sprintf("multipart/byteranges; boundary=%s", mpw.Boundary())) return nil } // SetExplicitPathMappings sets an optional mapping of request paths to local // file paths rooted at s.rootPath. func (s *Modifier) SetExplicitPathMappings(ep map[string]string) { s.explicitPaths = ep } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &staticJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } mod := NewModifier(msg.RootPath) mod.SetExplicitPathMappings(msg.ExplicitPaths) return parse.NewResult(mod, msg.Scope) } martian-3.3.2/static/static_file_modifier_test.go000066400000000000000000000336601421371434000222020ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package static import ( "bytes" "fmt" "io/ioutil" "net/http" "os" "path" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func Test404WhenExplictlyMappedFileDoesNotExist(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_explicit_path_mapping_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } //if err := os.MkdirAll(path.Join(tmpdir, "explicit/path"), 0777); err != nil { // t.Fatalf("os.Mkdir(): got %v, want no error", err) //} //if err := ioutil.WriteFile(path.Join(tmpdir, "explicit/path", "sfmtest.txt"), []byte("test file"), 0777); err != nil { // t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) //} req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) mod := NewModifier(tmpdir) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } mod.SetExplicitPathMappings(map[string]string{"/sfmtest.txt": "/explicit/path/sfmtest.txt"}) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusNotFound; 
got != want { t.Errorf("res.StatusCode: got %v, want %v", got, want) } } func TestFileExistsInBothExplictlyMappedPathAndInferredPath(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_explicit_path_mapping_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } if err := os.MkdirAll(path.Join(tmpdir, "explicit/path"), 0777); err != nil { t.Fatalf("os.Mkdir(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "sfmtest.txt"), []byte("dont return"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "explicit/path", "sfmtest.txt"), []byte("target"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) mod := NewModifier(tmpdir) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } mod.SetExplicitPathMappings(map[string]string{"/sfmtest.txt": "/explicit/path/sfmtest.txt"}) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("target"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("target")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } } func 
TestStaticModifierExplicitPathMapping(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_explicit_path_mapping_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } if err := os.MkdirAll(path.Join(tmpdir, "explicit/path"), 0777); err != nil { t.Fatalf("os.Mkdir(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "explicit/path", "sfmtest.txt"), []byte("test file"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) mod := NewModifier(tmpdir) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } mod.SetExplicitPathMappings(map[string]string{"/sfmtest.txt": "/explicit/path/sfmtest.txt"}) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("test file"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("test file")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } } func TestStaticModifierOnRequest(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_on_request_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "sfmtest.txt"), []byte("test file"), 0777); err 
!= nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) mod := NewModifier(tmpdir) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("test file"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("test file")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } } func TestRequestOverHTTPS(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_on_request_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "sfmtest.txt"), []byte("test file"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.URL.Scheme = "https" _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) mod := NewModifier(tmpdir) if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if 
err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("test file"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("test file")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } } func TestModifierFromJSON(t *testing.T) { tmpdir, err := ioutil.TempDir("", "test_static_modifier_on_request_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } tmpdir2 := path.Join(tmpdir, "subdir") err = os.Mkdir(tmpdir2, 0777) if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir, "sfmtest.txt"), []byte("test file"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } if err := ioutil.WriteFile(path.Join(tmpdir2, "sfmtest.txt"), []byte("test file2"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } msg := []byte(fmt.Sprintf(`{ "static.Modifier": { "scope": ["request", "response"], "explicitPaths": {"/foo/bar.baz": "/subdir/sfmtest.txt"}, "rootPath": %q } }`, tmpdir)) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } reqmod := r.RequestModifier() if reqmod == nil { t.Fatal("reqmod: got nil, want not nil") } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no 
error", err) } defer remove() res := proxyutil.NewResponse(http.StatusOK, nil, req) if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("test file"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("test file")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } req, err = http.NewRequest("GET", "/foo/bar.baz", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } _, remove, err = martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res = proxyutil.NewResponse(http.StatusOK, nil, req) if err := reqmod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } got, err = ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("test file2"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.ContentLength, int64(len("test file2")); got != want { t.Errorf("res.ContentLength: got %v, want %v", got, want) } } func TestStaticModifierSingleRangeRequest(t *testing.T) { tmpdir, err := 
ioutil.TempDir("", "test_static_modifier_on_request_") if err != nil { t.Fatalf("ioutil.TempDir(): got %v, want no error", err) } mod := NewModifier(tmpdir) if err := ioutil.WriteFile(path.Join(tmpdir, "sfmtest.txt"), []byte("0123456789"), 0777); err != nil { t.Fatalf("ioutil.WriteFile(): got %v, want no error", err) } req, err := http.NewRequest("GET", "/sfmtest.txt", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } req.Header.Set("Range", "bytes=1-4") _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() if err := mod.ModifyRequest(req); err != nil { t.Fatalf("ModifyRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(http.StatusOK, nil, req) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, http.StatusPartialContent; got != want { t.Errorf("res.Status: got %v, want %v", got, want) } if got, want := res.ContentLength, int64(len([]byte("1234"))); got != want { t.Errorf("res.ContentLength: got %d, want %d", got, want) } if got, want := res.Header.Get("Content-Range"), "bytes 1-4/10"; got != want { t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want) } got, err := ioutil.ReadAll(res.Body) if err != nil { t.Fatalf("ioutil.ReadAll(): got %v, want no error", err) } res.Body.Close() if want := []byte("1234"); !bytes.Equal(got, want) { t.Errorf("res.Body: got %q, want %q", got, want) } if got, want := res.Header.Get("Content-Type"), "text/plain; charset=utf-8"; got != want { t.Errorf("res.Header.Get('Content-Type'): got %v, want %v", got, want) } } martian-3.3.2/status/000077500000000000000000000000001421371434000144745ustar00rootroot00000000000000martian-3.3.2/status/status_modifier.go000066400000000000000000000037021421371434000202260ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package status contains a modifier to rewrite the status code on a response. package status import ( "encoding/json" "fmt" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" ) type statusModifier struct { statusCode int } type statusJSON struct { StatusCode int `json:"statusCode"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("status.Modifier", modifierFromJSON) } // ModifyResponse overwrites the status text and code on an HTTP response and // returns nil. func (s *statusModifier) ModifyResponse(res *http.Response) error { res.StatusCode = s.statusCode res.Status = fmt.Sprintf("%d %s", s.statusCode, http.StatusText(s.statusCode)) return nil } // NewModifier constructs a statusModifier that overrides response status // codes with the HTTP status code provided. func NewModifier(statusCode int) martian.ResponseModifier { return &statusModifier{ statusCode: statusCode, } } // modifierFromJSON builds a status.Modifier from JSON. 
// // Example JSON: // { // "status.Modifier": { // "scope": ["response"], // "statusCode": 401 // } // } func modifierFromJSON(b []byte) (*parse.Result, error) { msg := &statusJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewModifier(msg.StatusCode), msg.Scope) } martian-3.3.2/status/status_modifier_test.go000066400000000000000000000041611421371434000212650ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package status import ( "fmt" "net/http" "testing" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" ) func TestFromJSON(t *testing.T) { msg := []byte(`{ "status.Modifier": { "scope": ["response"], "statusCode": 400 } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } res := proxyutil.NewResponse(200, nil, nil) if err := resmod.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if got, want := res.StatusCode, 400; got != want { t.Errorf("res.StatusCode: got %d, want %d", got, want) } } func TestStatusModifierOnResponse(t *testing.T) { for i, status := range []int{ http.StatusForbidden, http.StatusOK, http.StatusTemporaryRedirect, } { req, err := http.NewRequest("GET", "/", nil) if err != nil { t.Fatalf("NewRequest(): got %v, want no error", err) } res := proxyutil.NewResponse(200, nil, req) mod := NewModifier(status) if err := mod.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } if got, want := res.StatusCode, status; got != want { t.Errorf("%d. res.StatusCode: got %v, want %v", i, got, want) } if got, want := res.Status, fmt.Sprintf("%d %s", res.StatusCode, http.StatusText(status)); got != want { t.Errorf("%d. res.Status: got %q, want %q", i, got, want) } } } martian-3.3.2/status/status_verifier.go000066400000000000000000000050071421371434000202430ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package status import ( "encoding/json" "fmt" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/verify" ) const errFormat = "response(%s) status code verify failure: got %d, want %d" // Verifier verifies the status codes of all responses. type Verifier struct { statusCode int err *martian.MultiError } type verifierJSON struct { StatusCode int `json:"statusCode"` Scope []parse.ModifierType `json:"scope"` } func init() { parse.Register("status.Verifier", verifierFromJSON) } // NewVerifier returns a new status.Verifier for statusCode. func NewVerifier(statusCode int) verify.ResponseVerifier { return &Verifier{ statusCode: statusCode, err: martian.NewMultiError(), } } // ModifyResponse verifies that the status code for all requests // matches statusCode. func (v *Verifier) ModifyResponse(res *http.Response) error { ctx := martian.NewContext(res.Request) if ctx.IsAPIRequest() { return nil } if res.StatusCode != v.statusCode { v.err.Add(fmt.Errorf(errFormat, res.Request.URL, res.StatusCode, v.statusCode)) } return nil } // VerifyResponses returns an error if verification for any // request failed. // If an error is returned it will be of type *martian.MultiError. func (v *Verifier) VerifyResponses() error { if v.err.Empty() { return nil } return v.err } // ResetResponseVerifications clears all failed response verifications. func (v *Verifier) ResetResponseVerifications() { v.err = martian.NewMultiError() } // verifierFromJSON builds a status.Verifier from JSON. 
// // Example JSON: // { // "status.Verifier": { // "scope": ["response"], // "statusCode": 401 // } // } func verifierFromJSON(b []byte) (*parse.Result, error) { msg := &verifierJSON{} if err := json.Unmarshal(b, msg); err != nil { return nil, err } return parse.NewResult(NewVerifier(msg.StatusCode), msg.Scope) } martian-3.3.2/status/status_verifier_test.go000066400000000000000000000063731421371434000213110ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package status import ( "net/http" "testing" "github.com/google/martian/v3" "github.com/google/martian/v3/parse" "github.com/google/martian/v3/proxyutil" "github.com/google/martian/v3/verify" ) func TestVerifyResponses(t *testing.T) { v := NewVerifier(301) tt := []struct { got int want string }{ {200, "response(http://www.example.com) status code verify failure: got 200, want 301"}, {302, "response(http://www.example.com) status code verify failure: got 302, want 301"}, {400, "response(http://www.example.com) status code verify failure: got 400, want 301"}, } for i, tc := range tt { req, err := http.NewRequest("GET", "http://www.example.com", nil) if err != nil { t.Fatalf("%d. 
http.NewRequest(): got %v, want no error", i, err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(tc.got, nil, req) if err := v.ModifyResponse(res); err != nil { t.Fatalf("%d. ModifyResponse(): got %v, want no error", i, err) } } merr, ok := v.VerifyResponses().(*martian.MultiError) if !ok { t.Fatal("VerifyResponses(): got nil, want *verify.MultiError") } errs := merr.Errors() if got, want := len(errs), len(tt); got != want { t.Fatalf("len(merr.Errors(): got %d, want %d", got, want) } for i, tc := range tt { if got, want := errs[i].Error(), tc.want; got != want { t.Errorf("%d. merr.Errors(): got %q, want %q", i, got, want) } } v.ResetResponseVerifications() if err := v.VerifyResponses(); err != nil { t.Errorf("v.VerifyResponses(): got %v, want no error", err) } } func TestVerifierFromJSON(t *testing.T) { msg := []byte(`{ "status.Verifier": { "scope": ["response"], "statusCode": 400 } }`) r, err := parse.FromJSON(msg) if err != nil { t.Fatalf("parse.FromJSON(): got %v, want no error", err) } resmod := r.ResponseModifier() if resmod == nil { t.Fatal("resmod: got nil, want not nil") } resv, ok := resmod.(verify.ResponseVerifier) if !ok { t.Fatal("reqmod.(verify.RequestVerifier): got !ok, want ok") } req, err := http.NewRequest("GET", "http://www.example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } _, remove, err := martian.TestContext(req, nil, nil) if err != nil { t.Fatalf("TestContext(): got %v, want no error", err) } defer remove() res := proxyutil.NewResponse(200, nil, req) if err := resv.ModifyResponse(res); err != nil { t.Fatalf("ModifyResponse(): got %v, want no error", err) } if err := resv.VerifyResponses(); err == nil { t.Error("VerifyResponses(): got nil, want not nil") } } 
martian-3.3.2/trafficshape/000077500000000000000000000000001421371434000156105ustar00rootroot00000000000000martian-3.3.2/trafficshape/bucket.go000066400000000000000000000136311421371434000174200ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "errors" "sync" "sync/atomic" "time" "github.com/google/martian/v3/log" ) // Bucket is a generic leaky bucket that drains at a configurable interval and // fills at user defined rate. The bucket may be used concurrently. type Bucket struct { capacity int64 // atomic fill int64 // atomic mu sync.Mutex t *time.Ticker closec chan struct{} } var ( // ErrBucketOverflow is an error that indicates the bucket has been overflown // by the user. This error is only returned iff fill > capacity. ErrBucketOverflow = errors.New("trafficshape: bucket overflow") errFillClosedBucket = errors.New("trafficshape: fill on closed bucket") ) // NewBucket returns a new leaky bucket with capacity that is drained // at interval. func NewBucket(capacity int64, interval time.Duration) *Bucket { b := &Bucket{ capacity: capacity, t: time.NewTicker(interval), closec: make(chan struct{}), } go b.loop() return b } // Capacity returns the capacity of the bucket. func (b *Bucket) Capacity() int64 { return atomic.LoadInt64(&b.capacity) } // SetCapacity sets the capacity for the bucket and resets the fill to zero. 
func (b *Bucket) SetCapacity(capacity int64) { log.Infof("trafficshape: set capacity: %d", capacity) atomic.StoreInt64(&b.capacity, capacity) atomic.StoreInt64(&b.fill, 0) } // Close stops the drain loop and marks the bucket as closed. func (b *Bucket) Close() error { log.Debugf("trafficshape: closing bucket") // Allow b to be closed multiple times without panicking. if b.closed() { return nil } b.t.Stop() close(b.closec) return nil } // FillThrottle calls fn with the available capacity remaining (capacity-fill) // and fills the bucket with the number of tokens returned by fn. If the // remaining capacity is <= 0, FillThrottle will wait for the next drain before // running fn. // // If fn returns an error, it will be returned by FillThrottle along with the // number of tokens processed by fn. // // fn is provided the remaining capacity as a soft maximum, fn is allowed to // use more than the remaining capacity without incurring spillage. // // If the bucket is closed when FillThrottle is called, or while waiting for // the next drain, fn will not be executed and FillThrottle will return with an // error. func (b *Bucket) FillThrottle(fn func(int64) (int64, error)) (int64, error) { for { if b.closed() { log.Errorf("trafficshape: fill on closed bucket") return 0, errFillClosedBucket } fill := atomic.LoadInt64(&b.fill) capacity := atomic.LoadInt64(&b.capacity) if fill < capacity { log.Debugf("trafficshape: under capacity (%d/%d)", fill, capacity) n, err := fn(capacity - fill) fill = atomic.AddInt64(&b.fill, n) return n, err } log.Debugf("trafficshape: bucket full (%d/%d)", fill, capacity) } } // FillThrottleLocked is like FillThrottle, except that it uses a lock to protect // the critical section between accessing the fill value and updating it. 
func (b *Bucket) FillThrottleLocked(fn func(int64) (int64, error)) (int64, error) { for { if b.closed() { log.Errorf("trafficshape: fill on closed bucket") return 0, errFillClosedBucket } b.mu.Lock() fill := atomic.LoadInt64(&b.fill) capacity := atomic.LoadInt64(&b.capacity) if fill < capacity { n, err := fn(capacity - fill) fill = atomic.AddInt64(&b.fill, n) b.mu.Unlock() return n, err } b.mu.Unlock() log.Debugf("trafficshape: bucket full (%d/%d)", fill, capacity) } } // Fill calls fn with the available capacity remaining (capacity-fill) and // fills the bucket with the number of tokens returned by fn. If the remaining // capacity is 0, Fill returns 0, nil. If the remaining capacity is < 0, Fill // returns 0, ErrBucketOverflow. // // If fn returns an error, it will be returned by Fill along with the remaining // capacity. // // fn is provided the remaining capacity as a soft maximum, fn is allowed to // use more than the remaining capacity without incurring spillage, though this // will cause subsequent calls to Fill to return ErrBucketOverflow until the // next drain. // // If the bucket is closed when Fill is called, fn will not be executed and // Fill will return with an error. func (b *Bucket) Fill(fn func(int64) (int64, error)) (int64, error) { if b.closed() { log.Errorf("trafficshape: fill on closed bucket") return 0, errFillClosedBucket } fill := atomic.LoadInt64(&b.fill) capacity := atomic.LoadInt64(&b.capacity) switch { case fill < capacity: log.Debugf("trafficshape: under capacity (%d/%d)", fill, capacity) n, err := fn(capacity - fill) fill = atomic.AddInt64(&b.fill, n) return n, err case fill > capacity: log.Debugf("trafficshape: bucket overflow (%d/%d)", fill, capacity) return 0, ErrBucketOverflow } log.Debugf("trafficshape: bucket full (%d/%d)", fill, capacity) return 0, nil } // loop drains the fill at interval and returns when the bucket is closed. 
func (b *Bucket) loop() { log.Debugf("trafficshape: started drain loop") defer log.Debugf("trafficshape: stopped drain loop") for { select { case t := <-b.t.C: atomic.StoreInt64(&b.fill, 0) log.Debugf("trafficshape: fill reset @ %s", t) case <-b.closec: log.Debugf("trafficshape: bucket closed") return } } } func (b *Bucket) closed() bool { select { case <-b.closec: return true default: return false } } martian-3.3.2/trafficshape/bucket_test.go000066400000000000000000000105021421371434000204510ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package trafficshape import ( "errors" "runtime" "sync/atomic" "testing" "time" ) func TestBucket(t *testing.T) { t.Parallel() b := NewBucket(10, 10*time.Millisecond) defer b.Close() if got, want := b.Capacity(), int64(10); got != want { t.Fatalf("b.Capacity(): got %d, want %d", got, want) } n, err := b.Fill(func(remaining int64) (int64, error) { if want := int64(10); remaining != want { t.Errorf("remaining: got %d, want %d", remaining, want) } return 5, nil }) if err != nil { t.Fatalf("Fill(): got %v, want no error", err) } if got, want := n, int64(5); got != want { t.Fatalf("n: got %d, want %d", got, want) } n, err = b.Fill(func(remaining int64) (int64, error) { if want := int64(5); remaining != want { t.Errorf("remaining: got %d, want %d", remaining, want) } return 5, nil }) if err != nil { t.Fatalf("Fill(): got %v, want no error", err) } if got, want := n, int64(5); got != want { t.Fatalf("n: got %d, want %d", got, want) } n, err = b.Fill(func(remaining int64) (int64, error) { t.Fatal("Fill: executed func when full, want skipped") return 0, nil }) if err != nil { t.Fatalf("Fill(): got %v, want no error", err) } // Wait for the bucket to drain. for { if atomic.LoadInt64(&b.fill) == 0 { break } // Allow for a goroutine switch, required for GOMAXPROCS = 1. 
runtime.Gosched() } wanterr := errors.New("fill function error") n, err = b.Fill(func(remaining int64) (int64, error) { if want := int64(10); remaining != want { t.Errorf("remaining: got %d, want %d", remaining, want) } return 0, wanterr }) if err != wanterr { t.Fatalf("Fill(): got %v, want %v", err, wanterr) } if got, want := n, int64(0); got != want { t.Fatalf("n: got %d, want %d", got, want) } } func TestBucketClosed(t *testing.T) { t.Parallel() b := NewBucket(0, time.Millisecond) b.Close() if _, err := b.Fill(nil); err != errFillClosedBucket { t.Errorf("Fill(): got %v, want errFillClosedBucket", err) } if _, err := b.FillThrottle(nil); err != errFillClosedBucket { t.Errorf("FillThrottle(): got %v, want errFillClosedBucket", err) } } func TestBucketOverflow(t *testing.T) { t.Parallel() b := NewBucket(10, 10*time.Millisecond) defer b.Close() n, err := b.Fill(func(remaining int64) (int64, error) { return 11, nil }) if err != nil { t.Fatalf("Fill(): got %v, want no error", err) } n, err = b.Fill(func(int64) (int64, error) { t.Fatal("Fill: executed func when full, want skipped") return 0, nil }) if err != ErrBucketOverflow { t.Fatalf("Fill(): got %v, want ErrBucketOverflow", err) } if got, want := n, int64(0); got != want { t.Fatalf("n: got %d, want %d", got, want) } } func TestBucketThrottle(t *testing.T) { t.Parallel() b := NewBucket(50, 50*time.Millisecond) defer b.Close() closec := make(chan struct{}) errc := make(chan error, 1) fill := func() { for { select { case <-closec: return default: if _, err := b.FillThrottle(func(remaining int64) (int64, error) { if remaining < 10 { return remaining, nil } return 10, nil }); err != nil { select { case errc <- err: default: } } } } } for i := 0; i < 5; i++ { go fill() } time.Sleep(time.Second) close(closec) select { case err := <-errc: t.Fatalf("FillThrottle: got %v, want no error", err) default: } } func TestBucketFillThrottleCloseBeforeTick(t *testing.T) { t.Parallel() b := NewBucket(0, time.Minute) 
time.AfterFunc(time.Second, func() { b.Close() }) if _, err := b.FillThrottle(func(int64) (int64, error) { t.Fatal("FillThrottle(): executed func after close, want skipped") return 0, nil }); err != errFillClosedBucket { t.Errorf("b.FillThrottle(): got nil, want errFillClosedBucket") } } martian-3.3.2/trafficshape/conn.go000066400000000000000000000345601421371434000171040ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "io" "net" "sort" "sync" "time" "github.com/google/martian/v3/log" ) // Conn wraps a net.Conn and simulates connection latency and bandwidth // charateristics. type Conn struct { Context *Context // Shapes represents the traffic shape map inherited from the listener. Shapes *urlShapes GlobalBuckets map[string]*Bucket // LocalBuckets represents a map from the url_regexes to their dedicated buckets. LocalBuckets map[string]*Buckets Established time.Time // Established is the time that the connection is established. DefaultBandwidth Bandwidth Listener *Listener ReadBucket *Bucket // Shared by listener. WriteBucket *Bucket // Shared by listener. conn net.Conn latency time.Duration ronce sync.Once wonce sync.Once } // Read reads bytes from connection into b, optionally simulating connection // latency and throttling read throughput based on desired bandwidth // constraints. 
func (c *Conn) Read(b []byte) (int, error) { c.ronce.Do(c.sleepLatency) n, err := c.ReadBucket.FillThrottle(func(remaining int64) (int64, error) { max := remaining if l := int64(len(b)); max > l { max = l } n, err := c.conn.Read(b[:max]) return int64(n), err }) if err != nil && err != io.EOF { log.Errorf("trafficshape: error on throttled read: %v", err) } return int(n), err } // ReadFrom reads data from r until EOF or error, optionally simulating // connection latency and throttling read throughput based on desired bandwidth // constraints. func (c *Conn) ReadFrom(r io.Reader) (int64, error) { c.ronce.Do(c.sleepLatency) var total int64 for { n, err := c.ReadBucket.FillThrottle(func(remaining int64) (int64, error) { return io.CopyN(c.conn, r, remaining) }) total += n if err == io.EOF { log.Debugf("trafficshape: exhausted reader successfully") return total, nil } else if err != nil { log.Errorf("trafficshape: failed copying from reader: %v", err) return total, err } } } // Close closes the connection. // Any blocked Read or Write operations will be unblocked and return errors. func (c *Conn) Close() error { return c.conn.Close() } // LocalAddr returns the local network address. func (c *Conn) LocalAddr() net.Addr { return c.conn.LocalAddr() } // RemoteAddr returns the remote network address. func (c *Conn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } // SetDeadline sets the read and write deadlines associated // with the connection. It is equivalent to calling both // SetReadDeadline and SetWriteDeadline. // // A deadline is an absolute time after which I/O operations // fail with a timeout (see type Error) instead of // blocking. The deadline applies to all future and pending // I/O, not just the immediately following call to Read or // Write. After a deadline has been exceeded, the connection // can be refreshed by setting a deadline in the future. 
// // An idle timeout can be implemented by repeatedly extending // the deadline after successful Read or Write calls. // // A zero value for t means I/O operations will not time out. // // Note that if a TCP connection has keep-alive turned on, // which is the default unless overridden by Dialer.KeepAlive // or ListenConfig.KeepAlive, then a keep-alive failure may // also return a timeout error. On Unix systems a keep-alive // failure on I/O can be detected using // errors.Is(err, syscall.ETIMEDOUT). func (c *Conn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } // SetReadDeadline sets the deadline for future Read calls // and any currently-blocked Read call. // A zero value for t means Read will not time out. func (c *Conn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } // SetWriteDeadline sets the deadline for future Write calls // and any currently-blocked Write call. // Even if write times out, it may return n > 0, indicating that // some of the data was successfully written. // A zero value for t means Write will not time out. func (c *Conn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } // GetWrappedConn returns the undrelying trafficshaped net.Conn. func (c *Conn) GetWrappedConn() net.Conn { return c.conn } // WriteTo writes data to w from the connection, optionally simulating // connection latency and throttling write throughput based on desired // bandwidth constraints. 
func (c *Conn) WriteTo(w io.Writer) (int64, error) { c.wonce.Do(c.sleepLatency) var total int64 for { n, err := c.WriteBucket.FillThrottle(func(remaining int64) (int64, error) { return io.CopyN(w, c.conn, remaining) }) total += n if err != nil { if err != io.EOF { log.Errorf("trafficshape: failed copying to writer: %v", err) } return total, err } } } func min(x, y int64) int64 { if x < y { return x } return y } // CheckExistenceAndValidity checks that the current url regex is present in the map, and that // the connection was established before the url shape map was last updated. We do not allow the // updated url shape map to traffic shape older connections. // Important: Assumes you have acquired the required locks and will release them youself. func (c *Conn) CheckExistenceAndValidity(URLRegex string) bool { shapeStillValid := c.Shapes.LastModifiedTime.Before(c.Established) _, p := c.Shapes.M[URLRegex] return p && shapeStillValid } // GetCurrentThrottle uses binary search to determine if the current byte offset ('start') // lies within a throttle interval. If so, also returns the bandwidth specified for that interval. func (c *Conn) GetCurrentThrottle(start int64) *ThrottleContext { c.Shapes.RLock() defer c.Shapes.RUnlock() if !c.CheckExistenceAndValidity(c.Context.URLRegex) { log.Debugf("existence check failed") return &ThrottleContext{ ThrottleNow: false, } } c.Shapes.M[c.Context.URLRegex].RLock() defer c.Shapes.M[c.Context.URLRegex].RUnlock() throttles := c.Shapes.M[c.Context.URLRegex].Shape.Throttles if l := len(throttles); l != 0 { // ind is the first index in throttles with ByteStart > start. // Once we get ind, we can check the previous throttle, if any, // to see if its ByteEnd is after 'start'. ind := sort.Search(len(throttles), func(i int) bool { return throttles[i].ByteStart > start }) // All throttles have Bytestart > start, hence not in throttle. 
if ind == 0 { return &ThrottleContext{ ThrottleNow: false, } } // No throttle has Bytestart > start, so check the last throttle to // see if it ends after 'start'. Note: the last throttle is special // since it can have -1 (meaning infinity) as the ByteEnd. if ind == l { if throttles[l-1].ByteEnd > start || throttles[l-1].ByteEnd == -1 { return &ThrottleContext{ ThrottleNow: true, Bandwidth: throttles[l-1].Bandwidth, } } return &ThrottleContext{ ThrottleNow: false, } } // Check the previous throttle to see if it ends after 'start'. if throttles[ind-1].ByteEnd > start { return &ThrottleContext{ ThrottleNow: true, Bandwidth: throttles[ind-1].Bandwidth, } } return &ThrottleContext{ ThrottleNow: false, } } return &ThrottleContext{ ThrottleNow: false, } } // GetNextActionFromByte takes in a byte offset and uses binary search to determine the upcoming // action, i.e the first action after the byte that still has a non zero count. func (c *Conn) GetNextActionFromByte(start int64) *NextActionInfo { c.Shapes.RLock() defer c.Shapes.RUnlock() if !c.CheckExistenceAndValidity(c.Context.URLRegex) { log.Debugf("existence check failed") return &NextActionInfo{ ActionNext: false, } } c.Shapes.M[c.Context.URLRegex].RLock() defer c.Shapes.M[c.Context.URLRegex].RUnlock() actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions if l := len(actions); l != 0 { ind := sort.Search(len(actions), func(i int) bool { return actions[i].getByte() >= start }) return c.GetNextActionFromIndex(int64(ind)) } return &NextActionInfo{ ActionNext: false, } } // GetNextActionFromIndex takes in an index and returns the first action after the index that // has a non zero count, if there is one. 
func (c *Conn) GetNextActionFromIndex(ind int64) *NextActionInfo { c.Shapes.RLock() defer c.Shapes.RUnlock() if !c.CheckExistenceAndValidity(c.Context.URLRegex) { return &NextActionInfo{ ActionNext: false, } } c.Shapes.M[c.Context.URLRegex].RLock() defer c.Shapes.M[c.Context.URLRegex].RUnlock() actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions if l := int64(len(actions)); l != 0 { for ind < l && (actions[ind].getCount() == 0) { ind++ } if ind >= l { return &NextActionInfo{ ActionNext: false, } } return &NextActionInfo{ ActionNext: true, Index: ind, ByteOffset: actions[ind].getByte(), } } return &NextActionInfo{ ActionNext: false, } } // WriteDefaultBuckets writes bytes from b to the connection, optionally simulating // connection latency and throttling write throughput based on desired // bandwidth constraints. It uses the WriteBucket inherited from the listener. func (c *Conn) WriteDefaultBuckets(b []byte) (int, error) { c.wonce.Do(c.sleepLatency) var total int64 for len(b) > 0 { var max int64 n, err := c.WriteBucket.FillThrottle(func(remaining int64) (int64, error) { max = remaining if l := int64(len(b)); remaining >= l { max = l } n, err := c.conn.Write(b[:max]) return int64(n), err }) total += n if err != nil { if err != io.EOF { log.Errorf("trafficshape: failed write: %v", err) } return int(total), err } b = b[max:] } return int(total), nil } // Write writes bytes from b to the connection, while enforcing throttles and performing actions. // It uses and updates the Context in the connection. func (c *Conn) Write(b []byte) (int, error) { if !c.Context.Shaping { return c.WriteDefaultBuckets(b) } c.wonce.Do(c.sleepLatency) var total int64 // Write the header if needed, without enforcing any traffic shaping, and without updating // ByteOffset. 
if headerToWrite := c.Context.HeaderLen - c.Context.HeaderBytesWritten; headerToWrite > 0 { writeAmount := min(int64(len(b)), headerToWrite) n, err := c.conn.Write(b[:writeAmount]) if err != nil { if err != io.EOF { log.Errorf("trafficshape: failed write: %v", err) } return int(n), err } c.Context.HeaderBytesWritten += writeAmount total += writeAmount b = b[writeAmount:] } var amountToWrite int64 for len(b) > 0 { var max int64 // Determine the amount to be written up till the next action. amountToWrite = int64(len(b)) if c.Context.NextActionInfo.ActionNext { amountTillNextAction := c.Context.NextActionInfo.ByteOffset - c.Context.ByteOffset if amountTillNextAction <= amountToWrite { amountToWrite = amountTillNextAction } } // Write into both the local and global buckets, as well as the underlying connection. n, err := c.Context.Buckets.WriteBucket.FillThrottleLocked(func(remaining int64) (int64, error) { max = min(remaining, amountToWrite) if max == 0 { return 0, nil } return c.Context.GlobalBucket.FillThrottleLocked(func(rem int64) (int64, error) { max = min(rem, max) n, err := c.conn.Write(b[:max]) return int64(n), err }) }) if err != nil { if err != io.EOF { log.Errorf("trafficshape: failed write: %v", err) } return int(total), err } // Update the current byte offset. c.Context.ByteOffset += n total += n b = b[max:] // Check if there was an upcoming action, and that the byte offset matches the action's byte. if c.Context.NextActionInfo.ActionNext && c.Context.ByteOffset >= c.Context.NextActionInfo.ByteOffset { // Note here, we check again that the url shape map is still valid and that the action still has // a non zero count, since that could have been modified since the last time we checked. 
ind := c.Context.NextActionInfo.Index c.Shapes.RLock() if !c.CheckExistenceAndValidity(c.Context.URLRegex) { c.Shapes.RUnlock() // Write the remaining b using default buckets, and set Shaping as false // so that subsequent calls to Write() also use default buckets // without performing any actions. c.Context.Shaping = false writeTotal, e := c.WriteDefaultBuckets(b) return int(total) + writeTotal, e } c.Shapes.M[c.Context.URLRegex].Lock() actions := c.Shapes.M[c.Context.URLRegex].Shape.Actions if actions[ind].getCount() != 0 { // Update the action count, determine the type of action and perform it. actions[ind].decrementCount() switch action := actions[ind].(type) { case *Halt: d := action.Duration log.Debugf("trafficshape: Sleeping for time %d ms for urlregex %s at byte offset %d", d, c.Context.URLRegex, c.Context.ByteOffset) c.Shapes.M[c.Context.URLRegex].Unlock() c.Shapes.RUnlock() time.Sleep(time.Duration(d) * time.Millisecond) case *CloseConnection: log.Infof("trafficshape: Closing connection for urlregex %s at byte offset %d", c.Context.URLRegex, c.Context.ByteOffset) c.Shapes.M[c.Context.URLRegex].Unlock() c.Shapes.RUnlock() return int(total), &ErrForceClose{message: "Forcing close connection"} case *ChangeBandwidth: bw := action.Bandwidth log.Infof("trafficshape: Changing connection bandwidth to %d for urlregex %s at byte offset %d", bw, c.Context.URLRegex, c.Context.ByteOffset) c.Shapes.M[c.Context.URLRegex].Unlock() c.Shapes.RUnlock() c.Context.Buckets.WriteBucket.SetCapacity(bw) default: c.Shapes.M[c.Context.URLRegex].Unlock() c.Shapes.RUnlock() } } else { c.Shapes.M[c.Context.URLRegex].Unlock() c.Shapes.RUnlock() } // Get the next action to be performed, if any. 
c.Context.NextActionInfo = c.GetNextActionFromIndex(ind + 1) } } return int(total), nil } func (c *Conn) sleepLatency() { log.Debugf("trafficshape: simulating latency: %s", c.latency) time.Sleep(c.latency) } martian-3.3.2/trafficshape/doc.go000066400000000000000000000013341421371434000167050ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package trafficshape provides tools for simulating latency and bandwidth at // the network layer. package trafficshape martian-3.3.2/trafficshape/handler.go000066400000000000000000000143571421371434000175660ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "bytes" "encoding/json" "io" "io/ioutil" "net/http" "time" "github.com/google/martian/v3/log" ) // Handler configures a trafficshape.Listener. 
type Handler struct { l *Listener } // Throttle represents a byte interval with a specific bandwidth. type Throttle struct { Bytes string `json:"bytes"` Bandwidth int64 `json:"bandwidth"` ByteStart int64 ByteEnd int64 } // Action represents an arbitrary event that needs to be executed while writing back to the client. type Action interface { // Byte offset to perform Action at. getByte() int64 // Number of times to perform the action. -1 for infinite times. getCount() int64 // Update the count when performing an action. decrementCount() } // Halt is the event that represents a period of time to sleep while writing. // It implements the Action interface. type Halt struct { Byte int64 `json:"byte"` Duration int64 `json:"duration"` Count int64 `json:"count"` } func (h *Halt) getByte() int64 { return h.Byte } func (h *Halt) getCount() int64 { return h.Count } func (h *Halt) decrementCount() { if h.Count > 0 { h.Count-- } } // CloseConnection is an event that represents the closing of a connection with a client. // It implements the Action interface. type CloseConnection struct { Byte int64 `json:"byte"` Count int64 `json:"count"` } func (cc *CloseConnection) getByte() int64 { return cc.Byte } func (cc *CloseConnection) getCount() int64 { return cc.Count } func (cc *CloseConnection) decrementCount() { if cc.Count > 0 { cc.Count-- } } // Shape encloses the traffic shape of a particular url regex. type Shape struct { URLRegex string `json:"url_regex"` MaxBandwidth int64 `json:"max_global_bandwidth"` Throttles []*Throttle `json:"throttles"` Halts []*Halt `json:"halts"` CloseConnections []*CloseConnection `json:"close_connections"` // Actions are populated after processing Throttles, Halts and CloseConnections. // Actions is sorted in the order of byte offset. Actions []Action // WriteBucket is initialized by us using MaxBandwidth. WriteBucket *Bucket } // Bandwidth encloses information about the upstream and downstream bandwidths. 
type Bandwidth struct { Up int64 `json:"up"` Down int64 `json:"down"` } // Default encloses information about the default traffic shaping parameters: bandwidth and latency. type Default struct { Bandwidth Bandwidth `json:"bandwidth"` Latency int64 `json:"latency"` } // Trafficshape contains global shape of traffic, i.e information about shape of each url specified and // the default traffic shaping parameters. type Trafficshape struct { Defaults *Default `json:"default"` Shapes []*Shape `json:"shapes"` } // ConfigRequest represents a request to configure the global traffic shape. type ConfigRequest struct { Trafficshape *Trafficshape `json:"trafficshape"` } // ChangeBandwidth represents the event of changing the current bandwidth. It is used as an // endpoint of a Throttle. It implements the Action interface. type ChangeBandwidth struct { Byte int64 Bandwidth int64 } func (cb *ChangeBandwidth) getByte() int64 { return cb.Byte } func (cb *ChangeBandwidth) getCount() int64 { return -1 } // No op. This is because Throttles have infinite count. func (cb *ChangeBandwidth) decrementCount() { } // NewHandler returns an http.Handler to configure traffic shaping. func NewHandler(l *Listener) *Handler { return &Handler{ l: l, } } // ServeHTTP configures latency and bandwidth constraints. // // The "latency" query string parameter accepts a duration string in any format // supported by time.ParseDuration. // The "up" and "down" query string parameters accept integers as bits per // second to be used for read and write throughput. 
func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { log.Infof("trafficshape: configuration request") receivedConfig := &ConfigRequest{} body, err := ioutil.ReadAll(req.Body) if err != nil { http.Error(rw, "Error reading request body", 400) return } bodystr := string(body) req.Body = ioutil.NopCloser(bytes.NewBuffer(body)) if err := json.NewDecoder(req.Body).Decode(&receivedConfig); err != nil { log.Errorf("Error while parsing the received json: %v", err) http.Error(rw, err.Error(), 400) return } if receivedConfig.Trafficshape == nil { http.Error(rw, "Error: trafficshape property not found", 400) return } defaults := receivedConfig.Trafficshape.Defaults if defaults == nil { defaults = &Default{} } if defaults.Bandwidth.Up < 0 || defaults.Bandwidth.Down < 0 || defaults.Latency < 0 { http.Error(rw, "Error: Invalid Defaults", 400) return } if defaults.Bandwidth.Up == 0 { defaults.Bandwidth.Up = DefaultBitrate / 8 } if defaults.Bandwidth.Down == 0 { defaults.Bandwidth.Down = DefaultBitrate / 8 } // Parse and verify the received shapes. if err := parseShapes(receivedConfig.Trafficshape); err != nil { http.Error(rw, err.Error(), 400) return } // Update the Listener with the new traffic shape. h.l.Shapes.Lock() h.l.Shapes.LastModifiedTime = time.Now() h.l.ReadBucket.SetCapacity(defaults.Bandwidth.Down) h.l.WriteBucket.SetCapacity(defaults.Bandwidth.Up) h.l.SetLatency(time.Duration(defaults.Latency) * time.Millisecond) h.l.SetDefaults(defaults) h.l.Shapes.M = make(map[string]*urlShape) for _, shape := range receivedConfig.Trafficshape.Shapes { h.l.Shapes.M[shape.URLRegex] = &urlShape{Shape: shape} } // Update the time that the map was last modified to the current time. h.l.Shapes.LastModifiedTime = time.Now() h.l.Shapes.Unlock() rw.WriteHeader(http.StatusOK) io.WriteString(rw, bodystr) } martian-3.3.2/trafficshape/handler_test.go000066400000000000000000000205551421371434000206220ustar00rootroot00000000000000// Copyright 2015 Google Inc. 
All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "bytes" "fmt" "net" "net/http" "net/http/httptest" "testing" "time" ) func compareActions(testSlice []Action, refSlice []Action) (bool, string) { if len(testSlice) != len(refSlice) { return false, fmt.Sprintf("length: got %d, want %d", len(testSlice), len(refSlice)) } for i, action := range refSlice { failure := false switch refAction := action.(type) { case *Halt: if testAction, ok := testSlice[i].(*Halt); ok { if *testAction != *refAction { failure = true } } else { failure = true } case *CloseConnection: if testAction, ok := testSlice[i].(*CloseConnection); ok { if *testAction != *refAction { failure = true } } else { failure = true } case *ChangeBandwidth: if testAction, ok := testSlice[i].(*ChangeBandwidth); ok { if *testAction != *refAction { failure = true } } else { failure = true } } if failure { return false, fmt.Sprintf("Action %d: got %+v, want %+v", i, testSlice[i], action) } } return true, "" } func TestHandlerIncorrectInputs(t *testing.T) { l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tt := []struct { testcase string body string }{ { testcase: `overlapping throttle`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100},{"bytes":"700-2000","bandwidth":100}],"close_connections":[{"byte":1078,"count":1}]}]}}`, }, { testcase: 
`negative bandwidth`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-","bandwidth":abc}],"close_connections":[{"byte":1078,"count":1}]}]}}`, }, { testcase: `negative close byte`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100}],"close_connections":[{"byte":-1,"count":1}]}]}}`, }, { testcase: `uncompiling regex`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example(","throttles":[{"bytes":"500-1000","bandwidth":100}],"close_connections":[{"byte":100,"count":1}]}]}}`, }, { testcase: `missing count`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100}],"halts":[{"byte":100}]}]}}`, }, { testcase: `illformed byte range`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500--1000","bandwidth":100}],"close_connections":[{"byte":10,"count":1}]}]}}`, }, { testcase: `throttle end < start`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-255","bandwidth":100}],"close_connections":[{"byte":100,"count":1}]}]}}`, }, { testcase: `missing comma`, body: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100}]"close_connections":[{"byte":100,"count":1}]}]}}`, }, { testcase: `missing regex`, body: `{"trafficshape":{"shapes":[{"throttles":[{"bytes":"500-1000","bandwidth":100}],"close_connections":[{"byte":-1,"count":1}]}]}}`, }, { testcase: `negative default bandwidth`, body: `{"trafficshape":{"default":{"bandwidth":{"up":-100000,"down":100000},"latency":1000},"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100}]",close_connections":[{"byte":100,"count":1}]}]}}`, }, { testcase: `negative default latency`, body: 
`{"trafficshape":{"default":{"bandwidth":{"up":100000,"down":100000},"latency":-1000},"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"500-1000","bandwidth":100}]",close_connections":[{"byte":100,"count":1}]}]}}`, }, } for i, tc := range tt { t.Logf("case %d: %s", i+1, tc.testcase) tsl := NewListener(l) defer tsl.Close() h := NewHandler(tsl) req, err := http.NewRequest("POST", "test", bytes.NewBufferString(tc.body)) if err != nil { t.Fatalf("%d. http.NewRequest(): got %v, want no error", i, err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got := rw.Code; got != 400 { t.Errorf("%d. rw.Code: got %d, want %d", i+1, got, 400) } } } func TestHandlerClear(t *testing.T) { l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tsl := NewListener(l) defer tsl.Close() h := NewHandler(tsl) startTime := time.Now() jsonString := `{"trafficshape":{}}` req, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf(" rw.Code: got %d, want %d", got, want) } defaults := tsl.Defaults() if got, want := defaults.Bandwidth.Down, DefaultBitrate/8; got != want { t.Errorf("default downstream bandwidth: got %d, want %d", got, want) } if got, want := defaults.Latency, int64(0); got != want { t.Errorf("default latency: got %d, want %d", got, want) } if got, want := tsl.WriteBucket.Capacity(), DefaultBitrate/8; got != want { t.Errorf("tsl WriteBucket Capacity: got %d, want %d", got, want) } tsl.Shapes.RLock() if got, want := len(tsl.Shapes.M), 0; got != want { t.Errorf("length of shape map: got %d, want %d", got, want) } if modifiedTime := tsl.Shapes.LastModifiedTime; modifiedTime.Before(startTime) { t.Errorf("modified time is before start time; should be after") } tsl.Shapes.RUnlock() } func TestHandlerActions(t 
*testing.T) { l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tt := []struct { jsonString string actions []Action }{ { jsonString: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example", "max_global_bandwidth":1000, "throttles":[{"bytes":"500-1000","bandwidth":100},{"bytes":"1000-2000","bandwidth":300},{"bytes":"2001-","bandwidth":400}], "halts":[{"byte":530,"duration": 5, "count": 1}],"close_connections":[{"byte":1078,"count":1}]}]}}`, actions: []Action{ &ChangeBandwidth{Byte: 500, Bandwidth: 100}, &Halt{Byte: 530, Duration: 5, Count: 1}, &ChangeBandwidth{Byte: 1000, Bandwidth: 300}, &CloseConnection{Byte: 1078, Count: 1}, &ChangeBandwidth{Byte: 2000, Bandwidth: 1000}, &ChangeBandwidth{Byte: 2001, Bandwidth: 400}, }, }, { jsonString: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example","throttles":[{"bytes":"-","bandwidth":100}], "close_connections":[{"byte":100,"count":1}]}]}}`, actions: []Action{ &ChangeBandwidth{Byte: 0, Bandwidth: 100}, &CloseConnection{Byte: 100, Count: 1}, }, }, } for i, tc := range tt { tsl := NewListener(l) defer tsl.Close() h := NewHandler(tsl) startTime := time.Now() req, err := http.NewRequest("POST", "test", bytes.NewBufferString(tc.jsonString)) if err != nil { t.Fatalf("%d. http.NewRequest(): got %v, want no error", i, err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("%d. 
rw.Code: got %d, want %d", i+1, got, want) } tsl.Shapes.RLock() defer tsl.Shapes.RUnlock() if got, want := len(tsl.Shapes.M), 1; got != want { t.Errorf("tc.%d length of shape map: got %d, want %d", i+1, got, want) } tsl.Shapes.M["http://example/example"].RLock() defer tsl.Shapes.M["http://example/example"].RUnlock() if same, errStr := compareActions(tsl.Shapes.M["http://example/example"].Shape.Actions, tc.actions); !same { t.Errorf(errStr) } if modifiedTime := tsl.Shapes.LastModifiedTime; modifiedTime.Before(startTime) { t.Errorf("tc.%d modified time is before start time; should be after", i+1) } } } martian-3.3.2/trafficshape/listener.go000066400000000000000000000164751421371434000200010ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "net" "sync" "time" "github.com/google/martian/v3/log" ) // DefaultBitrate represents the bitrate that will be for all url regexs for which a shape // has not been specified. var DefaultBitrate int64 = 500000000000 // 500Gbps (unlimited) // ErrForceClose is an error that communicates the need to close the connection. type ErrForceClose struct { message string } func (efc *ErrForceClose) Error() string { return efc.message } // urlShape contains a rw lock protected shape of a url_regex. type urlShape struct { sync.RWMutex Shape *Shape } // urlShapes contains a rw lock protected map of url regexs to their URLShapes. 
type urlShapes struct { sync.RWMutex M map[string]*urlShape LastModifiedTime time.Time } // Buckets contains the read and write buckets for a url_regex. type Buckets struct { ReadBucket *Bucket WriteBucket *Bucket } // NewBuckets returns a *Buckets with the specified up and down bandwidths. func NewBuckets(up int64, down int64) *Buckets { return &Buckets{ ReadBucket: NewBucket(up, time.Second), WriteBucket: NewBucket(down, time.Second), } } // ThrottleContext represents whether we are currently in a throttle interval for a particular // url_regex. If ThrottleNow is true, only then will the current throttle 'Bandwidth' be set // correctly. type ThrottleContext struct { ThrottleNow bool Bandwidth int64 } // NextActionInfo represents whether there is an upcoming action. Only if ActionNext is true will the // Index and ByteOffset be set correctly. type NextActionInfo struct { ActionNext bool Index int64 ByteOffset int64 } // Context represents the current information that is needed while writing back to the client. // Only if Shaping is true, that is we are currently writing back a response that matches a certain // url_regex will the other values be set correctly. If so, the Buckets represent the buckets // to be used for the current url_regex. NextActionInfo tells us whether there is an upcoming action // that needs to be performed, and ThrottleContext tells us whether we are currently in a throttle // interval (according to the RangeStart). Note, the ThrottleContext is only used once in the start // to determine the beginning bandwidth. It need not be updated after that. This // is because the subsequent throttles are captured in the upcoming ChangeBandwidth actions. // Byte Offset represents the absolute byte offset of response data that we are currently writing back. // It does not account for the header data. 
type Context struct { Shaping bool RangeStart int64 URLRegex string Buckets *Buckets GlobalBucket *Bucket ThrottleContext *ThrottleContext NextActionInfo *NextActionInfo ByteOffset int64 HeaderLen int64 HeaderBytesWritten int64 } // Listener wraps a net.Listener and simulates connection latency and bandwidth // constraints. type Listener struct { net.Listener ReadBucket *Bucket WriteBucket *Bucket mu sync.RWMutex latency time.Duration GlobalBuckets map[string]*Bucket Shapes *urlShapes defaults *Default } // NewListener returns a new bandwidth constrained listener. Defaults to // DefaultBitrate (uncapped). func NewListener(l net.Listener) *Listener { return &Listener{ Listener: l, ReadBucket: NewBucket(DefaultBitrate/8, time.Second), WriteBucket: NewBucket(DefaultBitrate/8, time.Second), Shapes: &urlShapes{M: make(map[string]*urlShape)}, GlobalBuckets: make(map[string]*Bucket), defaults: &Default{ Bandwidth: Bandwidth{ Up: DefaultBitrate / 8, Down: DefaultBitrate / 8, }, Latency: 0, }, } } // ReadBitrate returns the bitrate in bits per second for reads. func (l *Listener) ReadBitrate() int64 { return l.ReadBucket.Capacity() * 8 } // SetReadBitrate sets the bitrate in bits per second for reads. func (l *Listener) SetReadBitrate(bitrate int64) { l.ReadBucket.SetCapacity(bitrate / 8) } // WriteBitrate returns the bitrate in bits per second for writes. func (l *Listener) WriteBitrate() int64 { return l.WriteBucket.Capacity() * 8 } // SetWriteBitrate sets the bitrate in bits per second for writes. func (l *Listener) SetWriteBitrate(bitrate int64) { l.WriteBucket.SetCapacity(bitrate / 8) } // SetDefaults sets the default traffic shaping parameters for the listener. func (l *Listener) SetDefaults(defaults *Default) { l.mu.Lock() defer l.mu.Unlock() l.defaults = defaults } // Defaults returns the default traffic shaping parameters for the listener. 
func (l *Listener) Defaults() *Default { l.mu.RLock() defer l.mu.RUnlock() return l.defaults } // Latency returns the latency for connections. func (l *Listener) Latency() time.Duration { l.mu.Lock() defer l.mu.Unlock() return l.latency } // SetLatency sets the initial latency for connections. func (l *Listener) SetLatency(latency time.Duration) { l.mu.Lock() defer l.mu.Unlock() l.latency = latency } // GetTrafficShapedConn takes in a normal connection and returns a traffic shaped connection. func (l *Listener) GetTrafficShapedConn(oc net.Conn) *Conn { if tsconn, ok := oc.(*Conn); ok { return tsconn } urlbuckets := make(map[string]*Buckets) globalurlbuckets := make(map[string]*Bucket) l.Shapes.RLock() defaults := l.Defaults() latency := l.Latency() defaultBandwidth := defaults.Bandwidth for regex, shape := range l.Shapes.M { // It should be ok to not acquire the read lock on shape, since WriteBucket is never mutated. globalurlbuckets[regex] = shape.Shape.WriteBucket urlbuckets[regex] = NewBuckets(DefaultBitrate/8, shape.Shape.MaxBandwidth) } l.Shapes.RUnlock() curinfo := &Context{} lc := &Conn{ conn: oc, latency: latency, ReadBucket: l.ReadBucket, WriteBucket: l.WriteBucket, Shapes: l.Shapes, GlobalBuckets: globalurlbuckets, LocalBuckets: urlbuckets, Context: curinfo, Established: time.Now(), DefaultBandwidth: defaultBandwidth, Listener: l, } return lc } // Accept waits for and returns the next connection to the listener. func (l *Listener) Accept() (net.Conn, error) { oc, err := l.Listener.Accept() if err != nil { log.Errorf("trafficshape: failed accepting connection: %v", err) return nil, err } if tconn, ok := oc.(*net.TCPConn); ok { log.Debugf("trafficshape: setting keep-alive for TCP connection") tconn.SetKeepAlive(true) tconn.SetKeepAlivePeriod(3 * time.Minute) } return l.GetTrafficShapedConn(oc), nil } // Close closes the read and write buckets along with the underlying listener. 
func (l *Listener) Close() error { defer log.Debugf("trafficshape: closed read/write buckets and connection") l.ReadBucket.Close() l.WriteBucket.Close() return l.Listener.Close() } martian-3.3.2/trafficshape/listener_test.go000066400000000000000000000413241421371434000210270ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package trafficshape import ( "bytes" "io" "io/ioutil" "net" "net/http" "net/http/httptest" "sync" "testing" "time" ) func TestListenerRead(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tsl := NewListener(l) defer tsl.Close() if got := tsl.ReadBitrate(); got != DefaultBitrate { t.Errorf("tsl.ReadBitrate(): got %d, want DefaultBitrate", got) } if got := tsl.WriteBitrate(); got != DefaultBitrate { t.Errorf("tsl.WriteBitrate(): got %d, want DefaultBitrate", got) } tsl.SetReadBitrate(40) // 4 bytes per second if got, want := tsl.ReadBitrate(), int64(40); got != want { t.Errorf("tsl.ReadBitrate(): got %d, want %d", got, want) } tsl.SetWriteBitrate(40) // 4 bytes per second if got, want := tsl.WriteBitrate(), int64(40); got != want { t.Errorf("tsl.WriteBitrate(): got %d, want %d", got, want) } tsl.SetLatency(time.Second) if got, want := tsl.Latency(), time.Second; got != want { t.Errorf("tsl.Latency(): got %s, want %s", got, want) } var wg sync.WaitGroup wg.Add(1) want := 
bytes.Repeat([]byte("*"), 16) go func() { // Dial the local listener. c, err := net.Dial("tcp", tsl.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer c.Close() // Wait for the signal that it's okay to write to the connection; ensure // the test is ready to read it. wg.Wait() c.Write(want) }() tsc, err := tsl.Accept() if err != nil { t.Fatalf("tsl.Accept(): got %v, want no error", err) } defer tsc.Close() // Signal to the write goroutine that it may begin writing. wg.Done() start := time.Now() got, err := ioutil.ReadAll(tsc) end := time.Now() if err != nil { t.Fatalf("tsc.Read(): got %v, want no error", err) } if !bytes.Equal(got, want) { t.Errorf("tsc.Read(): got %q, want %q", got, want) } // Breakdown of ~3s minimum: // 1 second for the initial latency // ~2-3 seconds for throttled read // - 4 bytes per second with 16 bytes total = 3 seconds (first four bytes // are read immediately at the zeroth second; 0:4, 1:8, 2:12, 3:16) // - the drain ticker begins before the initial start time so some of that // tick time is unaccounted for in the difference; potentially up to a // full second (the drain interval). For example, if the ticker is 300ms // into its tick before start is calculated we will believe that the // throttled read will have occurred in 2.7s. Allow for up to drain // interval in skew to account for this and ensure the test does not // flake. // // The test runtime should be negligible compared the latency simulation, so // we assume the ~3s (> 2.95s) is accounted for by throttling and latency in // the worst case (we read and a new tick happens immediately). 
min := 2*time.Second + 950*time.Millisecond max := 4*time.Second + 50*time.Millisecond if got := end.Sub(start); !between(got, min, max) { t.Errorf("tsc.Read(): took %s, want within [%s, %s]", got, min, max) } } func TestListenerWrite(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tsl := NewListener(l) defer tsl.Close() tsl.SetReadBitrate(40) // 4 bytes per second tsl.SetWriteBitrate(40) // 4 bytes per second tsl.SetLatency(time.Second) var wg sync.WaitGroup wg.Add(1) want := bytes.Repeat([]byte("*"), 16) var start time.Time var end time.Time go func() { // Dial the local listener. c, err := net.Dial("tcp", tsl.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer c.Close() // Wait for the signal that it's okay to read from the connection; ensure // the test is ready to write to it. wg.Wait() got, err := ioutil.ReadAll(c) if err != nil { t.Fatalf("c.Read(): got %v, want no error", err) } if !bytes.Equal(got, want) { t.Errorf("c.Read(): got %q, want %q", got, want) } }() tsc, err := tsl.Accept() if err != nil { t.Fatalf("tsl.Accept(): got %v, want no error", err) } // Signal to the write goroutine that it may begin writing. wg.Done() start = time.Now() n, err := tsc.Write(want) end = time.Now() tsc.Close() if err != nil { t.Fatalf("tsc.Write(): got %v, want no error", err) } if got, want := n, len(want); got != want { t.Errorf("tsc.Write(): got %d bytes, want %d bytes", got, want) } // Breakdown of ~3s minimum: // 1 second for the initial latency // ~2-3 seconds for throttled write // - 4 bytes per second with 16 bytes total = 3 seconds (first four bytes // are written immediately at the zeroth second; 0:4, 1:8, 2:12, 3:16) // - the drain ticker begins before the initial start time so some of that // tick time is unaccounted for in the difference; potentially up to a // full second (the drain interval). 
For example, if the ticker is 300ms // into its tick before start is calculated we will believe that the // throttled write will have occurred in 2.7s. Allow for up to drain // interval in skew to account for this and ensure the test does not // flake. // // The test runtime should be negligible compared the latency simulation, so // we assume the ~3s (> 2.95s) is accounted for by throttling and latency in // the worst case (we write and a new tick happens immediately). min := 2*time.Second + 950*time.Millisecond max := 4*time.Second + 50*time.Millisecond if got := end.Sub(start); !between(got, min, max) { t.Errorf("tsc.Write(): took %s, want within [%s, %s]", got, min, max) } } func TestListenerWriteTo(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tsl := NewListener(l) defer tsl.Close() tsl.SetReadBitrate(40) // 4 bytes per second tsl.SetWriteBitrate(40) // 4 bytes per second tsl.SetLatency(time.Second) var wg sync.WaitGroup wg.Add(1) want := bytes.Repeat([]byte("*"), 16) go func() { // Dial the local listener. c, err := net.Dial("tcp", tsl.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer c.Close() // Wait for the signal that it's okay to write to the connection; ensure // the test is ready to read it. wg.Wait() c.Write(want) }() tsc, err := tsl.Accept() if err != nil { t.Fatalf("tsl.Accept(): got %v, want no error", err) } defer tsc.Close() // Signal to the write goroutine that it may begin writing. 
wg.Done() got := &bytes.Buffer{} wt, ok := tsc.(io.WriterTo) if !ok { t.Fatal("tsc.(io.WriterTo): got !ok, want ok") } start := time.Now() n, err := wt.WriteTo(got) end := time.Now() if err != io.EOF { t.Fatalf("tsc.WriteTo(): got %v, want io.EOF", err) } if got, want := n, int64(len(want)); got != want { t.Errorf("tsc.WriteTo(): got %d bytes, want %d bytes", got, want) } if !bytes.Equal(got.Bytes(), want) { t.Errorf("tsc.WriteTo(): got %q, want %q", got.Bytes(), want) } // Breakdown of ~3s minimum: // 1 second for the initial latency // ~2-3 seconds for throttled read // - 4 bytes per second with 16 bytes total = 3 seconds (first four bytes // are read immediately at the zeroth second; 0:4, 1:8, 2:12, 3:16) // - the drain ticker begins before the initial start time so some of that // tick time is unaccounted for in the difference; potentially up to a // full second (the drain interval). For example, if the ticker is 300ms // into its tick before start is calculated we will believe that the // throttled read will have occurred in 2.7s. Allow for up to drain // interval in skew to account for this and ensure the test does not // flake. // // The test runtime should be negligible compared the latency simulation, so // we assume the ~3s (> 2.95s) is accounted for by throttling and latency in // the worst case (we read and a new tick happens immediately). 
min := 2*time.Second + 950*time.Millisecond max := 4*time.Second + 50*time.Millisecond if got := end.Sub(start); !between(got, min, max) { t.Errorf("tsc.WriteTo(): took %s, want within [%s, %s]", got, min, max) } } func TestListenerReadFrom(t *testing.T) { t.Parallel() l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tsl := NewListener(l) defer tsl.Close() tsl.SetReadBitrate(40) // 4 bytes per second tsl.SetWriteBitrate(40) // 4 bytes per second tsl.SetLatency(time.Second) var wg sync.WaitGroup wg.Add(1) want := bytes.Repeat([]byte("*"), 16) var start time.Time var end time.Time go func() { // Dial the local listener. c, err := net.Dial("tcp", tsl.Addr().String()) if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } defer c.Close() // Wait for the signal that it's okay to read from the connection; ensure // the test is ready to write it. wg.Wait() got, err := ioutil.ReadAll(c) if err != nil { t.Fatalf("c.Read(): got %v, want no error", err) } if !bytes.Equal(got, want) { t.Errorf("c.Read(): got %q, want %q", got, want) } }() tsc, err := tsl.Accept() if err != nil { t.Fatalf("tsl.Accept(): got %v, want no error", err) } // Signal to the write goroutine that it may begin writing. 
wg.Done() buf := bytes.NewReader(want) rf, ok := tsc.(io.ReaderFrom) if !ok { t.Fatal("tsc.(io.ReaderFrom): got !ok, want ok") } start = time.Now() n, err := rf.ReadFrom(buf) end = time.Now() tsc.Close() if err != nil { t.Fatalf("tsc.ReadFrom(): got %v, want no error", err) } if got, want := n, int64(len(want)); got != want { t.Errorf("tsc.ReadFrom(): got %d bytes, want %d bytes", got, want) } // Breakdown of ~3s minimum: // 1 second for the initial latency // ~2-3 seconds for throttled writes // - 4 bytes per second with 16 bytes total = 3 seconds (first four bytes // are written immediately at the zeroth second; 0:4, 1:8, 2:12, 3:16) // - the drain ticker begins before the initial start time so some of that // tick time is unaccounted for in the difference; potentially up to a // full second (the drain interval). For example, if the ticker is 300ms // into its tick before start is calculated we will believe that the // throttled write will have occurred in 2.7s. Allow for up to drain // interval in skew to account for this and ensure the test does not // flake. // // The test runtime should be negligible compared the latency simulation, so // we assume the ~3s (> 2.95s) is accounted for by throttling and latency in // the worst case (we write and a new tick happens immediately). 
min := 2*time.Second + 950*time.Millisecond max := 4*time.Second + 50*time.Millisecond if got := end.Sub(start); !between(got, min, max) { t.Errorf("tsc.ReadFrom(): took %s, want within [%s, %s]", got, min, max) } } func between(d, min, max time.Duration) bool { return d >= min && d <= max } type throttleAssertion struct { Offset int64 ThrottleContext *ThrottleContext } type actionByteAssertion struct { Offset int64 NextActionInfo *NextActionInfo } func TestActionsAndThrottles(t *testing.T) { l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } tt := []struct { jsonString string throttleAssertions []throttleAssertion actionByteAssertions []actionByteAssertion }{ { jsonString: `{"trafficshape":{"shapes":[{"url_regex":"http://example/example", "max_global_bandwidth":1000, "throttles":[{"bytes":"500-1000","bandwidth":100},{"bytes":"1000-2000","bandwidth":300},{"bytes":"2001-","bandwidth":400}], "halts":[{"byte":530,"duration": 5, "count": 1}],"close_connections":[{"byte":1078,"count":1}]}]}}`, throttleAssertions: []throttleAssertion{ { Offset: 10, ThrottleContext: &ThrottleContext{ThrottleNow: false}, }, { Offset: 700, ThrottleContext: &ThrottleContext{ThrottleNow: true, Bandwidth: 100}, }, { Offset: 1000, ThrottleContext: &ThrottleContext{ThrottleNow: true, Bandwidth: 300}, }, { Offset: 5000, ThrottleContext: &ThrottleContext{ThrottleNow: true, Bandwidth: 400}, }, }, actionByteAssertions: []actionByteAssertion{ { Offset: 501, NextActionInfo: &NextActionInfo{ActionNext: true, ByteOffset: 530, Index: 1}, }, { Offset: 900, NextActionInfo: &NextActionInfo{ActionNext: true, ByteOffset: 1000, Index: 2}, }, { Offset: 1015, NextActionInfo: &NextActionInfo{ActionNext: true, ByteOffset: 1078, Index: 3}, }, { Offset: 2001, NextActionInfo: &NextActionInfo{ActionNext: true, ByteOffset: 2001, Index: 5}, }, }, }, } for i, tc := range tt { tsl := NewListener(l) defer tsl.Close() h := NewHandler(tsl) req, err := 
http.NewRequest("POST", "test", bytes.NewBufferString(tc.jsonString)) if err != nil { t.Fatalf("%d. http.NewRequest(): got %v, want no error", i, err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("%d. rw.Code: got %d, want %d", i+1, got, want) } conn, err := net.Dial("tcp", l.Addr().String()) defer conn.Close() if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } tsconn := tsl.GetTrafficShapedConn(conn) tsconn.Context = &Context{ Shaping: true, URLRegex: "http://example/example", } for _, ta := range tc.throttleAssertions { if got, want := *tsconn.GetCurrentThrottle(ta.Offset), *ta.ThrottleContext; got != want { t.Errorf("tc.%d CurtThrottleInfo at %d got %+v, want %+v", i+1, ta.Offset, got, want) } } for _, aba := range tc.actionByteAssertions { if got, want := *tsconn.GetNextActionFromByte(aba.Offset), *aba.NextActionInfo; got != want { t.Errorf("tc.%d NextActionInfo at %d got %+v, want %+v", i+1, aba.Offset, got, want) } } } } func TestActionsAfterUpdatingCounts(t *testing.T) { l, err := net.Listen("tcp", "[::]:0") if err != nil { t.Fatalf("net.Listen(): got %v, want no error", err) } jsonString := `{"trafficshape":{"shapes":[{"url_regex":"http://example/example", "max_global_bandwidth":1000, "throttles":[{"bytes":"500-1000","bandwidth":100}], "halts":[{"byte":530,"duration": 5, "count": 1},{"byte":550,"duration": 5, "count": 1}],"close_connections":[{"byte":1078,"count":1}]}]}}` tsl := NewListener(l) defer tsl.Close() h := NewHandler(tsl) req, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString)) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } conn, err := net.Dial("tcp", l.Addr().String()) defer conn.Close() if err != nil { t.Fatalf("net.Dial(): got %v, want no error", err) } tsconn := 
tsl.GetTrafficShapedConn(conn) tsconn.Context = &Context{ Shaping: true, URLRegex: "http://example/example", } actions := tsconn.Shapes.M[tsconn.Context.URLRegex].Shape.Actions nai := []*NextActionInfo{ &NextActionInfo{ActionNext: true, ByteOffset: 530, Index: 1}, &NextActionInfo{ActionNext: true, ByteOffset: 550, Index: 2}, &NextActionInfo{ActionNext: true, ByteOffset: 1000, Index: 3}, &NextActionInfo{ActionNext: true, ByteOffset: 1078, Index: 4}, &NextActionInfo{ActionNext: false}, } if got, want := *tsconn.GetNextActionFromByte(515), *nai[0]; got != want { t.Errorf("NextActionInfo at %d got %+v, want %+v", 515, got, want) } actions[1].decrementCount() if got, want := *tsconn.GetNextActionFromByte(515), *nai[1]; got != want { t.Errorf("NextActionInfo at %d got %+v, want %+v", 515, got, want) } actions[2].decrementCount() if got, want := *tsconn.GetNextActionFromByte(515), *nai[2]; got != want { t.Errorf("NextActionInfo at %d got %+v, want %+v", 515, got, want) } if got, want := *tsconn.GetNextActionFromByte(1015), *nai[3]; got != want { t.Errorf("NextActionInfo at %d got %+v, want %+v", 1015, got, want) } actions[4].decrementCount() if got, want := *tsconn.GetNextActionFromByte(1015), *nai[4]; got != want { t.Errorf("NextActionInfo at %d got %+v, want %+v", 1015, got, want) } } martian-3.3.2/trafficshape/utils.go000066400000000000000000000143411421371434000173020ustar00rootroot00000000000000package trafficshape import ( "errors" "fmt" "regexp" "sort" "strconv" "strings" "time" ) // Converts a sorted slice of Throttles to their ChangeBandwidth actions. In adddition, checks for // overlapping throttle ranges. Returns a slice of actions and an error specifying if the throttles // passed the non-overlapping verification. 
// // Idea: For every throttle, add two ChangeBandwidth actions (one for start and one for end), unless // the ending byte of one throttle is the same as the starting byte of the next throttle, in which // case we do not add the end ChangeBandwidth for the first throttle, or if the end of a throttle // is -1 (representing EOF), in which case we do not add the end ChangeBandwidth action for the throttle. // // Note, we only allow the last throttle in the sorted list to have an end of -1 in order to avoid an overlap. func getActionsFromThrottles(throttles []*Throttle, defaultBandwidth int64) ([]Action, error) { lenThr := len(throttles) var actions []Action for index, throttle := range throttles { start := throttle.ByteStart end := throttle.ByteEnd if index == lenThr-1 { if end == -1 { actions = append(actions, Action(&ChangeBandwidth{ Byte: start, Bandwidth: throttle.Bandwidth, })) } else { actions = append(actions, Action(&ChangeBandwidth{ Byte: start, Bandwidth: throttle.Bandwidth, }), Action(&ChangeBandwidth{ Byte: end, Bandwidth: defaultBandwidth, })) } break } if end > throttles[index+1].ByteStart || end == -1 { return actions, errors.New("overlapping throttle intervals found") } if end == throttles[index+1].ByteStart { actions = append(actions, Action(&ChangeBandwidth{ Byte: start, Bandwidth: throttle.Bandwidth, })) } else { actions = append(actions, Action(&ChangeBandwidth{ Byte: start, Bandwidth: throttle.Bandwidth, }), Action(&ChangeBandwidth{ Byte: end, Bandwidth: defaultBandwidth, })) } } return actions, nil } // Parses a Trafficshape and updates Traffficshape.Shapes while performing verifications. // // Returns an error in case a verification check fails. 
func parseShapes(ts *Trafficshape) error { var err error for shapeIndex, shape := range ts.Shapes { if shape == nil { return fmt.Errorf("nil shape at index: %d", shapeIndex) } if shape.URLRegex == "" { return fmt.Errorf("no url_regex for shape at index: %d", shapeIndex) } if _, err = regexp.Compile(shape.URLRegex); err != nil { return fmt.Errorf("url_regex for shape at index doesn't compile: %d", shapeIndex) } if shape.MaxBandwidth < 0 { return fmt.Errorf("max_bandwidth cannot be negative for shape at index: %d", shapeIndex) } if shape.MaxBandwidth == 0 { shape.MaxBandwidth = DefaultBitrate / 8 } shape.WriteBucket = NewBucket(shape.MaxBandwidth, time.Second) // Verify and process the throttles, filling in their ByteStart and ByteEnd. for throttleIndex, throttle := range shape.Throttles { if throttle == nil { return fmt.Errorf("nil throttle at index %d in shape index %d", throttleIndex, shapeIndex) } if throttle.Bandwidth <= 0 { return fmt.Errorf("invalid bandwidth: %d at throttle index %d in shape index %d", throttle.Bandwidth, throttleIndex, shapeIndex) } sl := strings.Split(throttle.Bytes, "-") if len(sl) != 2 { return fmt.Errorf("invalid bytes: %s at throttle index %d in shape index %d", throttle.Bytes, throttleIndex, shapeIndex) } start := sl[0] end := sl[1] if start == "" { throttle.ByteStart = 0 } else { throttle.ByteStart, err = strconv.ParseInt(start, 10, 64) if err != nil { return fmt.Errorf("invalid bytes: %s at throttle index %d in shape index %d", throttle.Bytes, throttleIndex, shapeIndex) } } if end == "" { throttle.ByteEnd = -1 } else { throttle.ByteEnd, err = strconv.ParseInt(end, 10, 64) if err != nil { return fmt.Errorf("invalid bytes: %s at throttle index %d in shape index %d", throttle.Bytes, throttleIndex, shapeIndex) } if throttle.ByteEnd < throttle.ByteStart { return fmt.Errorf("invalid bytes: %s at throttle index %d in shape index %d", throttle.Bytes, throttleIndex, shapeIndex) } } if throttle.ByteStart == throttle.ByteEnd { return 
fmt.Errorf("invalid bytes: %s at throttle index %d in shape index %d", throttle.Bytes, throttleIndex, shapeIndex) } } // Fill in the actions, while performing verification. shape.Actions = make([]Action, len(shape.Halts)+len(shape.CloseConnections)) for index, value := range shape.Halts { if value == nil { return fmt.Errorf("nil halt at index %d in shape index %d", index, shapeIndex) } if value.Duration < 0 || value.Byte < 0 { return fmt.Errorf("invalid halt at index %d in shape index %d", index, shapeIndex) } if value.Count == 0 { return fmt.Errorf(" 0 count for halt at index %d in shape index %d", index, shapeIndex) } shape.Actions[index] = Action(value) } offset := len(shape.Halts) for index, value := range shape.CloseConnections { if value == nil { return fmt.Errorf("nil close_connection at index %d in shape index %d", index, shapeIndex) } if value.Byte < 0 { return fmt.Errorf("invalid close_connection at index %d in shape index %d", index, shapeIndex) } if value.Count == 0 { return fmt.Errorf("0 count for close_connection at index %d in shape index %d", index, shapeIndex) } shape.Actions[offset+index] = Action(value) } sort.SliceStable(shape.Throttles, func(i, j int) bool { return shape.Throttles[i].ByteStart < shape.Throttles[j].ByteStart }) defaultBandwidth := DefaultBitrate / 8 if shape.MaxBandwidth > 0 { defaultBandwidth = shape.MaxBandwidth } throttleActions, err := getActionsFromThrottles(shape.Throttles, defaultBandwidth) if err != nil { return fmt.Errorf("err: %s in shape index %d", err.Error(), shapeIndex) } shape.Actions = append(shape.Actions, throttleActions...) // Sort the actions according to their byte offset. 
sort.SliceStable(shape.Actions, func(i, j int) bool { return shape.Actions[i].getByte() < shape.Actions[j].getByte() }) } return nil } martian-3.3.2/verify/000077500000000000000000000000001421371434000144555ustar00rootroot00000000000000martian-3.3.2/verify/verify.go000066400000000000000000000046541421371434000163210ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Package verify provides support for using martian modifiers for request and // response verifications. package verify import ( "net/http" "github.com/google/martian/v3" ) // RequestVerifier is a RequestModifier that maintains a verification state. // RequestVerifiers should only return an error from ModifyRequest for errors // unrelated to the expectation. type RequestVerifier interface { martian.RequestModifier VerifyRequests() error ResetRequestVerifications() } // ResponseVerifier is a ResponseModifier that maintains a verification state. // ResponseVerifiers should only return an error from ModifyResponse for errors // unrelated to the expectation. type ResponseVerifier interface { martian.ResponseModifier VerifyResponses() error ResetResponseVerifications() } // RequestResponseVerifier is a RequestVerifier and a ResponseVerifier. type RequestResponseVerifier interface { RequestVerifier ResponseVerifier } // TestVerifier is a request and response verifier with overridable errors for // verification. 
type TestVerifier struct { RequestError error ResponseError error } // ModifyRequest is a no-op. func (tv *TestVerifier) ModifyRequest(*http.Request) error { return nil } // ModifyResponse is a no-op. func (tv *TestVerifier) ModifyResponse(*http.Response) error { return nil } // VerifyRequests returns the set request error. func (tv *TestVerifier) VerifyRequests() error { return tv.RequestError } // VerifyResponses returns the set response error. func (tv *TestVerifier) VerifyResponses() error { return tv.ResponseError } // ResetRequestVerifications clears out the set request error. func (tv *TestVerifier) ResetRequestVerifications() { tv.RequestError = nil } // ResetResponseVerifications clears out the set response error. func (tv *TestVerifier) ResetResponseVerifications() { tv.ResponseError = nil } martian-3.3.2/verify/verify_handlers.go000066400000000000000000000070161421371434000201740ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package verify import ( "encoding/json" "net/http" "github.com/google/martian/v3" "github.com/google/martian/v3/log" ) // Handler is an http.Handler that returns the request and response // verifications of reqv and resv as JSON. type Handler struct { reqv RequestVerifier resv ResponseVerifier } // ResetHandler is an http.Handler that resets the request and response // verifications of reqv and resv. 
type ResetHandler struct { reqv RequestVerifier resv ResponseVerifier } type verifyResponse struct { Errors []verifyError `json:"errors"` } type verifyError struct { Message string `json:"message"` } // NewHandler returns an http.Handler for requesting the verification // error status. func NewHandler() *Handler { return &Handler{} } // NewResetHandler returns an http.Handler for reseting the verification error // status. func NewResetHandler() *ResetHandler { return &ResetHandler{} } // SetRequestVerifier sets the RequestVerifier to verify. func (h *Handler) SetRequestVerifier(reqv RequestVerifier) { h.reqv = reqv } // SetResponseVerifier sets the ResponseVerifier to verify. func (h *Handler) SetResponseVerifier(resv ResponseVerifier) { h.resv = resv } // ServeHTTP writes out a JSON response containing a list of verification // errors that occurred during the requests and responses sent to the proxy. func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("Content-Type", "application/json") if req.Method != "GET" { rw.Header().Set("Allow", "GET") rw.WriteHeader(405) log.Errorf("verify: invalid request method: %s", req.Method) return } vres := &verifyResponse{ Errors: make([]verifyError, 0), } if h.reqv != nil { if err := h.reqv.VerifyRequests(); err != nil { appendError(vres, err) } } if h.resv != nil { if err := h.resv.VerifyResponses(); err != nil { appendError(vres, err) } } json.NewEncoder(rw).Encode(vres) } func appendError(vres *verifyResponse, err error) { merr, ok := err.(*martian.MultiError) if !ok { vres.Errors = append(vres.Errors, verifyError{Message: err.Error()}) return } for _, err := range merr.Errors() { vres.Errors = append(vres.Errors, verifyError{Message: err.Error()}) } } // SetRequestVerifier sets the RequestVerifier to reset. func (h *ResetHandler) SetRequestVerifier(reqv RequestVerifier) { h.reqv = reqv } // SetResponseVerifier sets the ResponseVerifier to reset. 
func (h *ResetHandler) SetResponseVerifier(resv ResponseVerifier) { h.resv = resv } // ServeHTTP resets the verifier for the given ID so that it may // be run again. func (h *ResetHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if req.Method != "POST" { rw.Header().Set("Allow", "POST") rw.WriteHeader(405) log.Errorf("verify: invalid request method: %s", req.Method) return } if h.reqv != nil { h.reqv.ResetRequestVerifications() } if h.resv != nil { h.resv.ResetResponseVerifications() } rw.WriteHeader(204) } martian-3.3.2/verify/verify_handlers_test.go000066400000000000000000000112641421371434000212330ustar00rootroot00000000000000// Copyright 2015 Google Inc. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package verify import ( "bytes" "encoding/json" "fmt" "net/http" "net/http/httptest" "testing" "github.com/google/martian/v3" ) func TestHandlerServeHTTPUnsupportedMethod(t *testing.T) { h := NewHandler() for i, m := range []string{"POST", "PUT", "DELETE"} { req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Fatalf("%d. http.NewRequest(): got %v, want no error", i, err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 405; got != want { t.Errorf("%d. rw.Code: got %d, want %d", i, got, want) } if got, want := rw.Header().Get("Allow"), "GET"; got != want { t.Errorf("%d. 
rw.Header().Get(%q): got %q, want %q", i, "Allow", got, want) } } } func TestHandlerServeHTTPNoVerifiers(t *testing.T) { h := NewHandler() req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } } func TestHandlerServeHTTP(t *testing.T) { merr := martian.NewMultiError() merr.Add(fmt.Errorf("first response verification failure")) merr.Add(fmt.Errorf("second response verification failure")) v := &TestVerifier{ RequestError: fmt.Errorf("request verification failure"), ResponseError: merr, } h := NewHandler() h.SetRequestVerifier(v) h.SetResponseVerifier(v) req, err := http.NewRequest("GET", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 200; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } buf := new(bytes.Buffer) if err := json.Compact(buf, []byte(`{ "errors": [ { "message": "request verification failure" }, { "message": "first response verification failure" }, { "message": "second response verification failure" } ] }`)); err != nil { t.Fatalf("json.Compact(): got %v, want no error", err) } // json.(*Encoder).Encode writes a trailing newline, so we will too. // see: https://golang.org/src/encoding/json/stream.go buf.WriteByte('\n') if got, want := rw.Body.Bytes(), buf.Bytes(); !bytes.Equal(got, want) { t.Errorf("rw.Body: got %q, want %q", got, want) } } func TestResetHandlerServeHTTPUnsupportedMethod(t *testing.T) { h := NewResetHandler() for i, m := range []string{"GET", "PUT", "DELETE"} { req, err := http.NewRequest(m, "http://example.com", nil) if err != nil { t.Fatalf("%d. 
http.NewRequest(): got %v, want no error", i, err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 405; got != want { t.Errorf("%d. rw.Code: got %d, want %d", i, got, want) } if got, want := rw.Header().Get("Allow"), "POST"; got != want { t.Errorf("%d. rw.Header().Get(%q): got %q, want %q", i, "Allow", got, want) } } } func TestResetHandlerServeHTTPNoVerifiers(t *testing.T) { h := NewResetHandler() req, err := http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 204; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } } func TestResetHandlerServeHTTP(t *testing.T) { v := &TestVerifier{ RequestError: fmt.Errorf("request verification failure"), ResponseError: fmt.Errorf("response verification failure"), } h := NewResetHandler() h.SetRequestVerifier(v) h.SetResponseVerifier(v) req, err := http.NewRequest("POST", "http://example.com", nil) if err != nil { t.Fatalf("http.NewRequest(): got %v, want no error", err) } rw := httptest.NewRecorder() h.ServeHTTP(rw, req) if got, want := rw.Code, 204; got != want { t.Errorf("rw.Code: got %d, want %d", got, want) } if err := v.VerifyRequests(); err != nil { t.Errorf("v.VerifyRequests(): got %v, want no error", err) } if err := v.VerifyResponses(); err != nil { t.Errorf("v.VerifyResponses(): got %v, want no error", err) } }