pax_global_header00006660000000000000000000000064140404025750014513gustar00rootroot0000000000000052 comment=df0f91b29bbbdfc3a686a7a8edbe2b9de2072fdd go-grpc-middleware-1.3.0/000077500000000000000000000000001404040257500151655ustar00rootroot00000000000000go-grpc-middleware-1.3.0/.github/000077500000000000000000000000001404040257500165255ustar00rootroot00000000000000go-grpc-middleware-1.3.0/.github/stale.yml000066400000000000000000000017461404040257500203700ustar00rootroot00000000000000# Configuration for probot-stale - https://github.com/probot/stale # Number of days of inactivity before an Issue becomes stale daysUntilStale: 60 # Number of days of inactivity before a stale Issue is closed. # Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale. daysUntilClose: 7 # Issues with these labels will never be considered stale. Set to `[]` to disable exemptLabels: - bug - docs improvement - help wanted # Label to use when marking as stale staleLabel: stale issues: # Comment to post when marking Issues as stale. markComment: > This issue has been automatically marked as stale because it has not had any recent activity. It will be closed in 7 days if no further activity occurs. # Comment to post when closing a stale Issue. closeComment: > This issue has been automatically closed because it has not had any recent activity. If you have a question or comment, please open a new issue. 
go-grpc-middleware-1.3.0/.gitignore000066400000000000000000000062711404040257500171630ustar00rootroot00000000000000# Created by .ignore support plugin (hsz.mobi) ### Go template # Compiled Object files, Static and Dynamic libs (Shared Objects) *.o *.a *.so # Folders _obj _test # Architecture specific extensions/prefixes *.[568vq] [568vq].out *.cgo1.go *.cgo2.c _cgo_defun.c _cgo_gotypes.go _cgo_export.* _testmain.go *.exe *.test *.prof ### Windows template # Windows image file caches Thumbs.db ehthumbs.db # Folder config file Desktop.ini # Recycle Bin used on file shares $RECYCLE.BIN/ # Windows Installer files *.cab *.msi *.msm *.msp # Windows shortcuts *.lnk ### Kate template # Swap Files # .*.kate-swp .swp.* ### SublimeText template # cache files for sublime text *.tmlanguage.cache *.tmPreferences.cache *.stTheme.cache # workspace files are user-specific *.sublime-workspace # project files should be checked into the repository, unless a significant # proportion of contributors will probably not be using SublimeText # *.sublime-project # sftp configuration file sftp-config.json ### Linux template *~ # temporary files which can be created if a process still has a handle open of a deleted file .fuse_hidden* # KDE directory preferences .directory # Linux trash folder which might appear on any partition or disk .Trash-* ### JetBrains template # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # User-specific stuff: .idea .idea/tasks.xml .idea/dictionaries .idea/vcs.xml .idea/jsLibraryMappings.xml # Sensitive or high-churn files: .idea/dataSources.ids .idea/dataSources.xml .idea/dataSources.local.xml .idea/sqlDataSources.xml .idea/dynamic.xml .idea/uiDesigner.xml # Gradle: .idea/gradle.xml .idea/libraries # Mongo Explorer plugin: .idea/mongoSettings.xml ## File-based project format: *.iws ## Plugin-specific files: # IntelliJ /out/ # 
mpeltonen/sbt-idea plugin .idea_modules/ # JIRA plugin atlassian-ide-plugin.xml # Crashlytics plugin (for Android Studio and IntelliJ) com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties ### Xcode template # Xcode # # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore ## Build generated build/ DerivedData/ ## Various settings *.pbxuser !default.pbxuser *.mode1v3 !default.mode1v3 *.mode2v3 !default.mode2v3 *.perspectivev3 !default.perspectivev3 xcuserdata/ ## Other *.moved-aside *.xccheckout *.xcscmblueprint ### Eclipse template .metadata bin/ tmp/ *.tmp *.bak *.swp *~.nib local.properties .settings/ .loadpath .recommenders # Eclipse Core .project # External tool builders .externalToolBuilders/ # Locally stored "Eclipse launch configurations" *.launch # PyDev specific (Python IDE for Eclipse) *.pydevproject # CDT-specific (C/C++ Development Tooling) .cproject # JDT-specific (Eclipse Java Development Tools) .classpath # Java annotation processor (APT) .factorypath # PDT-specific (PHP Development Tools) .buildpath # sbteclipse plugin .target # Tern plugin .tern-project # TeXlipse plugin .texlipse # STS (Spring Tool Suite) .springBeans # Code Recommenders .recommenders/ coverage.txt #vendor vendor/ .envrcgo-grpc-middleware-1.3.0/.travis.yml000066400000000000000000000002651404040257500173010ustar00rootroot00000000000000sudo: false language: go go: - 1.13.x - 1.14.x - 1.15.x env: global: - GO111MODULE=on script: - make test after_success: - bash <(curl -s https://codecov.io/bash) go-grpc-middleware-1.3.0/CHANGELOG.md000066400000000000000000000046551404040257500170100ustar00rootroot00000000000000# Changelog All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html). 
Types of changes: - `Added` for new features. - `Changed` for changes in existing functionality. - `Deprecated` for soon-to-be removed features. - `Removed` for now removed features. - `Fixed` for any bug fixes. - `Security` in case of vulnerabilities. ## [Unreleased] ### Added - [#223](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/223) Add go-kit logging middleware - [adrien-f](https://github.com/adrien-f) ## [v1.1.0] - 2019-09-12 ### Added - [#226](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/226) Support for go modules. - [#221](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/221) logging/zap add support for gRPC LoggerV2 - [kush-patel-hs](https://github.com/kush-patel-hs) - [#181](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/181) Rate Limit support - [ceshihao](https://github.com/ceshihao) - [#161](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/161) Retry on server stream call - [lonnblad](https://github.com/lonnblad) - [#152](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/152) Exponential backoff functions - [polyfloyd](https://github.com/polyfloyd) - [#147](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/147) Jaeger support for ctxtags extraction - [vporoshok](https://github.com/vporoshok) - [#184](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/184) ctxTags identifies if the call was sampled ### Deprecated - [#201](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/201) `golang.org/x/net/context` - [houz42](https://github.com/houz42) - [#183](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/183) Documentation Generation in favour of [godoc.org](https://godoc.org/github.com/grpc-ecosystem/go-grpc-middleware). ### Fixed - [#172](https://github.com/grpc-ecosystem/go-grpc-middleware/pull/172) Passing ctx into retry and recover - [johanbrandhorst](https://github.com/johanbrandhorst) - Numerous documentation fixes. 
## v1.0.0 - 2018-05-08 ### Added - grpc_auth - grpc_ctxtags - grpc_zap - grpc_logrus - grpc_opentracing - grpc_retry - grpc_validator - grpc_recovery [Unreleased]: https://github.com/grpc-ecosystem/go-grpc-middleware/compare/v1.1.0...HEAD [v1.1.0]: https://github.com/grpc-ecosystem/go-grpc-middleware/compare/v1.0.0...v1.1.0 go-grpc-middleware-1.3.0/CONTRIBUTING.md000066400000000000000000000006461404040257500174240ustar00rootroot00000000000000# Contributing We would love to have people submit pull requests and help make `grpc-ecosystem/go-grpc-middleware` even better 👍. Fork, then clone the repo: ```bash git clone git@github.com:your-username/go-grpc-middleware.git ``` Before checking in please run the following: ```bash make all ``` This will `vet`, `fmt`, regenerate documentation and run all tests. Push to your fork and open a pull request.go-grpc-middleware-1.3.0/LICENSE000066400000000000000000000261141404040257500161760ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.go-grpc-middleware-1.3.0/README.md000066400000000000000000000117561404040257500164560ustar00rootroot00000000000000# Go gRPC Middleware [![Travis Build](https://travis-ci.org/grpc-ecosystem/go-grpc-middleware.svg?branch=master)](https://travis-ci.org/grpc-ecosystem/go-grpc-middleware) [![Go Report Card](https://goreportcard.com/badge/github.com/grpc-ecosystem/go-grpc-middleware)](https://goreportcard.com/report/github.com/grpc-ecosystem/go-grpc-middleware) [![GoDoc](http://img.shields.io/badge/GoDoc-Reference-blue.svg)](https://godoc.org/github.com/grpc-ecosystem/go-grpc-middleware) [![SourceGraph](https://sourcegraph.com/github.com/grpc-ecosystem/go-grpc-middleware/-/badge.svg)](https://sourcegraph.com/github.com/grpc-ecosystem/go-grpc-middleware/?badge) [![codecov](https://codecov.io/gh/grpc-ecosystem/go-grpc-middleware/branch/master/graph/badge.svg)](https://codecov.io/gh/grpc-ecosystem/go-grpc-middleware) [![Apache 2.0 License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) [![quality: production](https://img.shields.io/badge/quality-production-orange.svg)](#status) [![Slack](https://img.shields.io/badge/slack-%23grpc--middleware-brightgreen)](https://slack.com/share/IRUQCFC23/9Tm7hxRFVKKNoajQfMOcUiIk/enQtODc4ODI4NTIyMDcxLWM5NDA0ZTE4Njg5YjRjYWZkMTI5MzQwNDY3YzBjMzE1YzdjOGM5ZjI1NDNiM2JmNzI2YjM5ODE5OTRiNTEyOWE) [gRPC Go](https://github.com/grpc/grpc-go) Middleware: interceptors, helpers, 
utilities. ## Middleware [gRPC Go](https://github.com/grpc/grpc-go) recently acquired support for Interceptors, i.e. [middleware](https://medium.com/@matryer/writing-middleware-in-golang-and-how-go-makes-it-so-much-fun-4375c1246e81#.gv7tdlghs) that is executed either on the gRPC Server before the request is passed onto the user's application logic, or on the gRPC client around the user call. It is a perfect way to implement common patterns: auth, logging, message validation, retries or monitoring. These are generic building blocks that make it easy to build multiple microservices. The purpose of this repository is to act as a go-to point for such reusable functionality. It contains some of them itself, but also links to useful external repos. `grpc_middleware` itself provides support for chaining interceptors, here's an example: ```go import "github.com/grpc-ecosystem/go-grpc-middleware" myServer := grpc.NewServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( grpc_recovery.StreamServerInterceptor(), grpc_ctxtags.StreamServerInterceptor(), grpc_opentracing.StreamServerInterceptor(), grpc_prometheus.StreamServerInterceptor, grpc_zap.StreamServerInterceptor(zapLogger), grpc_auth.StreamServerInterceptor(myAuthFunction), )), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( grpc_recovery.UnaryServerInterceptor(), grpc_ctxtags.UnaryServerInterceptor(), grpc_opentracing.UnaryServerInterceptor(), grpc_prometheus.UnaryServerInterceptor, grpc_zap.UnaryServerInterceptor(zapLogger), grpc_auth.UnaryServerInterceptor(myAuthFunction), )), ) ``` ## Interceptors *Please send a PR to add new interceptors or middleware to this list* #### Auth * [`grpc_auth`](auth) - a customizable (via `AuthFunc`) piece of auth middleware #### Logging * [`grpc_ctxtags`](tags/) - a library that adds a `Tag` map to context, with data populated from request body * [`grpc_zap`](logging/zap/) - integration of [zap](https://github.com/uber-go/zap) logging library into gRPC 
handlers. * [`grpc_logrus`](logging/logrus/) - integration of [logrus](https://github.com/sirupsen/logrus) logging library into gRPC handlers. * [`grpc_kit`](logging/kit/) - integration of [go-kit](https://github.com/go-kit/kit/tree/master/log) logging library into gRPC handlers. * [`grpc_logsettable`](logging/settable/) - a wrapper around `grpclog.LoggerV2` that allows replacing loggers at runtime (thread-safe). #### Monitoring * [`grpc_prometheus`⚡](https://github.com/grpc-ecosystem/go-grpc-prometheus) - Prometheus client-side and server-side monitoring middleware * [`otgrpc`⚡](https://github.com/grpc-ecosystem/grpc-opentracing/tree/master/go/otgrpc) - [OpenTracing](http://opentracing.io/) client-side and server-side interceptors * [`grpc_opentracing`](tracing/opentracing) - [OpenTracing](http://opentracing.io/) client-side and server-side interceptors with support for streaming and handler-returned tags #### Client * [`grpc_retry`](retry/) - a generic gRPC response code retry mechanism, client-side middleware #### Server * [`grpc_validator`](validator/) - codegen inbound message validation from `.proto` options * [`grpc_recovery`](recovery/) - turn panics into gRPC errors * [`ratelimit`](ratelimit/) - gRPC rate limiting with your own limiter ## Status This code has been running in *production* since May 2016 as the basis of the gRPC microservices stack at [Improbable](https://improbable.io). Additional tooling will be added, and contributions are welcome. ## License `go-grpc-middleware` is released under the Apache 2.0 license. See the [LICENSE](LICENSE) file for details. go-grpc-middleware-1.3.0/auth/000077500000000000000000000000001404040257500161265ustar00rootroot00000000000000go-grpc-middleware-1.3.0/auth/auth.go000066400000000000000000000052751404040257500174270ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_auth

import (
	"context"

	"github.com/grpc-ecosystem/go-grpc-middleware"
	"google.golang.org/grpc"
)

// AuthFunc is the pluggable function that performs authentication.
//
// The passed in `Context` will contain the gRPC metadata.MD object (for header-based authentication) and
// the peer.Peer information that can contain transport-based credentials (e.g. `credentials.AuthInfo`).
//
// The returned context will be propagated to handlers, allowing user changes to `Context`. However,
// please make sure that the `Context` returned is a child `Context` of the one passed in.
//
// If an error is returned, the interceptors return it unchanged, so its gRPC status code (as reported by
// `status.Code`) and its verbatim message reach the caller. Please make sure you use
// `codes.Unauthenticated` (lacking auth) and `codes.PermissionDenied` (authed, but lacking perms)
// appropriately.
type AuthFunc func(ctx context.Context) (context.Context, error)

// ServiceAuthFuncOverride allows a given gRPC service implementation to override the global `AuthFunc`.
//
// If a service implements the AuthFuncOverride method, it takes precedence over the `AuthFunc` passed to
// the interceptor constructors, and will be called instead of it for all method invocations within that
// service. fullMethodName is the full RPC method name supplied by gRPC (info.FullMethod).
type ServiceAuthFuncOverride interface {
	AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error)
}

// UnaryServerInterceptor returns a new unary server interceptor that performs per-request auth.
//
// If the registered service implements ServiceAuthFuncOverride, its AuthFuncOverride is used instead of
// authFunc. When authentication fails the handler is not invoked and the error is returned as-is;
// otherwise the handler receives the context produced by the auth function.
func UnaryServerInterceptor(authFunc AuthFunc) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		var newCtx context.Context
		var err error
		if overrideSrv, ok := info.Server.(ServiceAuthFuncOverride); ok {
			newCtx, err = overrideSrv.AuthFuncOverride(ctx, info.FullMethod)
		} else {
			newCtx, err = authFunc(ctx)
		}
		if err != nil {
			return nil, err
		}
		return handler(newCtx, req)
	}
}

// StreamServerInterceptor returns a new stream server interceptor that performs per-request auth.
//
// If the registered service implements ServiceAuthFuncOverride, its AuthFuncOverride is used instead of
// authFunc. When authentication fails the handler is not invoked and the error is returned as-is;
// otherwise the stream is wrapped (grpc_middleware.WrapServerStream) with WrappedContext set to the
// context produced by the auth function, so the handler sees that context on the stream.
func StreamServerInterceptor(authFunc AuthFunc) grpc.StreamServerInterceptor {
	return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		var newCtx context.Context
		var err error
		if overrideSrv, ok := srv.(ServiceAuthFuncOverride); ok {
			newCtx, err = overrideSrv.AuthFuncOverride(stream.Context(), info.FullMethod)
		} else {
			newCtx, err = authFunc(stream.Context())
		}
		if err != nil {
			return err
		}
		wrapped := grpc_middleware.WrapServerStream(stream)
		wrapped.WrappedContext = newCtx
		return handler(srv, wrapped)
	}
}
go-grpc-middleware-1.3.0/auth/auth_test.go000066400000000000000000000166021404040257500204620ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms.
package grpc_auth_test import ( "context" "fmt" "testing" "time" "github.com/grpc-ecosystem/go-grpc-middleware/auth" "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/oauth" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) var ( commonAuthToken = "some_good_token" overrideAuthToken = "override_token" authedMarker = "some_context_marker" goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} )
// TODO(mwitkow): Add auth from metadata client dialer, which requires TLS.

// buildDummyAuthFunction returns an auth function that extracts the token for expectedScheme with
// grpc_auth.AuthFromMD (returning its error unchanged), rejects any token other than expectedToken
// with codes.PermissionDenied, and on success stores "marker_exists" under authedMarker in the context.
func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) { return func(ctx context.Context) (context.Context, error) { token, err := grpc_auth.AuthFromMD(ctx, expectedScheme) if err != nil { return nil, err } if token != expectedToken { return nil, status.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token") } return context.WithValue(ctx, authedMarker, "marker_exists"), nil } }
// assertAuthMarkerExists checks that ctx carries the marker stored by buildDummyAuthFunction.
// NOTE(review): the unchecked .(string) assertion panics (instead of failing the assertion) if the marker is absent.
func assertAuthMarkerExists(t *testing.T, ctx context.Context) { assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around") }
// assertingPingService wraps a TestServiceServer; PingError and PingList assert that the auth marker
// reached the handler's context before delegating to the wrapped server.
type assertingPingService struct { pb_testproto.TestServiceServer T *testing.T } func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { assertAuthMarkerExists(s.T, ctx) return s.TestServiceServer.PingError(ctx, ping) } func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { assertAuthMarkerExists(s.T,
stream.Context()) return s.TestServiceServer.PingList(ping, stream) }
// ctxWithToken returns ctx with outgoing metadata "authorization: <scheme> <token>".
func ctxWithToken(ctx context.Context, scheme string, token string) context.Context { md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token)) nCtx := metautils.NiceMD(md).ToOutgoing(ctx) return nCtx } func TestAuthTestSuite(t *testing.T) { authFunc := buildDummyAuthFunction("bearer", commonAuthToken) s := &AuthTestSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &assertingPingService{&grpc_testing.TestPingService{T: t}, t}, ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)), grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)), }, }, } suite.Run(t, s) } type AuthTestSuite struct { *grpc_testing.InterceptorTestSuite } func (s *AuthTestSuite) TestUnary_NoAuth() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated") } func (s *AuthTestSuite) TestUnary_BadAuth() { _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) assert.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied") } func (s *AuthTestSuite) TestUnary_PassesAuth() { _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing) require.NoError(s.T(), err, "no error must occur") } func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() { grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) _, err := client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "no error must occur") } func (s *AuthTestSuite) TestStream_NoAuth() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail
on establishing the stream") _, err = stream.Recv() assert.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated") } func (s *AuthTestSuite) TestStream_BadAuth() { stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") _, err = stream.Recv() assert.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied") } func (s *AuthTestSuite) TestStream_PassesAuth() { stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") require.NotNil(s.T(), pong, "pong must not be nil") } func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() { grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) stream, err := client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") require.NotNil(s.T(), pong, "pong must not be nil") } type authOverrideTestService struct { pb_testproto.TestServiceServer T *testing.T } func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) { assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around") return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx) } func TestAuthOverrideTestSuite(t *testing.T) { authFunc := buildDummyAuthFunction("bearer", commonAuthToken) s := &AuthOverrideTestSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService:
&authOverrideTestService{&assertingPingService{&grpc_testing.TestPingService{T: t}, t}, t}, ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)), grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)), }, }, } suite.Run(t, s) } type AuthOverrideTestSuite struct { *grpc_testing.InterceptorTestSuite } func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() { _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing) require.NoError(s.T(), err, "no error must occur") } func (s *AuthOverrideTestSuite) TestStream_PassesAuth() { stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") require.NotNil(s.T(), pong, "pong must not be nil") } // fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test. type fakeOAuth2TokenSource struct { accessToken string } func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) { t := &oauth2.Token{ AccessToken: ts.accessToken, Expiry: time.Now().Add(1 * time.Minute), TokenType: "bearer", } return t, nil } go-grpc-middleware-1.3.0/auth/doc.go000066400000000000000000000013131404040257500172200ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* `grpc_auth` is a generic server-side auth middleware for gRPC. Server Side Auth Middleware It allows for easy assertion of `:authorization` headers in gRPC calls, be it HTTP Basic auth, or OAuth2 Bearer tokens. The middleware takes a user-customizable `AuthFunc`, which can be customized to verify and extract auth information from the request. The extracted information can be put in the `context.Context` of handlers downstream for retrieval.
It also allows for per-service implementation overrides of `AuthFunc`. See `ServiceAuthFuncOverride`. Please see examples for simple examples of use. */ package grpc_auth go-grpc-middleware-1.3.0/auth/examples_test.go000066400000000000000000000045461404040257500213430ustar00rootroot00000000000000package grpc_auth_test import ( "context" "log" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" ) func parseToken(token string) (struct{}, error) { return struct{}{}, nil } func userClaimFromToken(struct{}) string { return "foobar" } // exampleAuthFunc is used by a middleware to authenticate requests func exampleAuthFunc(ctx context.Context) (context.Context, error) { token, err := grpc_auth.AuthFromMD(ctx, "bearer") if err != nil { return nil, err } tokenInfo, err := parseToken(token) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "invalid auth token: %v", err) } grpc_ctxtags.Extract(ctx).Set("auth.sub", userClaimFromToken(tokenInfo)) // WARNING: in production define your own type to avoid context collisions newCtx := context.WithValue(ctx, "tokenInfo", tokenInfo) return newCtx, nil } // Simple example of server initialization code func Example_serverConfig() { _ = grpc.NewServer( grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(exampleAuthFunc)), grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(exampleAuthFunc)), ) } type server struct { pb.UnimplementedTestServiceServer message string }
// Ping returns the server's configured message. Because server also implements
// AuthFuncOverride (below), the auth interceptors call that override instead of
// exampleAuthFunc for this service's methods.
func (g *server) Ping(ctx context.Context, request *pb.PingRequest) (*pb.PingResponse, error) { return &pb.PingResponse{Value: g.message}, nil }
// AuthFuncOverride is called instead of exampleAuthFunc. In this example it only logs the
// method name and returns ctx unchanged with a nil error, so every call to this service is let through.
func (g *server) AuthFuncOverride(ctx
context.Context, fullMethodName string) (context.Context, error) { log.Println("client is calling method:", fullMethodName) return ctx, nil } // Simple example of server initialization code with AuthFuncOverride method. func Example_serverConfigWithAuthOverride() { svr := grpc.NewServer( grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(exampleAuthFunc)), grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(exampleAuthFunc)), ) overrideActive := true if overrideActive { pb.RegisterTestServiceServer(svr, &server{message: "pong unauthenticated"}) } else { pb.RegisterTestServiceServer(svr, &server{message: "pong authenticated"}) } } go-grpc-middleware-1.3.0/auth/metadata.go000066400000000000000000000024131404040257500202350ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_auth import ( "context" "strings" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( headerAuthorize = "authorization" )
// AuthFromMD is a helper function for extracting the :authorization header from the gRPC metadata of the request.
//
// It expects the `:authorization` header to be of a certain scheme (e.g. `basic`, `bearer`), in a
// case-insensitive format (see rfc2617, sec 1.2). If no such authorization is found, or the token
// is of wrong scheme, an error with gRPC status `Unauthenticated` is returned.
//
// The header value is split on its first space into scheme and token. A value containing no space
// is also rejected with `Unauthenticated`; otherwise the returned token is everything after that space.
func AuthFromMD(ctx context.Context, expectedScheme string) (string, error) { val := metautils.ExtractIncoming(ctx).Get(headerAuthorize) if val == "" { return "", status.Errorf(codes.Unauthenticated, "Request unauthenticated with "+expectedScheme) } splits := strings.SplitN(val, " ", 2) if len(splits) < 2 { return "", status.Errorf(codes.Unauthenticated, "Bad authorization string") } if !strings.EqualFold(splits[0], expectedScheme) { return "", status.Errorf(codes.Unauthenticated, "Request unauthenticated with "+expectedScheme) } return splits[1], nil } go-grpc-middleware-1.3.0/auth/metadata_test.go000066400000000000000000000037401404040257500213000ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_auth import ( "context" "testing" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) func TestAuthFromMD(t *testing.T) { for _, run := range []struct { md metadata.MD value string errCode codes.Code msg string }{ { md: metadata.Pairs("authorization", "bearer some_token"), value: "some_token", msg: "must extract simple bearer tokens without case checking", }, { md: metadata.Pairs("authorization", "Bearer some_token"), value: "some_token", msg: "must extract simple bearer tokens with case checking", }, { md: metadata.Pairs("authorization", "Bearer some multi string bearer"), value: "some multi string bearer", msg: "must handle string based bearers", }, { md: metadata.Pairs("authorization", "Basic login:passwd"), value: "", errCode: codes.Unauthenticated, msg: "must check authentication type", }, { md: metadata.Pairs("authorization", "Basic login:passwd", "authorization", "bearer some_token"), value: "", errCode: codes.Unauthenticated, msg: "must not allow multiple authentication methods", }, { md: metadata.Pairs("authorization", ""), value: "", 
errCode: codes.Unauthenticated, msg: "authorization string must not be empty", }, { md: metadata.Pairs("authorization", "Bearer"), value: "", errCode: codes.Unauthenticated, msg: "bearer token must not be empty", }, } { ctx := metautils.NiceMD(run.md).ToIncoming(context.TODO()) out, err := AuthFromMD(ctx, "bearer") if run.errCode != codes.OK { assert.Equal(t, run.errCode, status.Code(err), run.msg) } else { assert.NoError(t, err, run.msg) } assert.Equal(t, run.value, out, run.msg) } } go-grpc-middleware-1.3.0/chain.go000066400000000000000000000121231404040257500165750ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. // gRPC Server Interceptor chaining middleware. package grpc_middleware import ( "context" "google.golang.org/grpc" ) // ChainUnaryServer creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainUnaryServer(one, two, three) will execute one before two before three, and three // will see context changes of one and two. func ChainUnaryServer(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { n := len(interceptors) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { chainer := func(currentInter grpc.UnaryServerInterceptor, currentHandler grpc.UnaryHandler) grpc.UnaryHandler { return func(currentCtx context.Context, currentReq interface{}) (interface{}, error) { return currentInter(currentCtx, currentReq, info, currentHandler) } } chainedHandler := handler for i := n - 1; i >= 0; i-- { chainedHandler = chainer(interceptors[i], chainedHandler) } return chainedHandler(ctx, req) } } // ChainStreamServer creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. 
// For example ChainUnaryServer(one, two, three) will execute one before two before three. // If you want to pass context between interceptors, use WrapServerStream. func ChainStreamServer(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { n := len(interceptors) return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { chainer := func(currentInter grpc.StreamServerInterceptor, currentHandler grpc.StreamHandler) grpc.StreamHandler { return func(currentSrv interface{}, currentStream grpc.ServerStream) error { return currentInter(currentSrv, currentStream, info, currentHandler) } } chainedHandler := handler for i := n - 1; i >= 0; i-- { chainedHandler = chainer(interceptors[i], chainedHandler) } return chainedHandler(srv, ss) } } // ChainUnaryClient creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainUnaryClient(one, two, three) will execute one before two before three. func ChainUnaryClient(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor { n := len(interceptors) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { chainer := func(currentInter grpc.UnaryClientInterceptor, currentInvoker grpc.UnaryInvoker) grpc.UnaryInvoker { return func(currentCtx context.Context, currentMethod string, currentReq, currentRepl interface{}, currentConn *grpc.ClientConn, currentOpts ...grpc.CallOption) error { return currentInter(currentCtx, currentMethod, currentReq, currentRepl, currentConn, currentInvoker, currentOpts...) } } chainedInvoker := invoker for i := n - 1; i >= 0; i-- { chainedInvoker = chainer(interceptors[i], chainedInvoker) } return chainedInvoker(ctx, method, req, reply, cc, opts...) 
} } // ChainStreamClient creates a single interceptor out of a chain of many interceptors. // // Execution is done in left-to-right order, including passing of context. // For example ChainStreamClient(one, two, three) will execute one before two before three. func ChainStreamClient(interceptors ...grpc.StreamClientInterceptor) grpc.StreamClientInterceptor { n := len(interceptors) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { chainer := func(currentInter grpc.StreamClientInterceptor, currentStreamer grpc.Streamer) grpc.Streamer { return func(currentCtx context.Context, currentDesc *grpc.StreamDesc, currentConn *grpc.ClientConn, currentMethod string, currentOpts ...grpc.CallOption) (grpc.ClientStream, error) { return currentInter(currentCtx, currentDesc, currentConn, currentMethod, currentStreamer, currentOpts...) } } chainedStreamer := streamer for i := n - 1; i >= 0; i-- { chainedStreamer = chainer(interceptors[i], chainedStreamer) } return chainedStreamer(ctx, desc, cc, method, opts...) } } // Chain creates a single interceptor out of a chain of many interceptors. // // WithUnaryServerChain is a grpc.Server config option that accepts multiple unary interceptors. // Basically syntactic sugar. func WithUnaryServerChain(interceptors ...grpc.UnaryServerInterceptor) grpc.ServerOption { return grpc.UnaryInterceptor(ChainUnaryServer(interceptors...)) } // WithStreamServerChain is a grpc.Server config option that accepts multiple stream interceptors. // Basically syntactic sugar. func WithStreamServerChain(interceptors ...grpc.StreamServerInterceptor) grpc.ServerOption { return grpc.StreamInterceptor(ChainStreamServer(interceptors...)) } go-grpc-middleware-1.3.0/chain_test.go000066400000000000000000000222471404040257500176440ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_middleware import ( "context" "fmt" "testing" "github.com/stretchr/testify/require" "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) var ( someServiceName = "SomeService.StreamMethod" parentUnaryInfo = &grpc.UnaryServerInfo{FullMethod: someServiceName} parentStreamInfo = &grpc.StreamServerInfo{ FullMethod: someServiceName, IsServerStream: true, } someValue = 1 parentContext = context.WithValue(context.TODO(), "parent", someValue) )
// TestChainUnaryServer checks that ChainUnaryServer runs interceptors left to right, gives each one the
// same UnaryServerInfo, propagates context values added by earlier interceptors to later ones and to
// the handler, and returns the handler's output.
func TestChainUnaryServer(t *testing.T) { input := "input" output := "output" first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value") require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo") ctx = context.WithValue(ctx, "first", 1) return handler(ctx, req) } second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value") requireContextValue(t, ctx, "first", "second interceptor must know the first context value") require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo") ctx = context.WithValue(ctx, "second", 1) return handler(ctx, req) } handler := func(ctx context.Context, req interface{}) (interface{}, error) { require.EqualValues(t, input, req, "handler must get the input") requireContextValue(t, ctx, "parent", "handler must know the parent context value") requireContextValue(t, ctx, "first", "handler must know the first context value") requireContextValue(t, ctx, "second", "handler must know the second context value") return output, nil } chain := ChainUnaryServer(first, second) out, _ := chain(parentContext, input, parentUnaryInfo, handler) require.EqualValues(t, output, out, "chain must return handler's
output") }
// TestChainStreamServer checks that ChainStreamServer passes the service and StreamServerInfo to every
// interceptor, that contexts replaced via WrapServerStream reach later interceptors and the handler,
// that the handler can still use the underlying stream, and that the handler's error is returned.
func TestChainStreamServer(t *testing.T) { someService := &struct{}{} recvMessage := "received" sentMessage := "sent" outputError := fmt.Errorf("some error") first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value") require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo") require.Equal(t, someService, srv, "first interceptor must know someService") wrapped := WrapServerStream(stream) wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1) return handler(srv, wrapped) } second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value") requireContextValue(t, stream.Context(), "first", "second interceptor must know the first context value") require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo") require.Equal(t, someService, srv, "second interceptor must know someService") wrapped := WrapServerStream(stream) wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1) return handler(srv, wrapped) } handler := func(srv interface{}, stream grpc.ServerStream) error { require.Equal(t, someService, srv, "handler must know someService") requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value") requireContextValue(t, stream.Context(), "first", "handler must know the first context value") requireContextValue(t, stream.Context(), "second", "handler must know the second context value") require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages") require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages") return outputError }
fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage} chain := ChainStreamServer(first, second) err := chain(someService, fakeStream, parentStreamInfo, handler) require.Equal(t, outputError, err, "chain must return handler's error") require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream") } func TestChainUnaryClient(t *testing.T) { ignoredMd := metadata.Pairs("foo", "bar") parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)} reqMessage := "request" replyMessage := "reply" outputError := fmt.Errorf("some error") first := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { requireContextValue(t, ctx, "parent", "first must know the parent context value") require.Equal(t, someServiceName, method, "first must know someService") require.Len(t, opts, 1, "first should see parent CallOptions") wrappedCtx := context.WithValue(ctx, "first", 1) return invoker(wrappedCtx, method, req, reply, cc, opts...) } second := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { requireContextValue(t, ctx, "parent", "second must know the parent context value") requireContextValue(t, ctx, "first", "second must know the first context value") require.Equal(t, someServiceName, method, "second must know someService") require.Len(t, opts, 1, "second should see parent CallOptions") wrappedOpts := append(opts, grpc.WaitForReady(false)) wrappedCtx := context.WithValue(ctx, "second", 1) return invoker(wrappedCtx, method, req, reply, cc, wrappedOpts...)
} invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error { require.Equal(t, someServiceName, method, "invoker must know someService") requireContextValue(t, ctx, "parent", "invoker must know the parent context value") requireContextValue(t, ctx, "first", "invoker must know the first context value") requireContextValue(t, ctx, "second", "invoker must know the second context value") require.Len(t, opts, 2, "invoker should see both CallOpts from second and parent") return outputError } chain := ChainUnaryClient(first, second) err := chain(parentContext, someServiceName, reqMessage, replyMessage, nil, invoker, parentOpts...) require.Equal(t, outputError, err, "chain must return invokers's error") } func TestChainStreamClient(t *testing.T) { ignoredMd := metadata.Pairs("foo", "bar") parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)} clientStream := &fakeClientStream{} fakeStreamDesc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true, StreamName: someServiceName} first := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { requireContextValue(t, ctx, "parent", "first must know the parent context value") require.Equal(t, someServiceName, method, "first must know someService") require.Len(t, opts, 1, "first should see parent CallOptions") wrappedCtx := context.WithValue(ctx, "first", 1) return streamer(wrappedCtx, desc, cc, method, opts...)
} second := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { requireContextValue(t, ctx, "parent", "second must know the parent context value") requireContextValue(t, ctx, "first", "second must know the first context value") require.Equal(t, someServiceName, method, "second must know someService") require.Len(t, opts, 1, "second should see parent CallOptions") wrappedOpts := append(opts, grpc.WaitForReady(false)) wrappedCtx := context.WithValue(ctx, "second", 1) return streamer(wrappedCtx, desc, cc, method, wrappedOpts...) } streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { require.Equal(t, someServiceName, method, "streamer must know someService") require.Equal(t, fakeStreamDesc, desc, "streamer must see the right StreamDesc") requireContextValue(t, ctx, "parent", "streamer must know the parent context value") requireContextValue(t, ctx, "first", "streamer must know the first context value") requireContextValue(t, ctx, "second", "streamer must know the second context value") require.Len(t, opts, 2, "streamer should see both CallOpts from second and parent") return clientStream, nil } chain := ChainStreamClient(first, second) someStream, err := chain(parentContext, fakeStreamDesc, nil, someServiceName, streamer, parentOpts...) require.NoError(t, err, "chain must not return an error") require.Equal(t, clientStream, someStream, "chain must return invokers's clientstream") }
// requireContextValue fails the test unless ctx holds a non-nil value under key that equals someValue.
// Every interceptor in these tests stores 1 (== someValue) under its own key, so one helper checks them all.
func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) { val := ctx.Value(key) require.NotNil(t, val, msg...) require.Equal(t, someValue, val, msg...) } go-grpc-middleware-1.3.0/doc.go000066400000000000000000000060501404040257500162620ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved.
// See LICENSE for licensing terms. /* `grpc_middleware` is a collection of gRPC middleware packages: interceptors, helpers and tools. Middleware gRPC is a fantastic RPC middleware, which sees a lot of adoption in the Golang world. However, the upstream gRPC codebase is relatively bare bones. This package, and most of its child packages, provide commonly needed middleware for gRPC: client-side interceptors for retries, server-side interceptors for input validation and auth, functions for chaining said interceptors, metadata convenience methods and more. Chaining By default, gRPC doesn't allow one to have more than one interceptor on either the client or the server side. `grpc_middleware` provides convenient chaining methods: a simple way of turning multiple interceptors into a single interceptor. Here's an example for server chaining: myServer := grpc.NewServer( grpc.StreamInterceptor(grpc_middleware.ChainStreamServer(loggingStream, monitoringStream, authStream)), grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer(loggingUnary, monitoringUnary, authUnary)), ) These interceptors will be executed from left to right: logging, monitoring and auth. Here's an example for client-side chaining: clientConn, err = grpc.Dial( address, grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(monitoringClientUnary, retryUnary)), grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(monitoringClientStream, retryStream)), ) client = pb_testproto.NewTestServiceClient(clientConn) resp, err := client.PingEmpty(s.ctx, &myservice.Request{Msg: "hello"}) These interceptors will be executed from left to right: monitoring and then retry logic. The retry interceptor will call every interceptor that follows it whenever a retry happens. Writing Your Own Implementing your own interceptor is pretty trivial: there are interfaces for that. But the interesting bit is exposing common data to handlers (and other middleware), similarly to HTTP Middleware design.
For example, you may want to pass the identity of the caller from the auth interceptor all the way to the handling function. A server-side interceptor for auth, for example, looks like: func FakeAuthUnaryInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { newCtx := context.WithValue(ctx, "user_id", "john@example.com") return handler(newCtx, req) } Unfortunately, it's not as easy for streaming RPCs. These have the `context.Context` embedded within the `grpc.ServerStream` object. To pass values through context, a wrapper (`WrappedServerStream`) is needed. For example: func FakeAuthStreamingInterceptor(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { newStream := grpc_middleware.WrapServerStream(stream) newStream.WrappedContext = context.WithValue(stream.Context(), "user_id", "john@example.com") return handler(srv, newStream) } */ package grpc_middleware go-grpc-middleware-1.3.0/go.mod000066400000000000000000000013521404040257500162740ustar00rootroot00000000000000module github.com/grpc-ecosystem/go-grpc-middleware require ( github.com/go-kit/kit v0.9.0 github.com/go-logfmt/logfmt v0.4.0 // indirect github.com/go-stack/stack v1.8.0 // indirect github.com/gogo/protobuf v1.3.2 github.com/golang/protobuf v1.3.3 github.com/opentracing/opentracing-go v1.1.0 github.com/pkg/errors v0.8.1 // indirect github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.4.0 go.uber.org/atomic v1.4.0 // indirect go.uber.org/multierr v1.1.0 // indirect go.uber.org/zap v1.10.0 golang.org/x/net v0.0.0-20201021035429-f5854403a974 golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215 // indirect google.golang.org/grpc v1.29.1 ) go 1.14 go-grpc-middleware-1.3.0/go.sum000066400000000000000000000270111404040257500163210ustar00rootroot00000000000000cloud.google.com/go v0.26.0
h1:e0WKqKTd5BnrG8aKH3J3h+QvEIQtSUcf2n5UZ5ZgLtQ= cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/go-kit/kit v0.9.0 h1:wDJmvq38kDhkVxi50ni9ykkdUr1PKgqKOoi01fa0Mdk= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-logfmt/logfmt v0.4.0 h1:MP4Eh7ZCb31lleYCFuwm0oe4/YGak+5l1vA2NOE80nA= github.com/go-logfmt/logfmt v0.4.0/go.mod h1:3RMwSq7FuexP4Kalkev3ejPJsZTpXXBr9+V4qmtdjCk= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod 
h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3 h1:gyjaxf+svBWX08ZjK86iN9geUJF0H6gp2IRKX6Nf6/I= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= github.com/google/go-cmp v0.2.0 h1:+dTQ8DZQJz0Mb/HjFlkptS1FeQ4cWSnN941F8aEG4SQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515 h1:T+h1c/A9Gawja4Y9mFVWj2vyii2bbUNDw3kt9VxK2EY= github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc= github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4= 
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.2.0/go.mod 
h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974 h1:IX6qOQeG5uLjB/hjjwjedwfjND0hgjPMMyO1RoIXQNI= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be h1:vEDujvNQGv4jgYKudGeI/+DAX4Jffq6hpD55MmoEvKs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f h1:+Nyd8tzPX9R7BWHguqsrbFdRx3WQ/1ib8I44HXV5yTA= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3 h1:cokOdA+Jmi5PJGXLlLllQSgYigAEfHXJAERHVMaCc2k= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors 
v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0 h1:/wp5JvzpHIxhs/dumFmF7BXTf3Z+dd4uXta4kVyO508= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215 h1:0Uz5jLJQioKgVozXa1gzGbzYxbb/rhQEVvSWxzw5oUs= google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.1 h1:EC2SB8S04d2r73uptxphDSUG+kTKVgjRPF+N3xpxRB4= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= 
go-grpc-middleware-1.3.0/logging/000077500000000000000000000000001404040257500166135ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/common.go000066400000000000000000000027121404040257500204340ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_logging import ( "context" "io" "github.com/golang/protobuf/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // ErrorToCode is a function that determines the gRPC status code of an error. // This makes using custom errors with grpc middleware easier. type ErrorToCode func(err error) codes.Code // DefaultErrorToCode is the default ErrorToCode implementation. It delegates to status.Code, // which yields codes.OK for a nil error, the embedded code for errors carrying a gRPC status, // and codes.Unknown for any other error. func DefaultErrorToCode(err error) codes.Code { return status.Code(err) } // Decider is a function that defines rules for suppressing interceptor logs: it reports // whether the call identified by fullMethodName (which finished with err) should be logged. type Decider func(fullMethodName string, err error) bool // DefaultDeciderMethod is the default Decider implementation. It always returns true, // so all calls are logged. func DefaultDeciderMethod(fullMethodName string, err error) bool { return true } // ServerPayloadLoggingDecider is a user-provided function for deciding whether to log the server-side // request/response payloads type ServerPayloadLoggingDecider func(ctx context.Context, fullMethodName string, servingObject interface{}) bool // ClientPayloadLoggingDecider is a user-provided function for deciding whether to log the client-side // request/response payloads type ClientPayloadLoggingDecider func(ctx context.Context, fullMethodName string) bool // JsonPbMarshaler is a marshaler that serializes protobuf messages. type JsonPbMarshaler interface { Marshal(out io.Writer, pb proto.Message) error } go-grpc-middleware-1.3.0/logging/doc.go000066400000000000000000000023241404040257500177100ustar00rootroot00000000000000// /* grpc_logging is a "parent" package for gRPC logging middlewares.
General functionality of all middleware The gRPC logging middleware populates request-scoped data to `grpc_ctxtags.Tags` that relate to the current gRPC call (e.g. service and method names). Once the gRPC logging middleware has added the gRPC-specific Tags to the ctx, they will then be written with the logs that are made using the `ctx_logrus` or `ctx_zap` loggers. All logging middleware will emit a final log statement. It is based on the error returned by the handler function: it includes the gRPC status code and the error (if any), and it is emitted at a level controlled via `WithLevels`. This parent package This particular package is intended for use by other middleware, logging or otherwise. It contains interfaces that other logging middlewares *could* share. This allows code to be shared between different implementations. Field names All field names of loggers follow the OpenTracing semantics definitions, with `grpc.` prefix if needed: https://github.com/opentracing/specification/blob/master/semantic_conventions.md Implementations There are three implementations at the moment: logrus, zap and kit. See relevant packages below. */ package grpc_logging go-grpc-middleware-1.3.0/logging/kit/000077500000000000000000000000001404040257500174025ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/kit/client_interceptors.go000066400000000000000000000040341404040257500240110ustar00rootroot00000000000000package kit import ( "path" "time" "context" "github.com/go-kit/kit/log" "google.golang.org/grpc" ) // UnaryClientInterceptor returns a new unary client interceptor that optionally logs the execution of external gRPC calls.
func UnaryClientInterceptor(logger log.Logger, opts ...Option) grpc.UnaryClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { fields := newClientLoggerFields(ctx, method) startTime := time.Now() err := invoker(ctx, method, req, reply, cc, opts...) logFinalClientLine(o, log.With(logger, fields...), startTime, err, "finished client unary call") return err } } // StreamClientInterceptor returns a new streaming client interceptor that optionally logs the execution of external gRPC calls. // Note that the final log line is emitted as soon as the streamer returns (i.e. once the stream has been established or has // failed to be), not when the stream itself completes. func StreamClientInterceptor(logger log.Logger, opts ...Option) grpc.StreamClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { fields := newClientLoggerFields(ctx, method) startTime := time.Now() clientStream, err := streamer(ctx, desc, cc, method, opts...) logFinalClientLine(o, log.With(logger, fields...), startTime, err, "finished client streaming call") return clientStream, err } } // logFinalClientLine emits the single summary log line for a client call: msg, err, the gRPC code // derived from err via o.codeFunc, and the elapsed-time fields produced by o.durationFunc, at the // level selected by o.levelFunc for that code. func logFinalClientLine(o *options, logger log.Logger, startTime time.Time, err error, msg string) { code := o.codeFunc(err) logger = o.levelFunc(code, logger) args := []interface{}{"msg", msg, "error", err, "grpc.code", code.String()} args = append(args, o.durationFunc(time.Since(startTime))...) logger.Log(args...)
} // newClientLoggerFields returns the key/value pairs that identify a client call. It assumes fullMethodString // has gRPC's "/package.Service/Method" form: the first byte of the directory part (the leading '/') is dropped. func newClientLoggerFields(ctx context.Context, fullMethodString string) []interface{} { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) return []interface{}{ "system", "grpc", "span.kind", "client", "grpc.service", service, "grpc.method", method, } } go-grpc-middleware-1.3.0/logging/kit/client_interceptors_test.go000066400000000000000000000166041404040257500250560ustar00rootroot00000000000000package kit_test import ( "io" "runtime" "strings" "testing" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" grpc_kit "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) func customClientCodeToLevel(c codes.Code, logger log.Logger) log.Logger { if c == codes.Unauthenticated { // Make this a special case for tests, and an error.
return level.Error(logger) } return grpc_kit.DefaultClientCodeToLevel(c, logger) } func TestKitClientSuite(t *testing.T) { opts := []grpc_kit.Option{ grpc_kit.WithLevels(customClientCodeToLevel), } b := newKitBaseSuite(t) b.logger = level.NewFilter(b.logger, level.AllowDebug()) // a lot of our stuff is on debug level by default b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_kit.UnaryClientInterceptor(b.logger, opts...)), grpc.WithStreamInterceptor(grpc_kit.StreamClientInterceptor(b.logger, opts...)), } suite.Run(t, &kitClientSuite{b}) } type kitClientSuite struct { *kitBaseSuite } func (s *kitClientSuite) TestPing() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "there must not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "handler's message must contain the correct message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time (duration in ms)") } func (s *kitClientSuite) TestPingList() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"],
"mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client streaming call", "handler's message must contain the correct message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time (duration in ms)") } func (s *kitClientSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level level.Value msg string }{ { code: codes.Internal, level: level.WarnValue(), msg: "Internal must remap to WarnLevel in DefaultClientCodeToLevel", }, { code: codes.NotFound, level: level.DebugValue(), msg: "NotFound must remap to DebugLevel in DefaultClientCodeToLevel", }, { code: codes.FailedPrecondition, level: level.DebugValue(), msg: "FailedPrecondition must remap to DebugLevel in DefaultClientCodeToLevel", }, { code: codes.Unauthenticated, level: level.ErrorValue(), msg: "Unauthenticated is overwritten to ErrorLevel with customClientCodeToLevel override, which probably didn't work", }, } { s.SetupTest() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(tcase.code)}) assert.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only a single log message is printed") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingError", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["grpc.code"], tcase.code.String(), "all lines must contain a grpc code")
assert.Equal(s.T(), msgs[0]["level"], tcase.level.String(), tcase.msg) } } func TestKitClientOverrideSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_kit.Option{ grpc_kit.WithDurationField(grpc_kit.DurationToDurationField), } b := newKitBaseSuite(t) b.logger = level.NewFilter(b.logger, level.AllowDebug()) // a lot of our stuff is on debug level by default b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_kit.UnaryClientInterceptor(b.logger, opts...)), grpc.WithStreamInterceptor(grpc_kit.StreamClientInterceptor(b.logger, opts...)), } suite.Run(t, &kitClientOverrideSuite{b}) } type kitClientOverrideSuite struct { *kitBaseSuite } func (s *kitClientOverrideSuite) TestPing_HasOverrides() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "there must not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "handler's message must contain the correct message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "message must contain overridden duration") } func (s *kitClientOverrideSuite) TestPingList_HasOverrides() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement
should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client streaming call", "log message must be correct") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "message must contain overridden duration") } go-grpc-middleware-1.3.0/logging/kit/ctxkit/000077500000000000000000000000001404040257500207105ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/kit/ctxkit/context.go000066400000000000000000000027061404040257500227300ustar00rootroot00000000000000package ctxkit import ( "context" "github.com/go-kit/kit/log" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" ) type ctxMarker struct{} type ctxLogger struct { logger log.Logger fields []interface{} } var ( ctxMarkerKey = &ctxMarker{} ) // AddFields adds fields to the logger. func AddFields(ctx context.Context, fields ...interface{}) { l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) if !ok || l == nil { return } l.fields = append(l.fields, fields...) } // Extract takes the call-scoped Logger from grpc_kit middleware. // // It always returns a Logger that has all the grpc_ctxtags updated. func Extract(ctx context.Context) log.Logger { l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) if !ok || l == nil { return log.NewNopLogger() } // Add grpc_ctxtags tags metadata until now. fields := TagsToFields(ctx) return log.With(l.logger, append(fields, l.fields...)...) } // TagsToFields transforms the Tags on the supplied context into kit fields.
func TagsToFields(ctx context.Context) []interface{} { var fields []interface{} tags := grpc_ctxtags.Extract(ctx) for k, v := range tags.Values() { fields = append(fields, k, v) } return fields } // ToContext adds the kit.Logger to the context for extraction later. // Returning the new context that has been created. func ToContext(ctx context.Context, logger log.Logger) context.Context { l := &ctxLogger{ logger: logger, } return context.WithValue(ctx, ctxMarkerKey, l) } go-grpc-middleware-1.3.0/logging/kit/ctxkit/doc.go000066400000000000000000000011751404040257500220100ustar00rootroot00000000000000/* `ctxkit` is a ctxlogger that is backed by go-kit It accepts a user-configured `log.Logger` that will be used for logging. The same `log.Logger` will be populated into the `context.Context` passed into gRPC handler code. You can use `ctxkit.Extract` to log into a request-scoped `log.Logger` instance in your handler code. As `ctxkit.Extract` will iterate all tags on from `grpc_ctxtags` it is therefore expensive so it is advised that you extract once at the start of the function from the context and reuse it for the remainder of the function (see examples). Please see examples and tests for examples of use. */ package ctxkit go-grpc-middleware-1.3.0/logging/kit/ctxkit/examples_test.go000066400000000000000000000016151404040257500241170ustar00rootroot00000000000000package ctxkit_test import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit/ctxkit" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" ) // Simple unary handler that adds custom fields to the requests's context. These will be used for all log statements. func ExampleExtract_unary() { _ = func(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { // Add fields the ctxtags of the request which will be added to all extracted loggers. 
grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) // Extract a single request-scoped log.Logger and log messages. l := ctxkit.Extract(ctx) l.Log("msg", "some ping") l.Log("msg", "another ping") return &pb_testproto.PingResponse{Value: ping.Value}, nil } } go-grpc-middleware-1.3.0/logging/kit/doc.go000066400000000000000000000064131404040257500205020ustar00rootroot00000000000000/* `grpc_kit` is a gRPC logging middleware backed by go-kit loggers It accepts a user-configured `log.Logger` that will be used for logging completed gRPC calls, and be populated into the `context.Context` passed into gRPC handler code. On calling `StreamServerInterceptor` or `UnaryServerInterceptor` this logging middleware will add gRPC call information to the ctx so that it will be present on subsequent use of the `ctxkit` logger. If a deadline is present on the gRPC request the grpc.request.deadline tag is populated when the request begins. grpc.request.deadline is a string representing the time (RFC3339) when the current call will expire. This package also implements request and response *payload* logging, both for server-side and client-side. These will be logged as structured `jsonpb` fields for every message received/sent (both unary and streaming). For that please use `Payload*Interceptor` functions for that. Please note that the user-provided function that determines whetether to log the full request/response payload needs to be written with care, this can significantly slow down gRPC. 
*Server Interceptor* Below is a JSON formatted example of a log that would be logged by the server interceptor: { "level": "info", // string log level "msg": "finished unary call with code OK", // string log message "grpc.code": "OK", // string grpc status code "grpc.method": "Ping", // string method name "grpc.service": "mwitkow.testproto.TestService", // string full name of the called service "grpc.start_time": "2006-01-02T15:04:05Z07:00", // string RFC3339 representation of the start time "grpc.request.deadline": "2006-01-02T15:04:05Z07:00", // string RFC3339 deadline of the current request if supplied "grpc.request.value": "something", // string value on the request "grpc.time_ms": 1.345, // float32 run time of the call in ms "peer.address": { "IP": "127.0.0.1", // string IP address of calling party "Port": 60216, // int port call is coming in on "Zone": "" // string peer zone for caller }, "span.kind": "server", // string client | server "system": "grpc", // string "custom_field": "custom_value", // string user defined field "custom_tags.int": 1337, // int user defined tag on the ctx "custom_tags.string": "something" // string user defined tag on the ctx } *Payload Interceptor* Below is a JSON formatted example of a log that would be logged by the payload interceptor: { "level": "info", // string kit log levels "msg": "client request payload logged as grpc.request.content", // string log message "grpc.request.content": { // object content of RPC request "msg" : { // object kit specific inner object "value": "something", // string defined by caller "sleepTimeMs": 9999 // int defined by caller } }, "grpc.method": "Ping", // string method being called "grpc.service": "mwitkow.testproto.TestService", // string service being called "span.kind": "client", // string client | server "system": "grpc" // string } Please see examples and tests for examples of use.
*/ package kit go-grpc-middleware-1.3.0/logging/kit/examples_test.go000066400000000000000000000050101404040257500226020ustar00rootroot00000000000000package kit_test import ( "time" "github.com/go-kit/kit/log" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "google.golang.org/grpc" ) var ( customFunc kit.CodeToLevel ) // Initialization shows a relatively complex initialization sequence. func Example_initialization() { // Logger is used, allowing pre-definition of certain fields by the user. logger := log.NewNopLogger() // Shared options for the logger, with a custom gRPC code to log level function. opts := []kit.Option{ kit.WithLevels(customFunc), } // Create a server, make sure we put the grpc_ctxtags context before everything else. _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), kit.UnaryServerInterceptor(logger, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), kit.StreamServerInterceptor(logger, opts...), ), ) } func Example_initializationWithDurationFieldOverride() { // Logger is used, allowing pre-definition of certain fields by the user. logger := log.NewNopLogger() // Shared options for the logger, with a custom duration to log field function. 
opts := []kit.Option{ kit.WithDurationField(func(duration time.Duration) []interface{} { return kit.DurationToTimeMillisField(duration) }), } _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), kit.UnaryServerInterceptor(logger, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), kit.StreamServerInterceptor(logger, opts...), ), ) } func ExampleWithDecider() { opts := []kit.Option{ kit.WithDecider(func(methodFullName string, err error) bool { // will not log gRPC calls if it was a call to healthcheck and no error was raised if err == nil && methodFullName == "blah.foo.healthcheck" { return false } // by default you will log all calls return true }), } _ = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), kit.StreamServerInterceptor(log.NewNopLogger(), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), kit.UnaryServerInterceptor(log.NewNopLogger(), opts...)), } } go-grpc-middleware-1.3.0/logging/kit/options.go000066400000000000000000000105571404040257500214340ustar00rootroot00000000000000package kit import ( "time" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" grpc_logging "github.com/grpc-ecosystem/go-grpc-middleware/logging" "google.golang.org/grpc/codes" ) var ( defaultOptions = &options{ shouldLog: grpc_logging.DefaultDeciderMethod, codeFunc: grpc_logging.DefaultErrorToCode, durationFunc: DefaultDurationToField, timestampFormat: time.RFC3339, } ) type options struct { levelFunc CodeToLevel shouldLog grpc_logging.Decider codeFunc grpc_logging.ErrorToCode durationFunc DurationToField timestampFormat string } type Option func(*options) // CodeToLevel function defines the mapping between gRPC return codes and interceptor log level. 
type CodeToLevel func(code codes.Code, logger log.Logger) log.Logger

// DurationToField function defines how to produce duration fields for logging.
type DurationToField func(duration time.Duration) []interface{}

// evaluateServerOpt returns the effective options for the server-side
// interceptors: a copy of defaultOptions that uses DefaultCodeToLevel as its
// level mapping, with opts applied on top in order.
func evaluateServerOpt(opts []Option) *options {
	return evaluateOptWithLevel(DefaultCodeToLevel, opts)
}

// evaluateClientOpt returns the effective options for the client-side
// interceptors: a copy of defaultOptions that uses DefaultClientCodeToLevel as
// its level mapping, with opts applied on top in order.
func evaluateClientOpt(opts []Option) *options {
	return evaluateOptWithLevel(DefaultClientCodeToLevel, opts)
}

// evaluateOptWithLevel is the shared implementation of evaluateServerOpt and
// evaluateClientOpt, which differ only in their default level mapping.
//
// It works on a copy of defaultOptions so the package-level defaults are never
// mutated. It installs levelFunc as the default mapping, then applies each
// Option in order, so a later option overrides an earlier one (for example,
// WithLevels replaces levelFunc). Nil options are skipped instead of causing a
// nil function call panic.
func evaluateOptWithLevel(levelFunc CodeToLevel, opts []Option) *options {
	optCopy := *defaultOptions
	optCopy.levelFunc = levelFunc
	for _, o := range opts {
		if o == nil {
			continue
		}
		o(&optCopy)
	}
	return &optCopy
}

// WithDecider customizes the function for deciding if the gRPC interceptor logs should log.
func WithDecider(f grpc_logging.Decider) Option {
	return func(o *options) {
		o.shouldLog = f
	}
}

// WithLevels customizes the function for mapping gRPC return codes and interceptor log level statements.
func WithLevels(f CodeToLevel) Option {
	return func(o *options) {
		o.levelFunc = f
	}
}

// WithCodes customizes the function for mapping errors to error codes.
func WithCodes(f grpc_logging.ErrorToCode) Option {
	return func(o *options) {
		o.codeFunc = f
	}
}

// WithDurationField customizes the function for mapping request durations to log fields.
func WithDurationField(f DurationToField) Option {
	return func(o *options) {
		o.durationFunc = f
	}
}

// WithTimestampFormat customizes the timestamps emitted in the log fields.
func WithTimestampFormat(format string) Option {
	return func(o *options) {
		o.timestampFormat = format
	}
}

// DefaultCodeToLevel is the default implementation of gRPC return codes and interceptor log level for server side.
func DefaultCodeToLevel(code codes.Code, logger log.Logger) log.Logger { switch code { case codes.OK, codes.Canceled, codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.Unauthenticated: return level.Info(logger) case codes.DeadlineExceeded, codes.PermissionDenied, codes.ResourceExhausted, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange, codes.Unavailable: return level.Warn(logger) case codes.Unknown, codes.Unimplemented, codes.Internal, codes.DataLoss: return level.Error(logger) default: return level.Error(logger) } } // DefaultClientCodeToLevel is the default implementation of gRPC return codes to log levels for client side. func DefaultClientCodeToLevel(code codes.Code, logger log.Logger) log.Logger { switch code { case codes.OK, codes.Canceled, codes.InvalidArgument, codes.NotFound, codes.AlreadyExists, codes.ResourceExhausted, codes.FailedPrecondition, codes.Aborted, codes.OutOfRange: return level.Debug(logger) case codes.Unknown, codes.DeadlineExceeded, codes.PermissionDenied, codes.Unauthenticated: return level.Info(logger) case codes.Unimplemented, codes.Internal, codes.Unavailable, codes.DataLoss: return level.Warn(logger) default: return level.Info(logger) } } // DefaultDurationToField is the default implementation of converting request duration to a kit field. var DefaultDurationToField = DurationToTimeMillisField // DurationToTimeMillisField converts the duration to milliseconds and uses the key `grpc.time_ms`. func DurationToTimeMillisField(duration time.Duration) []interface{} { return []interface{}{"grpc.time_ms", durationToMilliseconds(duration)} } // DurationToDurationField uses a Duration field to log the request duration // and leaves it up to Log's encoder settings to determine how that is output. 
func DurationToDurationField(duration time.Duration) []interface{} { return []interface{}{"grpc.duration", duration} } func durationToMilliseconds(duration time.Duration) float32 { return float32(duration.Nanoseconds()/1000) / 1000 } go-grpc-middleware-1.3.0/logging/kit/payload_interceptors.go000066400000000000000000000136161404040257500241720ustar00rootroot00000000000000package kit import ( "bytes" "fmt" "context" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" grpc_logging "github.com/grpc-ecosystem/go-grpc-middleware/logging" "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit/ctxkit" "google.golang.org/grpc" ) var ( // JsonPbMarshaller is the marshaller used for serializing protobuf messages. // If needed, this variable can be reassigned with a different marshaller with the same Marshal() signature. JsonPbMarshaller grpc_logging.JsonPbMarshaler = &jsonpb.Marshaler{} ) // PayloadUnaryServerInterceptor returns a new unary server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `kit.UnaryServerInterceptor`. However, the logging can be done to a // separate instance of the logger. func PayloadUnaryServerInterceptor(logger log.Logger, decider grpc_logging.ServerPayloadLoggingDecider) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if !decider(ctx, info.FullMethod, info.Server) { return handler(ctx, req) } // Use the provided log.Logger for logging but use the fields from context. logger = log.With(logger, append(serverCallFields(info.FullMethod), ctxkit.TagsToFields(ctx)...)...) 
logProtoMessageAsJson(logger, req, "grpc.request.content", "server request payload logged as grpc.request.content field") resp, err := handler(ctx, req) if err == nil { logProtoMessageAsJson(logger, resp, "grpc.response.content", "server response payload logged as grpc.request.content field") } return resp, err } } // PayloadStreamServerInterceptor returns a new server server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `kit.StreamServerInterceptor`. However, the logging can be done to a // separate instance of the logger. func PayloadStreamServerInterceptor(logger log.Logger, decider grpc_logging.ServerPayloadLoggingDecider) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !decider(stream.Context(), info.FullMethod, srv) { return handler(srv, stream) } logEntry := log.With(logger, append(serverCallFields(info.FullMethod), ctxkit.TagsToFields(stream.Context())...)...) newStream := &loggingServerStream{ServerStream: stream, logger: logEntry} return handler(srv, newStream) } } // PayloadUnaryClientInterceptor returns a new unary client interceptor that logs the paylods of requests and responses. func PayloadUnaryClientInterceptor(logger log.Logger, decider grpc_logging.ClientPayloadLoggingDecider) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if !decider(ctx, method) { return invoker(ctx, method, req, reply, cc, opts...) } logEntry := log.With(logger, newClientLoggerFields(ctx, method)...) logProtoMessageAsJson(logEntry, req, "grpc.request.content", "client request payload logged as grpc.request.content") err := invoker(ctx, method, req, reply, cc, opts...) 
if err == nil { logProtoMessageAsJson(logEntry, reply, "grpc.response.content", "client response payload logged as grpc.response.content") } return err } } // PayloadStreamClientInterceptor returns a new streaming client interceptor that logs the paylods of requests and responses. func PayloadStreamClientInterceptor(logger log.Logger, decider grpc_logging.ClientPayloadLoggingDecider) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if !decider(ctx, method) { return streamer(ctx, desc, cc, method, opts...) } logEntry := log.With(logger, newClientLoggerFields(ctx, method)...) clientStream, err := streamer(ctx, desc, cc, method, opts...) newStream := &loggingClientStream{ClientStream: clientStream, logger: logEntry} return newStream, err } } type loggingClientStream struct { grpc.ClientStream logger log.Logger } func (l *loggingClientStream) SendMsg(m interface{}) error { err := l.ClientStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.request.content", "server request payload logged as grpc.request.content field") } return err } func (l *loggingClientStream) RecvMsg(m interface{}) error { err := l.ClientStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } type loggingServerStream struct { grpc.ServerStream logger log.Logger } func (l *loggingServerStream) SendMsg(m interface{}) error { err := l.ServerStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } func (l *loggingServerStream) RecvMsg(m interface{}) error { err := l.ServerStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.request.content", "server request payload logged as 
grpc.request.content field") } return err } func logProtoMessageAsJson(logger log.Logger, pbMsg interface{}, key string, msg string) { if p, ok := pbMsg.(proto.Message); ok { payload, err := (&jsonpbObjectMarshaler{pb: p}).marshalJSON() if err != nil { level.Info(logger).Log(key, err) } level.Info(logger).Log(key, string(payload)) } } type jsonpbObjectMarshaler struct { pb proto.Message } func (j *jsonpbObjectMarshaler) marshalJSON() ([]byte, error) { b := &bytes.Buffer{} if err := JsonPbMarshaller.Marshal(b, j.pb); err != nil { return nil, fmt.Errorf("jsonpb serializer failed: %v", err) } return b.Bytes(), nil } go-grpc-middleware-1.3.0/logging/kit/payload_interceptors_test.go000066400000000000000000000133661404040257500252330ustar00rootroot00000000000000package kit_test import ( "io" "runtime" "strings" "testing" "context" "github.com/go-kit/kit/log" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_kit "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" ) func TestKitPayloadSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skipf("Skipping due to json.RawMessage incompatibility with go1.7") return } alwaysLoggingDeciderServer := func(ctx context.Context, fullMethodName string, servingObject interface{}) bool { return true } alwaysLoggingDeciderClient := func(ctx context.Context, fullMethodName string) bool { return true } b := newKitBaseSuite(t) b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_kit.PayloadUnaryClientInterceptor(b.logger, alwaysLoggingDeciderClient)), grpc.WithStreamInterceptor(grpc_kit.PayloadStreamClientInterceptor(b.logger, alwaysLoggingDeciderClient)), } noOpLogger := 
log.NewNopLogger() b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_kit.StreamServerInterceptor(noOpLogger), grpc_kit.PayloadStreamServerInterceptor(b.logger, alwaysLoggingDeciderServer)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_kit.UnaryServerInterceptor(noOpLogger), grpc_kit.PayloadUnaryServerInterceptor(b.logger, alwaysLoggingDeciderServer)), } suite.Run(t, &kitPayloadSuite{b}) } type kitPayloadSuite struct { *kitBaseSuite } func (s *kitPayloadSuite) getServerAndClientMessages(expectedServer int, expectedClient int) (serverMsgs []map[string]interface{}, clientMsgs []map[string]interface{}) { msgs := s.getOutputJSONs() for _, m := range msgs { if m["span.kind"] == "server" { serverMsgs = append(serverMsgs, m) } else if m["span.kind"] == "client" { clientMsgs = append(clientMsgs, m) } } require.Len(s.T(), serverMsgs, expectedServer, "must match expected number of server log messages") require.Len(s.T(), clientMsgs, expectedClient, "must match expected number of client log messages") return serverMsgs, clientMsgs } func (s *kitPayloadSuite) TestPing_LogsBothRequestAndResponse() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(2, 2) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "all payloads must be logged on info level") } serverReq, serverResp := serverMsgs[0], serverMsgs[1] clientReq, clientResp := clientMsgs[0], clientMsgs[1] s.T().Log(clientReq) assert.Contains(s.T(), clientReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), serverReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), clientResp, "grpc.response.content", "response payload must be logged in a structured way") assert.Contains(s.T(), serverResp, "grpc.response.content", "response payload must be logged in a structured way") } func (s *kitPayloadSuite) TestPingError_LogsOnlyRequestsOnError() { _, err := s.Client.PingError(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(4)}) require.Error(s.T(), err, "there must be an error on an unsuccessful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(1, 1) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "must be logged at the info level") } assert.Contains(s.T(), clientMsgs[0], "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), serverMsgs[0], "grpc.request.content", "request payload must be logged in a structured way") } func (s *kitPayloadSuite) TestPingStream_LogsAllRequestsAndResponses() { messagesExpected := 20 stream, err := s.Client.PingStream(s.SimpleCtx()) require.NoError(s.T(), err, "no error on stream creation") for i := 0; i < messagesExpected; i++ { require.NoError(s.T(), stream.Send(goodPing), "sending must succeed") } require.NoError(s.T(), stream.CloseSend(), "no error on send stream") for { pong := &pb_testproto.PingResponse{} err := stream.RecvMsg(pong) if err == io.EOF { break } require.NoError(s.T(), err, "no error on receive") } serverMsgs, clientMsgs := s.getServerAndClientMessages(2*messagesExpected, 2*messagesExpected) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingStream", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "all lines must logged at info level") content := m["grpc.request.content"] != nil || m["grpc.response.content"] != nil assert.True(s.T(), content, "all messages must contain payloads") } } go-grpc-middleware-1.3.0/logging/kit/server_interceptors.go000066400000000000000000000062401404040257500240420ustar00rootroot00000000000000package kit import ( "path" "time" "context" "github.com/go-kit/kit/log" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit/ctxkit" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) var ( // SystemField is used in every log statement made through grpc_zap. Can be overwritten before any initialization code. SystemField = "grpc" // ServerField is used in every server-side log statement made through grpc_zap.Can be overwritten before initialization. ServerField = "server" ) // UnaryServerInterceptor returns a new unary server interceptors that adds kit.Logger to the context. func UnaryServerInterceptor(logger log.Logger, opts ...Option) grpc.UnaryServerInterceptor { o := evaluateServerOpt(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { startTime := time.Now() newCtx := injectLogger(ctx, logger, info.FullMethod, startTime, o.timestampFormat) resp, err := handler(newCtx, req) if !o.shouldLog(info.FullMethod, err) { return resp, err } code := o.codeFunc(err) logCall(newCtx, o, "finished unary call with code "+code.String(), code, startTime, err) return resp, err } } // StreamServerInterceptor returns a new stream server interceptors that adds kit.Logger to the context. 
func StreamServerInterceptor(logger log.Logger, opts ...Option) grpc.StreamServerInterceptor { o := evaluateServerOpt(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { startTime := time.Now() newCtx := injectLogger(stream.Context(), logger, info.FullMethod, startTime, o.timestampFormat) wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx err := handler(srv, wrapped) if !o.shouldLog(info.FullMethod, err) { return err } code := o.codeFunc(err) logCall(newCtx, o, "finished streaming call with code "+code.String(), code, startTime, err) return err } } func injectLogger(ctx context.Context, logger log.Logger, fullMethodString string, start time.Time, timestampFormat string) context.Context { f := ctxkit.TagsToFields(ctx) f = append(f, "grpc.start_time", start.Format(timestampFormat)) if d, ok := ctx.Deadline(); ok { f = append(f, "grpc.request.deadline", d.Format(timestampFormat)) } f = append(f, serverCallFields(fullMethodString)...) callLog := log.With(logger, f...) return ctxkit.ToContext(ctx, callLog) } func serverCallFields(fullMethodString string) []interface{} { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) return []interface{}{ "system", SystemField, "span.kind", ServerField, "grpc.service", service, "grpc.method", method, } } func logCall(ctx context.Context, options *options, msg string, code codes.Code, startTime time.Time, err error) { extractedLogger := ctxkit.Extract(ctx) extractedLogger = options.levelFunc(code, extractedLogger) args := []interface{}{"msg", msg, "error", err, "grpc.code", code.String()} args = append(args, options.durationFunc(time.Since(startTime))...) _ = extractedLogger.Log(args...) 
} go-grpc-middleware-1.3.0/logging/kit/server_interceptors_test.go000066400000000000000000000342331404040257500251040ustar00rootroot00000000000000package kit_test import ( "io" "runtime" "strings" "testing" "time" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_kit "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) func customCodeToLevel(c codes.Code, logger log.Logger) log.Logger { if c == codes.Unauthenticated { // Make this a special case for tests, and an error. return level.Error(logger) } return grpc_kit.DefaultCodeToLevel(c, logger) } func TestKitLoggingSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skipf("Skipping due to json.RawMessage incompatibility with go1.7") return } for _, tcase := range []struct { timestampFormat string }{ { timestampFormat: time.RFC3339, }, { timestampFormat: "2006-01-02", }, } { opts := []grpc_kit.Option{ grpc_kit.WithLevels(customCodeToLevel), grpc_kit.WithTimestampFormat(tcase.timestampFormat), } b := newKitBaseSuite(t) b.timestampFormat = tcase.timestampFormat b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_kit.StreamServerInterceptor(b.logger, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_kit.UnaryServerInterceptor(b.logger, opts...)), } suite.Run(t, &kitServerSuite{b}) } } type kitServerSuite struct { *kitBaseSuite } 
func (s *kitServerSuite) TestPing_WithCustomTags() { deadline := time.Now().Add(3 * time.Second) _, err := s.Client.Ping(s.DeadlineCtx(deadline), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain `custom_tags.string`") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain fields extracted") assert.Equal(s.T(), m["custom_field"], "custom_value", "all lines must contain `custom_field`") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain `custom_tags.int`") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") require.Contains(s.T(), m, "grpc.request.deadline", "all lines must contain the deadline of the call") _, err = time.Parse(s.timestampFormat, m["grpc.request.deadline"].(string)) require.NoError(s.T(), err, "should be able to parse deadline") assert.Equal(s.T(), m["grpc.request.deadline"], deadline.Format(s.timestampFormat), "should have the same deadline that was set by the caller") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "must be logged at info level") assert.Contains(s.T(), msgs[1], "grpc.time_ms", 
"interceptor log statement should contain execution time") } func (s *kitServerSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level level.Value msg string }{ { code: codes.Internal, level: level.ErrorValue(), msg: "Internal must remap to ErrorLevel in DefaultCodeToLevel", }, { code: codes.NotFound, level: level.InfoValue(), msg: "NotFound must remap to InfoLevel in DefaultCodeToLevel", }, { code: codes.FailedPrecondition, level: level.WarnValue(), msg: "FailedPrecondition must remap to WarnLevel in DefaultCodeToLevel", }, { code: codes.Unauthenticated, level: level.ErrorValue(), msg: "Unauthenticated is overwritten to DPanicLevel with customCodeToLevel override, which probably didn't work", }, } { s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(tcase.code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["grpc.code"], tcase.code.String(), "all lines have the correct gRPC code") assert.Equal(s.T(), m["level"], tcase.level.String(), tcase.msg) assert.Equal(s.T(), m["msg"], "finished unary call with code "+tcase.code.String(), "needs the correct end message") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err = time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") } } func (s *kitServerSuite) TestPingList_WithCustomTags() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := 
stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain `custom_tags.string` set by AddFields") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain fields extracted from goodPing because of test.manual_extractfields.pb") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain `custom_tags.int` set by AddFields") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK codes must be logged on info level.") assert.Contains(s.T(), msgs[1], "grpc.time_ms", "interceptor log statement should contain execution time") } func TestKitLoggingOverrideSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_kit.Option{ grpc_kit.WithDurationField(grpc_kit.DurationToDurationField), } b := newKitBaseSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), 
grpc_kit.StreamServerInterceptor(b.logger, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_kit.UnaryServerInterceptor(b.logger, opts...)), } suite.Run(t, &kitServerOverrideSuite{b}) } type kitServerOverrideSuite struct { *kitBaseSuite } func (s *kitServerOverrideSuite) TestPing_HasOverriddenDuration() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK error codes must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[1], "grpc.duration", "handler's message must contain overridden duration") } func (s *kitServerOverrideSuite) TestPingList_HasOverriddenDuration() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { s.T() 
assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK error codes must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[1], "grpc.duration", "handler's message must contain overridden duration") } func TestKitServerOverrideSuppressedSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_kit.Option{ grpc_kit.WithDecider(func(method string, err error) bool { if err != nil && method == "/mwitkow.testproto.TestService/PingError" { return true } return false }), } b := newKitBaseSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_kit.StreamServerInterceptor(b.logger, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_kit.UnaryServerInterceptor(b.logger, opts...)), } suite.Run(t, &kitServerOverridenDeciderSuite{b}) } type kitServerOverridenDeciderSuite struct { *kitBaseSuite } func (s *kitServerOverridenDeciderSuite) TestPing_HasOverriddenDecider() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := 
s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") } func (s *kitServerOverridenDeciderSuite) TestPingError_HasOverriddenDecider() { code := codes.NotFound msg := "NotFound must remap to InfoLevel in DefaultCodeToLevel" s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["grpc.code"], code.String(), "all lines must contain the correct gRPC code") assert.Equal(s.T(), m["level"], "info", msg) } func (s *kitServerOverridenDeciderSuite) TestPingList_HasOverriddenDecider() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") 
assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") } go-grpc-middleware-1.3.0/logging/kit/shared_test.go000066400000000000000000000054611404040257500222440ustar00rootroot00000000000000package kit_test import ( "bytes" "encoding/json" "io" "testing" "github.com/grpc-ecosystem/go-grpc-middleware/logging/kit/ctxkit" "context" "github.com/go-kit/kit/log" "github.com/go-kit/kit/log/level" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} ) type loggingPingService struct { pb_testproto.TestServiceServer } func (s *loggingPingService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxkit.AddFields(ctx, []interface{}{"custom_field", "custom_value"}...) level.Info(ctxkit.Extract(ctx)).Log("msg", "some ping") return s.TestServiceServer.Ping(ctx, ping) } func (s *loggingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { return s.TestServiceServer.PingError(ctx, ping) } func (s *loggingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { grpc_ctxtags.Extract(stream.Context()).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxkit.AddFields(stream.Context(), []interface{}{"custom_field", "custom_value"}...) 
level.Info(ctxkit.Extract(stream.Context())).Log("msg", "some pinglist") return s.TestServiceServer.PingList(ping, stream) } func (s *loggingPingService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { return s.TestServiceServer.PingEmpty(ctx, empty) } type kitBaseSuite struct { *grpc_testing.InterceptorTestSuite mutexBuffer *grpc_testing.MutexReadWriter buffer *bytes.Buffer logger log.Logger timestampFormat string } func newKitBaseSuite(t *testing.T) *kitBaseSuite { b := &bytes.Buffer{} muB := grpc_testing.NewMutexReadWriter(b) logger := log.NewJSONLogger(log.NewSyncWriter(muB)) return &kitBaseSuite{ logger: logger, buffer: b, mutexBuffer: muB, InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &loggingPingService{&grpc_testing.TestPingService{T: t}}, }, } } func (s *kitBaseSuite) SetupTest() { s.mutexBuffer.Lock() s.buffer.Reset() s.mutexBuffer.Unlock() } func (s *kitBaseSuite) getOutputJSONs() []map[string]interface{} { ret := make([]map[string]interface{}, 0) dec := json.NewDecoder(s.mutexBuffer) for { var val map[string]interface{} err := dec.Decode(&val) if err == io.EOF { break } if err != nil { s.T().Fatalf("failed decoding output from go-kit JSON: %v", err) } ret = append(ret, val) } return ret } go-grpc-middleware-1.3.0/logging/logrus/000077500000000000000000000000001404040257500201265ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/logrus/client_interceptors.go000066400000000000000000000045121404040257500245360ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_logrus import ( "context" "path" "time" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "google.golang.org/grpc" ) // UnaryClientInterceptor returns a new unary client interceptor that optionally logs the execution of external gRPC calls. 
func UnaryClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { fields := newClientLoggerFields(ctx, method) startTime := time.Now() err := invoker(ctx, method, req, reply, cc, opts...) newCtx := ctxlogrus.ToContext(ctx, entry.WithFields(fields)) logFinalClientLine(newCtx, o, startTime, err, "finished client unary call") return err } } // StreamClientInterceptor returns a new streaming client interceptor that optionally logs the execution of external gRPC calls. func StreamClientInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { fields := newClientLoggerFields(ctx, method) startTime := time.Now() clientStream, err := streamer(ctx, desc, cc, method, opts...) 
newCtx := ctxlogrus.ToContext(ctx, entry.WithFields(fields)) logFinalClientLine(newCtx, o, startTime, err, "finished client streaming call") return clientStream, err } } func logFinalClientLine(ctx context.Context, o *options, startTime time.Time, err error, msg string) { code := o.codeFunc(err) level := o.levelFunc(code) durField, durVal := o.durationFunc(time.Now().Sub(startTime)) fields := logrus.Fields{ "grpc.code": code.String(), durField: durVal, } o.messageFunc(ctx, msg, level, code, err, fields) } func newClientLoggerFields(ctx context.Context, fullMethodString string) logrus.Fields { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) return logrus.Fields{ SystemField: "grpc", KindField: "client", "grpc.service": service, "grpc.method": method, } } go-grpc-middleware-1.3.0/logging/logrus/client_interceptors_test.go000066400000000000000000000216041404040257500255760ustar00rootroot00000000000000package grpc_logrus_test import ( "io" "testing" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" ) func customClientCodeToLevel(c codes.Code) logrus.Level { if c == codes.Unauthenticated { // Make this a special case for tests, and an error. 
return logrus.ErrorLevel } level := grpc_logrus.DefaultClientCodeToLevel(c) return level } func TestLogrusClientSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithLevels(customClientCodeToLevel), } b := newLogrusBaseSuite(t) b.logger.Level = logrus.DebugLevel // a lot of our stuff is on debug level by default b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_logrus.UnaryClientInterceptor(logrus.NewEntry(b.logger), opts...)), grpc.WithStreamInterceptor(grpc_logrus.StreamClientInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusClientSuite{b}) } type logrusClientSuite struct { *logrusBaseSuite } func (s *logrusClientSuite) TestPing() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "there must be not be an on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "handler's message must contain the correct message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time (duration in ms)") } func (s *logrusClientSuite) TestPingList() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") 
assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client streaming call", "handler's message must contain the correct message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time (duration in ms)") } func (s *logrusClientSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level logrus.Level msg string }{ { code: codes.Internal, level: logrus.WarnLevel, msg: "Internal must remap to ErrorLevel in DefaultClientCodeToLevel", }, { code: codes.NotFound, level: logrus.DebugLevel, msg: "NotFound must remap to InfoLevel in DefaultClientCodeToLevel", }, { code: codes.FailedPrecondition, level: logrus.DebugLevel, msg: "FailedPrecondition must remap to WarnLevel in DefaultClientCodeToLevel", }, { code: codes.Unauthenticated, level: logrus.ErrorLevel, msg: "Unauthenticated is overwritten to ErrorLevel with customClientCodeToLevel override, which probably didn't work", }, } { s.SetupTest() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(tcase.code)}) assert.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only a single log message is printed") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingError", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["grpc.code"], tcase.code.String(), 
"all lines must contain a grpc code") assert.Equal(s.T(), msgs[0]["level"], tcase.level.String(), tcase.msg) } } func TestLogrusClientOverrideSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithDurationField(grpc_logrus.DurationToDurationField), } b := newLogrusBaseSuite(t) b.logger.Level = logrus.DebugLevel // a lot of our stuff is on debug level by default b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_logrus.UnaryClientInterceptor(logrus.NewEntry(b.logger), opts...)), grpc.WithStreamInterceptor(grpc_logrus.StreamClientInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusClientOverrideSuite{b}) } type logrusClientOverrideSuite struct { *logrusBaseSuite } func (s *logrusClientOverrideSuite) TestPing_HasOverrides() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "there must be not be an on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "handler's message must contain the correct message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "message must contain overridden duration") } func (s *logrusClientOverrideSuite) TestPingList_HasOverrides() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), 
msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client streaming call", "log message must be correct") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "message must contain overridden duration") } func TestZapLoggingClientMessageProducerSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithMessageProducer(StubMessageProducer), } b := newLogrusBaseSuite(t) b.logger.Level = logrus.DebugLevel // a lot of our stuff is on debug level by default b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_logrus.UnaryClientInterceptor(logrus.NewEntry(b.logger), opts...)), grpc.WithStreamInterceptor(grpc_logrus.StreamClientInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusClientMessageProducerSuite{b}) } type logrusClientMessageProducerSuite struct { *logrusBaseSuite } func (s *logrusClientMessageProducerSuite) TestPing_HasOverriddenMessageProducer() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "there must be not be an on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), msgs[0]["msg"], "custom message", "handler's message must contain the correct message") assert.Equal(s.T(), 
msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time (duration in ms)") } go-grpc-middleware-1.3.0/logging/logrus/context.go000066400000000000000000000010621404040257500221400ustar00rootroot00000000000000package grpc_logrus import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" ) // AddFields adds logrus fields to the logger. // Deprecated: should use the ctxlogrus.Extract instead func AddFields(ctx context.Context, fields logrus.Fields) { ctxlogrus.AddFields(ctx, fields) } // Extract takes the call-scoped logrus.Entry from grpc_logrus middleware. // Deprecated: should use the ctxlogrus.Extract instead func Extract(ctx context.Context) *logrus.Entry { return ctxlogrus.Extract(ctx) } go-grpc-middleware-1.3.0/logging/logrus/ctxlogrus/000077500000000000000000000000001404040257500221605ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/logrus/ctxlogrus/context.go000066400000000000000000000027051404040257500241770ustar00rootroot00000000000000package ctxlogrus import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/sirupsen/logrus" ) type ctxLoggerMarker struct{} type ctxLogger struct { logger *logrus.Entry fields logrus.Fields } var ( ctxLoggerKey = &ctxLoggerMarker{} ) // AddFields adds logrus fields to the logger. func AddFields(ctx context.Context, fields logrus.Fields) { l, ok := ctx.Value(ctxLoggerKey).(*ctxLogger) if !ok || l == nil { return } for k, v := range fields { l.fields[k] = v } } // Extract takes the call-scoped logrus.Entry from ctx_logrus middleware. // // If the ctx_logrus middleware wasn't used, a no-op `logrus.Entry` is returned. This makes it safe to // use regardless. 
func Extract(ctx context.Context) *logrus.Entry { l, ok := ctx.Value(ctxLoggerKey).(*ctxLogger) if !ok || l == nil { return logrus.NewEntry(nullLogger) } fields := logrus.Fields{} // Add grpc_ctxtags tags metadata until now. tags := grpc_ctxtags.Extract(ctx) for k, v := range tags.Values() { fields[k] = v } // Add logrus fields added until now. for k, v := range l.fields { fields[k] = v } return l.logger.WithFields(fields) } // ToContext adds the logrus.Entry to the context for extraction later. // Returning the new context that has been created. func ToContext(ctx context.Context, entry *logrus.Entry) context.Context { l := &ctxLogger{ logger: entry, fields: logrus.Fields{}, } return context.WithValue(ctx, ctxLoggerKey, l) } go-grpc-middleware-1.3.0/logging/logrus/ctxlogrus/doc.go000066400000000000000000000012241404040257500232530ustar00rootroot00000000000000/* `ctxlogrus` is a ctxlogger that is backed by logrus It accepts a user-configured `logrus.Logger` that will be used for logging. The same `logrus.Logger` will be populated into the `context.Context` passed into gRPC handler code. You can use `ctx_logrus.Extract` to log into a request-scoped `logrus.Logger` instance in your handler code. As `ctx_logrus.Extract` will iterate all tags on from `grpc_ctxtags` it is therefore expensive so it is advised that you extract once at the start of the function from the context and reuse it for the remainder of the function (see examples). Please see examples and tests for examples of use. */ package ctxlogrus go-grpc-middleware-1.3.0/logging/logrus/ctxlogrus/examples_test.go000066400000000000000000000012111404040257500253570ustar00rootroot00000000000000package ctxlogrus_test import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/grpc-ecosystem/go-grpc-middleware/tags" ) // Simple unary handler that adds custom fields to the requests's context. These will be used for all log statements. 
func ExampleExtract_unary() { ctx := context.Background() // setting tags will be added to the logger as log fields grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) // Extract a single request-scoped logrus.Logger and log messages. l := ctxlogrus.Extract(ctx) l.Info("some ping") l.Info("another ping") } go-grpc-middleware-1.3.0/logging/logrus/ctxlogrus/noop.go000066400000000000000000000003771404040257500234710ustar00rootroot00000000000000package ctxlogrus import ( "io/ioutil" "github.com/sirupsen/logrus" ) var ( nullLogger = &logrus.Logger{ Out: ioutil.Discard, Formatter: new(logrus.TextFormatter), Hooks: make(logrus.LevelHooks), Level: logrus.PanicLevel, } ) go-grpc-middleware-1.3.0/logging/logrus/doc.go000066400000000000000000000066731404040257500212360ustar00rootroot00000000000000/* `grpc_logrus` is a gRPC logging middleware backed by Logrus loggers It accepts a user-configured `logrus.Entry` that will be used for logging completed gRPC calls. The same `logrus.Entry` will be used for logging completed gRPC calls, and be populated into the `context.Context` passed into gRPC handler code. On calling `StreamServerInterceptor` or `UnaryServerInterceptor` this logging middleware will add gRPC call information to the ctx so that it will be present on subsequent use of the `ctxlogrus` logger. This package also implements request and response *payload* logging, both for server-side and client-side. These will be logged as structured `jsonpb` fields for every message received/sent (both unary and streaming). For that please use `Payload*Interceptor` functions for that. Please note that the user-provided function that determines whether to log the full request/response payload needs to be written with care, this can significantly slow down gRPC. If a deadline is present on the gRPC request the grpc.request.deadline tag is populated when the request begins. 
grpc.request.deadline is a string representing the time (RFC3339) when the current call will expire. Logrus can also be made as a backend for gRPC library internals. For that use `ReplaceGrpcLogger`. *Server Interceptor* Below is a JSON formatted example of a log that would be logged by the server interceptor: { "level": "info", // string logrus log levels "msg": "finished unary call", // string log message "grpc.code": "OK", // string grpc status code "grpc.method": "Ping", // string method name "grpc.service": "mwitkow.testproto.TestService", // string full name of the called service "grpc.start_time": "2006-01-02T15:04:05Z07:00", // string RFC3339 representation of the start time "grpc.request.deadline": "2006-01-02T15:04:05Z07:00", // string RFC3339 deadline of the current request if supplied "grpc.request.value": "something", // string value on the request "grpc.time_ms": 1.234, // float32 run time of the call in ms "peer.address": { "IP": "127.0.0.1", // string IP address of calling party "Port": 60216, // int port call is coming in on "Zone": "" // string peer zone for caller }, "span.kind": "server", // string client | server "system": "grpc" // string "custom_field": "custom_value", // string user defined field "custom_tags.int": 1337, // int user defined tag on the ctx "custom_tags.string": "something", // string user defined tag on the ctx } *Payload Interceptor* Below is a JSON formatted example of a log that would be logged by the payload interceptor: { "level": "info", // string logrus log levels "msg": "client request payload logged as grpc.request.content", // string log message "grpc.request.content": { // object content of RPC request "value": "something", // string defined by caller "sleepTimeMs": 9999 // int defined by caller }, "grpc.method": "Ping", // string method being called "grpc.service": "mwitkow.testproto.TestService", // string service being called "span.kind": "client", // string client | server "system": "grpc" // string } Note - 
due to implementation ZAP differs from Logrus in the "grpc.request.content" object by having an inner "msg" object. Please see examples and tests for examples of use. */ package grpc_logrus go-grpc-middleware-1.3.0/logging/logrus/examples_test.go000066400000000000000000000072401404040257500233350ustar00rootroot00000000000000package grpc_logrus_test import ( "context" "time" "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/sirupsen/logrus" "google.golang.org/grpc" ) var ( logrusLogger *logrus.Logger customFunc grpc_logrus.CodeToLevel ) // Initialization shows a relatively complex initialization sequence. func Example_initialization() { // Logrus entry is used, allowing pre-definition of certain fields by the user. logrusEntry := logrus.NewEntry(logrusLogger) // Shared options for the logger, with a custom gRPC code to log level function. opts := []grpc_logrus.Option{ grpc_logrus.WithLevels(customFunc), } // Make sure that log statements internal to gRPC library are logged using the logrus Logger as well. grpc_logrus.ReplaceGrpcLogger(logrusEntry) // Create a server, make sure we put the grpc_ctxtags context before everything else. 
_ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.UnaryServerInterceptor(logrusEntry, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.StreamServerInterceptor(logrusEntry, opts...), ), ) } func Example_initializationWithDurationFieldOverride() { // Logrus entry is used, allowing pre-definition of certain fields by the user. logrusEntry := logrus.NewEntry(logrusLogger) // Shared options for the logger, with a custom duration to log field function. opts := []grpc_logrus.Option{ grpc_logrus.WithDurationField(func(duration time.Duration) (key string, value interface{}) { return "grpc.time_ns", duration.Nanoseconds() }), } _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_logrus.UnaryServerInterceptor(logrusEntry, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_logrus.StreamServerInterceptor(logrusEntry, opts...), ), ) } // Simple unary handler that adds custom fields to the requests's context. These will be used for all log statements. func ExampleExtract_unary() { _ = func(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { // Add fields the ctxtags of the request which will be added to all extracted loggers. grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) // Extract a single request-scoped logrus.Logger and log messages. 
l := ctxlogrus.Extract(ctx) l.Info("some ping") l.Info("another ping") return &pb_testproto.PingResponse{Value: ping.Value}, nil } } func ExampleWithDecider() { opts := []grpc_logrus.Option{ grpc_logrus.WithDecider(func(methodFullName string, err error) bool { // will not log gRPC calls if it was a call to healthcheck and no error was raised if err == nil && methodFullName == "blah.foo.healthcheck" { return false } // by default you will log all calls return true }), } _ = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(logrus.New()), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(logrus.New()), opts...)), } } go-grpc-middleware-1.3.0/logging/logrus/grpclogger.go000066400000000000000000000012151404040257500226070ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_logrus import ( "github.com/sirupsen/logrus" "google.golang.org/grpc/grpclog" ) // ReplaceGrpcLogger sets the given logrus.Logger as a gRPC-level logger. // This should be called *before* any other initialization, preferably from init() functions. 
func ReplaceGrpcLogger(logger *logrus.Entry) { grpclog.SetLoggerV2(&logrusGrpcLoggerV2{ logger.WithField("system", SystemField), }) } type logrusGrpcLoggerV2 struct { *logrus.Entry } func (l *logrusGrpcLoggerV2) V(level int) bool { return l.Logger.IsLevelEnabled(logrus.Level(level)) } go-grpc-middleware-1.3.0/logging/logrus/grpclogger_test.go000066400000000000000000000033561404040257500236560ustar00rootroot00000000000000package grpc_logrus import ( "testing" "github.com/sirupsen/logrus" ) func Test_logrusGrpcLoggerV2_V(t *testing.T) { tests := []struct { name string setupLevel logrus.Level inLevel logrus.Level want bool }{ { name: "WarnLevel setup when we have WarnLevel msg should return TRUE", setupLevel: logrus.WarnLevel, inLevel: logrus.WarnLevel, want: true, }, { name: "WarnLevel setup when we have ErrorLevel msg should return TRUE", setupLevel: logrus.WarnLevel, inLevel: logrus.ErrorLevel, want: true, }, { name: "WarnLevel setup when we have InfoLevel msg should return FALSE", setupLevel: logrus.WarnLevel, inLevel: logrus.InfoLevel, want: false, }, { name: "WarnLevel setup when we have DebugLevel msg should return FALSE", setupLevel: logrus.WarnLevel, inLevel: logrus.DebugLevel, want: false, }, { name: "WarnLevel setup when we have TraceLevel msg should return FALSE", setupLevel: logrus.WarnLevel, inLevel: logrus.TraceLevel, want: false, }, { name: "TraceLevel setup when we have WarnLevel msg should return TRUE", setupLevel: logrus.TraceLevel, inLevel: logrus.WarnLevel, want: true, }, { name: "TraceLevel setup when we have TraceLevel msg should return TRUE", setupLevel: logrus.TraceLevel, inLevel: logrus.TraceLevel, want: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { lr := logrus.New() lr.SetLevel(tt.setupLevel) l := &logrusGrpcLoggerV2{logrus.NewEntry(lr)} if got := l.V(int(tt.inLevel)); got != tt.want { t.Errorf("V() = %v, want %v", got, tt.want) } }) } } 
go-grpc-middleware-1.3.0/logging/logrus/options.go000066400000000000000000000146501404040257500221560ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_logrus import ( "context" "time" grpc_logging "github.com/grpc-ecosystem/go-grpc-middleware/logging" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" ) var ( defaultOptions = &options{ levelFunc: nil, shouldLog: grpc_logging.DefaultDeciderMethod, codeFunc: grpc_logging.DefaultErrorToCode, durationFunc: DefaultDurationToField, messageFunc: DefaultMessageProducer, timestampFormat: time.RFC3339, } ) type options struct { levelFunc CodeToLevel shouldLog grpc_logging.Decider codeFunc grpc_logging.ErrorToCode durationFunc DurationToField messageFunc MessageProducer timestampFormat string } func evaluateServerOpt(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions optCopy.levelFunc = DefaultCodeToLevel for _, o := range opts { o(optCopy) } return optCopy } func evaluateClientOpt(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions optCopy.levelFunc = DefaultClientCodeToLevel for _, o := range opts { o(optCopy) } return optCopy } type Option func(*options) // CodeToLevel function defines the mapping between gRPC return codes and interceptor log level. type CodeToLevel func(code codes.Code) logrus.Level // DurationToField function defines how to produce duration fields for logging type DurationToField func(duration time.Duration) (key string, value interface{}) // WithDecider customizes the function for deciding if the gRPC interceptor logs should log. func WithDecider(f grpc_logging.Decider) Option { return func(o *options) { o.shouldLog = f } } // WithLevels customizes the function for mapping gRPC return codes and interceptor log level statements. 
func WithLevels(f CodeToLevel) Option { return func(o *options) { o.levelFunc = f } } // WithCodes customizes the function for mapping errors to error codes. func WithCodes(f grpc_logging.ErrorToCode) Option { return func(o *options) { o.codeFunc = f } } // WithDurationField customizes the function for mapping request durations to log fields. func WithDurationField(f DurationToField) Option { return func(o *options) { o.durationFunc = f } } // WithMessageProducer customizes the function for message formation. func WithMessageProducer(f MessageProducer) Option { return func(o *options) { o.messageFunc = f } } // WithTimestampFormat customizes the timestamps emitted in the log fields. func WithTimestampFormat(format string) Option { return func(o *options) { o.timestampFormat = format } } // DefaultCodeToLevel is the default implementation of gRPC return codes to log levels for server side. func DefaultCodeToLevel(code codes.Code) logrus.Level { switch code { case codes.OK: return logrus.InfoLevel case codes.Canceled: return logrus.InfoLevel case codes.Unknown: return logrus.ErrorLevel case codes.InvalidArgument: return logrus.InfoLevel case codes.DeadlineExceeded: return logrus.WarnLevel case codes.NotFound: return logrus.InfoLevel case codes.AlreadyExists: return logrus.InfoLevel case codes.PermissionDenied: return logrus.WarnLevel case codes.Unauthenticated: return logrus.InfoLevel // unauthenticated requests can happen case codes.ResourceExhausted: return logrus.WarnLevel case codes.FailedPrecondition: return logrus.WarnLevel case codes.Aborted: return logrus.WarnLevel case codes.OutOfRange: return logrus.WarnLevel case codes.Unimplemented: return logrus.ErrorLevel case codes.Internal: return logrus.ErrorLevel case codes.Unavailable: return logrus.WarnLevel case codes.DataLoss: return logrus.ErrorLevel default: return logrus.ErrorLevel } } // DefaultClientCodeToLevel is the default implementation of gRPC return codes to log levels for client side. 
func DefaultClientCodeToLevel(code codes.Code) logrus.Level { switch code { case codes.OK: return logrus.DebugLevel case codes.Canceled: return logrus.DebugLevel case codes.Unknown: return logrus.InfoLevel case codes.InvalidArgument: return logrus.DebugLevel case codes.DeadlineExceeded: return logrus.InfoLevel case codes.NotFound: return logrus.DebugLevel case codes.AlreadyExists: return logrus.DebugLevel case codes.PermissionDenied: return logrus.InfoLevel case codes.Unauthenticated: return logrus.InfoLevel // unauthenticated requests can happen case codes.ResourceExhausted: return logrus.DebugLevel case codes.FailedPrecondition: return logrus.DebugLevel case codes.Aborted: return logrus.DebugLevel case codes.OutOfRange: return logrus.DebugLevel case codes.Unimplemented: return logrus.WarnLevel case codes.Internal: return logrus.WarnLevel case codes.Unavailable: return logrus.WarnLevel case codes.DataLoss: return logrus.WarnLevel default: return logrus.InfoLevel } } // DefaultDurationToField is the default implementation of converting request duration to a log field (key and value). var DefaultDurationToField = DurationToTimeMillisField // DurationToTimeMillisField converts the duration to milliseconds and uses the key `grpc.time_ms`. func DurationToTimeMillisField(duration time.Duration) (key string, value interface{}) { return "grpc.time_ms", durationToMilliseconds(duration) } // DurationToDurationField uses the duration value to log the request duration. 
func DurationToDurationField(duration time.Duration) (key string, value interface{}) { return "grpc.duration", duration } func durationToMilliseconds(duration time.Duration) float32 { return float32(duration.Nanoseconds()/1000) / 1000 } // MessageProducer produces a user defined log message type MessageProducer func(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) // DefaultMessageProducer writes the default message func DefaultMessageProducer(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { if err != nil { fields[logrus.ErrorKey] = err } entry := ctxlogrus.Extract(ctx).WithContext(ctx).WithFields(fields) switch level { case logrus.DebugLevel: entry.Debugf(format) case logrus.InfoLevel: entry.Infof(format) case logrus.WarnLevel: entry.Warningf(format) case logrus.ErrorLevel: entry.Errorf(format) case logrus.FatalLevel: entry.Fatalf(format) case logrus.PanicLevel: entry.Panicf(format) } } go-grpc-middleware-1.3.0/logging/logrus/options_test.go000066400000000000000000000004471404040257500232140ustar00rootroot00000000000000package grpc_logrus import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestDurationToTimeMillisField(t *testing.T) { _, val := DurationToTimeMillisField(time.Microsecond * 100) assert.Equal(t, val.(float32), float32(0.1), "sub millisecond values should be correct") } go-grpc-middleware-1.3.0/logging/logrus/payload_interceptors.go000066400000000000000000000134111404040257500247070ustar00rootroot00000000000000package grpc_logrus import ( "bytes" "context" "fmt" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/go-grpc-middleware/logging" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "google.golang.org/grpc" ) var ( // JsonPbMarshaller is the marshaller used for serializing protobuf messages. 
// If needed, this variable can be reassigned with a different marshaller with the same Marshal() signature. JsonPbMarshaller grpc_logging.JsonPbMarshaler = &jsonpb.Marshaler{} ) // PayloadUnaryServerInterceptor returns a new unary server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `grpc_logrus.UnaryServerInterceptor`. However, the logging can be done to a // separate instance of the logger. func PayloadUnaryServerInterceptor(entry *logrus.Entry, decider grpc_logging.ServerPayloadLoggingDecider) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if !decider(ctx, info.FullMethod, info.Server) { return handler(ctx, req) } // Use the provided logrus.Entry for logging but use the fields from context. logEntry := entry.WithFields(ctxlogrus.Extract(ctx).Data) logProtoMessageAsJson(logEntry, req, "grpc.request.content", "server request payload logged as grpc.request.content field") resp, err := handler(ctx, req) if err == nil { logProtoMessageAsJson(logEntry, resp, "grpc.response.content", "server response payload logged as grpc.request.content field") } return resp, err } } // PayloadStreamServerInterceptor returns a new server server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `grpc_logrus.StreamServerInterceptor`. However, the logging can be done to a // separate instance of the logger. func PayloadStreamServerInterceptor(entry *logrus.Entry, decider grpc_logging.ServerPayloadLoggingDecider) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !decider(stream.Context(), info.FullMethod, srv) { return handler(srv, stream) } // Use the provided logrus.Entry for logging but use the fields from context. 
logEntry := entry.WithFields(ctxlogrus.Extract(stream.Context()).Data) newStream := &loggingServerStream{ServerStream: stream, entry: logEntry} return handler(srv, newStream) } } // PayloadUnaryClientInterceptor returns a new unary client interceptor that logs the payloads of requests and responses. func PayloadUnaryClientInterceptor(entry *logrus.Entry, decider grpc_logging.ClientPayloadLoggingDecider) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if !decider(ctx, method) { return invoker(ctx, method, req, reply, cc, opts...) } logEntry := entry.WithFields(newClientLoggerFields(ctx, method)) logProtoMessageAsJson(logEntry, req, "grpc.request.content", "client request payload logged as grpc.request.content") err := invoker(ctx, method, req, reply, cc, opts...) if err == nil { logProtoMessageAsJson(logEntry, reply, "grpc.response.content", "client response payload logged as grpc.response.content") } return err } } // PayloadStreamClientInterceptor returns a new streaming client interceptor that logs the payloads of requests and responses. func PayloadStreamClientInterceptor(entry *logrus.Entry, decider grpc_logging.ClientPayloadLoggingDecider) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if !decider(ctx, method) { return streamer(ctx, desc, cc, method, opts...) } logEntry := entry.WithFields(newClientLoggerFields(ctx, method)) clientStream, err := streamer(ctx, desc, cc, method, opts...) 
newStream := &loggingClientStream{ClientStream: clientStream, entry: logEntry} return newStream, err } } type loggingClientStream struct { grpc.ClientStream entry *logrus.Entry } func (l *loggingClientStream) SendMsg(m interface{}) error { err := l.ClientStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.entry, m, "grpc.request.content", "server request payload logged as grpc.request.content field") } return err } func (l *loggingClientStream) RecvMsg(m interface{}) error { err := l.ClientStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.entry, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } type loggingServerStream struct { grpc.ServerStream entry *logrus.Entry } func (l *loggingServerStream) SendMsg(m interface{}) error { err := l.ServerStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.entry, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } func (l *loggingServerStream) RecvMsg(m interface{}) error { err := l.ServerStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.entry, m, "grpc.request.content", "server request payload logged as grpc.request.content field") } return err } func logProtoMessageAsJson(entry *logrus.Entry, pbMsg interface{}, key string, msg string) { if p, ok := pbMsg.(proto.Message); ok { entry.WithField(key, &jsonpbMarshalleble{p}).Info(msg) } } type jsonpbMarshalleble struct { proto.Message } func (j *jsonpbMarshalleble) MarshalJSON() ([]byte, error) { b := &bytes.Buffer{} if err := JsonPbMarshaller.Marshal(b, j.Message); err != nil { return nil, fmt.Errorf("jsonpb serializer failed: %v", err) } return b.Bytes(), nil } go-grpc-middleware-1.3.0/logging/logrus/payload_interceptors_test.go000066400000000000000000000140631404040257500257520ustar00rootroot00000000000000package grpc_logrus_test import ( "context" "io" "io/ioutil" "runtime" "strings" "testing" 
"github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" ) var ( nullLogger = &logrus.Logger{ Out: ioutil.Discard, Formatter: new(logrus.TextFormatter), Hooks: make(logrus.LevelHooks), Level: logrus.PanicLevel, } ) func TestLogrusPayloadSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skipf("Skipping due to json.RawMessage incompatibility with go1.7") return } alwaysLoggingDeciderServer := func(ctx context.Context, fullMethodName string, servingObject interface{}) bool { return true } alwaysLoggingDeciderClient := func(ctx context.Context, fullMethodName string) bool { return true } b := newLogrusBaseSuite(t) b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_logrus.PayloadUnaryClientInterceptor(logrus.NewEntry(b.logger), alwaysLoggingDeciderClient)), grpc.WithStreamInterceptor(grpc_logrus.PayloadStreamClientInterceptor(logrus.NewEntry(b.logger), alwaysLoggingDeciderClient)), } b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(nullLogger)), grpc_logrus.PayloadStreamServerInterceptor(logrus.NewEntry(b.logger), alwaysLoggingDeciderServer)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(nullLogger)), grpc_logrus.PayloadUnaryServerInterceptor(logrus.NewEntry(b.logger), alwaysLoggingDeciderServer)), } 
suite.Run(t, &logrusPayloadSuite{b}) } type logrusPayloadSuite struct { *logrusBaseSuite } func (s *logrusPayloadSuite) getServerAndClientMessages(expectedServer int, expectedClient int) (serverMsgs []map[string]interface{}, clientMsgs []map[string]interface{}) { msgs := s.getOutputJSONs() for _, m := range msgs { if m["span.kind"] == "server" { serverMsgs = append(serverMsgs, m) } else if m["span.kind"] == "client" { clientMsgs = append(clientMsgs, m) } } require.Len(s.T(), serverMsgs, expectedServer, "must match expected number of server log messages") require.Len(s.T(), clientMsgs, expectedClient, "must match expected number of client log messages") return serverMsgs, clientMsgs } func (s *logrusPayloadSuite) TestPing_LogsBothRequestAndResponse() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an on a successful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(2, 2) for _, m := range append(serverMsgs, clientMsgs...) { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), m["level"], "info", "all lines must contain method name") } serverReq, serverResp := serverMsgs[0], serverMsgs[1] clientReq, clientResp := clientMsgs[0], clientMsgs[1] assert.Contains(s.T(), clientReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), serverReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), clientResp, "grpc.response.content", "response payload must be logged in a structured way") assert.Contains(s.T(), serverResp, "grpc.response.content", "response payload must be logged in a structured way") } func (s *logrusPayloadSuite) TestPingError_LogsOnlyRequestsOnError() { _, err := s.Client.PingError(s.SimpleCtx(), 
&pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(4)}) require.Error(s.T(), err, "there must be not be an error on a successful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(1, 1) for _, m := range append(serverMsgs, clientMsgs...) { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain the correct method name") assert.Equal(s.T(), m["level"], "info", "all lines must be logged at info level") } assert.Contains(s.T(), clientMsgs[0], "grpc.request.content", "request payload must be logged by the client") assert.Contains(s.T(), serverMsgs[0], "grpc.request.content", "request payload must be logged by the server") } func (s *logrusPayloadSuite) TestPingStream_LogsAllRequestsAndResponses() { messagesExpected := 20 stream, err := s.Client.PingStream(s.SimpleCtx()) require.NoError(s.T(), err, "no error on stream creation") for i := 0; i < messagesExpected; i++ { require.NoError(s.T(), stream.Send(goodPing), "sending must succeed") } require.NoError(s.T(), stream.CloseSend(), "no error on close of stream") for { pong := &pb_testproto.PingResponse{} err := stream.RecvMsg(pong) if err == io.EOF { break } require.NoError(s.T(), err, "no error on receive") } serverMsgs, clientMsgs := s.getServerAndClientMessages(2*messagesExpected, 2*messagesExpected) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "PingStream", "all lines must contain the correct method name") assert.Equal(s.T(), m["level"], "info", "all lines must be at info log level") content := m["grpc.request.content"] != nil || m["grpc.response.content"] != nil assert.True(s.T(), content, "all messages must contain a payload") } } go-grpc-middleware-1.3.0/logging/logrus/server_interceptors.go000066400000000000000000000062761404040257500245770ustar00rootroot00000000000000// Copyright (c) Improbable Worlds Ltd, All Rights Reserved package grpc_logrus import ( "context" "path" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" "google.golang.org/grpc" ) var ( // SystemField is used in every log statement made through grpc_logrus. Can be overwritten before any initialization code. SystemField = "system" // KindField describes the log field used to indicate whether this is a server or a client log statement. KindField = "span.kind" ) // UnaryServerInterceptor returns a new unary server interceptors that adds logrus.Entry to the context. 
func UnaryServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.UnaryServerInterceptor { o := evaluateServerOpt(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { startTime := time.Now() newCtx := newLoggerForCall(ctx, entry, info.FullMethod, startTime, o.timestampFormat) resp, err := handler(newCtx, req) if !o.shouldLog(info.FullMethod, err) { return resp, err } code := o.codeFunc(err) level := o.levelFunc(code) durField, durVal := o.durationFunc(time.Since(startTime)) fields := logrus.Fields{ "grpc.code": code.String(), durField: durVal, } if err != nil { fields[logrus.ErrorKey] = err } o.messageFunc(newCtx, "finished unary call with code "+code.String(), level, code, err, fields) return resp, err } } // StreamServerInterceptor returns a new streaming server interceptor that adds logrus.Entry to the context. func StreamServerInterceptor(entry *logrus.Entry, opts ...Option) grpc.StreamServerInterceptor { o := evaluateServerOpt(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { startTime := time.Now() newCtx := newLoggerForCall(stream.Context(), entry, info.FullMethod, startTime, o.timestampFormat) wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx err := handler(srv, wrapped) if !o.shouldLog(info.FullMethod, err) { return err } code := o.codeFunc(err) level := o.levelFunc(code) durField, durVal := o.durationFunc(time.Since(startTime)) fields := logrus.Fields{ "grpc.code": code.String(), durField: durVal, } o.messageFunc(newCtx, "finished streaming call with code "+code.String(), level, code, err, fields) return err } } func newLoggerForCall(ctx context.Context, entry *logrus.Entry, fullMethodString string, start time.Time, timestampFormat string) context.Context { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) callLog := entry.WithFields( 
logrus.Fields{ SystemField: "grpc", KindField: "server", "grpc.service": service, "grpc.method": method, "grpc.start_time": start.Format(timestampFormat), }) if d, ok := ctx.Deadline(); ok { callLog = callLog.WithFields( logrus.Fields{ "grpc.request.deadline": d.Format(timestampFormat), }) } callLog = callLog.WithFields(ctxlogrus.Extract(ctx).Data) return ctxlogrus.ToContext(ctx, callLog) } go-grpc-middleware-1.3.0/logging/logrus/server_interceptors_test.go000066400000000000000000000371311404040257500256300ustar00rootroot00000000000000package grpc_logrus_test import ( "io" "testing" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) func TestLogrusServerSuite(t *testing.T) { for _, tcase := range []struct { timestampFormat string }{ { timestampFormat: time.RFC3339, }, { timestampFormat: "2006-01-02", }, } { opts := []grpc_logrus.Option{ grpc_logrus.WithLevels(customCodeToLevel), grpc_logrus.WithTimestampFormat(tcase.timestampFormat), } b := newLogrusBaseSuite(t) b.timestampFormat = tcase.timestampFormat b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(b.logger), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusServerSuite{b}) 
} } type logrusServerSuite struct { *logrusBaseSuite } func (s *logrusServerSuite) TestPing_WithCustomTags() { deadline := time.Now().Add(3 * time.Second) _, err := s.Client.Ping(s.DeadlineCtx(deadline), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain the correct method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain `custom_tags.string` with expected value") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain the correct request value") assert.Equal(s.T(), m["custom_field"], "custom_value", "all lines must contain `custom_field` with the correct value") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain `custom_tags.int`") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time of the call") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") require.Contains(s.T(), m, "grpc.request.deadline", "all lines must contain the deadline of the call") _, err = time.Parse(s.timestampFormat, m["grpc.request.deadline"].(string)) require.NoError(s.T(), err, "should be able to parse deadline") assert.Equal(s.T(), m["grpc.request.deadline"], deadline.Format(s.timestampFormat), "should have the same deadline that was set by the caller") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "first message must contain the correct user message") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "second message must 
contain the correct user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK codes must be logged on info level.") assert.Contains(s.T(), msgs[1], "grpc.time_ms", "interceptor log statement should contain execution time") } func (s *logrusServerSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level logrus.Level msg string }{ { code: codes.Internal, level: logrus.ErrorLevel, msg: "Internal must remap to ErrorLevel in DefaultCodeToLevel", }, { code: codes.NotFound, level: logrus.InfoLevel, msg: "NotFound must remap to InfoLevel in DefaultCodeToLevel", }, { code: codes.FailedPrecondition, level: logrus.WarnLevel, msg: "FailedPrecondition must remap to WarnLevel in DefaultCodeToLevel", }, { code: codes.Unauthenticated, level: logrus.ErrorLevel, msg: "Unauthenticated is overwritten to ErrorLevel with customCodeToLevel override, which probably didn't work", }, } { s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(tcase.code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain the correct method name") assert.Equal(s.T(), m["grpc.code"], tcase.code.String(), "a gRPC code must be present") assert.Equal(s.T(), m["level"], tcase.level.String(), tcase.msg) assert.Equal(s.T(), m["msg"], "finished unary call with code "+tcase.code.String(), "must have the correct finish message") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain a start time for the call") _, err = time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse the start time") 
require.Contains(s.T(), m, "grpc.request.deadline", "all lines must contain the deadline of the call") _, err = time.Parse(s.timestampFormat, m["grpc.request.deadline"].(string)) require.NoError(s.T(), err, "should be able to parse deadline") } } func (s *logrusServerSuite) TestPingList_WithCustomTags() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain the correct service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain the correct method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain the correct `custom_tags.string`") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain the correct request value") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain `custom_tags.int`") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time for the call") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time as RFC3339") require.Contains(s.T(), m, "grpc.request.deadline", "all lines must contain the deadline of the call") _, err = time.Parse(s.timestampFormat, m["grpc.request.deadline"].(string)) require.NoError(s.T(), err, "should be able to parse deadline") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "msg must be the correct message") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "msg must be the correct 
message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK codes must be logged on info level.") assert.Contains(s.T(), msgs[1], "grpc.time_ms", "interceptor log statement should contain execution time") } func TestLogrusServerOverrideSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithDurationField(grpc_logrus.DurationToDurationField), } b := newLogrusBaseSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(b.logger), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusServerOverrideSuite{b}) } type logrusServerOverrideSuite struct { *logrusBaseSuite } func (s *logrusServerOverrideSuite) TestPing_HasOverriddenDuration() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "first message must be correct") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "first message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "first message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "second message must be correct") assert.Equal(s.T(), msgs[1]["level"], "info", "second must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "second message must not contain default duration") assert.Contains(s.T(), msgs[1], 
"grpc.duration", "second message must contain overridden duration") } func (s *logrusServerOverrideSuite) TestPingList_HasOverriddenDuration() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "first message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "first message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "first message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "second message must contain correct message") assert.Equal(s.T(), msgs[1]["level"], "info", "second message must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "second message must not contain default duration") assert.Contains(s.T(), msgs[1], "grpc.duration", "second message must contain overridden duration") } func TestLogrusServerOverrideDeciderSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithDecider(func(method string, err error) bool { if err != nil && method == "/mwitkow.testproto.TestService/PingError" { return true } return false }), } b := newLogrusBaseSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(b.logger), opts...)), grpc_middleware.WithUnaryServerChain( 
grpc_ctxtags.UnaryServerInterceptor(), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusServerOverrideDeciderSuite{b}) } type logrusServerOverrideDeciderSuite struct { *logrusBaseSuite } func (s *logrusServerOverrideDeciderSuite) TestPing_HasOverriddenDecider() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") } func (s *logrusServerOverrideDeciderSuite) TestPingError_HasOverriddenDecider() { code := codes.NotFound level := logrus.InfoLevel msg := "NotFound must remap to InfoLevel in DefaultCodeToLevel" s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["grpc.code"], code.String(), "all lines must correct gRPC code") assert.Equal(s.T(), m["level"], level.String(), msg) } func (s *logrusServerOverrideDeciderSuite) TestPingList_HasOverriddenDecider() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } 
require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") } func TestLogrusServerMessageProducerSuite(t *testing.T) { opts := []grpc_logrus.Option{ grpc_logrus.WithMessageProducer(StubMessageProducer), } b := newLogrusBaseSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_logrus.StreamServerInterceptor(logrus.NewEntry(b.logger), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_logrus.UnaryServerInterceptor(logrus.NewEntry(b.logger), opts...)), } suite.Run(t, &logrusServerMessageProducerSuite{b}) } type logrusServerMessageProducerSuite struct { *logrusBaseSuite } func (s *logrusServerMessageProducerSuite) TestPing_HasMessageProducer() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[1]["msg"], "custom message", "user defined message producer must be used") assert.Equal(s.T(), msgs[0]["msg"], "some 
ping", "handler's message must contain user message") } go-grpc-middleware-1.3.0/logging/logrus/shared_test.go000066400000000000000000000073521404040257500227710ustar00rootroot00000000000000package grpc_logrus_test import ( "bytes" "context" "encoding/json" "io" "testing" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} ) type loggingPingService struct { pb_testproto.TestServiceServer } func customCodeToLevel(c codes.Code) logrus.Level { if c == codes.Unauthenticated { // Make this a special case for tests, and an error. return logrus.ErrorLevel } level := grpc_logrus.DefaultCodeToLevel(c) return level } func (s *loggingPingService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxlogrus.AddFields(ctx, logrus.Fields{"custom_field": "custom_value"}) ctxlogrus.Extract(ctx).Info("some ping") return s.TestServiceServer.Ping(ctx, ping) } func (s *loggingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { return s.TestServiceServer.PingError(ctx, ping) } func (s *loggingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { grpc_ctxtags.Extract(stream.Context()).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxlogrus.AddFields(stream.Context(), logrus.Fields{"custom_field": "custom_value"}) ctxlogrus.Extract(stream.Context()).Info("some pinglist") return 
s.TestServiceServer.PingList(ping, stream) } func (s *loggingPingService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { return s.TestServiceServer.PingEmpty(ctx, empty) } type logrusBaseSuite struct { *grpc_testing.InterceptorTestSuite mutexBuffer *grpc_testing.MutexReadWriter buffer *bytes.Buffer logger *logrus.Logger timestampFormat string } func newLogrusBaseSuite(t *testing.T) *logrusBaseSuite { b := &bytes.Buffer{} muB := grpc_testing.NewMutexReadWriter(b) logger := logrus.New() logger.Out = muB logger.Formatter = &logrus.JSONFormatter{DisableTimestamp: true} return &logrusBaseSuite{ logger: logger, buffer: b, mutexBuffer: muB, InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &loggingPingService{&grpc_testing.TestPingService{T: t}}, }, } } func (s *logrusBaseSuite) SetupTest() { s.mutexBuffer.Lock() s.buffer.Reset() s.mutexBuffer.Unlock() } func (s *logrusBaseSuite) getOutputJSONs() []map[string]interface{} { ret := make([]map[string]interface{}, 0) dec := json.NewDecoder(s.mutexBuffer) for { var val map[string]interface{} err := dec.Decode(&val) if err == io.EOF { break } if err != nil { s.T().Fatalf("failed decoding output from Logrus JSON: %v", err) } ret = append(ret, val) } return ret } func StubMessageProducer(ctx context.Context, format string, level logrus.Level, code codes.Code, err error, fields logrus.Fields) { if err != nil { fields[logrus.ErrorKey] = err } format = "custom message" entry := ctxlogrus.Extract(ctx).WithContext(ctx).WithFields(fields) switch level { case logrus.DebugLevel: entry.Debugf(format) case logrus.InfoLevel: entry.Infof(format) case logrus.WarnLevel: entry.Warningf(format) case logrus.ErrorLevel: entry.Errorf(format) case logrus.FatalLevel: entry.Fatalf(format) case logrus.PanicLevel: entry.Panicf(format) } } 
go-grpc-middleware-1.3.0/logging/settable/000077500000000000000000000000001404040257500204165ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/settable/doc.go000066400000000000000000000010031404040257500215040ustar00rootroot00000000000000// /* grpc_logsettable contains a thread-safe wrapper around grpc-logging infrastructure. The go-grpc assumes that logger can be only configured once as the `SetLoggerV2` method is: ```Not mutex-protected, should be called before any gRPC functions.``` This package allows to supply parent logger once ("before any grpc"), but later change underlying implementation in thread-safe way when needed. It's in particular useful for testing, where each testcase might need its own logger. */ package grpc_logsettable go-grpc-middleware-1.3.0/logging/settable/logsettable.go000066400000000000000000000041651404040257500232600ustar00rootroot00000000000000package grpc_logsettable import ( "io/ioutil" "sync" "google.golang.org/grpc/grpclog" ) // SettableLoggerV2 is thread-safe. type SettableLoggerV2 interface { grpclog.LoggerV2 // Sets given logger as the underlying implementation. Set(loggerv2 grpclog.LoggerV2) // Sets `discard` logger as the underlying implementation. Reset() } // ReplaceGrpcLoggerV2 creates and configures SettableLoggerV2 as grpc logger. 
func ReplaceGrpcLoggerV2() SettableLoggerV2 { settable := &settableLoggerV2{} settable.Reset() grpclog.SetLoggerV2(settable) return settable } // SettableLoggerV2 implements SettableLoggerV2 type settableLoggerV2 struct { log grpclog.LoggerV2 mu sync.RWMutex } func (s *settableLoggerV2) Set(log grpclog.LoggerV2) { s.mu.Lock() defer s.mu.Unlock() s.log = log } func (s *settableLoggerV2) Reset() { s.Set(grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, ioutil.Discard)) } func (s *settableLoggerV2) get() grpclog.LoggerV2 { s.mu.RLock() defer s.mu.RUnlock() return s.log } func (s *settableLoggerV2) Info(args ...interface{}) { s.get().Info(args) } func (s *settableLoggerV2) Infoln(args ...interface{}) { s.get().Infoln(args) } func (s *settableLoggerV2) Infof(format string, args ...interface{}) { s.get().Infof(format, args) } func (s *settableLoggerV2) Warning(args ...interface{}) { s.get().Warning(args) } func (s *settableLoggerV2) Warningln(args ...interface{}) { s.get().Warningln(args) } func (s *settableLoggerV2) Warningf(format string, args ...interface{}) { s.get().Warningf(format, args) } func (s *settableLoggerV2) Error(args ...interface{}) { s.get().Error(args) } func (s *settableLoggerV2) Errorln(args ...interface{}) { s.get().Errorln(args) } func (s *settableLoggerV2) Errorf(format string, args ...interface{}) { s.get().Errorf(format, args) } func (s *settableLoggerV2) Fatal(args ...interface{}) { s.get().Fatal(args) } func (s *settableLoggerV2) Fatalln(args ...interface{}) { s.get().Fatalln(args) } func (s *settableLoggerV2) Fatalf(format string, args ...interface{}) { s.get().Fatalf(format, args) } func (s *settableLoggerV2) V(l int) bool { return s.get().V(l) } go-grpc-middleware-1.3.0/logging/settable/logsettable_test.go000066400000000000000000000024201404040257500243070ustar00rootroot00000000000000package grpc_logsettable_test import ( "bytes" "io/ioutil" "os" "testing" grpc_logsettable "github.com/grpc-ecosystem/go-grpc-middleware/logging/settable" 
"github.com/stretchr/testify/assert" "google.golang.org/grpc/grpclog" ) func ExampleSettableLoggerV2_init() { l1 := grpclog.NewLoggerV2(ioutil.Discard, ioutil.Discard, ioutil.Discard) l2 := grpclog.NewLoggerV2(os.Stdout, os.Stdout, os.Stdout) settableLogger := grpc_logsettable.ReplaceGrpcLoggerV2() grpclog.Info("Discarded by default") settableLogger.Set(l1) grpclog.Info("Discarded log by l1") settableLogger.Set(l2) grpclog.Info("Emitted log by l2") // Expected output: INFO: 2021/03/15 12:59:54 [Emitted log by l2] } func TestSettableLoggerV2_init(t *testing.T) { l1buffer := &bytes.Buffer{} l1 := grpclog.NewLoggerV2(l1buffer, l1buffer, l1buffer) l2buffer := &bytes.Buffer{} l2 := grpclog.NewLoggerV2(l2buffer, l2buffer, l2buffer) settableLogger := grpc_logsettable.ReplaceGrpcLoggerV2() grpclog.Info("Discarded by default") settableLogger.Set(l1) grpclog.SetLoggerV2(settableLogger) grpclog.Info("Emitted log by l1") settableLogger.Set(l2) grpclog.Info("Emitted log by l2") assert.Contains(t, l1buffer.String(), "Emitted log by l1") assert.Contains(t, l2buffer.String(), "Emitted log by l2") } go-grpc-middleware-1.3.0/logging/zap/000077500000000000000000000000001404040257500174055ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/zap/client_interceptors.go000066400000000000000000000046461404040257500240250ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_zap import ( "context" "path" "time" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" ) var ( // ClientField is used in every client-side log statement made through grpc_zap. Can be overwritten before initialization. ClientField = zap.String("span.kind", "client") ) // UnaryClientInterceptor returns a new unary client interceptor that optionally logs the execution of external gRPC calls. 
func UnaryClientInterceptor(logger *zap.Logger, opts ...Option) grpc.UnaryClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { fields := newClientLoggerFields(ctx, method) startTime := time.Now() err := invoker(ctx, method, req, reply, cc, opts...) newCtx := ctxzap.ToContext(ctx, logger.With(fields...)) logFinalClientLine(newCtx, o, startTime, err, "finished client unary call") return err } } // StreamClientInterceptor returns a new streaming client interceptor that optionally logs the execution of external gRPC calls. func StreamClientInterceptor(logger *zap.Logger, opts ...Option) grpc.StreamClientInterceptor { o := evaluateClientOpt(opts) return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { fields := newClientLoggerFields(ctx, method) startTime := time.Now() clientStream, err := streamer(ctx, desc, cc, method, opts...) 
newCtx := ctxzap.ToContext(ctx, logger.With(fields...)) logFinalClientLine(newCtx, o, startTime, err, "finished client streaming call") return clientStream, err } } func logFinalClientLine(ctx context.Context, o *options, startTime time.Time, err error, msg string) { code := o.codeFunc(err) level := o.levelFunc(code) duration := o.durationFunc(time.Now().Sub(startTime)) o.messageFunc(ctx, msg, level, code, err, duration) } func newClientLoggerFields(ctx context.Context, fullMethodString string) []zapcore.Field { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) return []zapcore.Field{ SystemField, ClientField, zap.String("grpc.service", service), zap.String("grpc.method", method), } } go-grpc-middleware-1.3.0/logging/zap/client_interceptors_test.go000066400000000000000000000177041404040257500250630ustar00rootroot00000000000000package grpc_zap_test import ( "io" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "go.uber.org/zap/zapcore" ) func customClientCodeToLevel(c codes.Code) zapcore.Level { if c == codes.Unauthenticated { // Make this a special case for tests, and an error. 
return zapcore.ErrorLevel } level := grpc_zap.DefaultClientCodeToLevel(c) return level } func TestZapClientSuite(t *testing.T) { opts := []grpc_zap.Option{ grpc_zap.WithLevels(customClientCodeToLevel), } b := newBaseZapSuite(t) b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_zap.UnaryClientInterceptor(b.log, opts...)), grpc.WithStreamInterceptor(grpc_zap.StreamClientInterceptor(b.log, opts...)), } suite.Run(t, &zapClientSuite{b}) } type zapClientSuite struct { *zapBaseSuite } func (s *zapClientSuite) TestPing() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "must contain correct message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "interceptor log statement should contain execution time") } func (s *zapClientSuite) TestPingList() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), 
msgs[0]["msg"], "finished client streaming call", "handler's message must contain user message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], "debug", "OK codes must be logged on debug level.") assert.Contains(s.T(), msgs[0], "grpc.time_ms", "handler's message must contain time in ms") } func (s *zapClientSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level zapcore.Level msg string }{ { code: codes.Internal, level: zapcore.WarnLevel, msg: "Internal must remap to ErrorLevel in DefaultClientCodeToLevel", }, { code: codes.NotFound, level: zapcore.DebugLevel, msg: "NotFound must remap to InfoLevel in DefaultClientCodeToLevel", }, { code: codes.FailedPrecondition, level: zapcore.DebugLevel, msg: "FailedPrecondition must remap to WarnLevel in DefaultClientCodeToLevel", }, { code: codes.Unauthenticated, level: zapcore.ErrorLevel, msg: "Unauthenticated is overwritten to ErrorLevel with customClientCodeToLevel override, which probably didn't work", }, } { s.SetupTest() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(tcase.code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["grpc.code"], tcase.code.String(), "all lines must contain the correct gRPC code") assert.Equal(s.T(), msgs[0]["level"], tcase.level.String(), tcase.msg) } } func TestZapClientOverrideSuite(t *testing.T) { opts := []grpc_zap.Option{ grpc_zap.WithDurationField(grpc_zap.DurationToDurationField), } b := newBaseZapSuite(t) 
b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_zap.UnaryClientInterceptor(b.log, opts...)), grpc.WithStreamInterceptor(grpc_zap.StreamClientInterceptor(b.log, opts...)), } suite.Run(t, &zapClientOverrideSuite{b}) } type zapClientOverrideSuite struct { *zapBaseSuite } func (s *zapClientOverrideSuite) TestPing_HasOverrides() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client unary call", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "handler's message must contain overridden duration") } func (s *zapClientOverrideSuite) TestPingList_HasOverrides() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "finished client streaming call", "handler's message must contain user message") assert.Equal(s.T(), msgs[0]["span.kind"], "client", "all lines must contain the kind of call (client)") assert.Equal(s.T(), msgs[0]["level"], 
"debug", "must be logged on debug level.") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[0], "grpc.duration", "handler's message must contain overridden duration") } func TestZapLoggingClientMessageProducerSuite(t *testing.T) { opts := []grpc_zap.Option{ grpc_zap.WithMessageProducer(StubMessageProducer), } b := newBaseZapSuite(t) b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_zap.UnaryClientInterceptor(b.log, opts...)), grpc.WithStreamInterceptor(grpc_zap.StreamClientInterceptor(b.log, opts...)), } suite.Run(t, &zapClientMessageProducerSuite{b}) } type zapClientMessageProducerSuite struct { *zapBaseSuite } func (s *zapClientMessageProducerSuite) TestPing_HasOverriddenMessageProducer() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "one log statement should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "custom message", "handler's message must contain user message") } go-grpc-middleware-1.3.0/logging/zap/context.go000066400000000000000000000010471404040257500214220ustar00rootroot00000000000000package grpc_zap import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) // AddFields adds zap fields to the logger. // Deprecated: should use the ctxzap.AddFields instead func AddFields(ctx context.Context, fields ...zapcore.Field) { ctxzap.AddFields(ctx, fields...) } // Extract takes the call-scoped Logger from grpc_zap middleware. 
// Deprecated: should use the ctxzap.Extract instead func Extract(ctx context.Context) *zap.Logger { return ctxzap.Extract(ctx) } go-grpc-middleware-1.3.0/logging/zap/ctxzap/000077500000000000000000000000001404040257500207165ustar00rootroot00000000000000go-grpc-middleware-1.3.0/logging/zap/ctxzap/context.go000066400000000000000000000047541404040257500227430ustar00rootroot00000000000000package ctxzap import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/tags" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) type ctxMarker struct{} type ctxLogger struct { logger *zap.Logger fields []zapcore.Field } var ( ctxMarkerKey = &ctxMarker{} nullLogger = zap.NewNop() ) // AddFields adds zap fields to the logger. func AddFields(ctx context.Context, fields ...zapcore.Field) { l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) if !ok || l == nil { return } l.fields = append(l.fields, fields...) } // Extract takes the call-scoped Logger from grpc_zap middleware. // // It always returns a Logger that has all the grpc_ctxtags updated. func Extract(ctx context.Context) *zap.Logger { l, ok := ctx.Value(ctxMarkerKey).(*ctxLogger) if !ok || l == nil { return nullLogger } // Add grpc_ctxtags tags metadata until now. fields := TagsToFields(ctx) // Add zap fields added until now. fields = append(fields, l.fields...) return l.logger.With(fields...) } // TagsToFields transforms the Tags on the supplied context into zap fields. func TagsToFields(ctx context.Context) []zapcore.Field { fields := []zapcore.Field{} tags := grpc_ctxtags.Extract(ctx) for k, v := range tags.Values() { fields = append(fields, zap.Any(k, v)) } return fields } // ToContext adds the zap.Logger to the context for extraction later. // Returning the new context that has been created. func ToContext(ctx context.Context, logger *zap.Logger) context.Context { l := &ctxLogger{ logger: logger, } return context.WithValue(ctx, ctxMarkerKey, l) } // Debug is equivalent to calling Debug on the zap.Logger in the context. 
// It is a no-op if the context does not contain a zap.Logger. func Debug(ctx context.Context, msg string, fields ...zap.Field) { Extract(ctx).Debug(msg, fields...) } // Info is equivalent to calling Info on the zap.Logger in the context. // It is a no-op if the context does not contain a zap.Logger. func Info(ctx context.Context, msg string, fields ...zap.Field) { Extract(ctx).Info(msg, fields...) } // Warn is equivalent to calling Warn on the zap.Logger in the context. // It is a no-op if the context does not contain a zap.Logger. func Warn(ctx context.Context, msg string, fields ...zap.Field) { Extract(ctx).Warn(msg, fields...) } // Error is equivalent to calling Error on the zap.Logger in the context. // It is a no-op if the context does not contain a zap.Logger. func Error(ctx context.Context, msg string, fields ...zap.Field) { Extract(ctx).Error(msg, fields...) } go-grpc-middleware-1.3.0/logging/zap/ctxzap/context_test.go000066400000000000000000000023071404040257500237720ustar00rootroot00000000000000package ctxzap import ( "context" "testing" "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest" ) func TestShorthands(t *testing.T) { cases := []struct { fn func(ctx context.Context, msg string, fields ...zapcore.Field) level zapcore.Level }{ {Debug, zap.DebugLevel}, {Info, zap.InfoLevel}, {Warn, zap.WarnLevel}, {Error, zap.ErrorLevel}, } const message = "omg!" for _, c := range cases { t.Run(c.level.String(), func(t *testing.T) { called := false logger := zaptest.NewLogger(t, zaptest.WrapOptions(zap.Hooks(func(e zapcore.Entry) error { called = true if e.Level != c.level { t.Fatalf("Expected %v, got %v", c.level, e.Level) } if e.Message != message { t.Fatalf("message: expected %v, got %v", message, e.Message) } return nil }))) ctx := ToContext(context.Background(), logger) c.fn(ctx, message) if !called { t.Fatal("hook not called") } }) } } func TestShorthandsNoop(t *testing.T) { // Just check we don't panic if there is no logger in the context. 
Debug(context.Background(), "no-op") Info(context.Background(), "no-op") Warn(context.Background(), "no-op") Error(context.Background(), "no-op") } go-grpc-middleware-1.3.0/logging/zap/ctxzap/doc.go000066400000000000000000000011721404040257500220130ustar00rootroot00000000000000/* `ctxzap` is a ctxlogger that is backed by Zap It accepts a user-configured `zap.Logger` that will be used for logging. The same `zap.Logger` will be populated into the `context.Context` passed into gRPC handler code. You can use `ctxzap.Extract` to log into a request-scoped `zap.Logger` instance in your handler code. As `ctxzap.Extract` will iterate all tags on from `grpc_ctxtags` it is therefore expensive so it is advised that you extract once at the start of the function from the context and reuse it for the remainder of the function (see examples). Please see examples and tests for examples of use. */ package ctxzap go-grpc-middleware-1.3.0/logging/zap/ctxzap/examples_test.go000066400000000000000000000016421404040257500241250ustar00rootroot00000000000000package ctxzap_test import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "go.uber.org/zap" ) var zapLogger *zap.Logger // Simple unary handler that adds custom fields to the requests's context. These will be used for all log statements. func ExampleExtract_unary() { _ = func(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { // Add fields the ctxtags of the request which will be added to all extracted loggers. grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) // Extract a single request-scoped zap.Logger and log messages. 
l := ctxzap.Extract(ctx) l.Info("some ping") l.Info("another ping") return &pb_testproto.PingResponse{Value: ping.Value}, nil } } go-grpc-middleware-1.3.0/logging/zap/doc.go000066400000000000000000000072431404040257500205070ustar00rootroot00000000000000/* `grpc_zap` is a gRPC logging middleware backed by ZAP loggers It accepts a user-configured `zap.Logger` that will be used for logging completed gRPC calls. The same `zap.Logger` will be used for logging completed gRPC calls, and be populated into the `context.Context` passed into gRPC handler code. On calling `StreamServerInterceptor` or `UnaryServerInterceptor` this logging middleware will add gRPC call information to the ctx so that it will be present on subsequent use of the `ctx_zap` logger. If a deadline is present on the gRPC request the grpc.request.deadline tag is populated when the request begins. grpc.request.deadline is a string representing the time (RFC3339) when the current call will expire. This package also implements request and response *payload* logging, both for server-side and client-side. These will be logged as structured `jsonpb` fields for every message received/sent (both unary and streaming). For that please use `Payload*Interceptor` functions for that. Please note that the user-provided function that determines whether to log the full request/response payload needs to be written with care, this can significantly slow down gRPC. ZAP can also be made as a backend for gRPC library internals. For that use `ReplaceGrpcLoggerV2`. 
*Server Interceptor* Below is a JSON formatted example of a log that would be logged by the server interceptor: { "level": "info", // string zap log levels "msg": "finished unary call", // string log message "grpc.code": "OK", // string grpc status code "grpc.method": "Ping", // string method name "grpc.service": "mwitkow.testproto.TestService", // string full name of the called service "grpc.start_time": "2006-01-02T15:04:05Z07:00", // string RFC3339 representation of the start time "grpc.request.deadline": "2006-01-02T15:04:05Z07:00", // string RFC3339 deadline of the current request if supplied "grpc.request.value": "something", // string value on the request "grpc.time_ms": 1.345, // float32 run time of the call in ms "peer.address": { "IP": "127.0.0.1", // string IP address of calling party "Port": 60216, // int port call is coming in on "Zone": "" // string peer zone for caller }, "span.kind": "server", // string client | server "system": "grpc", // string "custom_field": "custom_value", // string user defined field "custom_tags.int": 1337, // int user defined tag on the ctx "custom_tags.string": "something" // string user defined tag on the ctx } *Payload Interceptor* Below is a JSON formatted example of a log that would be logged by the payload interceptor: { "level": "info", // string zap log levels "msg": "client request payload logged as grpc.request.content", // string log message "grpc.request.content": { // object content of RPC request "msg" : { // object ZAP specific inner object "value": "something", // string defined by caller "sleepTimeMs": 9999 // int defined by caller } }, "grpc.method": "Ping", // string method being called "grpc.service": "mwitkow.testproto.TestService", // string service being called "span.kind": "client", // string client | server "system": "grpc" // string } Note: due to how it is implemented, the ZAP payload log differs from the Logrus one by wrapping the "grpc.request.content" object in an inner "msg" object.
Please see examples and tests for examples of use. Please see settable_test.go for canonical integration through "zaptest" with golang testing infrastructure. */ package grpc_zap go-grpc-middleware-1.3.0/logging/zap/examples_test.go000066400000000000000000000066431404040257500226220ustar00rootroot00000000000000package grpc_zap_test import ( "context" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" ) var ( zapLogger *zap.Logger customFunc grpc_zap.CodeToLevel ) // Initialization shows a relatively complex initialization sequence. func Example_initialization() { // Shared options for the logger, with a custom gRPC code to log level function. opts := []grpc_zap.Option{ grpc_zap.WithLevels(customFunc), } // Make sure that log statements internal to gRPC library are logged using the zapLogger as well. grpc_zap.ReplaceGrpcLoggerV2(zapLogger) // Create a server, make sure we put the grpc_ctxtags context before everything else. _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.UnaryServerInterceptor(zapLogger, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.StreamServerInterceptor(zapLogger, opts...), ), ) } // Initialization shows an initialization sequence with the duration field generation overridden. 
func Example_initializationWithDurationFieldOverride() { opts := []grpc_zap.Option{ grpc_zap.WithDurationField(func(duration time.Duration) zapcore.Field { return zap.Int64("grpc.time_ns", duration.Nanoseconds()) }), } _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_zap.UnaryServerInterceptor(zapLogger, opts...), ), grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(zapLogger, opts...), ), ) } // Simple unary handler that adds custom fields to the requests's context. These will be used for all log statements. func ExampleExtract_unary() { _ = func(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { // Add fields the ctxtags of the request which will be added to all extracted loggers. grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) // Extract a single request-scoped zap.Logger and log messages. (containing the grpc.xxx tags) l := ctxzap.Extract(ctx) l.Info("some ping") l.Info("another ping") return &pb_testproto.PingResponse{Value: ping.Value}, nil } } func Example_initializationWithDecider() { opts := []grpc_zap.Option{ grpc_zap.WithDecider(func(fullMethodName string, err error) bool { // will not log gRPC calls if it was a call to healthcheck and no error was raised if err == nil && fullMethodName == "foo.bar.healthcheck" { return false } // by default everything will be logged return true }), } _ = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(zap.NewNop(), opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_zap.UnaryServerInterceptor(zap.NewNop(), opts...)), } } go-grpc-middleware-1.3.0/logging/zap/grpclogger.go000066400000000000000000000076151404040257500221000ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. 
All Rights Reserved. // See LICENSE for licensing terms. package grpc_zap import ( "fmt" grpc_logsettable "github.com/grpc-ecosystem/go-grpc-middleware/logging/settable" "go.uber.org/zap" "google.golang.org/grpc/grpclog" ) // ReplaceGrpcLogger sets the given zap.Logger as a gRPC-level logger. // This should be called *before* any other initialization, preferably from init() functions. // Deprecated: use ReplaceGrpcLoggerV2. func ReplaceGrpcLogger(logger *zap.Logger) { zgl := &zapGrpcLogger{logger.With(SystemField, zap.Bool("grpc_log", true))} grpclog.SetLogger(zgl) } type zapGrpcLogger struct { logger *zap.Logger } func (l *zapGrpcLogger) Fatal(args ...interface{}) { l.logger.Fatal(fmt.Sprint(args...)) } func (l *zapGrpcLogger) Fatalf(format string, args ...interface{}) { l.logger.Fatal(fmt.Sprintf(format, args...)) } func (l *zapGrpcLogger) Fatalln(args ...interface{}) { l.logger.Fatal(fmt.Sprint(args...)) } func (l *zapGrpcLogger) Print(args ...interface{}) { l.logger.Info(fmt.Sprint(args...)) } func (l *zapGrpcLogger) Printf(format string, args ...interface{}) { l.logger.Info(fmt.Sprintf(format, args...)) } func (l *zapGrpcLogger) Println(args ...interface{}) { l.logger.Info(fmt.Sprint(args...)) } // ReplaceGrpcLoggerV2 replaces the grpc_log.LoggerV2 with the provided logger. // It should be called before any gRPC functions. func ReplaceGrpcLoggerV2(logger *zap.Logger) { ReplaceGrpcLoggerV2WithVerbosity(logger, 0) } // ReplaceGrpcLoggerV2WithVerbosity replaces the grpc_.LoggerV2 with the provided logger and verbosity. // It should be called before any gRPC functions. func ReplaceGrpcLoggerV2WithVerbosity(logger *zap.Logger, verbosity int) { zgl := &zapGrpcLoggerV2{ logger: logger.With(SystemField, zap.Bool("grpc_log", true)), verbosity: verbosity, } grpclog.SetLoggerV2(zgl) } // SetGrpcLoggerV2 replaces the grpc_log.LoggerV2 with the provided logger. // It can be used even when grpc infrastructure was initialized. 
func SetGrpcLoggerV2(settable grpc_logsettable.SettableLoggerV2, logger *zap.Logger) { SetGrpcLoggerV2WithVerbosity(settable, logger, 0) } // SetGrpcLoggerV2WithVerbosity replaces the grpc_.LoggerV2 with the provided logger and verbosity. // It can be used even when grpc infrastructure was initialized. func SetGrpcLoggerV2WithVerbosity(settable grpc_logsettable.SettableLoggerV2, logger *zap.Logger, verbosity int) { zgl := &zapGrpcLoggerV2{ logger: logger.With(SystemField, zap.Bool("grpc_log", true)), verbosity: verbosity, } settable.Set(zgl) } type zapGrpcLoggerV2 struct { logger *zap.Logger verbosity int } func (l *zapGrpcLoggerV2) Info(args ...interface{}) { l.logger.Info(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Infoln(args ...interface{}) { l.logger.Info(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Infof(format string, args ...interface{}) { l.logger.Info(fmt.Sprintf(format, args...)) } func (l *zapGrpcLoggerV2) Warning(args ...interface{}) { l.logger.Warn(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Warningln(args ...interface{}) { l.logger.Warn(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Warningf(format string, args ...interface{}) { l.logger.Warn(fmt.Sprintf(format, args...)) } func (l *zapGrpcLoggerV2) Error(args ...interface{}) { l.logger.Error(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Errorln(args ...interface{}) { l.logger.Error(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Errorf(format string, args ...interface{}) { l.logger.Error(fmt.Sprintf(format, args...)) } func (l *zapGrpcLoggerV2) Fatal(args ...interface{}) { l.logger.Fatal(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Fatalln(args ...interface{}) { l.logger.Fatal(fmt.Sprint(args...)) } func (l *zapGrpcLoggerV2) Fatalf(format string, args ...interface{}) { l.logger.Fatal(fmt.Sprintf(format, args...)) } func (l *zapGrpcLoggerV2) V(level int) bool { return l.verbosity <= level } 
go-grpc-middleware-1.3.0/logging/zap/grpclogger_test.go000066400000000000000000000013761404040257500231350ustar00rootroot00000000000000package grpc_zap import ( "testing" "github.com/stretchr/testify/assert" "go.uber.org/zap" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" "google.golang.org/grpc/grpclog" ) func Test_zapGrpcLogger_V(t *testing.T) { // copied from gRPC const ( // infoLog indicates Info severity. infoLog int = iota // warningLog indicates Warning severity. warningLog // errorLog indicates Error severity. errorLog // fatalLog indicates Fatal severity. fatalLog ) core, _ := observer.New(zapcore.DebugLevel) logger := zap.New(core) ReplaceGrpcLoggerV2WithVerbosity(logger, warningLog) assert.False(t, grpclog.V(infoLog)) assert.True(t, grpclog.V(warningLog)) assert.True(t, grpclog.V(errorLog)) assert.True(t, grpclog.V(fatalLog)) } go-grpc-middleware-1.3.0/logging/zap/options.go000066400000000000000000000141151404040257500214310ustar00rootroot00000000000000package grpc_zap import ( "context" "time" grpc_logging "github.com/grpc-ecosystem/go-grpc-middleware/logging" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc/codes" ) var ( defaultOptions = &options{ levelFunc: DefaultCodeToLevel, shouldLog: grpc_logging.DefaultDeciderMethod, codeFunc: grpc_logging.DefaultErrorToCode, durationFunc: DefaultDurationToField, messageFunc: DefaultMessageProducer, timestampFormat: time.RFC3339, } ) type options struct { levelFunc CodeToLevel shouldLog grpc_logging.Decider codeFunc grpc_logging.ErrorToCode durationFunc DurationToField messageFunc MessageProducer timestampFormat string } func evaluateServerOpt(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions optCopy.levelFunc = DefaultCodeToLevel for _, o := range opts { o(optCopy) } return optCopy } func evaluateClientOpt(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions 
optCopy.levelFunc = DefaultClientCodeToLevel for _, o := range opts { o(optCopy) } return optCopy } type Option func(*options) // CodeToLevel function defines the mapping between gRPC return codes and interceptor log level. type CodeToLevel func(code codes.Code) zapcore.Level // DurationToField function defines how to produce duration fields for logging type DurationToField func(duration time.Duration) zapcore.Field // WithDecider customizes the function for deciding if the gRPC interceptor logs should log. func WithDecider(f grpc_logging.Decider) Option { return func(o *options) { o.shouldLog = f } } // WithLevels customizes the function for mapping gRPC return codes and interceptor log level statements. func WithLevels(f CodeToLevel) Option { return func(o *options) { o.levelFunc = f } } // WithCodes customizes the function for mapping errors to error codes. func WithCodes(f grpc_logging.ErrorToCode) Option { return func(o *options) { o.codeFunc = f } } // WithDurationField customizes the function for mapping request durations to Zap fields. func WithDurationField(f DurationToField) Option { return func(o *options) { o.durationFunc = f } } // WithMessageProducer customizes the function for message formation. func WithMessageProducer(f MessageProducer) Option { return func(o *options) { o.messageFunc = f } } // WithTimestampFormat customizes the timestamps emitted in the log fields. func WithTimestampFormat(format string) Option { return func(o *options) { o.timestampFormat = format } } // DefaultCodeToLevel is the default implementation of gRPC return codes and interceptor log level for server side. 
func DefaultCodeToLevel(code codes.Code) zapcore.Level { switch code { case codes.OK: return zap.InfoLevel case codes.Canceled: return zap.InfoLevel case codes.Unknown: return zap.ErrorLevel case codes.InvalidArgument: return zap.InfoLevel case codes.DeadlineExceeded: return zap.WarnLevel case codes.NotFound: return zap.InfoLevel case codes.AlreadyExists: return zap.InfoLevel case codes.PermissionDenied: return zap.WarnLevel case codes.Unauthenticated: return zap.InfoLevel // unauthenticated requests can happen case codes.ResourceExhausted: return zap.WarnLevel case codes.FailedPrecondition: return zap.WarnLevel case codes.Aborted: return zap.WarnLevel case codes.OutOfRange: return zap.WarnLevel case codes.Unimplemented: return zap.ErrorLevel case codes.Internal: return zap.ErrorLevel case codes.Unavailable: return zap.WarnLevel case codes.DataLoss: return zap.ErrorLevel default: return zap.ErrorLevel } } // DefaultClientCodeToLevel is the default implementation of gRPC return codes to log levels for client side. 
func DefaultClientCodeToLevel(code codes.Code) zapcore.Level { switch code { case codes.OK: return zap.DebugLevel case codes.Canceled: return zap.DebugLevel case codes.Unknown: return zap.InfoLevel case codes.InvalidArgument: return zap.DebugLevel case codes.DeadlineExceeded: return zap.InfoLevel case codes.NotFound: return zap.DebugLevel case codes.AlreadyExists: return zap.DebugLevel case codes.PermissionDenied: return zap.InfoLevel case codes.Unauthenticated: return zap.InfoLevel // unauthenticated requests can happen case codes.ResourceExhausted: return zap.DebugLevel case codes.FailedPrecondition: return zap.DebugLevel case codes.Aborted: return zap.DebugLevel case codes.OutOfRange: return zap.DebugLevel case codes.Unimplemented: return zap.WarnLevel case codes.Internal: return zap.WarnLevel case codes.Unavailable: return zap.WarnLevel case codes.DataLoss: return zap.WarnLevel default: return zap.InfoLevel } } // DefaultDurationToField is the default implementation of converting request duration to a Zap field. var DefaultDurationToField = DurationToTimeMillisField // DurationToTimeMillisField converts the duration to milliseconds and uses the key `grpc.time_ms`. func DurationToTimeMillisField(duration time.Duration) zapcore.Field { return zap.Float32("grpc.time_ms", durationToMilliseconds(duration)) } // DurationToDurationField uses a Duration field to log the request duration // and leaves it up to Zap's encoder settings to determine how that is output. 
func DurationToDurationField(duration time.Duration) zapcore.Field { return zap.Duration("grpc.duration", duration) } func durationToMilliseconds(duration time.Duration) float32 { return float32(duration.Nanoseconds()/1000) / 1000 } // MessageProducer produces a user defined log message type MessageProducer func(ctx context.Context, msg string, level zapcore.Level, code codes.Code, err error, duration zapcore.Field) // DefaultMessageProducer writes the default message func DefaultMessageProducer(ctx context.Context, msg string, level zapcore.Level, code codes.Code, err error, duration zapcore.Field) { // re-extract logger from newCtx, as it may have extra fields that changed in the holder. ctxzap.Extract(ctx).Check(level, msg).Write( zap.Error(err), zap.String("grpc.code", code.String()), duration, ) } go-grpc-middleware-1.3.0/logging/zap/options_test.go000066400000000000000000000006521404040257500224710ustar00rootroot00000000000000package grpc_zap import ( "math" "testing" "time" "github.com/stretchr/testify/assert" "go.uber.org/zap/zapcore" ) func TestDurationToTimeMillisField(t *testing.T) { val := DurationToTimeMillisField(time.Microsecond * 100) assert.Equal(t, val.Type, zapcore.Float32Type, "should be a float type") assert.Equal(t, math.Float32frombits(uint32(val.Integer)), float32(0.1), "sub millisecond values should be correct") } go-grpc-middleware-1.3.0/logging/zap/payload_interceptors.go000066400000000000000000000140261404040257500241710ustar00rootroot00000000000000package grpc_zap import ( "bytes" "context" "fmt" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/go-grpc-middleware/logging" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" ) var ( // JsonPbMarshaller is the marshaller used for serializing protobuf messages. 
// If needed, this variable can be reassigned with a different marshaller with the same Marshal() signature. JsonPbMarshaller grpc_logging.JsonPbMarshaler = &jsonpb.Marshaler{} ) // PayloadUnaryServerInterceptor returns a new unary server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `grpc_zap.UnaryServerInterceptor`. However, the logging can be done to a // separate instance of the logger. func PayloadUnaryServerInterceptor(logger *zap.Logger, decider grpc_logging.ServerPayloadLoggingDecider) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if !decider(ctx, info.FullMethod, info.Server) { return handler(ctx, req) } // Use the provided zap.Logger for logging but use the fields from context. logEntry := logger.With(append(serverCallFields(info.FullMethod), ctxzap.TagsToFields(ctx)...)...) logProtoMessageAsJson(logEntry, req, "grpc.request.content", "server request payload logged as grpc.request.content field") resp, err := handler(ctx, req) if err == nil { logProtoMessageAsJson(logEntry, resp, "grpc.response.content", "server response payload logged as grpc.response.content field") } return resp, err } } // PayloadStreamServerInterceptor returns a new server server interceptors that logs the payloads of requests. // // This *only* works when placed *after* the `grpc_zap.StreamServerInterceptor`. However, the logging can be done to a // separate instance of the logger. 
func PayloadStreamServerInterceptor(logger *zap.Logger, decider grpc_logging.ServerPayloadLoggingDecider) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if !decider(stream.Context(), info.FullMethod, srv) { return handler(srv, stream) } logEntry := logger.With(append(serverCallFields(info.FullMethod), ctxzap.TagsToFields(stream.Context())...)...) newStream := &loggingServerStream{ServerStream: stream, logger: logEntry} return handler(srv, newStream) } } // PayloadUnaryClientInterceptor returns a new unary client interceptor that logs the payloads of requests and responses. func PayloadUnaryClientInterceptor(logger *zap.Logger, decider grpc_logging.ClientPayloadLoggingDecider) grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if !decider(ctx, method) { return invoker(ctx, method, req, reply, cc, opts...) } logEntry := logger.With(newClientLoggerFields(ctx, method)...) logProtoMessageAsJson(logEntry, req, "grpc.request.content", "client request payload logged as grpc.request.content") err := invoker(ctx, method, req, reply, cc, opts...) if err == nil { logProtoMessageAsJson(logEntry, reply, "grpc.response.content", "client response payload logged as grpc.response.content") } return err } } // PayloadStreamClientInterceptor returns a new streaming client interceptor that logs the payloads of requests and responses. func PayloadStreamClientInterceptor(logger *zap.Logger, decider grpc_logging.ClientPayloadLoggingDecider) grpc.StreamClientInterceptor { return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if !decider(ctx, method) { return streamer(ctx, desc, cc, method, opts...) 
} logEntry := logger.With(newClientLoggerFields(ctx, method)...) clientStream, err := streamer(ctx, desc, cc, method, opts...) newStream := &loggingClientStream{ClientStream: clientStream, logger: logEntry} return newStream, err } } type loggingClientStream struct { grpc.ClientStream logger *zap.Logger } func (l *loggingClientStream) SendMsg(m interface{}) error { err := l.ClientStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.request.content", "server request payload logged as grpc.request.content field") } return err } func (l *loggingClientStream) RecvMsg(m interface{}) error { err := l.ClientStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } type loggingServerStream struct { grpc.ServerStream logger *zap.Logger } func (l *loggingServerStream) SendMsg(m interface{}) error { err := l.ServerStream.SendMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.response.content", "server response payload logged as grpc.response.content field") } return err } func (l *loggingServerStream) RecvMsg(m interface{}) error { err := l.ServerStream.RecvMsg(m) if err == nil { logProtoMessageAsJson(l.logger, m, "grpc.request.content", "server request payload logged as grpc.request.content field") } return err } func logProtoMessageAsJson(logger *zap.Logger, pbMsg interface{}, key string, msg string) { if p, ok := pbMsg.(proto.Message); ok { logger.Check(zapcore.InfoLevel, msg).Write(zap.Object(key, &jsonpbObjectMarshaler{pb: p})) } } type jsonpbObjectMarshaler struct { pb proto.Message } func (j *jsonpbObjectMarshaler) MarshalLogObject(e zapcore.ObjectEncoder) error { // ZAP jsonEncoder deals with AddReflect by using json.MarshalObject. The same thing applies for consoleEncoder. 
return e.AddReflected("msg", j) } func (j *jsonpbObjectMarshaler) MarshalJSON() ([]byte, error) { b := &bytes.Buffer{} if err := JsonPbMarshaller.Marshal(b, j.pb); err != nil { return nil, fmt.Errorf("jsonpb serializer failed: %v", err) } return b.Bytes(), nil } go-grpc-middleware-1.3.0/logging/zap/payload_interceptors_test.go000066400000000000000000000133351404040257500252320ustar00rootroot00000000000000package grpc_zap_test import ( "context" "io" "runtime" "strings" "testing" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) func TestZapPayloadSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skipf("Skipping due to json.RawMessage incompatibility with go1.7") return } alwaysLoggingDeciderServer := func(ctx context.Context, fullMethodName string, servingObject interface{}) bool { return true } alwaysLoggingDeciderClient := func(ctx context.Context, fullMethodName string) bool { return true } b := newBaseZapSuite(t) b.InterceptorTestSuite.ClientOpts = []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_zap.PayloadUnaryClientInterceptor(b.log, alwaysLoggingDeciderClient)), grpc.WithStreamInterceptor(grpc_zap.PayloadStreamClientInterceptor(b.log, alwaysLoggingDeciderClient)), } noOpZap := zap.New(zapcore.NewNopCore()) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.StreamServerInterceptor(noOpZap), grpc_zap.PayloadStreamServerInterceptor(b.log, alwaysLoggingDeciderServer)), grpc_middleware.WithUnaryServerChain( 
grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.UnaryServerInterceptor(noOpZap), grpc_zap.PayloadUnaryServerInterceptor(b.log, alwaysLoggingDeciderServer)), } suite.Run(t, &zapPayloadSuite{b}) } type zapPayloadSuite struct { *zapBaseSuite } func (s *zapPayloadSuite) getServerAndClientMessages(expectedServer int, expectedClient int) (serverMsgs []map[string]interface{}, clientMsgs []map[string]interface{}) { msgs := s.getOutputJSONs() for _, m := range msgs { if m["span.kind"] == "server" { serverMsgs = append(serverMsgs, m) } else if m["span.kind"] == "client" { clientMsgs = append(clientMsgs, m) } } require.Len(s.T(), serverMsgs, expectedServer, "must match expected number of server log messages") require.Len(s.T(), clientMsgs, expectedClient, "must match expected number of client log messages") return serverMsgs, clientMsgs } func (s *zapPayloadSuite) TestPing_LogsBothRequestAndResponse() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(2, 2) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "all payloads must be logged on info level") } serverReq, serverResp := serverMsgs[0], serverMsgs[1] clientReq, clientResp := clientMsgs[0], clientMsgs[1] s.T().Log(clientReq) assert.Contains(s.T(), clientReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), serverReq, "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), clientResp, "grpc.response.content", "response payload must be logged in a structured way") assert.Contains(s.T(), serverResp, "grpc.response.content", "response payload must be logged in a structured way") } func (s *zapPayloadSuite) TestPingError_LogsOnlyRequestsOnError() { _, err := s.Client.PingError(s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(4)}) require.Error(s.T(), err, "there must be an error on an unsuccessful call") serverMsgs, clientMsgs := s.getServerAndClientMessages(1, 1) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "must be logged at the info level") } assert.Contains(s.T(), clientMsgs[0], "grpc.request.content", "request payload must be logged in a structured way") assert.Contains(s.T(), serverMsgs[0], "grpc.request.content", "request payload must be logged in a structured way") } func (s *zapPayloadSuite) TestPingStream_LogsAllRequestsAndResponses() { messagesExpected := 20 stream, err := s.Client.PingStream(s.SimpleCtx()) require.NoError(s.T(), err, "no error on stream creation") for i := 0; i < messagesExpected; i++ { require.NoError(s.T(), stream.Send(goodPing), "sending must succeed") } require.NoError(s.T(), stream.CloseSend(), "no error on send stream") for { pong := &pb_testproto.PingResponse{} err := stream.RecvMsg(pong) if err == io.EOF { break } require.NoError(s.T(), err, "no error on receive") } serverMsgs, clientMsgs := s.getServerAndClientMessages(2*messagesExpected, 2*messagesExpected) for _, m := range append(serverMsgs, clientMsgs...) 
{ assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingStream", "all lines must contain method name") assert.Equal(s.T(), m["level"], "info", "all lines must logged at info level") content := m["grpc.request.content"] != nil || m["grpc.response.content"] != nil assert.True(s.T(), content, "all messages must contain payloads") } } go-grpc-middleware-1.3.0/logging/zap/server_interceptors.go000066400000000000000000000060031404040257500240420ustar00rootroot00000000000000package grpc_zap import ( "context" "path" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" ) var ( // SystemField is used in every log statement made through grpc_zap. Can be overwritten before any initialization code. SystemField = zap.String("system", "grpc") // ServerField is used in every server-side log statement made through grpc_zap.Can be overwritten before initialization. ServerField = zap.String("span.kind", "server") ) // UnaryServerInterceptor returns a new unary server interceptors that adds zap.Logger to the context. 
func UnaryServerInterceptor(logger *zap.Logger, opts ...Option) grpc.UnaryServerInterceptor { o := evaluateServerOpt(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { startTime := time.Now() newCtx := newLoggerForCall(ctx, logger, info.FullMethod, startTime, o.timestampFormat) resp, err := handler(newCtx, req) if !o.shouldLog(info.FullMethod, err) { return resp, err } code := o.codeFunc(err) level := o.levelFunc(code) duration := o.durationFunc(time.Since(startTime)) o.messageFunc(newCtx, "finished unary call with code "+code.String(), level, code, err, duration) return resp, err } } // StreamServerInterceptor returns a new streaming server interceptor that adds zap.Logger to the context. func StreamServerInterceptor(logger *zap.Logger, opts ...Option) grpc.StreamServerInterceptor { o := evaluateServerOpt(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { startTime := time.Now() newCtx := newLoggerForCall(stream.Context(), logger, info.FullMethod, startTime, o.timestampFormat) wrapped := grpc_middleware.WrapServerStream(stream) wrapped.WrappedContext = newCtx err := handler(srv, wrapped) if !o.shouldLog(info.FullMethod, err) { return err } code := o.codeFunc(err) level := o.levelFunc(code) duration := o.durationFunc(time.Since(startTime)) o.messageFunc(newCtx, "finished streaming call with code "+code.String(), level, code, err, duration) return err } } func serverCallFields(fullMethodString string) []zapcore.Field { service := path.Dir(fullMethodString)[1:] method := path.Base(fullMethodString) return []zapcore.Field{ SystemField, ServerField, zap.String("grpc.service", service), zap.String("grpc.method", method), } } func newLoggerForCall(ctx context.Context, logger *zap.Logger, fullMethodString string, start time.Time, timestampFormat string) context.Context { var f []zapcore.Field f = append(f, 
zap.String("grpc.start_time", start.Format(timestampFormat))) if d, ok := ctx.Deadline(); ok { f = append(f, zap.String("grpc.request.deadline", d.Format(timestampFormat))) } callLog := logger.With(append(f, serverCallFields(fullMethodString)...)...) return ctxzap.ToContext(ctx, callLog) } go-grpc-middleware-1.3.0/logging/zap/server_interceptors_test.go000066400000000000000000000373031404040257500251100ustar00rootroot00000000000000package grpc_zap_test import ( "io" "runtime" "strings" "testing" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) func customCodeToLevel(c codes.Code) zapcore.Level { if c == codes.Unauthenticated { // Make this a special case for tests, and an error. 
return zap.DPanicLevel } level := grpc_zap.DefaultCodeToLevel(c) return level } func TestZapLoggingSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skipf("Skipping due to json.RawMessage incompatibility with go1.7") return } for _, tcase := range []struct { timestampFormat string }{ { timestampFormat: time.RFC3339, }, { timestampFormat: "2006-01-02", }, } { opts := []grpc_zap.Option{ grpc_zap.WithLevels(customCodeToLevel), grpc_zap.WithTimestampFormat(tcase.timestampFormat), } b := newBaseZapSuite(t) b.timestampFormat = tcase.timestampFormat b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.StreamServerInterceptor(b.log, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_zap.UnaryServerInterceptor(b.log, opts...)), } suite.Run(t, &zapServerSuite{b}) } } type zapServerSuite struct { *zapBaseSuite } func (s *zapServerSuite) TestPing_WithCustomTags() { deadline := time.Now().Add(3 * time.Second) _, err := s.Client.Ping(s.DeadlineCtx(deadline), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain `custom_tags.string`") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain fields extracted") assert.Equal(s.T(), 
m["custom_field"], "custom_value", "all lines must contain `custom_field`") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain `custom_tags.int`") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") require.Contains(s.T(), m, "grpc.request.deadline", "all lines must contain the deadline of the call") _, err = time.Parse(s.timestampFormat, m["grpc.request.deadline"].(string)) require.NoError(s.T(), err, "should be able to parse deadline") assert.Equal(s.T(), m["grpc.request.deadline"], deadline.Format(s.timestampFormat), "should have the same deadline that was set by the caller") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "must be logged at info level") assert.Contains(s.T(), msgs[1], "grpc.time_ms", "interceptor log statement should contain execution time") } func (s *zapServerSuite) TestPingError_WithCustomLevels() { for _, tcase := range []struct { code codes.Code level zapcore.Level msg string }{ { code: codes.Internal, level: zap.ErrorLevel, msg: "Internal must remap to ErrorLevel in DefaultCodeToLevel", }, { code: codes.NotFound, level: zap.InfoLevel, msg: "NotFound must remap to InfoLevel in DefaultCodeToLevel", }, { code: codes.FailedPrecondition, level: zap.WarnLevel, msg: "FailedPrecondition must remap to WarnLevel in DefaultCodeToLevel", }, { code: codes.Unauthenticated, level: zap.DPanicLevel, msg: "Unauthenticated is overwritten to DPanicLevel with customCodeToLevel override, which probably didn't work", }, } { s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: 
uint32(tcase.code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["grpc.code"], tcase.code.String(), "all lines have the correct gRPC code") assert.Equal(s.T(), m["level"], tcase.level.String(), tcase.msg) assert.Equal(s.T(), m["msg"], "finished unary call with code "+tcase.code.String(), "needs the correct end message") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err = time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") } } func (s *zapServerSuite) TestPingList_WithCustomTags() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), m["span.kind"], "server", "all lines must contain the kind of call (server)") assert.Equal(s.T(), m["custom_tags.string"], "something", "all lines must contain `custom_tags.string` set by AddFields") assert.Equal(s.T(), m["grpc.request.value"], "something", "all lines must contain fields extracted from goodPing because of test.manual_extractfields.pb") assert.Contains(s.T(), m, "custom_tags.int", "all lines must contain 
`custom_tags.int` set by AddFields") require.Contains(s.T(), m, "grpc.start_time", "all lines must contain the start time") _, err := time.Parse(s.timestampFormat, m["grpc.start_time"].(string)) assert.NoError(s.T(), err, "should be able to parse start time") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK codes must be logged on info level.") assert.Contains(s.T(), msgs[1], "grpc.time_ms", "interceptor log statement should contain execution time") } func TestZapLoggingOverrideSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_zap.Option{ grpc_zap.WithDurationField(grpc_zap.DurationToDurationField), } b := newBaseZapSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(b.log, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_zap.UnaryServerInterceptor(b.log, opts...)), } suite.Run(t, &zapServerOverrideSuite{b}) } type zapServerOverrideSuite struct { *zapBaseSuite } func (s *zapServerOverrideSuite) TestPing_HasOverriddenDuration() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") 
assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished unary call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK error codes must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[1], "grpc.duration", "handler's message must contain overridden duration") } func (s *zapServerOverrideSuite) TestPingList_HasOverriddenDuration() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { s.T() assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingList", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") assert.Equal(s.T(), msgs[1]["msg"], "finished streaming call with code OK", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK error codes must be logged on info level.") assert.NotContains(s.T(), msgs[1], "grpc.time_ms", "handler's message must not contain default duration") assert.Contains(s.T(), msgs[1], "grpc.duration", "handler's message 
must contain overridden duration") } func TestZapServerOverrideSuppressedSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_zap.Option{ grpc_zap.WithDecider(func(method string, err error) bool { if err != nil && method == "/mwitkow.testproto.TestService/PingError" { return true } return false }), } b := newBaseZapSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(b.log, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_zap.UnaryServerInterceptor(b.log, opts...)), } suite.Run(t, &zapServerOverriddenDeciderSuite{b}) } type zapServerOverriddenDeciderSuite struct { *zapBaseSuite } func (s *zapServerOverriddenDeciderSuite) TestPing_HasOverriddenDecider() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "Ping", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") } func (s *zapServerOverriddenDeciderSuite) TestPingError_HasOverriddenDecider() { code := codes.NotFound level := zapcore.InfoLevel msg := "NotFound must remap to InfoLevel in DefaultCodeToLevel" s.buffer.Reset() _, err := s.Client.PingError( s.SimpleCtx(), &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(code)}) require.Error(s.T(), err, "each call here must return an error") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "only the interceptor log message is printed in PingErr") m := 
msgs[0] assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "PingError", "all lines must contain method name") assert.Equal(s.T(), m["grpc.code"], code.String(), "all lines must contain the correct gRPC code") assert.Equal(s.T(), m["level"], level.String(), msg) } func (s *zapServerOverriddenDeciderSuite) TestPingList_HasOverriddenDecider() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 1, "single log statements should be logged") assert.Equal(s.T(), msgs[0]["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), msgs[0]["grpc.method"], "PingList", "all lines must contain method name") assert.Equal(s.T(), msgs[0]["msg"], "some pinglist", "handler's message must contain user message") assert.NotContains(s.T(), msgs[0], "grpc.time_ms", "handler's message must not contain default duration") assert.NotContains(s.T(), msgs[0], "grpc.duration", "handler's message must not contain overridden duration") } func TestZapLoggingServerMessageProducerSuite(t *testing.T) { if strings.HasPrefix(runtime.Version(), "go1.7") { t.Skip("Skipping due to json.RawMessage incompatibility with go1.7") return } opts := []grpc_zap.Option{ grpc_zap.WithMessageProducer(StubMessageProducer), } b := newBaseZapSuite(t) b.InterceptorTestSuite.ServerOpts = []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(), grpc_zap.StreamServerInterceptor(b.log, opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(), grpc_zap.UnaryServerInterceptor(b.log, opts...)), } suite.Run(t, &zapServerMessageProducerSuite{b}) } type 
zapServerMessageProducerSuite struct { *zapBaseSuite } func (s *zapServerMessageProducerSuite) TestPing_HasOverriddenMessageProducer() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") msgs := s.getOutputJSONs() require.Len(s.T(), msgs, 2, "two log statements should be logged") for _, m := range msgs { assert.Equal(s.T(), m["grpc.service"], "mwitkow.testproto.TestService", "all lines must contain service name") assert.Equal(s.T(), m["grpc.method"], "Ping", "all lines must contain method name") } assert.Equal(s.T(), msgs[0]["msg"], "some ping", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["msg"], "custom message", "handler's message must contain user message") assert.Equal(s.T(), msgs[1]["level"], "info", "OK error codes must be logged on info level.") } go-grpc-middleware-1.3.0/logging/zap/settable_test.go000066400000000000000000000015661404040257500226060ustar00rootroot00000000000000package grpc_zap_test import ( "testing" grpc_logsettable "github.com/grpc-ecosystem/go-grpc-middleware/logging/settable" grpc_zap "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap" "go.uber.org/zap/zaptest" ) var grpc_logger grpc_logsettable.SettableLoggerV2 func init() { grpc_logger = grpc_logsettable.ReplaceGrpcLoggerV2() } func beforeTest(t testing.TB) { grpc_zap.SetGrpcLoggerV2(grpc_logger, zaptest.NewLogger(t)) // Starting from go-1.15+ automated 'reset' can also be set: // t.Cleanup(func() { // grpc_logger.Reset() // }) } // This test illustrates setting up a testing harness that attributes // all grpc logs emitted during the test to the test-specific log. // // In case of test failure, only logs emitted by this testcase will be printed. 
func TestSpecificLogging(t *testing.T) { beforeTest(t) grpc_logger.Info("Test specific log-line") } go-grpc-middleware-1.3.0/logging/zap/shared_test.go000066400000000000000000000070371404040257500222500ustar00rootroot00000000000000package grpc_zap_test import ( "bytes" "context" "encoding/json" "io" "testing" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "go.uber.org/zap" "go.uber.org/zap/zapcore" "google.golang.org/grpc/codes" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} ) type loggingPingService struct { pb_testproto.TestServiceServer } func (s *loggingPingService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { grpc_ctxtags.Extract(ctx).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxzap.AddFields(ctx, zap.String("custom_field", "custom_value")) ctxzap.Extract(ctx).Info("some ping") return s.TestServiceServer.Ping(ctx, ping) } func (s *loggingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { return s.TestServiceServer.PingError(ctx, ping) } func (s *loggingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { grpc_ctxtags.Extract(stream.Context()).Set("custom_tags.string", "something").Set("custom_tags.int", 1337) ctxzap.Extract(stream.Context()).Info("some pinglist") return s.TestServiceServer.PingList(ping, stream) } func (s *loggingPingService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { return s.TestServiceServer.PingEmpty(ctx, empty) } func newBaseZapSuite(t *testing.T) *zapBaseSuite { b := &bytes.Buffer{} muB := grpc_testing.NewMutexReadWriter(b) 
zap.NewDevelopmentConfig() jsonEncoder := zapcore.NewJSONEncoder(zapcore.EncoderConfig{ TimeKey: "ts", LevelKey: "level", NameKey: "logger", CallerKey: "caller", MessageKey: "msg", StacktraceKey: "stacktrace", EncodeLevel: zapcore.LowercaseLevelEncoder, EncodeTime: zapcore.EpochTimeEncoder, EncodeDuration: zapcore.SecondsDurationEncoder, }) core := zapcore.NewCore(jsonEncoder, zapcore.AddSync(muB), zap.LevelEnablerFunc(func(zapcore.Level) bool { return true })) log := zap.New(core) s := &zapBaseSuite{ log: log, buffer: b, mutexBuffer: muB, InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &loggingPingService{&grpc_testing.TestPingService{T: t}}, }, } return s } type zapBaseSuite struct { *grpc_testing.InterceptorTestSuite mutexBuffer *grpc_testing.MutexReadWriter buffer *bytes.Buffer log *zap.Logger timestampFormat string } func (s *zapBaseSuite) SetupTest() { s.mutexBuffer.Lock() s.buffer.Reset() s.mutexBuffer.Unlock() } func (s *zapBaseSuite) getOutputJSONs() []map[string]interface{} { ret := make([]map[string]interface{}, 0) dec := json.NewDecoder(s.mutexBuffer) for { var val map[string]interface{} err := dec.Decode(&val) if err == io.EOF { break } if err != nil { s.T().Fatalf("failed decoding output from Logrus JSON: %v", err) } ret = append(ret, val) } return ret } func StubMessageProducer(ctx context.Context, msg string, level zapcore.Level, code codes.Code, err error, duration zapcore.Field) { // re-extract logger from newCtx, as it may have extra fields that changed in the holder. ctxzap.Extract(ctx).Check(level, "custom message").Write( zap.Error(err), zap.String("grpc.code", code.String()), duration, ) } go-grpc-middleware-1.3.0/makefile000066400000000000000000000004351404040257500166670ustar00rootroot00000000000000SHELL=/bin/bash GOFILES_NOVENDOR = $(shell go list ./... | grep -v /vendor/) all: vet fmt test fmt: go fmt $(GOFILES_NOVENDOR) vet: # do not check lostcancel, they are intentional. 
go vet -lostcancel=false $(GOFILES_NOVENDOR) test: vet ./scripts/test_all.sh .PHONY: all test go-grpc-middleware-1.3.0/ratelimit/000077500000000000000000000000001404040257500171575ustar00rootroot00000000000000go-grpc-middleware-1.3.0/ratelimit/doc.go000066400000000000000000000004161404040257500202540ustar00rootroot00000000000000/* `ratelimit` a generic server-side ratelimit middleware for gRPC. Server Side Ratelimit Middleware It allows to do grpc rate limit by your own rate limiter (e.g. token bucket, leaky bucket, etc.) Please see examples for simple examples of use. */ package ratelimit go-grpc-middleware-1.3.0/ratelimit/examples_test.go000066400000000000000000000015251404040257500223660ustar00rootroot00000000000000package ratelimit_test import ( "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/ratelimit" "google.golang.org/grpc" ) // alwaysPassLimiter is an example limiter which implements Limiter interface. // It does not limit any request because Limit function always returns false. type alwaysPassLimiter struct{} func (*alwaysPassLimiter) Limit() bool { return false } // Simple example of server initialization code. func Example() { // Create unary/stream rateLimiters, based on token bucket here. // You can implement your own ratelimiter for the interface. limiter := &alwaysPassLimiter{} _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( ratelimit.UnaryServerInterceptor(limiter), ), grpc_middleware.WithStreamServerChain( ratelimit.StreamServerInterceptor(limiter), ), ) } go-grpc-middleware-1.3.0/ratelimit/ratelimit.go000066400000000000000000000025131404040257500215010ustar00rootroot00000000000000package ratelimit import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // Limiter defines the interface to perform request rate limiting. // If Limit function return true, the request will be rejected. // Otherwise, the request will pass. 
type Limiter interface { Limit() bool } // UnaryServerInterceptor returns a new unary server interceptors that performs request rate limiting. func UnaryServerInterceptor(limiter Limiter) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if limiter.Limit() { return nil, status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod) } return handler(ctx, req) } } // StreamServerInterceptor returns a new stream server interceptor that performs rate limiting on the request. func StreamServerInterceptor(limiter Limiter) grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if limiter.Limit() { return status.Errorf(codes.ResourceExhausted, "%s is rejected by grpc_ratelimit middleware, please retry later.", info.FullMethod) } return handler(srv, stream) } } go-grpc-middleware-1.3.0/ratelimit/ratelimit_test.go000066400000000000000000000041261404040257500225420ustar00rootroot00000000000000package ratelimit import ( "context" "errors" "testing" "google.golang.org/grpc" "github.com/stretchr/testify/assert" ) const errMsgFake = "fake error" type mockPassLimiter struct{} func (*mockPassLimiter) Limit() bool { return false } func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) { interceptor := UnaryServerInterceptor(&mockPassLimiter{}) handler := func(ctx context.Context, req interface{}) (interface{}, error) { return nil, errors.New(errMsgFake) } info := &grpc.UnaryServerInfo{ FullMethod: "FakeMethod", } resp, err := interceptor(nil, nil, info, handler) assert.Nil(t, resp) assert.EqualError(t, err, errMsgFake) } type mockFailLimiter struct{} func (*mockFailLimiter) Limit() bool { return true } func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) { interceptor := 
UnaryServerInterceptor(&mockFailLimiter{}) handler := func(ctx context.Context, req interface{}) (interface{}, error) { return nil, errors.New(errMsgFake) } info := &grpc.UnaryServerInfo{ FullMethod: "FakeMethod", } resp, err := interceptor(nil, nil, info, handler) assert.Nil(t, resp) assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.") } func TestStreamServerInterceptor_RateLimitPass(t *testing.T) { interceptor := StreamServerInterceptor(&mockPassLimiter{}) handler := func(srv interface{}, stream grpc.ServerStream) error { return errors.New(errMsgFake) } info := &grpc.StreamServerInfo{ FullMethod: "FakeMethod", } err := interceptor(nil, nil, info, handler) assert.EqualError(t, err, errMsgFake) } func TestStreamServerInterceptor_RateLimitFail(t *testing.T) { interceptor := StreamServerInterceptor(&mockFailLimiter{}) handler := func(srv interface{}, stream grpc.ServerStream) error { return errors.New(errMsgFake) } info := &grpc.StreamServerInfo{ FullMethod: "FakeMethod", } err := interceptor(nil, nil, info, handler) assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.") } go-grpc-middleware-1.3.0/recovery/000077500000000000000000000000001404040257500170235ustar00rootroot00000000000000go-grpc-middleware-1.3.0/recovery/doc.go000066400000000000000000000006501404040257500201200ustar00rootroot00000000000000// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. /* `grpc_recovery` are interceptors that recover from gRPC handler panics. Server Side Recovery Middleware By default a panic will be converted into a gRPC error with `code.Internal`. Handling can be customised by providing an alternate recovery function. Please see examples for simple examples of use. 
*/ package grpc_recovery go-grpc-middleware-1.3.0/recovery/examples_test.go000066400000000000000000000023341404040257500222310ustar00rootroot00000000000000// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. package grpc_recovery_test import ( "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/recovery" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( customFunc grpc_recovery.RecoveryHandlerFunc ) // Initialization shows an initialization sequence with a custom recovery handler func. func Example_initialization() { // Define customfunc to handle panic customFunc = func(p interface{}) (err error) { return status.Errorf(codes.Unknown, "panic triggered: %v", p) } // Shared options for the logger, with a custom gRPC code to log level function. opts := []grpc_recovery.Option{ grpc_recovery.WithRecoveryHandler(customFunc), } // Create a server. Recovery handlers should typically be last in the chain so that other middleware // (e.g. logging) can operate on the recovered state instead of being directly affected by any panic _ = grpc.NewServer( grpc_middleware.WithUnaryServerChain( grpc_recovery.UnaryServerInterceptor(opts...), ), grpc_middleware.WithStreamServerChain( grpc_recovery.StreamServerInterceptor(opts...), ), ) } go-grpc-middleware-1.3.0/recovery/interceptors.go000066400000000000000000000036031404040257500220750ustar00rootroot00000000000000// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. package grpc_recovery import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // RecoveryHandlerFunc is a function that recovers from the panic `p` by returning an `error`. type RecoveryHandlerFunc func(p interface{}) (err error) // RecoveryHandlerFuncContext is a function that recovers from the panic `p` by returning an `error`. 
// The context can be used to extract request scoped metadata and context values. type RecoveryHandlerFuncContext func(ctx context.Context, p interface{}) (err error) // UnaryServerInterceptor returns a new unary server interceptor for panic recovery. func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { o := evaluateOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) { panicked := true defer func() { if r := recover(); r != nil || panicked { err = recoverFrom(ctx, r, o.recoveryHandlerFunc) } }() resp, err := handler(ctx, req) panicked = false return resp, err } } // StreamServerInterceptor returns a new streaming server interceptor for panic recovery. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { o := evaluateOptions(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { panicked := true defer func() { if r := recover(); r != nil || panicked { err = recoverFrom(stream.Context(), r, o.recoveryHandlerFunc) } }() err = handler(srv, stream) panicked = false return err } } func recoverFrom(ctx context.Context, p interface{}, r RecoveryHandlerFuncContext) error { if r == nil { return status.Errorf(codes.Internal, "%v", p) } return r(ctx, p) } go-grpc-middleware-1.3.0/recovery/interceptors_test.go000066400000000000000000000141711404040257500231360ustar00rootroot00000000000000// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_recovery_test import ( "context" "testing" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery" grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} panicPing = &pb_testproto.PingRequest{Value: "panic", SleepTimeMs: 9999} nilPanicPing = &pb_testproto.PingRequest{Value: "nilpanic", SleepTimeMs: 9999} ) type recoveryAssertService struct { pb_testproto.TestServiceServer } func (s *recoveryAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { if ping.Value == "panic" { panic("very bad thing happened") } if ping.Value == "nilpanic" { panic(nil) } return s.TestServiceServer.Ping(ctx, ping) } func (s *recoveryAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { if ping.Value == "panic" { panic("very bad thing happened") } if ping.Value == "nilpanic" { panic(nil) } return s.TestServiceServer.PingList(ping, stream) } func TestRecoverySuite(t *testing.T) { s := &RecoverySuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}}, ServerOpts: []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_recovery.StreamServerInterceptor()), grpc_middleware.WithUnaryServerChain( grpc_recovery.UnaryServerInterceptor()), }, }, } suite.Run(t, s) } type RecoverySuite struct { *grpc_testing.InterceptorTestSuite } func (s *RecoverySuite) TestUnary_SuccessfulRequest() { _, err := s.Client.Ping(s.SimpleCtx(), 
goodPing) require.NoError(s.T(), err, "no error must occur") } func (s *RecoverySuite) TestUnary_PanickingRequest() { _, err := s.Client.Ping(s.SimpleCtx(), panicPing) require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal") assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message") } func (s *RecoverySuite) TestUnary_NilPanickingRequest() { _, err := s.Client.Ping(s.SimpleCtx(), nilPanicPing) require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal") assert.Equal(s.T(), "", status.Convert(err).Message(), "must error with ") } func (s *RecoverySuite) TestStream_SuccessfulReceive() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") require.NotNil(s.T(), pong, "pong must not be nil") } func (s *RecoverySuite) TestStream_PanickingReceive() { stream, err := s.Client.PingList(s.SimpleCtx(), panicPing) require.NoError(s.T(), err, "should not fail on establishing the stream") _, err = stream.Recv() require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal") assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message") } func (s *RecoverySuite) TestStream_NilPanickingReceive() { stream, err := s.Client.PingList(s.SimpleCtx(), nilPanicPing) require.NoError(s.T(), err, "should not fail on establishing the stream") _, err = stream.Recv() require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal") assert.Equal(s.T(), "", status.Convert(err).Message(), "must error with ") } func TestRecoveryOverrideSuite(t *testing.T) { opts := 
[]grpc_recovery.Option{ grpc_recovery.WithRecoveryHandler(func(p interface{}) (err error) { return status.Errorf(codes.Unknown, "panic triggered: %v", p) }), } s := &RecoveryOverrideSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}}, ServerOpts: []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_recovery.StreamServerInterceptor(opts...)), grpc_middleware.WithUnaryServerChain( grpc_recovery.UnaryServerInterceptor(opts...)), }, }, } suite.Run(t, s) } type RecoveryOverrideSuite struct { *grpc_testing.InterceptorTestSuite } func (s *RecoveryOverrideSuite) TestUnary_SuccessfulRequest() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "no error must occur") } func (s *RecoveryOverrideSuite) TestUnary_PanickingRequest() { _, err := s.Client.Ping(s.SimpleCtx(), panicPing) require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown") assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message") } func (s *RecoveryOverrideSuite) TestStream_SuccessfulReceive() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") pong, err := stream.Recv() require.NoError(s.T(), err, "no error must occur") require.NotNil(s.T(), pong, "pong must not be nil") } func (s *RecoveryOverrideSuite) TestStream_PanickingReceive() { stream, err := s.Client.PingList(s.SimpleCtx(), panicPing) require.NoError(s.T(), err, "should not fail on establishing the stream") _, err = stream.Recv() require.Error(s.T(), err, "there must be an error") assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown") assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message") } 
go-grpc-middleware-1.3.0/recovery/options.go000066400000000000000000000017021404040257500210450ustar00rootroot00000000000000// Copyright 2017 David Ackroyd. All Rights Reserved. // See LICENSE for licensing terms. package grpc_recovery import "context" var ( defaultOptions = &options{ recoveryHandlerFunc: nil, } ) type options struct { recoveryHandlerFunc RecoveryHandlerFuncContext } func evaluateOptions(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions for _, o := range opts { o(optCopy) } return optCopy } type Option func(*options) // WithRecoveryHandler customizes the function for recovering from a panic. func WithRecoveryHandler(f RecoveryHandlerFunc) Option { return func(o *options) { o.recoveryHandlerFunc = RecoveryHandlerFuncContext(func(ctx context.Context, p interface{}) error { return f(p) }) } } // WithRecoveryHandlerContext customizes the function for recovering from a panic. func WithRecoveryHandlerContext(f RecoveryHandlerFuncContext) Option { return func(o *options) { o.recoveryHandlerFunc = f } } go-grpc-middleware-1.3.0/retry/000077500000000000000000000000001404040257500163325ustar00rootroot00000000000000go-grpc-middleware-1.3.0/retry/backoff.go000066400000000000000000000030571404040257500202610ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_retry import ( "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/backoffutils" ) // BackoffLinear is very simple: it waits for a fixed period of time between calls. func BackoffLinear(waitBetween time.Duration) BackoffFunc { return func(attempt uint) time.Duration { return waitBetween } } // BackoffLinearWithJitter waits a set period of time, allowing for jitter (fractional adjustment). // // For example waitBetween=1s and jitter=0.10 can generate waits between 900ms and 1100ms. 
func BackoffLinearWithJitter(waitBetween time.Duration, jitterFraction float64) BackoffFunc { return func(attempt uint) time.Duration { return backoffutils.JitterUp(waitBetween, jitterFraction) } } // BackoffExponential produces increasing intervals for each attempt. // // The scalar is multiplied times 2 raised to the current attempt. So the first // retry with a scalar of 100ms is 100ms, while the 5th attempt would be 1.6s. func BackoffExponential(scalar time.Duration) BackoffFunc { return func(attempt uint) time.Duration { return scalar * time.Duration(backoffutils.ExponentBase2(attempt)) } } // BackoffExponentialWithJitter creates an exponential backoff like // BackoffExponential does, but adds jitter. func BackoffExponentialWithJitter(scalar time.Duration, jitterFraction float64) BackoffFunc { return func(attempt uint) time.Duration { return backoffutils.JitterUp(scalar*time.Duration(backoffutils.ExponentBase2(attempt)), jitterFraction) } } go-grpc-middleware-1.3.0/retry/doc.go000066400000000000000000000016541404040257500174340ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* `grpc_retry` provides client-side request retry logic for gRPC. Client-Side Request Retry Interceptor It allows for automatic retry, inside the generated gRPC code of requests based on the gRPC status of the reply. It supports unary (1:1), and server stream (1:n) requests. By default the interceptors *are disabled*, preventing accidental use of retries. You can easily override the number of retries (setting them to more than 0) with a `grpc.ClientOption`, e.g.: myclient.Ping(ctx, goodPing, grpc_retry.WithMax(5)) Other default options are: retry on `ResourceExhausted` and `Unavailable` gRPC codes, use a 50ms linear backoff with 10% jitter. For chained interceptors, the retry interceptor will call every interceptor that follows it whenever when a retry happens. Please see examples for more advanced use. 
*/ package grpc_retry go-grpc-middleware-1.3.0/retry/examples_test.go000066400000000000000000000057721404040257500215510ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_retry_test import ( "context" "fmt" "io" "time" "github.com/grpc-ecosystem/go-grpc-middleware/retry" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) var cc *grpc.ClientConn func newCtx(timeout time.Duration) context.Context { ctx, _ := context.WithTimeout(context.TODO(), timeout) return ctx } // Simple example of using the default interceptor configuration. func Example_initialization() { grpc.Dial("myservice.example.com", grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor()), grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor()), ) } // Complex example with a 100ms linear backoff interval, and retry only on NotFound and Unavailable. func Example_initializationWithOptions() { opts := []grpc_retry.CallOption{ grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), grpc_retry.WithCodes(codes.NotFound, codes.Aborted), } grpc.Dial("myservice.example.com", grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)), grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)), ) } // Example with an exponential backoff starting with 100ms. // // Each next interval is the previous interval multiplied by 2. func Example_initializationWithExponentialBackoff() { opts := []grpc_retry.CallOption{ grpc_retry.WithBackoff(grpc_retry.BackoffExponential(100 * time.Millisecond)), } grpc.Dial("myservice.example.com", grpc.WithStreamInterceptor(grpc_retry.StreamClientInterceptor(opts...)), grpc.WithUnaryInterceptor(grpc_retry.UnaryClientInterceptor(opts...)), ) } // Simple example of an idempotent `ServerStream` call, that will be retried automatically 3 times. 
func Example_simpleCall() { client := pb_testproto.NewTestServiceClient(cc) stream, _ := client.PingList(newCtx(1*time.Second), &pb_testproto.PingRequest{}, grpc_retry.WithMax(3)) for { pong, err := stream.Recv() // retries happen here if err == io.EOF { break } else if err != nil { return } fmt.Printf("got pong: %v", pong) } } // This is an example of an `Unary` call that will also retry on deadlines. // // Because the passed in context has a `5s` timeout, the whole `Ping` invocation should finish // within that time. However, by default all retried calls will use the parent context for their // deadlines. This means, that unless you shorten the deadline of each call of the retry, you won't // be able to retry the first call at all. // // `WithPerRetryTimeout` allows you to shorten the deadline of each retry call, allowing you to fit // multiple retries in the single parent deadline. func ExampleWithPerRetryTimeout() { client := pb_testproto.NewTestServiceClient(cc) pong, _ := client.Ping( newCtx(5*time.Second), &pb_testproto.PingRequest{}, grpc_retry.WithMax(3), grpc_retry.WithPerRetryTimeout(1*time.Second)) fmt.Printf("got pong: %v", pong) } go-grpc-middleware-1.3.0/retry/options.go000066400000000000000000000117611404040257500203620ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_retry import ( "context" "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) var ( // DefaultRetriableCodes is a set of well known types gRPC codes that should be retri-able. // // `ResourceExhausted` means that the user quota, e.g. per-RPC limits, have been reached. // `Unavailable` means that system is currently unavailable and the client should retry again. 
DefaultRetriableCodes = []codes.Code{codes.ResourceExhausted, codes.Unavailable} defaultOptions = &options{ max: 0, // disabled perCallTimeout: 0, // disabled includeHeader: true, codes: DefaultRetriableCodes, backoffFunc: BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration { return BackoffLinearWithJitter(50*time.Millisecond /*jitter*/, 0.10)(attempt) }), } ) // BackoffFunc denotes a family of functions that control the backoff duration between call retries. // // They are called with an identifier of the attempt, and should return a time the system client should // hold off for. If the time returned is longer than the `context.Context.Deadline` of the request // the deadline of the request takes precedence and the wait will be interrupted before proceeding // with the next iteration. type BackoffFunc func(attempt uint) time.Duration // BackoffFuncContext denotes a family of functions that control the backoff duration between call retries. // // They are called with an identifier of the attempt, and should return a time the system client should // hold off for. If the time returned is longer than the `context.Context.Deadline` of the request // the deadline of the request takes precedence and the wait will be interrupted before proceeding // with the next iteration. The context can be used to extract request scoped metadata and context values. type BackoffFuncContext func(ctx context.Context, attempt uint) time.Duration // Disable disables the retry behaviour on this call, or this interceptor. // // Its semantically the same to `WithMax` func Disable() CallOption { return WithMax(0) } // WithMax sets the maximum number of retries on this call, or this interceptor. func WithMax(maxRetries uint) CallOption { return CallOption{applyFunc: func(o *options) { o.max = maxRetries }} } // WithBackoff sets the `BackoffFunc` used to control time between retries. 
func WithBackoff(bf BackoffFunc) CallOption { return CallOption{applyFunc: func(o *options) { o.backoffFunc = BackoffFuncContext(func(ctx context.Context, attempt uint) time.Duration { return bf(attempt) }) }} } // WithBackoffContext sets the `BackoffFuncContext` used to control time between retries. func WithBackoffContext(bf BackoffFuncContext) CallOption { return CallOption{applyFunc: func(o *options) { o.backoffFunc = bf }} } // WithCodes sets which codes should be retried. // // Please *use with care*, as you may be retrying non-idempotent calls. // // You cannot automatically retry on Cancelled and Deadline, please use `WithPerRetryTimeout` for these. func WithCodes(retryCodes ...codes.Code) CallOption { return CallOption{applyFunc: func(o *options) { o.codes = retryCodes }} } // WithPerRetryTimeout sets the RPC timeout per call (including initial call) on this call, or this interceptor. // // The context.Deadline of the call takes precedence and sets the maximum time the whole invocation // will take, but WithPerRetryTimeout can be used to limit the RPC time per each call. // // For example, with context.Deadline = now + 10s, and WithPerRetryTimeout(3 * time.Seconds), each // of the retry calls (including the initial one) will have a deadline of now + 3s. // // A value of 0 disables the timeout overrides completely and returns to each retry call using the // parent `context.Deadline`. // // Note that when this is enabled, any DeadlineExceeded errors that are propagated up will be retried. func WithPerRetryTimeout(timeout time.Duration) CallOption { return CallOption{applyFunc: func(o *options) { o.perCallTimeout = timeout }} } type options struct { max uint perCallTimeout time.Duration includeHeader bool codes []codes.Code backoffFunc BackoffFuncContext } // CallOption is a grpc.CallOption that is local to grpc_retry. type CallOption struct { grpc.EmptyCallOption // make sure we implement private after() and before() fields so we don't panic. 
applyFunc func(opt *options) } func reuseOrNewWithCallOptions(opt *options, callOptions []CallOption) *options { if len(callOptions) == 0 { return opt } optCopy := &options{} *optCopy = *opt for _, f := range callOptions { f.applyFunc(optCopy) } return optCopy } func filterCallOptions(callOptions []grpc.CallOption) (grpcOptions []grpc.CallOption, retryOptions []CallOption) { for _, opt := range callOptions { if co, ok := opt.(CallOption); ok { retryOptions = append(retryOptions, co) } else { grpcOptions = append(grpcOptions, opt) } } return grpcOptions, retryOptions } go-grpc-middleware-1.3.0/retry/retry.go000066400000000000000000000255051404040257500200350ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_retry import ( "context" "fmt" "io" "sync" "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "golang.org/x/net/trace" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) const ( AttemptMetadataKey = "x-retry-attempty" ) // UnaryClientInterceptor returns a new retrying unary client interceptor. // // The default configuration of the interceptor is to not retry *at all*. This behaviour can be // changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions). func UnaryClientInterceptor(optFuncs ...CallOption) grpc.UnaryClientInterceptor { intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs) return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { grpcOpts, retryOpts := filterCallOptions(opts) callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts) // short circuit for simplicity, and avoiding allocations. if callOpts.max == 0 { return invoker(parentCtx, method, req, reply, cc, grpcOpts...) 
} var lastErr error for attempt := uint(0); attempt < callOpts.max; attempt++ { if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil { return err } callCtx := perCallContext(parentCtx, callOpts, attempt) lastErr = invoker(callCtx, method, req, reply, cc, grpcOpts...) // TODO(mwitkow): Maybe dial and transport errors should be retriable? if lastErr == nil { return nil } logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr) if isContextError(lastErr) { if parentCtx.Err() != nil { logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err()) // its the parent context deadline or cancellation. return lastErr } else if callOpts.perCallTimeout != 0 { // We have set a perCallTimeout in the retry middleware, which would result in a context error if // the deadline was exceeded, in which case try again. logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt) continue } } if !isRetriable(lastErr, callOpts) { return lastErr } } return lastErr } } // StreamClientInterceptor returns a new retrying stream client interceptor for server side streaming calls. // // The default configuration of the interceptor is to not retry *at all*. This behaviour can be // changed through options (e.g. WithMax) on creation of the interceptor or on call (through grpc.CallOptions). // // Retry logic is available *only for ServerStreams*, i.e. 1:n streams, as the internal logic needs // to buffer the messages sent by the client. If retry is enabled on any other streams (ClientStreams, // BidiStreams), the retry interceptor will fail the call. 
func StreamClientInterceptor(optFuncs ...CallOption) grpc.StreamClientInterceptor { intOpts := reuseOrNewWithCallOptions(defaultOptions, optFuncs) return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { grpcOpts, retryOpts := filterCallOptions(opts) callOpts := reuseOrNewWithCallOptions(intOpts, retryOpts) // short circuit for simplicity, and avoiding allocations. if callOpts.max == 0 { return streamer(parentCtx, desc, cc, method, grpcOpts...) } if desc.ClientStreams { return nil, status.Errorf(codes.Unimplemented, "grpc_retry: cannot retry on ClientStreams, set grpc_retry.Disable()") } var lastErr error for attempt := uint(0); attempt < callOpts.max; attempt++ { if err := waitRetryBackoff(attempt, parentCtx, callOpts); err != nil { return nil, err } callCtx := perCallContext(parentCtx, callOpts, 0) var newStreamer grpc.ClientStream newStreamer, lastErr = streamer(callCtx, desc, cc, method, grpcOpts...) if lastErr == nil { retryingStreamer := &serverStreamingRetryingStream{ ClientStream: newStreamer, callOpts: callOpts, parentCtx: parentCtx, streamerCall: func(ctx context.Context) (grpc.ClientStream, error) { return streamer(ctx, desc, cc, method, grpcOpts...) }, } return retryingStreamer, nil } logTrace(parentCtx, "grpc_retry attempt: %d, got err: %v", attempt, lastErr) if isContextError(lastErr) { if parentCtx.Err() != nil { logTrace(parentCtx, "grpc_retry attempt: %d, parent context error: %v", attempt, parentCtx.Err()) // its the parent context deadline or cancellation. return nil, lastErr } else if callOpts.perCallTimeout != 0 { // We have set a perCallTimeout in the retry middleware, which would result in a context error if // the deadline was exceeded, in which case try again. 
logTrace(parentCtx, "grpc_retry attempt: %d, context error from retry call", attempt) continue } } if !isRetriable(lastErr, callOpts) { return nil, lastErr } } return nil, lastErr } } // type serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a // proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish // a new ClientStream according to the retry policy. type serverStreamingRetryingStream struct { grpc.ClientStream bufferedSends []interface{} // single message that the client can sen receivedGood bool // indicates whether any prior receives were successful wasClosedSend bool // indicates that CloseSend was closed parentCtx context.Context callOpts *options streamerCall func(ctx context.Context) (grpc.ClientStream, error) mu sync.RWMutex } func (s *serverStreamingRetryingStream) setStream(clientStream grpc.ClientStream) { s.mu.Lock() s.ClientStream = clientStream s.mu.Unlock() } func (s *serverStreamingRetryingStream) getStream() grpc.ClientStream { s.mu.RLock() defer s.mu.RUnlock() return s.ClientStream } func (s *serverStreamingRetryingStream) SendMsg(m interface{}) error { s.mu.Lock() s.bufferedSends = append(s.bufferedSends, m) s.mu.Unlock() return s.getStream().SendMsg(m) } func (s *serverStreamingRetryingStream) CloseSend() error { s.mu.Lock() s.wasClosedSend = true s.mu.Unlock() return s.getStream().CloseSend() } func (s *serverStreamingRetryingStream) Header() (metadata.MD, error) { return s.getStream().Header() } func (s *serverStreamingRetryingStream) Trailer() metadata.MD { return s.getStream().Trailer() } func (s *serverStreamingRetryingStream) RecvMsg(m interface{}) error { attemptRetry, lastErr := s.receiveMsgAndIndicateRetry(m) if !attemptRetry { return lastErr // success or hard failure } // We start off from attempt 1, because zeroth was already made on normal SendMsg(). 
for attempt := uint(1); attempt < s.callOpts.max; attempt++ { if err := waitRetryBackoff(attempt, s.parentCtx, s.callOpts); err != nil { return err } callCtx := perCallContext(s.parentCtx, s.callOpts, attempt) newStream, err := s.reestablishStreamAndResendBuffer(callCtx) if err != nil { // Retry dial and transport errors of establishing stream as grpc doesn't retry. if isRetriable(err, s.callOpts) { continue } return err } s.setStream(newStream) attemptRetry, lastErr = s.receiveMsgAndIndicateRetry(m) //fmt.Printf("Received message and indicate: %v %v\n", attemptRetry, lastErr) if !attemptRetry { return lastErr } } return lastErr } func (s *serverStreamingRetryingStream) receiveMsgAndIndicateRetry(m interface{}) (bool, error) { s.mu.RLock() wasGood := s.receivedGood s.mu.RUnlock() err := s.getStream().RecvMsg(m) if err == nil || err == io.EOF { s.mu.Lock() s.receivedGood = true s.mu.Unlock() return false, err } else if wasGood { // previous RecvMsg in the stream succeeded, no retry logic should interfere return false, err } if isContextError(err) { if s.parentCtx.Err() != nil { logTrace(s.parentCtx, "grpc_retry parent context error: %v", s.parentCtx.Err()) return false, err } else if s.callOpts.perCallTimeout != 0 { // We have set a perCallTimeout in the retry middleware, which would result in a context error if // the deadline was exceeded, in which case try again. 
logTrace(s.parentCtx, "grpc_retry context error from retry call") return true, err } } return isRetriable(err, s.callOpts), err } func (s *serverStreamingRetryingStream) reestablishStreamAndResendBuffer( callCtx context.Context, ) (grpc.ClientStream, error) { s.mu.RLock() bufferedSends := s.bufferedSends s.mu.RUnlock() newStream, err := s.streamerCall(callCtx) if err != nil { logTrace(callCtx, "grpc_retry failed redialing new stream: %v", err) return nil, err } for _, msg := range bufferedSends { if err := newStream.SendMsg(msg); err != nil { logTrace(callCtx, "grpc_retry failed resending message: %v", err) return nil, err } } if err := newStream.CloseSend(); err != nil { logTrace(callCtx, "grpc_retry failed CloseSend on new stream %v", err) return nil, err } return newStream, nil } func waitRetryBackoff(attempt uint, parentCtx context.Context, callOpts *options) error { var waitTime time.Duration = 0 if attempt > 0 { waitTime = callOpts.backoffFunc(parentCtx, attempt) } if waitTime > 0 { logTrace(parentCtx, "grpc_retry attempt: %d, backoff for %v", attempt, waitTime) timer := time.NewTimer(waitTime) select { case <-parentCtx.Done(): timer.Stop() return contextErrToGrpcErr(parentCtx.Err()) case <-timer.C: } } return nil } func isRetriable(err error, callOpts *options) bool { errCode := status.Code(err) if isContextError(err) { // context errors are not retriable based on user settings. 
return false } for _, code := range callOpts.codes { if code == errCode { return true } } return false } func isContextError(err error) bool { code := status.Code(err) return code == codes.DeadlineExceeded || code == codes.Canceled } func perCallContext(parentCtx context.Context, callOpts *options, attempt uint) context.Context { ctx := parentCtx if callOpts.perCallTimeout != 0 { ctx, _ = context.WithTimeout(ctx, callOpts.perCallTimeout) } if attempt > 0 && callOpts.includeHeader { mdClone := metautils.ExtractOutgoing(ctx).Clone().Set(AttemptMetadataKey, fmt.Sprintf("%d", attempt)) ctx = mdClone.ToOutgoing(ctx) } return ctx } func contextErrToGrpcErr(err error) error { switch err { case context.DeadlineExceeded: return status.Error(codes.DeadlineExceeded, err.Error()) case context.Canceled: return status.Error(codes.Canceled, err.Error()) default: return status.Error(codes.Unknown, err.Error()) } } func logTrace(ctx context.Context, format string, a ...interface{}) { tr, ok := trace.FromContext(ctx) if !ok { return } tr.LazyPrintf(format, a...) } go-grpc-middleware-1.3.0/retry/retry_test.go000066400000000000000000000403151404040257500210700ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_retry_test import ( "context" "io" "sync" "testing" "time" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( retriableErrors = []codes.Code{codes.Unavailable, codes.DataLoss} goodPing = &pb_testproto.PingRequest{Value: "something"} noSleep = 0 * time.Second retryTimeout = 50 * time.Millisecond ) type failingService struct { pb_testproto.TestServiceServer mu sync.Mutex reqCounter uint reqModulo uint reqSleep time.Duration reqError codes.Code } func (s *failingService) resetFailingConfiguration(modulo uint, errorCode codes.Code, sleepTime time.Duration) { s.mu.Lock() defer s.mu.Unlock() s.reqCounter = 0 s.reqModulo = modulo s.reqError = errorCode s.reqSleep = sleepTime } func (s *failingService) requestCount() uint { s.mu.Lock() defer s.mu.Unlock() return s.reqCounter } func (s *failingService) maybeFailRequest() error { s.mu.Lock() s.reqCounter += 1 reqModulo := s.reqModulo reqCounter := s.reqCounter reqSleep := s.reqSleep reqError := s.reqError s.mu.Unlock() if (reqModulo > 0) && (reqCounter%reqModulo == 0) { return nil } time.Sleep(reqSleep) return status.Errorf(reqError, "maybeFailRequest: failing it") } func (s *failingService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { if err := s.maybeFailRequest(); err != nil { return nil, err } return s.TestServiceServer.Ping(ctx, ping) } func (s *failingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { if err := s.maybeFailRequest(); err != nil { return err } return 
s.TestServiceServer.PingList(ping, stream) } func (s *failingService) PingStream(stream pb_testproto.TestService_PingStreamServer) error { if err := s.maybeFailRequest(); err != nil { return err } return s.TestServiceServer.PingStream(stream) } func TestRetrySuite(t *testing.T) { service := &failingService{ TestServiceServer: &grpc_testing.TestPingService{T: t}, } unaryInterceptor := grpc_retry.UnaryClientInterceptor( grpc_retry.WithCodes(retriableErrors...), grpc_retry.WithMax(3), grpc_retry.WithBackoff(grpc_retry.BackoffLinear(retryTimeout)), ) streamInterceptor := grpc_retry.StreamClientInterceptor( grpc_retry.WithCodes(retriableErrors...), grpc_retry.WithMax(3), grpc_retry.WithBackoff(grpc_retry.BackoffLinear(retryTimeout)), ) s := &RetrySuite{ srv: service, InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: service, ClientOpts: []grpc.DialOption{ grpc.WithStreamInterceptor(streamInterceptor), grpc.WithUnaryInterceptor(unaryInterceptor), }, }, } suite.Run(t, s) } type RetrySuite struct { *grpc_testing.InterceptorTestSuite srv *failingService } func (s *RetrySuite) SetupTest() { s.srv.resetFailingConfiguration( /* don't fail */ 0, codes.OK, noSleep) } func (s *RetrySuite) TestUnary_FailsOnNonRetriableError() { s.srv.resetFailingConfiguration(5, codes.Internal, noSleep) _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.Error(s.T(), err, "error must occur from the failing service") require.Equal(s.T(), codes.Internal, status.Code(err), "failure code must come from retrier") require.EqualValues(s.T(), 1, s.srv.requestCount(), "one request should have been made") } func (s *RetrySuite) TestUnary_FailsOnNonRetriableContextError() { s.srv.resetFailingConfiguration(5, codes.Canceled, noSleep) _, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.Error(s.T(), err, "error must occur from the failing service") require.Equal(s.T(), codes.Canceled, status.Code(err), "failure code must come from retrier") require.EqualValues(s.T(), 1, 
s.srv.requestCount(), "one request should have been made") } func (s *RetrySuite) TestCallOptionsDontPanicWithoutInterceptor() { // Fix for https://github.com/grpc-ecosystem/go-grpc-middleware/issues/37 // If this code doesn't panic, that's good. s.srv.resetFailingConfiguration(100, codes.DataLoss, noSleep) // doesn't matter all requests should fail nonMiddlewareClient := s.NewClient() _, err := nonMiddlewareClient.Ping(s.SimpleCtx(), goodPing, grpc_retry.WithMax(5), grpc_retry.WithBackoff(grpc_retry.BackoffLinear(1*time.Millisecond)), grpc_retry.WithCodes(codes.DataLoss), grpc_retry.WithPerRetryTimeout(1*time.Millisecond), ) require.Error(s.T(), err) } func (s *RetrySuite) TestServerStream_FailsOnNonRetriableError() { s.srv.resetFailingConfiguration(5, codes.Internal, noSleep) stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") _, err = stream.Recv() require.Error(s.T(), err, "error must occur from the failing service") require.Equal(s.T(), codes.Internal, status.Code(err), "failure code must come from retrier") } func (s *RetrySuite) TestUnary_SucceedsOnRetriableError() { s.srv.resetFailingConfiguration(3, codes.DataLoss, noSleep) // see retriable_errors out, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "the third invocation should succeed") require.NotNil(s.T(), out, "Pong must be not nil") require.EqualValues(s.T(), 3, s.srv.requestCount(), "three requests should have been made") } func (s *RetrySuite) TestUnary_OverrideFromDialOpts() { s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors out, err := s.Client.Ping(s.SimpleCtx(), goodPing, grpc_retry.WithCodes(codes.ResourceExhausted), grpc_retry.WithMax(5)) require.NoError(s.T(), err, "the fifth invocation should succeed") require.NotNil(s.T(), out, "Pong must be not nil") require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have 
been made") } func (s *RetrySuite) TestUnary_PerCallDeadline_Succeeds() { // This tests 5 requests, with first 4 sleeping for 10 millisecond, and the retry logic firing // a retry call with a 5 millisecond deadline. The 5th one doesn't sleep and succeeds. deadlinePerCall := 5 * time.Millisecond s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) out, err := s.Client.Ping(s.SimpleCtx(), goodPing, grpc_retry.WithPerRetryTimeout(deadlinePerCall), grpc_retry.WithMax(5)) require.NoError(s.T(), err, "the fifth invocation should succeed") require.NotNil(s.T(), out, "Pong must be not nil") require.EqualValues(s.T(), 5, s.srv.requestCount(), "five requests should have been made") } func (s *RetrySuite) TestUnary_PerCallDeadline_FailsOnParent() { // This tests that the parent context (passed to the invocation) takes precedence over retries. // The parent context has 150 milliseconds of deadline. // Each failed call sleeps for 100milliseconds, and there is 5 milliseconds between each one. // This means that unlike in TestUnary_PerCallDeadline_Succeeds, the fifth successful call won't // be made. parentDeadline := 150 * time.Millisecond deadlinePerCall := 50 * time.Millisecond // All 0-4 requests should have 10 millisecond sleeps and deadline, while the last one works. 
s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) ctx, _ := context.WithTimeout(context.TODO(), parentDeadline) _, err := s.Client.Ping(ctx, goodPing, grpc_retry.WithPerRetryTimeout(deadlinePerCall), grpc_retry.WithMax(5)) require.Error(s.T(), err, "the retries must fail due to context deadline exceeded") require.Equal(s.T(), codes.DeadlineExceeded, status.Code(err), "failre code must be a gRPC error of Deadline class") } func (s *RetrySuite) TestServerStream_SucceedsOnRetriableError() { s.srv.resetFailingConfiguration(3, codes.DataLoss, noSleep) // see retriable_errors stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "establishing the connection must always succeed") s.assertPingListWasCorrect(stream) require.EqualValues(s.T(), 3, s.srv.requestCount(), "three requests should have been made") } func (s *RetrySuite) TestServerStream_OverrideFromContext() { s.srv.resetFailingConfiguration(5, codes.ResourceExhausted, noSleep) // default is 3 and retriable_errors stream, err := s.Client.PingList(s.SimpleCtx(), goodPing, grpc_retry.WithCodes(codes.ResourceExhausted), grpc_retry.WithMax(5)) require.NoError(s.T(), err, "establishing the connection must always succeed") s.assertPingListWasCorrect(stream) require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") } func (s *RetrySuite) TestServerStream_PerCallDeadline_Succeeds() { // This tests 5 requests, with first 4 sleeping for 100 millisecond, and the retry logic firing // a retry call with a 50 millisecond deadline. The 5th one doesn't sleep and succeeds. 
deadlinePerCall := 50 * time.Millisecond s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) stream, err := s.Client.PingList(s.SimpleCtx(), goodPing, grpc_retry.WithPerRetryTimeout(deadlinePerCall), grpc_retry.WithMax(5)) require.NoError(s.T(), err, "establishing the connection must always succeed") s.assertPingListWasCorrect(stream) require.EqualValues(s.T(), 5, s.srv.requestCount(), "three requests should have been made") } func (s *RetrySuite) TestServerStream_PerCallDeadline_FailsOnParent() { // This tests that the parent context (passed to the invocation) takes precedence over retries. // The parent context has 150 milliseconds of deadline. // Each failed call sleeps for 50milliseconds, and there is 25 milliseconds between each one. // This means that unlike in TestServerStream_PerCallDeadline_Succeeds, the fifth successful call won't // be made. parentDeadline := 150 * time.Millisecond deadlinePerCall := 50 * time.Millisecond // All 0-4 requests should have 10 millisecond sleeps and deadline, while the last one works. 
s.srv.resetFailingConfiguration(5, codes.NotFound, 2*deadlinePerCall) parentCtx, _ := context.WithTimeout(context.TODO(), parentDeadline) stream, err := s.Client.PingList(parentCtx, goodPing, grpc_retry.WithPerRetryTimeout(deadlinePerCall), grpc_retry.WithMax(5)) require.NoError(s.T(), err, "establishing the connection must always succeed") _, err = stream.Recv() require.Equal(s.T(), codes.DeadlineExceeded, status.Code(err), "failre code must be a gRPC error of Deadline class") } func (s *RetrySuite) TestServerStream_CallFailsOnOutOfRetries() { restarted := s.RestartServer(3 * retryTimeout) _, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.Error(s.T(), err, "establishing the connection should not succeed") assert.Equal(s.T(), codes.Unavailable, status.Code(err)) <-restarted } func (s *RetrySuite) TestServerStream_CallFailsOnDeadlineExceeded() { restarted := s.RestartServer(3 * retryTimeout) ctx, _ := context.WithTimeout(context.TODO(), retryTimeout) _, err := s.Client.PingList(ctx, goodPing) require.Error(s.T(), err, "establishing the connection should not succeed") assert.Equal(s.T(), codes.DeadlineExceeded, status.Code(err)) <-restarted } func (s *RetrySuite) TestServerStream_CallRetrySucceeds() { restarted := s.RestartServer(retryTimeout) _, err := s.Client.PingList(s.SimpleCtx(), goodPing, grpc_retry.WithMax(40), ) assert.NoError(s.T(), err, "establishing the connection should succeed") <-restarted } func (s *RetrySuite) assertPingListWasCorrect(stream pb_testproto.TestService_PingListClient) { count := 0 for { pong, err := stream.Recv() if err == io.EOF { break } require.NotNil(s.T(), pong, "received values must not be nil") require.NoError(s.T(), err, "no errors during receive on client side") require.Equal(s.T(), goodPing.Value, pong.Value, "the returned pong contained the outgoing ping") count += 1 } require.EqualValues(s.T(), grpc_testing.ListResponseCount, count, "should have received all ping items") } type trackedInterceptor struct { called 
int } func (ti *trackedInterceptor) UnaryClientInterceptor(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { ti.called++ return invoker(ctx, method, req, reply, cc, opts...) } func (ti *trackedInterceptor) StreamClientInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { ti.called++ return streamer(ctx, desc, cc, method, opts...) } func TestChainedRetrySuite(t *testing.T) { service := &failingService{ TestServiceServer: &grpc_testing.TestPingService{T: t}, } preRetryInterceptor := &trackedInterceptor{} postRetryInterceptor := &trackedInterceptor{} s := &ChainedRetrySuite{ srv: service, preRetryInterceptor: preRetryInterceptor, postRetryInterceptor: postRetryInterceptor, InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: service, ClientOpts: []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_middleware.ChainUnaryClient(preRetryInterceptor.UnaryClientInterceptor, grpc_retry.UnaryClientInterceptor(), postRetryInterceptor.UnaryClientInterceptor)), grpc.WithStreamInterceptor(grpc_middleware.ChainStreamClient(preRetryInterceptor.StreamClientInterceptor, grpc_retry.StreamClientInterceptor(), postRetryInterceptor.StreamClientInterceptor)), }, }, } suite.Run(t, s) } type ChainedRetrySuite struct { *grpc_testing.InterceptorTestSuite srv *failingService preRetryInterceptor *trackedInterceptor postRetryInterceptor *trackedInterceptor } func (s *ChainedRetrySuite) SetupTest() { s.srv.resetFailingConfiguration( /* don't fail */ 0, codes.OK, noSleep) s.preRetryInterceptor.called = 0 s.postRetryInterceptor.called = 0 } func (s *ChainedRetrySuite) TestUnaryWithChainedInterceptors_NoFailure() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing, grpc_retry.WithMax(2)) require.NoError(s.T(), err, "the invocation should succeed") require.EqualValues(s.T(), 1, 
s.srv.requestCount(), "one request should have been made") require.EqualValues(s.T(), 1, s.preRetryInterceptor.called, "pre-retry interceptor should be called once") require.EqualValues(s.T(), 1, s.postRetryInterceptor.called, "post-retry interceptor should be called once") } func (s *ChainedRetrySuite) TestUnaryWithChainedInterceptors_WithRetry() { s.srv.resetFailingConfiguration(2, codes.Unavailable, noSleep) _, err := s.Client.Ping(s.SimpleCtx(), goodPing, grpc_retry.WithMax(2)) require.NoError(s.T(), err, "the second invocation should succeed") require.EqualValues(s.T(), 2, s.srv.requestCount(), "two requests should have been made") require.EqualValues(s.T(), 1, s.preRetryInterceptor.called, "pre-retry interceptor should be called once") require.EqualValues(s.T(), 2, s.postRetryInterceptor.called, "post-retry interceptor should be called twice") } func (s *ChainedRetrySuite) TestStreamWithChainedInterceptors_NoFailure() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing, grpc_retry.WithMax(2)) require.NoError(s.T(), err, "the invocation should succeed") _, err = stream.Recv() require.NoError(s.T(), err, "the Recv should succeed") require.EqualValues(s.T(), 1, s.srv.requestCount(), "one request should have been made") require.EqualValues(s.T(), 1, s.preRetryInterceptor.called, "pre-retry interceptor should be called once") require.EqualValues(s.T(), 1, s.postRetryInterceptor.called, "post-retry interceptor should be called once") } func (s *ChainedRetrySuite) TestStreamWithChainedInterceptors_WithRetry() { s.srv.resetFailingConfiguration(2, codes.Unavailable, noSleep) stream, err := s.Client.PingList(s.SimpleCtx(), goodPing, grpc_retry.WithMax(2)) require.NoError(s.T(), err, "the second invocation should succeed") _, err = stream.Recv() require.NoError(s.T(), err, "the Recv should succeed") require.EqualValues(s.T(), 2, s.srv.requestCount(), "two requests should have been made") require.EqualValues(s.T(), 1, s.preRetryInterceptor.called, "pre-retry 
interceptor should be called once") require.EqualValues(s.T(), 2, s.postRetryInterceptor.called, "post-retry interceptor should be called twice") } go-grpc-middleware-1.3.0/scripts/000077500000000000000000000000001404040257500166545ustar00rootroot00000000000000go-grpc-middleware-1.3.0/scripts/test_all.sh000077500000000000000000000005651404040257500210300ustar00rootroot00000000000000#!/usr/bin/env bash set -e echo "" > coverage.txt for d in $(go list ./... | grep -v vendor); do echo -e "TESTS FOR: for \033[0;35m${d}\033[0m" go test -race -v -coverprofile=profile.coverage.out -covermode=atomic $d if [ -f profile.coverage.out ]; then cat profile.coverage.out >> coverage.txt rm profile.coverage.out fi echo "" done go-grpc-middleware-1.3.0/slack.png000066400000000000000000000117401404040257500167730ustar00rootroot00000000000000PNG  IHDRIӪgAMA a cHRMz&u0`:pQ<bKGDC pHYsZZp#}tIME1 TRIDATx\yTWUwu 6N@WA"DQ!qp\&$$Ƹdp D\0GDPQpAnhlGy}}ibl~66N.8]ؾ"$"..^:ZXaC*\46Rzl99Dt!Nl&lJM@jm_BPEQVXaOl^lY#&)&VVXOgXʢ,R/X +'<;;Vb~ +JOVXaEwpbx^evvvvv6@$Iϯ#G9r˔޶G7g{V@&݇N?kb'BBxAmmuukcC//C . H0(hjm/?z4* 鼼t:pOr̕l2D@[i|> xxxxxxT*J%cǎ;vrݙ{իW ؽ{ݻcǎ;llllllv8p۷o?re˖-[۷o߾}d2L= hjru}X_B"O}π7ZZ.\`ۛ68/4@5;:kהʼ< >;- '//V.(H*>HL><=p.hk3T* +޽끆`9{6,c '.+ nT]RRl߽H| psc*U^bqHt766-K w߿G[mm}?lNNskg|iUڵ@hY@`D2m+eKe@}}W{ccOy'?@7u:`Tl !wt:ܹor=Hƍsp4 HJ>|^0)dfVLg`78`p8૯Fv 8}X42hJKUmooΞml<|X~/<=|}$hljkki>n4Oryv(L$&&&&&,uHܹsΝ;JR
( '''''mwڵk׮222222Xp8puuuuu8>|yyyyy9pСCD"nw<ܼgpD\\q1siyyPI?lH| ߌ ,YrHDPtv>~ Oޖ ;;֍ҥ99gڒ%|%ZVٷ]է;jk++L9@>,+Ν:v  t }7g'O C\.2jDP@\P( `'€p_h 3s֬Ç/^LJo~`BUO?w/5g0D;9W@bѣ3f Mm-k>-E"U||oV:~|@h طo3gKUu@t)˗1+7?.,L}&L N.-+BܰPZ[J6SNJ YX=9ήݶxrvtt]٬?NL,q@QQQQQ`0 Cc3f̘1  pLy~ȑ#Gz@&d2VjYBey͛7@^^^^^ bq_? ..3|/P_VS k<0v O?=y H$x==60(+qj==ƆN~Q %e͛-۷JKť6֭ۛoG(L>qRy `o?jI/(hiH0 'O;wnF8Ψ9]-F (-e,LL'OBc<,-t:^JJ:;/]bjKJr##&1c^F1 > Xr1nP H$$l?0JVt#I΃# aa_gfK.o$Nm-RAI@xtfPHz[< .\{q2`miiiiiɓ'O< bbbbbbDFFFFF))))))Fh4@hhhhh(` rz^׳Dx<^O2_~=7'ظqܸ={Jll,n@&\\β2/Ʈ. -V7:;i9:puS/89qrF#Es渻  I8zπ6&&N;  ttPTW@hMGOJ$E;qA/9,ݾll{K9``3!lt:I3B 0hR<,Gd D[IEQ@MJU] \QQ=7zcy<(,| 8X,`^r 蕋3>=rH4a0NJ@bbN̙Fө@XY\ӞY/| d|@qqqqq1k׮]vP( =R ݓJR9s̙3{1Ku______Ϗg2N;;;;;޿ ɜ}eCZ=|@FFne@fk׮=xp`x8uٳlB**&Lpu OYcbz\*,Ǝuv,4{[l,pNS͛@GѨǎ|-_=q"[CYّ#קCzF?澻@0bmĉ@EEKKa!pJu:0rLOMh4u \Xܷ}O#f+M\ߺT޸%rl}H`zMFxc @,7$r7 I ի^^WeeIPZqAx@qqsի  E[[c#@Çb?Is,XxzV[8uwwwww, ǏgXIIIII tߴiӦM"\.{{{{{{VLMMMMMe 9ٻw޽{J̙4}~ݺ'>={VΎg f/*+7l4gkz7%>}ܓ'*UA8:XfnVX2*F" K t?[[OO._>R +^* p3sr KX'VX'999\+>W|\D\!WD"ÇC"P+@ j-Šuo"A n'y^k dž%tEXtdate:create2018-04-27T04:49:09-04:00CA%tEXtdate:modify2018-04-27T04:49:09-04:002d0tEXtsvg:base-urifile:///tmp/magick-5188Ce7IL-5Wy5e6IENDB`go-grpc-middleware-1.3.0/tags/000077500000000000000000000000001404040257500161235ustar00rootroot00000000000000go-grpc-middleware-1.3.0/tags/context.go000066400000000000000000000035641404040257500201460ustar00rootroot00000000000000package grpc_ctxtags import ( "context" ) type ctxMarker struct{} var ( // ctxMarkerKey is the Context value marker used by *all* logging middleware. // The logging middleware object must interf ctxMarkerKey = &ctxMarker{} // NoopTags is a trivial, minimum overhead implementation of Tags for which all operations are no-ops. NoopTags = &noopTags{} ) // Tags is the interface used for storing request tags between Context calls. // The default implementation is *not* thread safe, and should be handled only in the context of the request. 
type Tags interface { // Set sets the given key in the metadata tags. Set(key string, value interface{}) Tags // Has checks if the given key exists. Has(key string) bool // Values returns a map of key to values. // Do not modify the underlying map, please use Set instead. Values() map[string]interface{} } type mapTags struct { values map[string]interface{} } func (t *mapTags) Set(key string, value interface{}) Tags { t.values[key] = value return t } func (t *mapTags) Has(key string) bool { _, ok := t.values[key] return ok } func (t *mapTags) Values() map[string]interface{} { return t.values } type noopTags struct{} func (t *noopTags) Set(key string, value interface{}) Tags { return t } func (t *noopTags) Has(key string) bool { return false } func (t *noopTags) Values() map[string]interface{} { return nil } // Extracts returns a pre-existing Tags object in the Context. // If the context wasn't set in a tag interceptor, a no-op Tag storage is returned that will *not* be propagated in context. func Extract(ctx context.Context) Tags { t, ok := ctx.Value(ctxMarkerKey).(Tags) if !ok { return NoopTags } return t } func SetInContext(ctx context.Context, tags Tags) context.Context { return context.WithValue(ctx, ctxMarkerKey, tags) } func NewTags() Tags { return &mapTags{values: make(map[string]interface{})} } go-grpc-middleware-1.3.0/tags/doc.go000066400000000000000000000024221404040257500172170ustar00rootroot00000000000000/* `grpc_ctxtags` adds a Tag object to the context that can be used by other middleware to add context about a request. Request Context Tags Tags describe information about the request, and can be set and used by other middleware, or handlers. Tags are used for logging and tracing of requests. Tags are populated both upwards, *and* downwards in the interceptor-handler stack. You can automatically extract tags (in `grpc.request.`) from request payloads. For unary and server-streaming methods, pass in the `WithFieldExtractor` option. 
For client-streams and bidirectional-streams, you can use `WithFieldExtractorForInitialReq` which will extract the tags from the first message passed from client to server. Note the tags will not be modified for subsequent requests, so this option only makes sense when the initial message establishes the meta-data for the stream. If a user doesn't use the interceptors that initialize the `Tags` object, all operations following from an `Extract(ctx)` will be no-ops. This is to ensure that code doesn't panic if the interceptors weren't used. Tags fields are typed, and shallow and should follow the OpenTracing semantics convention: https://github.com/opentracing/specification/blob/master/semantic_conventions.md */ package grpc_ctxtags go-grpc-middleware-1.3.0/tags/examples_test.go000066400000000000000000000017061404040257500213330ustar00rootroot00000000000000package grpc_ctxtags_test import ( "github.com/grpc-ecosystem/go-grpc-middleware/tags" "google.golang.org/grpc" ) // Simple example of server initialization code, with data automatically populated from `log_fields` Golang tags. func Example_initialization() { opts := []grpc_ctxtags.Option{ grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")), } _ = grpc.NewServer( grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)), grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)), ) } // Example using WithFieldExtractorForInitialReq func Example_initialisationWithOptions() { opts := []grpc_ctxtags.Option{ grpc_ctxtags.WithFieldExtractorForInitialReq(grpc_ctxtags.TagBasedRequestFieldExtractor("log_fields")), } _ = grpc.NewServer( grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)), grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)), ) } go-grpc-middleware-1.3.0/tags/fieldextractor.go000066400000000000000000000061401404040257500214720ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. 
// See LICENSE for licensing terms. package grpc_ctxtags import ( "reflect" ) // RequestFieldExtractorFunc is a user-provided function that extracts field information from a gRPC request. // It is called from tags middleware on arrival of unary request or a server-stream request. // Keys and values will be added to the context tags of the request. If there are no fields, you should return a nil. type RequestFieldExtractorFunc func(fullMethod string, req interface{}) map[string]interface{} type requestFieldsExtractor interface { // ExtractRequestFields is a method declared on a Protobuf message that extracts fields from the interface. // The values from the extracted fields should be set in the appendToMap, in order to avoid allocations. ExtractRequestFields(appendToMap map[string]interface{}) } // CodeGenRequestFieldExtractor is a function that relies on code-generated functions that export log fields from requests. // These are usually coming from a protoc-plugin that generates additional information based on custom field options. func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{} { if ext, ok := req.(requestFieldsExtractor); ok { retMap := make(map[string]interface{}) ext.ExtractRequestFields(retMap) if len(retMap) == 0 { return nil } return retMap } return nil } // TagBasedRequestFieldExtractor is a function that relies on Go struct tags to export log fields from requests. // These are usually coming from a protoc-plugin, such as Gogo protobuf. // // message Metadata { // repeated string tags = 1 [ (gogoproto.moretags) = "log_field:\"meta_tags\"" ]; // } // // The tagName is configurable using the tagName variable. Here it would be "log_field". 
func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc { return func(fullMethod string, req interface{}) map[string]interface{} { retMap := make(map[string]interface{}) reflectMessageTags(req, retMap, tagName) if len(retMap) == 0 { return nil } return retMap } } func reflectMessageTags(msg interface{}, existingMap map[string]interface{}, tagName string) { v := reflect.ValueOf(msg) // Only deal with pointers to structs. if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { return } // Deref the pointer get to the struct. v = v.Elem() t := v.Type() for i := 0; i < v.NumField(); i++ { field := v.Field(i) kind := field.Kind() // Only recurse down direct pointers, which should only be to nested structs. if (kind == reflect.Ptr || kind == reflect.Interface) && field.CanInterface() { reflectMessageTags(field.Interface(), existingMap, tagName) } // In case of arrays/slices (repeated fields) go down to the concrete type. if kind == reflect.Array || kind == reflect.Slice { if field.Len() == 0 { continue } kind = field.Index(0).Kind() } // Only be interested in if (kind >= reflect.Bool && kind <= reflect.Float64) || kind == reflect.String { if tag := t.Field(i).Tag.Get(tagName); tag != "" { existingMap[tag] = field.Interface() } } } return } go-grpc-middleware-1.3.0/tags/fieldextractor_test.go000066400000000000000000000052361404040257500225360ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_ctxtags_test import ( "testing" "time" "github.com/grpc-ecosystem/go-grpc-middleware/tags" pb_gogotestproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/gogotestproto" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestCodeGenRequestLogFieldExtractor_ManualIsDeclared(t *testing.T) { req := &pb_testproto.PingRequest{Value: "my_value"} valMap := grpc_ctxtags.CodeGenRequestFieldExtractor("", req) require.Len(t, valMap, 1, "PingRequest should have a ExtractLogFields method declared in test.manual_extractfields.pb") require.EqualValues(t, valMap, map[string]interface{}{"value": "my_value"}) } func TestTaggedRequestFiledExtractor_PingRequest(t *testing.T) { req := &pb_gogotestproto.PingRequest{ Ping: &pb_gogotestproto.Ping{ Id: &pb_gogotestproto.PingId{ Id: 1337, // logfield is ping_id }, Value: "something", }, Meta: &pb_gogotestproto.Metadata{ Tags: []string{"tagone", "tagtwo"}, // logfield is meta_tags }, } valMap := grpc_ctxtags.TagBasedRequestFieldExtractor("log_field")("", req) assert.EqualValues(t, 1337, valMap["ping_id"]) assert.EqualValues(t, []string{"tagone", "tagtwo"}, valMap["meta_tags"]) } func TestTaggedRequestFiledExtractor_PongRequest(t *testing.T) { req := &pb_gogotestproto.PongRequest{ Pong: &pb_gogotestproto.Pong{ Id: "some_id", }, Meta: &pb_gogotestproto.Metadata{ Tags: []string{"tagone", "tagtwo"}, // logfield is meta_tags }, } valMap := grpc_ctxtags.TagBasedRequestFieldExtractor("log_field")("", req) assert.EqualValues(t, "some_id", valMap["pong_id"]) assert.EqualValues(t, []string{"tagone", "tagtwo"}, valMap["meta_tags"]) } func TestTaggedRequestFiledExtractor_OneOfLogField(t *testing.T) { req := &pb_gogotestproto.OneOfLogField{ Identifier: &pb_gogotestproto.OneOfLogField_BarId{ BarId: "bar-log-field", }, } valMap := grpc_ctxtags.TagBasedRequestFieldExtractor("log_field")("", req) assert.EqualValues(t, 
"bar-log-field", valMap["bar_id"]) } // Test to ensure TagBasedRequestFieldExtractor does not panic when encountering private struct members such as // when using gogoproto.stdtime which results in a time.Time that has private struct members func TestTaggedRequestFiledExtractor_GogoTime(t *testing.T) { ts := time.Date(2010, 01, 01, 0, 0, 0, 0, time.UTC) req := &pb_gogotestproto.GoGoProtoStdTime{ Timestamp: &ts, } assert.NotPanics(t, func() { valMap := grpc_ctxtags.TagBasedRequestFieldExtractor("log_field")("", req) assert.Empty(t, valMap) }) } go-grpc-middleware-1.3.0/tags/interceptors.go000066400000000000000000000054651404040257500212050ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_ctxtags import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/peer" "github.com/grpc-ecosystem/go-grpc-middleware" ) // UnaryServerInterceptor returns a new unary server interceptors that sets the values for request tags. func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { o := evaluateOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { newCtx := newTagsForCtx(ctx) if o.requestFieldsFunc != nil { setRequestFieldTags(newCtx, o.requestFieldsFunc, info.FullMethod, req) } return handler(newCtx, req) } } // StreamServerInterceptor returns a new streaming server interceptor that sets the values for request tags. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { o := evaluateOptions(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { newCtx := newTagsForCtx(stream.Context()) if o.requestFieldsFunc == nil { // Short-circuit, don't do the expensive bit of allocating a wrappedStream. 
wrappedStream := grpc_middleware.WrapServerStream(stream) wrappedStream.WrappedContext = newCtx return handler(srv, wrappedStream) } wrapped := &wrappedStream{stream, info, o, newCtx, true} err := handler(srv, wrapped) return err } } // wrappedStream is a thin wrapper around grpc.ServerStream that allows modifying context and extracts log fields from the initial message. type wrappedStream struct { grpc.ServerStream info *grpc.StreamServerInfo opts *options // WrappedContext is the wrapper's own Context. You can assign it. WrappedContext context.Context initial bool } // Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() func (w *wrappedStream) Context() context.Context { return w.WrappedContext } func (w *wrappedStream) RecvMsg(m interface{}) error { err := w.ServerStream.RecvMsg(m) // We only do log fields extraction on the single-request of a server-side stream. if !w.info.IsClientStream || w.opts.requestFieldsFromInitial && w.initial { w.initial = false setRequestFieldTags(w.Context(), w.opts.requestFieldsFunc, w.info.FullMethod, m) } return err } func newTagsForCtx(ctx context.Context) context.Context { t := NewTags() if peer, ok := peer.FromContext(ctx); ok { t.Set("peer.address", peer.Addr.String()) } return SetInContext(ctx, t) } func setRequestFieldTags(ctx context.Context, f RequestFieldExtractorFunc, fullMethodName string, req interface{}) { if valMap := f(fullMethodName, req); valMap != nil { t := Extract(ctx) for k, v := range valMap { t.Set("grpc.request."+k, v) } } } go-grpc-middleware-1.3.0/tags/interceptors_test.go000066400000000000000000000151061404040257500222350ustar00rootroot00000000000000package grpc_ctxtags_test import ( "context" "encoding/json" "io" "testing" "time" "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} anotherPing = &pb_testproto.PingRequest{Value: "else", SleepTimeMs: 9999} ) func tagsToJson(value map[string]interface{}) string { str, _ := json.Marshal(value) return string(str) } func tagsFromJson(t *testing.T, jstring string) map[string]interface{} { var msgMapTemplate interface{} err := json.Unmarshal([]byte(jstring), &msgMapTemplate) if err != nil { t.Fatalf("failed unmarshaling tags from response %v", err) } return msgMapTemplate.(map[string]interface{}) } type tagPingBack struct { pb_testproto.TestServiceServer } func (s *tagPingBack) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { return &pb_testproto.PingResponse{Value: tagsToJson(grpc_ctxtags.Extract(ctx).Values())}, nil } func (s *tagPingBack) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { return s.TestServiceServer.PingError(ctx, ping) } func (s *tagPingBack) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { out := &pb_testproto.PingResponse{Value: tagsToJson(grpc_ctxtags.Extract(stream.Context()).Values())} return stream.Send(out) } func (s *tagPingBack) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { return s.TestServiceServer.PingEmpty(ctx, empty) } func (s *tagPingBack) PingStream(stream pb_testproto.TestService_PingStreamServer) error { for { _, err := stream.Recv() if err == io.EOF { return nil } if err != nil { return err } out := &pb_testproto.PingResponse{Value: tagsToJson(grpc_ctxtags.Extract(stream.Context()).Values())} err = stream.Send(out) if err != nil { return err } } } func TestTaggingSuite(t *testing.T) { opts := []grpc_ctxtags.Option{ grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor), } s := 
&TaggingSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &tagPingBack{&grpc_testing.TestPingService{T: t}}, ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)), grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)), }, }, } suite.Run(t, s) } type TaggingSuite struct { *grpc_testing.InterceptorTestSuite } func (s *TaggingSuite) SetupTest() { } func (s *TaggingSuite) TestPing_WithCustomTags() { resp, err := s.Client.Ping(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "must not be an error on a successful call") tags := tagsFromJson(s.T(), resp.Value) require.Len(s.T(), tags, 2, "the tags should contain only two values") assert.Equal(s.T(), tags["grpc.request.value"], "something", "the tags should contain the correct request value") assert.Contains(s.T(), tags, "peer.address", "the tags should contain a peer address") } func (s *TaggingSuite) TestPing_WithDeadline() { ctx, _ := context.WithDeadline(context.TODO(), time.Now().AddDate(0, 0, 5)) resp, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), err, "must not be an error on a successful call") tags := tagsFromJson(s.T(), resp.Value) require.Len(s.T(), tags, 2, "the tags should contain only two values") assert.Equal(s.T(), tags["grpc.request.value"], "something", "the tags should contain the correct request value") assert.Contains(s.T(), tags, "peer.address", "the tags should contain a peer address") } func (s *TaggingSuite) TestPing_WithNoDeadline() { ctx := context.TODO() resp, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), err, "must not be an error on a successful call") tags := tagsFromJson(s.T(), resp.Value) require.Len(s.T(), tags, 2, "the tags should contain only two values") assert.Equal(s.T(), tags["grpc.request.value"], "something", "the tags should contain the correct request value") assert.Contains(s.T(), tags, "peer.address", "the tags should contain a peer address") } func (s 
*TaggingSuite) TestPingList_WithCustomTags() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { resp, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") tags := tagsFromJson(s.T(), resp.Value) require.Len(s.T(), tags, 2, "the tags should contain only two values") assert.Contains(s.T(), tags, "peer.address", "the tags should contain a peer address") assert.Equal(s.T(), tags["grpc.request.value"], "something", "the tags should contain the correct request value") } } func TestTaggingOnInitialRequestSuite(t *testing.T) { opts := []grpc_ctxtags.Option{ grpc_ctxtags.WithFieldExtractorForInitialReq(grpc_ctxtags.CodeGenRequestFieldExtractor), } // Embeds TaggingSuite as the behaviour should be identical in // the case of unary and server-streamed calls s := &ClientStreamedTaggingSuite{ TaggingSuite: &TaggingSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ TestService: &tagPingBack{&grpc_testing.TestPingService{T: t}}, ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(grpc_ctxtags.StreamServerInterceptor(opts...)), grpc.UnaryInterceptor(grpc_ctxtags.UnaryServerInterceptor(opts...)), }, }, }, } suite.Run(t, s) } type ClientStreamedTaggingSuite struct { *TaggingSuite } func (s *ClientStreamedTaggingSuite) TestPingStream_WithCustomTagsFirstRequest() { stream, err := s.Client.PingStream(s.SimpleCtx()) require.NoError(s.T(), err, "should not fail on establishing the stream") count := 0 for { switch { case count == 0: err = stream.Send(goodPing) case count < 3: err = stream.Send(anotherPing) default: err = stream.CloseSend() } resp, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") tags := tagsFromJson(s.T(), resp.Value) require.Len(s.T(), tags, 2, "the tags should contain only two values") assert.Equal(s.T(), tags["grpc.request.value"], "something", 
"the tags should contain the correct request value") assert.Contains(s.T(), tags, "peer.address", "the tags should contain a peer address") count++ } assert.Equal(s.T(), count, 3) } go-grpc-middleware-1.3.0/tags/logrus/000077500000000000000000000000001404040257500174365ustar00rootroot00000000000000go-grpc-middleware-1.3.0/tags/logrus/context.go000066400000000000000000000014451404040257500214550ustar00rootroot00000000000000package ctx_logrus import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus/ctxlogrus" "github.com/sirupsen/logrus" ) // AddFields adds logrus fields to the logger. // Deprecated: should use the ctxlogrus.Extract instead func AddFields(ctx context.Context, fields logrus.Fields) { ctxlogrus.AddFields(ctx, fields) } // Extract takes the call-scoped logrus.Entry from grpc_logrus middleware. // Deprecated: should use the ctxlogrus.Extract instead func Extract(ctx context.Context) *logrus.Entry { return ctxlogrus.Extract(ctx) } // ToContext adds the logrus.Entry to the context for extraction later. // Deprecated: should use ctxlogrus.ToContext instead func ToContext(ctx context.Context, entry *logrus.Entry) context.Context { return ctxlogrus.ToContext(ctx, entry) } go-grpc-middleware-1.3.0/tags/options.go000066400000000000000000000022271404040257500201500ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_ctxtags var ( defaultOptions = &options{ requestFieldsFunc: nil, } ) type options struct { requestFieldsFunc RequestFieldExtractorFunc requestFieldsFromInitial bool } func evaluateOptions(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions for _, o := range opts { o(optCopy) } return optCopy } type Option func(*options) // WithFieldExtractor customizes the function for extracting log fields from protobuf messages, for // unary and server-streamed methods only. 
func WithFieldExtractor(f RequestFieldExtractorFunc) Option { return func(o *options) { o.requestFieldsFunc = f } } // WithFieldExtractorForInitialReq customizes the function for extracting log fields from protobuf messages, // for all unary and streaming methods. For client-streams and bidirectional-streams, the tags will be // extracted from the first message from the client. func WithFieldExtractorForInitialReq(f RequestFieldExtractorFunc) Option { return func(o *options) { o.requestFieldsFunc = f o.requestFieldsFromInitial = true } } go-grpc-middleware-1.3.0/tags/zap/000077500000000000000000000000001404040257500167155ustar00rootroot00000000000000go-grpc-middleware-1.3.0/tags/zap/context.go000066400000000000000000000020101404040257500207210ustar00rootroot00000000000000package ctx_zap import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) // AddFields adds zap fields to the logger. // Deprecated: should use the ctxzap.AddFields instead func AddFields(ctx context.Context, fields ...zapcore.Field) { ctxzap.AddFields(ctx, fields...) } // Extract takes the call-scoped Logger from grpc_zap middleware. // Deprecated: should use the ctxzap.Extract instead func Extract(ctx context.Context) *zap.Logger { return ctxzap.Extract(ctx) } // TagsToFields transforms the Tags on the supplied context into zap fields. // Deprecated: use ctxzap.TagsToFields func TagsToFields(ctx context.Context) []zapcore.Field { return ctxzap.TagsToFields(ctx) } // ToContext adds the zap.Logger to the context for extraction later. // Returning the new context that has been created. 
// Deprecated: use ctxzap.ToContext func ToContext(ctx context.Context, logger *zap.Logger) context.Context { return ctxzap.ToContext(ctx, logger) } go-grpc-middleware-1.3.0/testing/000077500000000000000000000000001404040257500166425ustar00rootroot00000000000000go-grpc-middleware-1.3.0/testing/gogotestproto/000077500000000000000000000000001404040257500215615ustar00rootroot00000000000000go-grpc-middleware-1.3.0/testing/gogotestproto/Makefile000066400000000000000000000002311404040257500232150ustar00rootroot00000000000000all: test_go fields_go: fields.proto PATH="${GOPATH}/bin:${PATH}" protoc \ -I. \ -I${GOPATH}/src \ --gogo_out=plugins=grpc:. \ fields.proto go-grpc-middleware-1.3.0/testing/gogotestproto/fields.pb.go000066400000000000000000000361771404040257500237740ustar00rootroot00000000000000// Code generated by protoc-gen-gogo. DO NOT EDIT. // source: fields.proto // This file is used for testing discovery of log fields from requests using reflection and gogo proto more tags. package mwitkow_testproto import ( fmt "fmt" _ "github.com/gogo/protobuf/gogoproto" proto "github.com/golang/protobuf/proto" _ "github.com/golang/protobuf/ptypes/timestamp" math "math" time "time" ) // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal var _ = fmt.Errorf var _ = math.Inf var _ = time.Kitchen // This is a compile-time assertion to ensure that this generated file // is compatible with the proto package it is being compiled against. // A compilation error at this line likely means your copy of the // proto package needs to be updated. 
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type Metadata struct { Tags []string `protobuf:"bytes,1,rep,name=tags,proto3" json:"tags,omitempty" log_field:"meta_tags"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *Metadata) Reset() { *m = Metadata{} } func (m *Metadata) String() string { return proto.CompactTextString(m) } func (*Metadata) ProtoMessage() {} func (*Metadata) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{0} } func (m *Metadata) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Metadata.Unmarshal(m, b) } func (m *Metadata) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Metadata.Marshal(b, m, deterministic) } func (m *Metadata) XXX_Merge(src proto.Message) { xxx_messageInfo_Metadata.Merge(m, src) } func (m *Metadata) XXX_Size() int { return xxx_messageInfo_Metadata.Size(m) } func (m *Metadata) XXX_DiscardUnknown() { xxx_messageInfo_Metadata.DiscardUnknown(m) } var xxx_messageInfo_Metadata proto.InternalMessageInfo func (m *Metadata) GetTags() []string { if m != nil { return m.Tags } return nil } type PingId struct { Id int32 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty" log_field:"ping_id"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *PingId) Reset() { *m = PingId{} } func (m *PingId) String() string { return proto.CompactTextString(m) } func (*PingId) ProtoMessage() {} func (*PingId) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{1} } func (m *PingId) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_PingId.Unmarshal(m, b) } func (m *PingId) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_PingId.Marshal(b, m, deterministic) } func (m *PingId) XXX_Merge(src proto.Message) { xxx_messageInfo_PingId.Merge(m, src) } func (m 
*PingId) XXX_Size() int { return xxx_messageInfo_PingId.Size(m) } func (m *PingId) XXX_DiscardUnknown() { xxx_messageInfo_PingId.DiscardUnknown(m) } var xxx_messageInfo_PingId proto.InternalMessageInfo func (m *PingId) GetId() int32 { if m != nil { return m.Id } return 0 } type Ping struct { Id *PingId `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *Ping) Reset() { *m = Ping{} } func (m *Ping) String() string { return proto.CompactTextString(m) } func (*Ping) ProtoMessage() {} func (*Ping) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{2} } func (m *Ping) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Ping.Unmarshal(m, b) } func (m *Ping) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Ping.Marshal(b, m, deterministic) } func (m *Ping) XXX_Merge(src proto.Message) { xxx_messageInfo_Ping.Merge(m, src) } func (m *Ping) XXX_Size() int { return xxx_messageInfo_Ping.Size(m) } func (m *Ping) XXX_DiscardUnknown() { xxx_messageInfo_Ping.DiscardUnknown(m) } var xxx_messageInfo_Ping proto.InternalMessageInfo func (m *Ping) GetId() *PingId { if m != nil { return m.Id } return nil } func (m *Ping) GetValue() string { if m != nil { return m.Value } return "" } type PingRequest struct { Ping *Ping `protobuf:"bytes,1,opt,name=ping,proto3" json:"ping,omitempty"` Meta *Metadata `protobuf:"bytes,2,opt,name=meta,proto3" json:"meta,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *PingRequest) Reset() { *m = PingRequest{} } func (m *PingRequest) String() string { return proto.CompactTextString(m) } func (*PingRequest) ProtoMessage() {} func (*PingRequest) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, 
[]int{3} } func (m *PingRequest) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_PingRequest.Unmarshal(m, b) } func (m *PingRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_PingRequest.Marshal(b, m, deterministic) } func (m *PingRequest) XXX_Merge(src proto.Message) { xxx_messageInfo_PingRequest.Merge(m, src) } func (m *PingRequest) XXX_Size() int { return xxx_messageInfo_PingRequest.Size(m) } func (m *PingRequest) XXX_DiscardUnknown() { xxx_messageInfo_PingRequest.DiscardUnknown(m) } var xxx_messageInfo_PingRequest proto.InternalMessageInfo func (m *PingRequest) GetPing() *Ping { if m != nil { return m.Ping } return nil } func (m *PingRequest) GetMeta() *Metadata { if m != nil { return m.Meta } return nil } type Pong struct { Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty" log_field:"pong_id"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *Pong) Reset() { *m = Pong{} } func (m *Pong) String() string { return proto.CompactTextString(m) } func (*Pong) ProtoMessage() {} func (*Pong) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{4} } func (m *Pong) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Pong.Unmarshal(m, b) } func (m *Pong) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Pong.Marshal(b, m, deterministic) } func (m *Pong) XXX_Merge(src proto.Message) { xxx_messageInfo_Pong.Merge(m, src) } func (m *Pong) XXX_Size() int { return xxx_messageInfo_Pong.Size(m) } func (m *Pong) XXX_DiscardUnknown() { xxx_messageInfo_Pong.DiscardUnknown(m) } var xxx_messageInfo_Pong proto.InternalMessageInfo func (m *Pong) GetId() string { if m != nil { return m.Id } return "" } type OneOfLogField struct { // Types that are valid to be assigned to Identifier: // *OneOfLogField_BarId // *OneOfLogField_BazId Identifier isOneOfLogField_Identifier `protobuf_oneof:"identifier"` 
XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *OneOfLogField) Reset() { *m = OneOfLogField{} } func (m *OneOfLogField) String() string { return proto.CompactTextString(m) } func (*OneOfLogField) ProtoMessage() {} func (*OneOfLogField) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{5} } func (m *OneOfLogField) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_OneOfLogField.Unmarshal(m, b) } func (m *OneOfLogField) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_OneOfLogField.Marshal(b, m, deterministic) } func (m *OneOfLogField) XXX_Merge(src proto.Message) { xxx_messageInfo_OneOfLogField.Merge(m, src) } func (m *OneOfLogField) XXX_Size() int { return xxx_messageInfo_OneOfLogField.Size(m) } func (m *OneOfLogField) XXX_DiscardUnknown() { xxx_messageInfo_OneOfLogField.DiscardUnknown(m) } var xxx_messageInfo_OneOfLogField proto.InternalMessageInfo type isOneOfLogField_Identifier interface { isOneOfLogField_Identifier() } type OneOfLogField_BarId struct { BarId string `protobuf:"bytes,1,opt,name=bar_id,json=barId,proto3,oneof" json:"bar_id,omitempty" log_field:"bar_id"` } type OneOfLogField_BazId struct { BazId string `protobuf:"bytes,2,opt,name=baz_id,json=bazId,proto3,oneof" json:"baz_id,omitempty" log_field:"baz_id"` } func (*OneOfLogField_BarId) isOneOfLogField_Identifier() {} func (*OneOfLogField_BazId) isOneOfLogField_Identifier() {} func (m *OneOfLogField) GetIdentifier() isOneOfLogField_Identifier { if m != nil { return m.Identifier } return nil } func (m *OneOfLogField) GetBarId() string { if x, ok := m.GetIdentifier().(*OneOfLogField_BarId); ok { return x.BarId } return "" } func (m *OneOfLogField) GetBazId() string { if x, ok := m.GetIdentifier().(*OneOfLogField_BazId); ok { return x.BazId } return "" } // XXX_OneofWrappers is for the internal use of the proto package. 
func (*OneOfLogField) XXX_OneofWrappers() []interface{} { return []interface{}{ (*OneOfLogField_BarId)(nil), (*OneOfLogField_BazId)(nil), } } type PongRequest struct { Pong *Pong `protobuf:"bytes,1,opt,name=pong,proto3" json:"pong,omitempty"` Meta *Metadata `protobuf:"bytes,2,opt,name=meta,proto3" json:"meta,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *PongRequest) Reset() { *m = PongRequest{} } func (m *PongRequest) String() string { return proto.CompactTextString(m) } func (*PongRequest) ProtoMessage() {} func (*PongRequest) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{6} } func (m *PongRequest) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_PongRequest.Unmarshal(m, b) } func (m *PongRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_PongRequest.Marshal(b, m, deterministic) } func (m *PongRequest) XXX_Merge(src proto.Message) { xxx_messageInfo_PongRequest.Merge(m, src) } func (m *PongRequest) XXX_Size() int { return xxx_messageInfo_PongRequest.Size(m) } func (m *PongRequest) XXX_DiscardUnknown() { xxx_messageInfo_PongRequest.DiscardUnknown(m) } var xxx_messageInfo_PongRequest proto.InternalMessageInfo func (m *PongRequest) GetPong() *Pong { if m != nil { return m.Pong } return nil } func (m *PongRequest) GetMeta() *Metadata { if m != nil { return m.Meta } return nil } type GoGoProtoStdTime struct { Timestamp *time.Time `protobuf:"bytes,1,opt,name=timestamp,proto3,stdtime" json:"timestamp,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *GoGoProtoStdTime) Reset() { *m = GoGoProtoStdTime{} } func (m *GoGoProtoStdTime) String() string { return proto.CompactTextString(m) } func (*GoGoProtoStdTime) ProtoMessage() {} func (*GoGoProtoStdTime) Descriptor() ([]byte, []int) { return fileDescriptor_d39ad626ec0e575e, []int{7} } func 
(m *GoGoProtoStdTime) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_GoGoProtoStdTime.Unmarshal(m, b) } func (m *GoGoProtoStdTime) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_GoGoProtoStdTime.Marshal(b, m, deterministic) } func (m *GoGoProtoStdTime) XXX_Merge(src proto.Message) { xxx_messageInfo_GoGoProtoStdTime.Merge(m, src) } func (m *GoGoProtoStdTime) XXX_Size() int { return xxx_messageInfo_GoGoProtoStdTime.Size(m) } func (m *GoGoProtoStdTime) XXX_DiscardUnknown() { xxx_messageInfo_GoGoProtoStdTime.DiscardUnknown(m) } var xxx_messageInfo_GoGoProtoStdTime proto.InternalMessageInfo func (m *GoGoProtoStdTime) GetTimestamp() *time.Time { if m != nil { return m.Timestamp } return nil } func init() { proto.RegisterType((*Metadata)(nil), "mwitkow.testproto.Metadata") proto.RegisterType((*PingId)(nil), "mwitkow.testproto.PingId") proto.RegisterType((*Ping)(nil), "mwitkow.testproto.Ping") proto.RegisterType((*PingRequest)(nil), "mwitkow.testproto.PingRequest") proto.RegisterType((*Pong)(nil), "mwitkow.testproto.Pong") proto.RegisterType((*OneOfLogField)(nil), "mwitkow.testproto.OneOfLogField") proto.RegisterType((*PongRequest)(nil), "mwitkow.testproto.PongRequest") proto.RegisterType((*GoGoProtoStdTime)(nil), "mwitkow.testproto.GoGoProtoStdTime") } func init() { proto.RegisterFile("fields.proto", fileDescriptor_d39ad626ec0e575e) } var fileDescriptor_d39ad626ec0e575e = []byte{ // 427 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x9c, 0x92, 0xcd, 0x8e, 0xd3, 0x30, 0x14, 0x85, 0x27, 0x25, 0xad, 0xe8, 0xed, 0x20, 0x81, 0xf9, 0x99, 0x4e, 0x91, 0x48, 0xe5, 0x0d, 0x45, 0x68, 0x1c, 0x31, 0xac, 0x60, 0xc1, 0x22, 0x0b, 0x4a, 0x25, 0xd0, 0x54, 0x66, 0xf6, 0x95, 0x83, 0x5d, 0x63, 0x4d, 0x93, 0x5b, 0x1a, 0x87, 0x91, 0xba, 0xe1, 0x15, 0x58, 0xf2, 0x74, 0x65, 0xc1, 0x1b, 0xf4, 0x09, 0x90, 0x9d, 0xfe, 0x04, 0x0d, 0x65, 0xc1, 0x2e, 0xf6, 0x3d, 0xdf, 0x71, 0xce, 0xd1, 0x85, 0xe3, 0xa9, 
0x51, 0x33, 0x59, 0xb0, 0xf9, 0x02, 0x2d, 0x92, 0x7b, 0xd9, 0xb5, 0xb1, 0x57, 0x78, 0xcd, 0xac, 0x2a, 0xac, 0xbf, 0xea, 0x9d, 0x69, 0x63, 0x3f, 0x97, 0x29, 0xfb, 0x84, 0x59, 0xac, 0x51, 0x63, 0xec, 0xaf, 0xd3, 0x72, 0xea, 0x4f, 0xfe, 0xe0, 0xbf, 0x2a, 0x87, 0x5e, 0xa4, 0x11, 0xf5, 0x4c, 0xed, 0x55, 0xd6, 0x64, 0xaa, 0xb0, 0x22, 0x9b, 0x57, 0x02, 0xfa, 0x0a, 0x6e, 0x7f, 0x50, 0x56, 0x48, 0x61, 0x05, 0x39, 0x83, 0xd0, 0x0a, 0x5d, 0x74, 0x83, 0xfe, 0xad, 0x41, 0x3b, 0x39, 0x5d, 0xaf, 0xa2, 0x87, 0x33, 0xd4, 0x13, 0xff, 0x4b, 0xaf, 0x69, 0xa6, 0xac, 0x98, 0xb8, 0x39, 0xe5, 0x5e, 0x46, 0x5f, 0x40, 0x6b, 0x6c, 0x72, 0x3d, 0x92, 0xe4, 0x29, 0x34, 0x8c, 0xec, 0x06, 0xfd, 0x60, 0xd0, 0x4c, 0x4e, 0xd6, 0xab, 0xe8, 0x7e, 0x0d, 0x9b, 0x9b, 0x5c, 0x4f, 0x8c, 0xa4, 0xbc, 0x61, 0x24, 0x1d, 0x42, 0xe8, 0x10, 0xf2, 0x6c, 0x07, 0x74, 0xce, 0x4f, 0xd9, 0x8d, 0x94, 0xac, 0xf2, 0x75, 0x08, 0x79, 0x00, 0xcd, 0xaf, 0x62, 0x56, 0xaa, 0x6e, 0xa3, 0x1f, 0x0c, 0xda, 0xbc, 0x3a, 0xd0, 0x2b, 0xe8, 0x38, 0x0d, 0x57, 0x5f, 0x4a, 0x55, 0x58, 0xf2, 0x1c, 0x42, 0xf7, 0xce, 0xc6, 0xf1, 0xe4, 0x80, 0x23, 0xf7, 0x22, 0x12, 0x43, 0xe8, 0xb2, 0x78, 0xc3, 0xce, 0xf9, 0xe3, 0xbf, 0x88, 0xb7, 0x8d, 0x70, 0x2f, 0xa4, 0x31, 0x84, 0x63, 0xcc, 0x75, 0x2d, 0x66, 0xfb, 0x66, 0x4c, 0xac, 0xc5, 0xfc, 0x06, 0x77, 0x2e, 0x72, 0x75, 0x31, 0x7d, 0x8f, 0xfa, 0xad, 0x1b, 0x93, 0x18, 0x5a, 0xa9, 0x58, 0x4c, 0x76, 0xf4, 0xa3, 0xf5, 0x2a, 0x22, 0x35, 0xba, 0x1a, 0xd2, 0x77, 0x47, 0xbc, 0x99, 0x8a, 0xc5, 0x68, 0x03, 0x2c, 0x1d, 0xd0, 0x38, 0x00, 0x2c, 0xf7, 0xc0, 0x72, 0x24, 0x93, 0x63, 0x00, 0x23, 0x55, 0x6e, 0xcd, 0xd4, 0xa8, 0x85, 0xaf, 0x07, 0xff, 0xac, 0x07, 0xff, 0x5d, 0x0f, 0xfa, 0x7a, 0xf0, 0x7f, 0xea, 0xe1, 0x70, 0x77, 0x88, 0x43, 0x1c, 0xbb, 0xd9, 0x47, 0x2b, 0x2f, 0x4d, 0xa6, 0xc8, 0x1b, 0x68, 0xef, 0x36, 0x6d, 0xf3, 0x6c, 0x8f, 0x55, 0xbb, 0xc8, 0xb6, 0xbb, 0xc8, 0x2e, 0xb7, 0x8a, 0x24, 0xfc, 0xfe, 0x33, 0x0a, 0xf8, 0x1e, 0x49, 0xc2, 0x1f, 0xbf, 0x9e, 0x1c, 0xa5, 0x2d, 0x2f, 0x7d, 0xf9, 0x3b, 0x00, 0x00, 0xff, 
0xff, 0xa5, 0x38, 0xdf, 0x28, 0x16, 0x03, 0x00, 0x00, } go-grpc-middleware-1.3.0/testing/gogotestproto/fields.proto000066400000000000000000000020761404040257500241210ustar00rootroot00000000000000syntax = "proto3"; // This file is used for testing discovery of log fields from requests using reflection and gogo proto more tags. package mwitkow.testproto; import "github.com/gogo/protobuf/gogoproto/gogo.proto"; import "google/protobuf/timestamp.proto"; option (gogoproto.gogoproto_import) = false; message Metadata { repeated string tags = 1 [(gogoproto.moretags) = "log_field:\"meta_tags\""]; } message PingId { int32 id = 1 [(gogoproto.moretags) = "log_field:\"ping_id\""]; } message Ping { PingId id = 1; string value = 2; } message PingRequest { Ping ping = 1; Metadata meta = 2; } message Pong { string id = 1 [(gogoproto.moretags) = "log_field:\"pong_id\""]; } message OneOfLogField { oneof identifier { string bar_id = 1 [(gogoproto.moretags) = "log_field:\"bar_id\""]; string baz_id = 2 [(gogoproto.moretags) = "log_field:\"baz_id\""]; } } message PongRequest { Pong pong = 1; Metadata meta = 2; } message GoGoProtoStdTime { google.protobuf.Timestamp timestamp = 1 [(gogoproto.stdtime) = true]; } go-grpc-middleware-1.3.0/testing/interceptor_suite.go000066400000000000000000000132171404040257500227440ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_testing import ( "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "flag" "math/big" "net" "time" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/credentials" ) var ( flagTls = flag.Bool("use_tls", true, "whether all gRPC middleware tests should use tls") certPEM []byte keyPEM []byte ) // InterceptorTestSuite is a testify/Suite that starts a gRPC PingService server and a client. type InterceptorTestSuite struct { suite.Suite TestService pb_testproto.TestServiceServer ServerOpts []grpc.ServerOption ClientOpts []grpc.DialOption serverAddr string ServerListener net.Listener Server *grpc.Server clientConn *grpc.ClientConn Client pb_testproto.TestServiceClient restartServerWithDelayedStart chan time.Duration serverRunning chan bool } func (s *InterceptorTestSuite) SetupSuite() { s.restartServerWithDelayedStart = make(chan time.Duration) s.serverRunning = make(chan bool) s.serverAddr = "127.0.0.1:0" var err error certPEM, keyPEM, err = generateCertAndKey([]string{"localhost", "example.com"}) if err != nil { s.T().Fatalf("unable to generate test certificate/key: " + err.Error()) } go func() { for { var err error s.ServerListener, err = net.Listen("tcp", s.serverAddr) if err != nil { s.T().Fatalf("unable to listen on address %s: %v", s.serverAddr, err) } s.serverAddr = s.ServerListener.Addr().String() require.NoError(s.T(), err, "must be able to allocate a port for serverListener") if *flagTls { cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { s.T().Fatalf("unable to load test TLS certificate: %v", err) } creds := credentials.NewServerTLSFromCert(&cert) s.ServerOpts = append(s.ServerOpts, grpc.Creds(creds)) } // This is the point where we hook up the interceptor s.Server = grpc.NewServer(s.ServerOpts...) 
// Create a service of the instantiator hasn't provided one. if s.TestService == nil { s.TestService = &TestPingService{T: s.T()} } pb_testproto.RegisterTestServiceServer(s.Server, s.TestService) go func() { s.Server.Serve(s.ServerListener) }() if s.Client == nil { s.Client = s.NewClient(s.ClientOpts...) } s.serverRunning <- true d := <-s.restartServerWithDelayedStart s.Server.Stop() time.Sleep(d) } }() select { case <-s.serverRunning: case <-time.After(2 * time.Second): s.T().Fatal("server failed to start before deadline") } } func (s *InterceptorTestSuite) RestartServer(delayedStart time.Duration) <-chan bool { s.restartServerWithDelayedStart <- delayedStart time.Sleep(10 * time.Millisecond) return s.serverRunning } func (s *InterceptorTestSuite) NewClient(dialOpts ...grpc.DialOption) pb_testproto.TestServiceClient { newDialOpts := append(dialOpts, grpc.WithBlock()) if *flagTls { cp := x509.NewCertPool() if !cp.AppendCertsFromPEM(certPEM) { s.T().Fatal("failed to append certificate") } creds := credentials.NewTLS(&tls.Config{ServerName: "localhost", RootCAs: cp}) newDialOpts = append(newDialOpts, grpc.WithTransportCredentials(creds)) } else { newDialOpts = append(newDialOpts, grpc.WithInsecure()) } ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() clientConn, err := grpc.DialContext(ctx, s.ServerAddr(), newDialOpts...) 
require.NoError(s.T(), err, "must not error on client Dial") return pb_testproto.NewTestServiceClient(clientConn) } func (s *InterceptorTestSuite) ServerAddr() string { return s.serverAddr } func (s *InterceptorTestSuite) SimpleCtx() context.Context { ctx, _ := context.WithTimeout(context.TODO(), 2*time.Second) return ctx } func (s *InterceptorTestSuite) DeadlineCtx(deadline time.Time) context.Context { ctx, _ := context.WithDeadline(context.TODO(), deadline) return ctx } func (s *InterceptorTestSuite) TearDownSuite() { time.Sleep(10 * time.Millisecond) if s.ServerListener != nil { s.Server.GracefulStop() s.T().Logf("stopped grpc.Server at: %v", s.ServerAddr()) s.ServerListener.Close() } if s.clientConn != nil { s.clientConn.Close() } } // generateCertAndKey copied from https://github.com/johanbrandhorst/certify/blob/master/issuers/vault/vault_suite_test.go#L255 // with minor modifications. func generateCertAndKey(san []string) ([]byte, []byte, error) { priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, nil, err } notBefore := time.Now() notAfter := notBefore.Add(time.Hour) serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) if err != nil { return nil, nil, err } template := x509.Certificate{ SerialNumber: serialNumber, Subject: pkix.Name{ CommonName: "example.com", }, NotBefore: notBefore, NotAfter: notAfter, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, DNSNames: san, } derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, priv.Public(), priv) if err != nil { return nil, nil, err } certOut := pem.EncodeToMemory(&pem.Block{ Type: "CERTIFICATE", Bytes: derBytes, }) keyOut := pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv), }) return certOut, keyOut, nil } 
go-grpc-middleware-1.3.0/testing/mutex_readerwriter.go000066400000000000000000000012411404040257500231100ustar00rootroot00000000000000package grpc_testing import ( "io" "sync" ) // MutexReadWriter is a io.ReadWriter that can be read and worked on from multiple go routines. type MutexReadWriter struct { sync.Mutex rw io.ReadWriter } // NewMutexReadWriter creates a new thread-safe io.ReadWriter. func NewMutexReadWriter(rw io.ReadWriter) *MutexReadWriter { return &MutexReadWriter{rw: rw} } // Write implements the io.Writer interface. func (m *MutexReadWriter) Write(p []byte) (int, error) { m.Lock() defer m.Unlock() return m.rw.Write(p) } // Read implements the io.Reader interface. func (m *MutexReadWriter) Read(p []byte) (int, error) { m.Lock() defer m.Unlock() return m.rw.Read(p) } go-grpc-middleware-1.3.0/testing/pingservice.go000066400000000000000000000037701404040257500215160ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* Package `grpc_testing` provides helper functions for testing validators in this package. */ package grpc_testing import ( "context" "io" "testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) const ( // DefaultPongValue is the default value used. DefaultResponseValue = "default_response_value" // ListResponseCount is the expected number of responses to PingList ListResponseCount = 100 ) type TestPingService struct { T *testing.T } func (s *TestPingService) PingEmpty(ctx context.Context, _ *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { return &pb_testproto.PingResponse{Value: DefaultResponseValue, Counter: 42}, nil } func (s *TestPingService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { // Send user trailers and headers. 
return &pb_testproto.PingResponse{Value: ping.Value, Counter: 42}, nil } func (s *TestPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { code := codes.Code(ping.ErrorCodeReturned) return nil, status.Errorf(code, "Userspace error.") } func (s *TestPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { if ping.ErrorCodeReturned != 0 { return status.Errorf(codes.Code(ping.ErrorCodeReturned), "foobar") } // Send user trailers and headers. for i := 0; i < ListResponseCount; i++ { stream.Send(&pb_testproto.PingResponse{Value: ping.Value, Counter: int32(i)}) } return nil } func (s *TestPingService) PingStream(stream pb_testproto.TestService_PingStreamServer) error { count := 0 for true { ping, err := stream.Recv() if err == io.EOF { break } if err != nil { return err } stream.Send(&pb_testproto.PingResponse{Value: ping.Value, Counter: int32(count)}) count += 1 } return nil } go-grpc-middleware-1.3.0/testing/testproto/000077500000000000000000000000001404040257500207055ustar00rootroot00000000000000go-grpc-middleware-1.3.0/testing/testproto/Makefile000066400000000000000000000002211404040257500223400ustar00rootroot00000000000000all: test_go test_go: test.proto PATH="${GOPATH}/bin:${PATH}" protoc \ -I. \ -I${GOPATH}/src \ --go_out=plugins=grpc:. \ test.proto go-grpc-middleware-1.3.0/testing/testproto/test.manual_extractfields.pb.go000066400000000000000000000004011404040257500270030ustar00rootroot00000000000000// Manual code for logging field extraction tests. package mwitkow_testproto // This is implementing grpc_logging.requestLogFieldsExtractor func (m *PingRequest) ExtractRequestFields(appendToMap map[string]interface{}) { appendToMap["value"] = m.Value } go-grpc-middleware-1.3.0/testing/testproto/test.manual_validator.pb.go000066400000000000000000000010101404040257500261240ustar00rootroot00000000000000// Manual code for validation tests. 
package mwitkow_testproto import ( "errors" "math" ) // Implements the legacy validation interface from protoc-gen-validate. func (p *PingRequest) Validate() error { if p.SleepTimeMs > 10000 { return errors.New("cannot sleep for more than 10s") } return nil } // Implements the new validation interface from protoc-gen-validate. func (p *PingResponse) Validate(bool) error { if p.Counter > math.MaxInt16 { return errors.New("ping allocation exceeded") } return nil } go-grpc-middleware-1.3.0/testing/testproto/test.pb.go000066400000000000000000000374241404040257500226250ustar00rootroot00000000000000// Code generated by protoc-gen-go. DO NOT EDIT. // source: test.proto package mwitkow_testproto import ( context "context" fmt "fmt" proto "github.com/golang/protobuf/proto" grpc "google.golang.org/grpc" codes "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" math "math" ) // Reference imports to suppress errors if they are not otherwise used. var _ = proto.Marshal var _ = fmt.Errorf var _ = math.Inf // This is a compile-time assertion to ensure that this generated file // is compatible with the proto package it is being compiled against. // A compilation error at this line likely means your copy of the // proto package needs to be updated. // TODO(domgreen): This is blocking us from upgrading to later versions of gRPC. Fix in new PR. 
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package type Empty struct { XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *Empty) Reset() { *m = Empty{} } func (m *Empty) String() string { return proto.CompactTextString(m) } func (*Empty) ProtoMessage() {} func (*Empty) Descriptor() ([]byte, []int) { return fileDescriptor_c161fcfdc0c3ff1e, []int{0} } func (m *Empty) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_Empty.Unmarshal(m, b) } func (m *Empty) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_Empty.Marshal(b, m, deterministic) } func (m *Empty) XXX_Merge(src proto.Message) { xxx_messageInfo_Empty.Merge(m, src) } func (m *Empty) XXX_Size() int { return xxx_messageInfo_Empty.Size(m) } func (m *Empty) XXX_DiscardUnknown() { xxx_messageInfo_Empty.DiscardUnknown(m) } var xxx_messageInfo_Empty proto.InternalMessageInfo type PingRequest struct { Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` SleepTimeMs int32 `protobuf:"varint,2,opt,name=sleep_time_ms,json=sleepTimeMs,proto3" json:"sleep_time_ms,omitempty"` ErrorCodeReturned uint32 `protobuf:"varint,3,opt,name=error_code_returned,json=errorCodeReturned,proto3" json:"error_code_returned,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *PingRequest) Reset() { *m = PingRequest{} } func (m *PingRequest) String() string { return proto.CompactTextString(m) } func (*PingRequest) ProtoMessage() {} func (*PingRequest) Descriptor() ([]byte, []int) { return fileDescriptor_c161fcfdc0c3ff1e, []int{1} } func (m *PingRequest) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_PingRequest.Unmarshal(m, b) } func (m *PingRequest) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_PingRequest.Marshal(b, m, deterministic) } func (m *PingRequest) XXX_Merge(src 
proto.Message) { xxx_messageInfo_PingRequest.Merge(m, src) } func (m *PingRequest) XXX_Size() int { return xxx_messageInfo_PingRequest.Size(m) } func (m *PingRequest) XXX_DiscardUnknown() { xxx_messageInfo_PingRequest.DiscardUnknown(m) } var xxx_messageInfo_PingRequest proto.InternalMessageInfo func (m *PingRequest) GetValue() string { if m != nil { return m.Value } return "" } func (m *PingRequest) GetSleepTimeMs() int32 { if m != nil { return m.SleepTimeMs } return 0 } func (m *PingRequest) GetErrorCodeReturned() uint32 { if m != nil { return m.ErrorCodeReturned } return 0 } type PingResponse struct { Value string `protobuf:"bytes,1,opt,name=Value,proto3" json:"Value,omitempty"` Counter int32 `protobuf:"varint,2,opt,name=counter,proto3" json:"counter,omitempty"` XXX_NoUnkeyedLiteral struct{} `json:"-"` XXX_unrecognized []byte `json:"-"` XXX_sizecache int32 `json:"-"` } func (m *PingResponse) Reset() { *m = PingResponse{} } func (m *PingResponse) String() string { return proto.CompactTextString(m) } func (*PingResponse) ProtoMessage() {} func (*PingResponse) Descriptor() ([]byte, []int) { return fileDescriptor_c161fcfdc0c3ff1e, []int{2} } func (m *PingResponse) XXX_Unmarshal(b []byte) error { return xxx_messageInfo_PingResponse.Unmarshal(m, b) } func (m *PingResponse) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { return xxx_messageInfo_PingResponse.Marshal(b, m, deterministic) } func (m *PingResponse) XXX_Merge(src proto.Message) { xxx_messageInfo_PingResponse.Merge(m, src) } func (m *PingResponse) XXX_Size() int { return xxx_messageInfo_PingResponse.Size(m) } func (m *PingResponse) XXX_DiscardUnknown() { xxx_messageInfo_PingResponse.DiscardUnknown(m) } var xxx_messageInfo_PingResponse proto.InternalMessageInfo func (m *PingResponse) GetValue() string { if m != nil { return m.Value } return "" } func (m *PingResponse) GetCounter() int32 { if m != nil { return m.Counter } return 0 } func init() { proto.RegisterType((*Empty)(nil), 
"mwitkow.testproto.Empty") proto.RegisterType((*PingRequest)(nil), "mwitkow.testproto.PingRequest") proto.RegisterType((*PingResponse)(nil), "mwitkow.testproto.PingResponse") } func init() { proto.RegisterFile("test.proto", fileDescriptor_c161fcfdc0c3ff1e) } var fileDescriptor_c161fcfdc0c3ff1e = []byte{ // 289 bytes of a gzipped FileDescriptorProto 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xac, 0x50, 0x4f, 0x4b, 0xfb, 0x40, 0x10, 0xed, 0xfe, 0xfa, 0x8b, 0xb5, 0x13, 0x7b, 0xe8, 0xea, 0x21, 0x78, 0xd0, 0xb0, 0xa7, 0x9c, 0x42, 0xd1, 0xbb, 0x17, 0x11, 0x15, 0x14, 0x25, 0x29, 0x5e, 0x43, 0x4d, 0x06, 0x59, 0xec, 0x66, 0xe3, 0xee, 0xa4, 0xc1, 0x8f, 0xe1, 0x37, 0x96, 0xdd, 0x46, 0x28, 0x68, 0xd1, 0x43, 0x8f, 0xef, 0xbd, 0xe1, 0xfd, 0x19, 0x00, 0x42, 0x4b, 0x69, 0x63, 0x34, 0x69, 0x3e, 0x55, 0x9d, 0xa4, 0x57, 0xdd, 0xa5, 0x8e, 0xf3, 0x94, 0x18, 0x41, 0x70, 0xa5, 0x1a, 0x7a, 0x17, 0x1d, 0x84, 0x8f, 0xb2, 0x7e, 0xc9, 0xf0, 0xad, 0x45, 0x4b, 0xfc, 0x08, 0x82, 0xd5, 0x62, 0xd9, 0x62, 0xc4, 0x62, 0x96, 0x8c, 0xb3, 0x35, 0xe0, 0x02, 0x26, 0x76, 0x89, 0xd8, 0x14, 0x24, 0x15, 0x16, 0xca, 0x46, 0xff, 0x62, 0x96, 0x04, 0x59, 0xe8, 0xc9, 0xb9, 0x54, 0x78, 0x6f, 0x79, 0x0a, 0x87, 0x68, 0x8c, 0x36, 0x45, 0xa9, 0x2b, 0x2c, 0x0c, 0x52, 0x6b, 0x6a, 0xac, 0xa2, 0x61, 0xcc, 0x92, 0x49, 0x36, 0xf5, 0xd2, 0xa5, 0xae, 0x30, 0xeb, 0x05, 0x71, 0x01, 0x07, 0xeb, 0x60, 0xdb, 0xe8, 0xda, 0xa2, 0x4b, 0x7e, 0xda, 0x4c, 0xf6, 0x80, 0x47, 0x30, 0x2a, 0x75, 0x5b, 0x13, 0x9a, 0x3e, 0xf3, 0x0b, 0x9e, 0x7d, 0x0c, 0x21, 0x9c, 0xa3, 0xa5, 0x1c, 0xcd, 0x4a, 0x96, 0xc8, 0x6f, 0x60, 0xec, 0xfc, 0xfc, 0x2a, 0x1e, 0xa5, 0xdf, 0x26, 0xa7, 0x5e, 0x39, 0x3e, 0xfd, 0x41, 0xd9, 0xec, 0x21, 0x06, 0xfc, 0x16, 0xfe, 0x3b, 0x86, 0x9f, 0x6c, 0x3d, 0xf5, 0xbf, 0xfa, 0x8b, 0xd5, 0x75, 0x5f, 0xca, 0xad, 0xff, 0xd5, 0x6f, 0x6b, 0x69, 0x31, 0xe0, 0x0f, 0xb0, 0xef, 0x4e, 0xef, 0xa4, 0xa5, 0x1d, 0xf4, 0x9a, 0x31, 0x9e, 0x03, 0x38, 0x2e, 0x27, 0x83, 0x0b, 0xb5, 0x03, 0xcb, 0x84, 0xcd, 0xd8, 0xf3, 0x9e, 0x57, 
0xce, 0x3f, 0x03, 0x00, 0x00, 0xff, 0xff, 0xf0, 0x75, 0xf0, 0x5c, 0x7d, 0x02, 0x00, 0x00, } // Reference imports to suppress errors if they are not otherwise used. var _ context.Context var _ grpc.ClientConn // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. const _ = grpc.SupportPackageIsVersion4 // TestServiceClient is the client API for TestService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://godoc.org/google.golang.org/grpc#ClientConn.NewStream. type TestServiceClient interface { PingEmpty(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*PingResponse, error) Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) PingError(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*Empty, error) PingList(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (TestService_PingListClient, error) PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) } type testServiceClient struct { cc *grpc.ClientConn } func NewTestServiceClient(cc *grpc.ClientConn) TestServiceClient { return &testServiceClient{cc} } func (c *testServiceClient) PingEmpty(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*PingResponse, error) { out := new(PingResponse) err := c.cc.Invoke(ctx, "/mwitkow.testproto.TestService/PingEmpty", in, out, opts...) if err != nil { return nil, err } return out, nil } func (c *testServiceClient) Ping(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*PingResponse, error) { out := new(PingResponse) err := c.cc.Invoke(ctx, "/mwitkow.testproto.TestService/Ping", in, out, opts...) 
if err != nil { return nil, err } return out, nil } func (c *testServiceClient) PingError(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (*Empty, error) { out := new(Empty) err := c.cc.Invoke(ctx, "/mwitkow.testproto.TestService/PingError", in, out, opts...) if err != nil { return nil, err } return out, nil } func (c *testServiceClient) PingList(ctx context.Context, in *PingRequest, opts ...grpc.CallOption) (TestService_PingListClient, error) { stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[0], "/mwitkow.testproto.TestService/PingList", opts...) if err != nil { return nil, err } x := &testServicePingListClient{stream} if err := x.ClientStream.SendMsg(in); err != nil { return nil, err } if err := x.ClientStream.CloseSend(); err != nil { return nil, err } return x, nil } type TestService_PingListClient interface { Recv() (*PingResponse, error) grpc.ClientStream } type testServicePingListClient struct { grpc.ClientStream } func (x *testServicePingListClient) Recv() (*PingResponse, error) { m := new(PingResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } func (c *testServiceClient) PingStream(ctx context.Context, opts ...grpc.CallOption) (TestService_PingStreamClient, error) { stream, err := c.cc.NewStream(ctx, &_TestService_serviceDesc.Streams[1], "/mwitkow.testproto.TestService/PingStream", opts...) 
if err != nil { return nil, err } x := &testServicePingStreamClient{stream} return x, nil } type TestService_PingStreamClient interface { Send(*PingRequest) error Recv() (*PingResponse, error) grpc.ClientStream } type testServicePingStreamClient struct { grpc.ClientStream } func (x *testServicePingStreamClient) Send(m *PingRequest) error { return x.ClientStream.SendMsg(m) } func (x *testServicePingStreamClient) Recv() (*PingResponse, error) { m := new(PingResponse) if err := x.ClientStream.RecvMsg(m); err != nil { return nil, err } return m, nil } // TestServiceServer is the server API for TestService service. type TestServiceServer interface { PingEmpty(context.Context, *Empty) (*PingResponse, error) Ping(context.Context, *PingRequest) (*PingResponse, error) PingError(context.Context, *PingRequest) (*Empty, error) PingList(*PingRequest, TestService_PingListServer) error PingStream(TestService_PingStreamServer) error } // UnimplementedTestServiceServer can be embedded to have forward compatible implementations. 
type UnimplementedTestServiceServer struct { } func (*UnimplementedTestServiceServer) PingEmpty(ctx context.Context, req *Empty) (*PingResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method PingEmpty not implemented") } func (*UnimplementedTestServiceServer) Ping(ctx context.Context, req *PingRequest) (*PingResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Ping not implemented") } func (*UnimplementedTestServiceServer) PingError(ctx context.Context, req *PingRequest) (*Empty, error) { return nil, status.Errorf(codes.Unimplemented, "method PingError not implemented") } func (*UnimplementedTestServiceServer) PingList(req *PingRequest, srv TestService_PingListServer) error { return status.Errorf(codes.Unimplemented, "method PingList not implemented") } func (*UnimplementedTestServiceServer) PingStream(srv TestService_PingStreamServer) error { return status.Errorf(codes.Unimplemented, "method PingStream not implemented") } func RegisterTestServiceServer(s *grpc.Server, srv TestServiceServer) { s.RegisterService(&_TestService_serviceDesc, srv) } func _TestService_PingEmpty_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(Empty) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(TestServiceServer).PingEmpty(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/mwitkow.testproto.TestService/PingEmpty", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(TestServiceServer).PingEmpty(ctx, req.(*Empty)) } return interceptor(ctx, in, info, handler) } func _TestService_Ping_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(PingRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(TestServiceServer).Ping(ctx, in) 
} info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/mwitkow.testproto.TestService/Ping", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(TestServiceServer).Ping(ctx, req.(*PingRequest)) } return interceptor(ctx, in, info, handler) } func _TestService_PingError_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(PingRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(TestServiceServer).PingError(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/mwitkow.testproto.TestService/PingError", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(TestServiceServer).PingError(ctx, req.(*PingRequest)) } return interceptor(ctx, in, info, handler) } func _TestService_PingList_Handler(srv interface{}, stream grpc.ServerStream) error { m := new(PingRequest) if err := stream.RecvMsg(m); err != nil { return err } return srv.(TestServiceServer).PingList(m, &testServicePingListServer{stream}) } type TestService_PingListServer interface { Send(*PingResponse) error grpc.ServerStream } type testServicePingListServer struct { grpc.ServerStream } func (x *testServicePingListServer) Send(m *PingResponse) error { return x.ServerStream.SendMsg(m) } func _TestService_PingStream_Handler(srv interface{}, stream grpc.ServerStream) error { return srv.(TestServiceServer).PingStream(&testServicePingStreamServer{stream}) } type TestService_PingStreamServer interface { Send(*PingResponse) error Recv() (*PingRequest, error) grpc.ServerStream } type testServicePingStreamServer struct { grpc.ServerStream } func (x *testServicePingStreamServer) Send(m *PingResponse) error { return x.ServerStream.SendMsg(m) } func (x *testServicePingStreamServer) Recv() (*PingRequest, error) { m := new(PingRequest) if err := x.ServerStream.RecvMsg(m); err != nil { return nil, err 
} return m, nil } var _TestService_serviceDesc = grpc.ServiceDesc{ ServiceName: "mwitkow.testproto.TestService", HandlerType: (*TestServiceServer)(nil), Methods: []grpc.MethodDesc{ { MethodName: "PingEmpty", Handler: _TestService_PingEmpty_Handler, }, { MethodName: "Ping", Handler: _TestService_Ping_Handler, }, { MethodName: "PingError", Handler: _TestService_PingError_Handler, }, }, Streams: []grpc.StreamDesc{ { StreamName: "PingList", Handler: _TestService_PingList_Handler, ServerStreams: true, }, { StreamName: "PingStream", Handler: _TestService_PingStream_Handler, ServerStreams: true, ClientStreams: true, }, }, Metadata: "test.proto", } go-grpc-middleware-1.3.0/testing/testproto/test.proto000066400000000000000000000010441404040257500227500ustar00rootroot00000000000000syntax = "proto3"; package mwitkow.testproto; message Empty { } message PingRequest { string value = 1; int32 sleep_time_ms = 2; uint32 error_code_returned = 3; } message PingResponse { string Value = 1; int32 counter = 2; } service TestService { rpc PingEmpty(Empty) returns (PingResponse) {} rpc Ping(PingRequest) returns (PingResponse) {} rpc PingError(PingRequest) returns (Empty) {} rpc PingList(PingRequest) returns (stream PingResponse) {} rpc PingStream(stream PingRequest) returns (stream PingResponse) {} } go-grpc-middleware-1.3.0/tracing/000077500000000000000000000000001404040257500166145ustar00rootroot00000000000000go-grpc-middleware-1.3.0/tracing/opentracing/000077500000000000000000000000001404040257500211255ustar00rootroot00000000000000go-grpc-middleware-1.3.0/tracing/opentracing/client_interceptors.go000066400000000000000000000113511404040257500255340ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_opentracing import ( "context" "io" "sync" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" opentracing "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/log" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/metadata" ) // UnaryClientInterceptor returns a new unary client interceptor for OpenTracing. func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { o := evaluateOptions(opts) return func(parentCtx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if o.filterOutFunc != nil && !o.filterOutFunc(parentCtx, method) { return invoker(parentCtx, method, req, reply, cc, opts...) } newCtx, clientSpan := newClientSpanFromContext(parentCtx, o.tracer, method) if o.unaryRequestHandlerFunc != nil { o.unaryRequestHandlerFunc(clientSpan, req) } err := invoker(newCtx, method, req, reply, cc, opts...) finishClientSpan(clientSpan, err) return err } } // StreamClientInterceptor returns a new streaming client interceptor for OpenTracing. func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { o := evaluateOptions(opts) return func(parentCtx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { if o.filterOutFunc != nil && !o.filterOutFunc(parentCtx, method) { return streamer(parentCtx, desc, cc, method, opts...) } newCtx, clientSpan := newClientSpanFromContext(parentCtx, o.tracer, method) clientStream, err := streamer(newCtx, desc, cc, method, opts...) 
if err != nil { finishClientSpan(clientSpan, err) return nil, err } return &tracedClientStream{ClientStream: clientStream, clientSpan: clientSpan}, nil } } // type serverStreamingRetryingStream is the implementation of grpc.ClientStream that acts as a // proxy to the underlying call. If any of the RecvMsg() calls fail, it will try to reestablish // a new ClientStream according to the retry policy. type tracedClientStream struct { grpc.ClientStream mu sync.Mutex alreadyFinished bool clientSpan opentracing.Span } func (s *tracedClientStream) Header() (metadata.MD, error) { h, err := s.ClientStream.Header() if err != nil { s.finishClientSpan(err) } return h, err } func (s *tracedClientStream) SendMsg(m interface{}) error { err := s.ClientStream.SendMsg(m) if err != nil { s.finishClientSpan(err) } return err } func (s *tracedClientStream) CloseSend() error { err := s.ClientStream.CloseSend() s.finishClientSpan(err) return err } func (s *tracedClientStream) RecvMsg(m interface{}) error { err := s.ClientStream.RecvMsg(m) if err != nil { s.finishClientSpan(err) } return err } func (s *tracedClientStream) finishClientSpan(err error) { s.mu.Lock() defer s.mu.Unlock() if !s.alreadyFinished { finishClientSpan(s.clientSpan, err) s.alreadyFinished = true } } // ClientAddContextTags returns a context with specified opentracing tags, which // are used by UnaryClientInterceptor/StreamClientInterceptor when creating a // new span. 
func ClientAddContextTags(ctx context.Context, tags opentracing.Tags) context.Context { return context.WithValue(ctx, clientSpanTagKey{}, tags) } type clientSpanTagKey struct{} func newClientSpanFromContext(ctx context.Context, tracer opentracing.Tracer, fullMethodName string) (context.Context, opentracing.Span) { var parentSpanCtx opentracing.SpanContext if parent := opentracing.SpanFromContext(ctx); parent != nil { parentSpanCtx = parent.Context() } opts := []opentracing.StartSpanOption{ opentracing.ChildOf(parentSpanCtx), ext.SpanKindRPCClient, grpcTag, } if tagx := ctx.Value(clientSpanTagKey{}); tagx != nil { if opt, ok := tagx.(opentracing.StartSpanOption); ok { opts = append(opts, opt) } } clientSpan := tracer.StartSpan(fullMethodName, opts...) // Make sure we add this to the metadata of the call, so it gets propagated: md := metautils.ExtractOutgoing(ctx).Clone() if err := tracer.Inject(clientSpan.Context(), opentracing.HTTPHeaders, metadataTextMap(md)); err != nil { grpclog.Infof("grpc_opentracing: failed serializing trace information: %v", err) } ctxWithMetadata := md.ToOutgoing(ctx) return opentracing.ContextWithSpan(ctxWithMetadata, clientSpan), clientSpan } func finishClientSpan(clientSpan opentracing.Span, err error) { if err != nil && err != io.EOF { ext.Error.Set(clientSpan, true) clientSpan.LogFields(log.String("event", "error"), log.String("message", err.Error())) } clientSpan.Finish() } go-grpc-middleware-1.3.0/tracing/opentracing/doc.go000066400000000000000000000013301404040257500222160ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* `grpc_opentracing` adds OpenTracing OpenTracing Interceptors These are both client-side and server-side interceptors for OpenTracing. They are a provider-agnostic, with backends such as Zipkin, or Google Stackdriver Trace. 
For a service that sends out requests and receives requests, you *need* to use both, otherwise downstream requests will not have the appropriate requests propagated. All server-side spans are tagged with grpc_ctxtags information. For more information see: http://opentracing.io/documentation/ https://github.com/opentracing/specification/blob/master/semantic_conventions.md */ package grpc_opentracing go-grpc-middleware-1.3.0/tracing/opentracing/id_extract.go000066400000000000000000000052171404040257500236070ustar00rootroot00000000000000package grpc_opentracing import ( "strings" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" opentracing "github.com/opentracing/opentracing-go" "google.golang.org/grpc/grpclog" ) const ( TagTraceId = "trace.traceid" TagSpanId = "trace.spanid" TagSampled = "trace.sampled" jaegerNotSampledFlag = "0" ) // injectOpentracingIdsToTags writes trace data to ctxtags. // This is done in an incredibly hacky way, because the public-facing interface of opentracing doesn't give access to // the TraceId and SpanId of the SpanContext. Only the Tracer's Inject/Extract methods know what these are. 
// Most tracers have them encoded as keys with 'traceid' and 'spanid': // https://github.com/openzipkin/zipkin-go-opentracing/blob/594640b9ef7e5c994e8d9499359d693c032d738c/propagation_ot.go#L29 // https://github.com/opentracing/basictracer-go/blob/1b32af207119a14b1b231d451df3ed04a72efebf/propagation_ot.go#L26 // Jaeger from Uber use one-key schema with next format '{trace-id}:{span-id}:{parent-span-id}:{flags}' // https://www.jaegertracing.io/docs/client-libraries/#trace-span-identity // Datadog uses keys ending with 'trace-id' and 'parent-id' (for span) by default: // https://github.com/DataDog/dd-trace-go/blob/v1/ddtrace/tracer/textmap.go#L77 func injectOpentracingIdsToTags(traceHeaderName string, span opentracing.Span, tags grpc_ctxtags.Tags) { if err := span.Tracer().Inject(span.Context(), opentracing.HTTPHeaders, &tagsCarrier{Tags: tags, traceHeaderName: traceHeaderName}); err != nil { grpclog.Infof("grpc_opentracing: failed extracting trace info into ctx %v", err) } } // tagsCarrier is a really hacky way of type tagsCarrier struct { grpc_ctxtags.Tags traceHeaderName string } func (t *tagsCarrier) Set(key, val string) { key = strings.ToLower(key) if key == t.traceHeaderName { parts := strings.Split(val, ":") if len(parts) == 4 { t.Tags.Set(TagTraceId, parts[0]) t.Tags.Set(TagSpanId, parts[1]) if parts[3] != jaegerNotSampledFlag { t.Tags.Set(TagSampled, "true") } else { t.Tags.Set(TagSampled, "false") } return } } if strings.Contains(key, "traceid") { t.Tags.Set(TagTraceId, val) // this will most likely be base-16 (hex) encoded } if strings.Contains(key, "spanid") && !strings.Contains(strings.ToLower(key), "parent") { t.Tags.Set(TagSpanId, val) // this will most likely be base-16 (hex) encoded } if strings.Contains(key, "sampled") { switch val { case "true", "false": t.Tags.Set(TagSampled, val) } } if strings.HasSuffix(key, "trace-id") { t.Tags.Set(TagTraceId, val) } if strings.HasSuffix(key, "parent-id") { t.Tags.Set(TagSpanId, val) } } 
go-grpc-middleware-1.3.0/tracing/opentracing/id_extract_test.go000066400000000000000000000014301404040257500246370ustar00rootroot00000000000000package grpc_opentracing import ( "fmt" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/stretchr/testify/assert" "testing" ) func TestTagsCarrier_Set_JaegerTraceFormat(t *testing.T) { var ( fakeTraceSampled = 1 fakeInboundTraceId = "deadbeef" fakeInboundSpanId = "c0decafe" traceHeaderName = "uber-trace-id" ) traceHeaderValue := fmt.Sprintf("%s:%s:%s:%d", fakeInboundTraceId, fakeInboundSpanId, fakeInboundSpanId, fakeTraceSampled) c := &tagsCarrier{ Tags: grpc_ctxtags.NewTags(), traceHeaderName: traceHeaderName, } c.Set(traceHeaderName, traceHeaderValue) assert.EqualValues(t, map[string]interface{}{ TagTraceId: fakeInboundTraceId, TagSpanId: fakeInboundSpanId, TagSampled: "true", }, c.Tags.Values()) } go-grpc-middleware-1.3.0/tracing/opentracing/interceptors_test.go000066400000000000000000000314421404040257500252400ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_opentracing_test import ( "context" "errors" "fmt" "io" "net/http" "strconv" "strings" "testing" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/log" "github.com/opentracing/opentracing-go/mocktracer" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" ) var ( goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} fakeInboundTraceId = 1337 fakeInboundSpanId = 999 traceHeaderName = "uber-trace-id" filterFunc = func(ctx context.Context, fullMethodName string) bool { return true } unaryRequestHandlerFunc = func(span opentracing.Span, req interface{}) { span.LogFields(log.Bool("unary-request-handler", true)) } ) type tracingAssertService struct { pb_testproto.TestServiceServer T *testing.T } func (s *tracingAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) { assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") tags := grpc_ctxtags.Extract(ctx) assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true") return s.TestServiceServer.Ping(ctx, ping) } func (s *tracingAssertService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) 
{ assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") return s.TestServiceServer.PingError(ctx, ping) } func (s *tracingAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { assert.NotNil(s.T, opentracing.SpanFromContext(stream.Context()), "handlers must have the spancontext in their context, otherwise propagation will fail") tags := grpc_ctxtags.Extract(stream.Context()) assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "true", "sampled must be set to true") return s.TestServiceServer.PingList(ping, stream) } func (s *tracingAssertService) PingEmpty(ctx context.Context, empty *pb_testproto.Empty) (*pb_testproto.PingResponse, error) { assert.NotNil(s.T, opentracing.SpanFromContext(ctx), "handlers must have the spancontext in their context, otherwise propagation will fail") tags := grpc_ctxtags.Extract(ctx) assert.True(s.T, tags.Has(grpc_opentracing.TagTraceId), "tags must contain traceid") assert.True(s.T, tags.Has(grpc_opentracing.TagSpanId), "tags must contain spanid") assert.True(s.T, tags.Has(grpc_opentracing.TagSampled), "tags must contain sampled") assert.Equal(s.T, tags.Values()[grpc_opentracing.TagSampled], "false", "sampled must be set to false") return s.TestServiceServer.PingEmpty(ctx, empty) } func TestTaggingSuite(t *testing.T) { mockTracer := mocktracer.New() opts := []grpc_opentracing.Option{ grpc_opentracing.WithTracer(mockTracer), grpc_opentracing.WithFilterFunc(filterFunc), grpc_opentracing.WithTraceHeaderName(traceHeaderName), grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc), } s := &OpentracingSuite{ mockTracer: 
mockTracer, InterceptorTestSuite: makeInterceptorTestSuite(t, opts), } suite.Run(t, s) } func TestTaggingSuiteJaeger(t *testing.T) { mockTracer := mocktracer.New() mockTracer.RegisterInjector(opentracing.HTTPHeaders, jaegerFormatInjector{}) mockTracer.RegisterExtractor(opentracing.HTTPHeaders, jaegerFormatExtractor{}) opts := []grpc_opentracing.Option{ grpc_opentracing.WithTracer(mockTracer), grpc_opentracing.WithUnaryRequestHandlerFunc(unaryRequestHandlerFunc), } s := &OpentracingSuite{ mockTracer: mockTracer, InterceptorTestSuite: makeInterceptorTestSuite(t, opts), } suite.Run(t, s) } func makeInterceptorTestSuite(t *testing.T, opts []grpc_opentracing.Option) *grpc_testing.InterceptorTestSuite { return &grpc_testing.InterceptorTestSuite{ TestService: &tracingAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}, T: t}, ClientOpts: []grpc.DialOption{ grpc.WithUnaryInterceptor(grpc_opentracing.UnaryClientInterceptor(opts...)), grpc.WithStreamInterceptor(grpc_opentracing.StreamClientInterceptor(opts...)), }, ServerOpts: []grpc.ServerOption{ grpc_middleware.WithStreamServerChain( grpc_ctxtags.StreamServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_opentracing.StreamServerInterceptor(opts...)), grpc_middleware.WithUnaryServerChain( grpc_ctxtags.UnaryServerInterceptor(grpc_ctxtags.WithFieldExtractor(grpc_ctxtags.CodeGenRequestFieldExtractor)), grpc_opentracing.UnaryServerInterceptor(opts...)), }, } } type OpentracingSuite struct { *grpc_testing.InterceptorTestSuite mockTracer *mocktracer.MockTracer } func (s *OpentracingSuite) SetupTest() { s.mockTracer.Reset() } func (s *OpentracingSuite) createContextFromFakeHttpRequestParent(ctx context.Context, sampled bool, opName string) context.Context { jFlag := 0 if sampled { jFlag = 1 } if len(opName) == 0 { opName = "/fake/parent/http/request" } hdr := http.Header{} hdr.Set(traceHeaderName, fmt.Sprintf("%d:%d:%d:%d", fakeInboundTraceId, fakeInboundSpanId, 
fakeInboundSpanId, jFlag)) hdr.Set("mockpfx-ids-traceid", fmt.Sprint(fakeInboundTraceId)) hdr.Set("mockpfx-ids-spanid", fmt.Sprint(fakeInboundSpanId)) hdr.Set("mockpfx-ids-sampled", fmt.Sprint(sampled)) parentSpanContext, err := s.mockTracer.Extract(opentracing.HTTPHeaders, opentracing.HTTPHeadersCarrier(hdr)) require.NoError(s.T(), err, "parsing a fake HTTP request headers shouldn't fail, ever") fakeSpan := s.mockTracer.StartSpan( opName, // this is magical, it attaches the new span to the parent parentSpanContext, and creates an unparented one if empty. opentracing.ChildOf(parentSpanContext), ) fakeSpan.Finish() return opentracing.ContextWithSpan(ctx, fakeSpan) } func (s *OpentracingSuite) assertTracesCreated(methodName string) (clientSpan *mocktracer.MockSpan, serverSpan *mocktracer.MockSpan) { spans := s.mockTracer.FinishedSpans() for _, span := range spans { s.T().Logf("span: %v, tags: %v", span, span.Tags()) } require.Len(s.T(), spans, 3, "should record 3 spans: one fake inbound, one client, one server") traceIdAssert := fmt.Sprintf("traceId=%d", fakeInboundTraceId) for _, span := range spans { assert.Contains(s.T(), span.String(), traceIdAssert, "not part of the fake parent trace: %v", span) if span.OperationName == methodName { kind := fmt.Sprintf("%v", span.Tag("span.kind")) if kind == "client" { clientSpan = span } else if kind == "server" { serverSpan = span } assert.EqualValues(s.T(), span.Tag("component"), "gRPC", "span must be tagged with gRPC component") } } require.NotNil(s.T(), clientSpan, "client span must be there") require.NotNil(s.T(), serverSpan, "server span must be there") assert.EqualValues(s.T(), serverSpan.Tag("grpc.request.value"), "something", "grpc_ctxtags must be propagated, in this case ones from request fields") return clientSpan, serverSpan } func (s *OpentracingSuite) TestPing_PropagatesTraces() { ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "") _, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), 
err, "there must be not be an on a successful call") s.assertTracesCreated("/mwitkow.testproto.TestService/Ping") } func (s *OpentracingSuite) TestPing_CustomOpName() { customOpName := "customOpName" ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, customOpName) _, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), err, "there must be not be an error on a successful call") spans := s.mockTracer.FinishedSpans() spanOpNames := make([]string, len(spans)) for _, span := range spans { spanOpNames = append(spanOpNames, span.OperationName) } require.Contains(s.T(), spanOpNames, customOpName, "finished spans must contain the custom operation name") } func (s *OpentracingSuite) TestPing_WithUnaryRequestHandlerFunc() { ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "") _, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), err, "there must be not be an on a successful call") var hasLogKey bool Loop: for _, span := range s.mockTracer.FinishedSpans() { for _, record := range span.Logs() { for _, field := range record.Fields { if field.Key == "unary-request-handler" { hasLogKey = true break Loop } } } } require.True(s.T(), hasLogKey, "span field 'unary-request-handler' not found") } func (s *OpentracingSuite) TestPing_ClientContextTags() { const name = "opentracing.custom" ctx := grpc_opentracing.ClientAddContextTags( s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, ""), opentracing.Tags{name: ""}, ) _, err := s.Client.Ping(ctx, goodPing) require.NoError(s.T(), err, "there must be not be an on a successful call") for _, span := range s.mockTracer.FinishedSpans() { if span.OperationName == "/mwitkow.testproto.TestService/Ping" { kind := fmt.Sprintf("%v", span.Tag("span.kind")) if kind == "client" { assert.Contains(s.T(), span.Tags(), name, "custom opentracing.Tags must be included in context") } } } } func (s *OpentracingSuite) TestPingList_PropagatesTraces() { ctx := 
s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "") stream, err := s.Client.PingList(ctx, goodPing) require.NoError(s.T(), err, "should not fail on establishing the stream") for { _, err := stream.Recv() if err == io.EOF { break } require.NoError(s.T(), err, "reading stream should not fail") } s.assertTracesCreated("/mwitkow.testproto.TestService/PingList") } func (s *OpentracingSuite) TestPingError_PropagatesTraces() { ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), true, "") erroringPing := &pb_testproto.PingRequest{Value: "something", ErrorCodeReturned: uint32(codes.OutOfRange)} _, err := s.Client.PingError(ctx, erroringPing) require.Error(s.T(), err, "there must be an error returned here") clientSpan, serverSpan := s.assertTracesCreated("/mwitkow.testproto.TestService/PingError") assert.Equal(s.T(), true, clientSpan.Tag("error"), "client span needs to be marked as an error") assert.Equal(s.T(), true, serverSpan.Tag("error"), "server span needs to be marked as an error") } func (s *OpentracingSuite) TestPingEmpty_NotSampleTraces() { ctx := s.createContextFromFakeHttpRequestParent(s.SimpleCtx(), false, "") _, err := s.Client.PingEmpty(ctx, &pb_testproto.Empty{}) require.NoError(s.T(), err, "there must be not be an on a successful call") } type jaegerFormatInjector struct{} func (jaegerFormatInjector) Inject(ctx mocktracer.MockSpanContext, carrier interface{}) error { w := carrier.(opentracing.TextMapWriter) flags := 0 if ctx.Sampled { flags = 1 } w.Set(traceHeaderName, fmt.Sprintf("%d:%d::%d", ctx.TraceID, ctx.SpanID, flags)) return nil } type jaegerFormatExtractor struct{} func (jaegerFormatExtractor) Extract(carrier interface{}) (mocktracer.MockSpanContext, error) { rval := mocktracer.MockSpanContext{Sampled: true} reader, ok := carrier.(opentracing.TextMapReader) if !ok { return rval, opentracing.ErrInvalidCarrier } err := reader.ForeachKey(func(key, val string) error { lowerKey := strings.ToLower(key) switch { case lowerKey == 
traceHeaderName: parts := strings.Split(val, ":") if len(parts) != 4 { return errors.New("invalid trace id format") } traceId, err := strconv.Atoi(parts[0]) if err != nil { return err } rval.TraceID = traceId spanId, err := strconv.Atoi(parts[1]) if err != nil { return err } rval.SpanID = spanId flags, err := strconv.Atoi(parts[3]) if err != nil { return err } rval.Sampled = flags%2 == 1 } return nil }) if rval.TraceID == 0 || rval.SpanID == 0 { return rval, opentracing.ErrSpanContextNotFound } if err != nil { return rval, err } return rval, nil }go-grpc-middleware-1.3.0/tracing/opentracing/metadata.go000066400000000000000000000025401404040257500232350ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_opentracing import ( "encoding/base64" "strings" "google.golang.org/grpc/metadata" ) const ( binHdrSuffix = "-bin" ) // metadataTextMap extends a metadata.MD to be an opentracing textmap type metadataTextMap metadata.MD // Set is a opentracing.TextMapReader interface that extracts values. func (m metadataTextMap) Set(key, val string) { // gRPC allows for complex binary values to be written. encodedKey, encodedVal := encodeKeyValue(key, val) // The metadata object is a multimap, and previous values may exist, but for opentracing headers, we do not append // we just override. m[encodedKey] = []string{encodedVal} } // ForeachKey is a opentracing.TextMapReader interface that extracts values. func (m metadataTextMap) ForeachKey(callback func(key, val string) error) error { for k, vv := range m { for _, v := range vv { if err := callback(k, v); err != nil { return err } } } return nil } // encodeKeyValue encodes key and value qualified for transmission via gRPC. 
// note: copy pasted from private values of grpc.metadata func encodeKeyValue(k, v string) (string, string) { k = strings.ToLower(k) if strings.HasSuffix(k, binHdrSuffix) { val := base64.StdEncoding.EncodeToString([]byte(v)) v = string(val) } return k, v } go-grpc-middleware-1.3.0/tracing/opentracing/options.go000066400000000000000000000044401404040257500231510ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_opentracing import ( "context" "github.com/opentracing/opentracing-go" ) var ( defaultOptions = &options{ filterOutFunc: nil, tracer: nil, } ) // FilterFunc allows users to provide a function that filters out certain methods from being traced. // // If it returns false, the given request will not be traced. type FilterFunc func(ctx context.Context, fullMethodName string) bool // UnaryRequestHandlerFunc is a custom request handler type UnaryRequestHandlerFunc func(span opentracing.Span, req interface{}) // OpNameFunc is a func that allows custom operation names instead of the gRPC method. type OpNameFunc func(method string) string type options struct { filterOutFunc FilterFunc tracer opentracing.Tracer traceHeaderName string unaryRequestHandlerFunc UnaryRequestHandlerFunc opNameFunc OpNameFunc } func evaluateOptions(opts []Option) *options { optCopy := &options{} *optCopy = *defaultOptions for _, o := range opts { o(optCopy) } if optCopy.tracer == nil { optCopy.tracer = opentracing.GlobalTracer() } if optCopy.traceHeaderName == "" { optCopy.traceHeaderName = "uber-trace-id" } return optCopy } type Option func(*options) // WithFilterFunc customizes the function used for deciding whether a given call is traced or not. func WithFilterFunc(f FilterFunc) Option { return func(o *options) { o.filterOutFunc = f } } // WithTraceHeaderName customizes the trace header name where trace metadata passed with requests. 
// Default one is `uber-trace-id` func WithTraceHeaderName(name string) Option { return func(o *options) { o.traceHeaderName = name } } // WithTracer sets a custom tracer to be used for this middleware, otherwise the opentracing.GlobalTracer is used. func WithTracer(tracer opentracing.Tracer) Option { return func(o *options) { o.tracer = tracer } } // WithUnaryRequestHandlerFunc sets a custom handler for the request func WithUnaryRequestHandlerFunc(f UnaryRequestHandlerFunc) Option { return func(o *options) { o.unaryRequestHandlerFunc = f } } // WithOpName customizes the trace Operation name func WithOpName(f OpNameFunc) Option { return func(o *options) { o.opNameFunc = f } } go-grpc-middleware-1.3.0/tracing/opentracing/server_interceptors.go000066400000000000000000000066701404040257500255740ustar00rootroot00000000000000// Copyright 2017 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_opentracing import ( "context" "github.com/grpc-ecosystem/go-grpc-middleware" "github.com/grpc-ecosystem/go-grpc-middleware/tags" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/opentracing/opentracing-go/log" "google.golang.org/grpc" "google.golang.org/grpc/grpclog" ) var ( grpcTag = opentracing.Tag{Key: string(ext.Component), Value: "gRPC"} ) // UnaryServerInterceptor returns a new unary server interceptor for OpenTracing. 
func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { o := evaluateOptions(opts) return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if o.filterOutFunc != nil && !o.filterOutFunc(ctx, info.FullMethod) { return handler(ctx, req) } opName := info.FullMethod if o.opNameFunc != nil { opName = o.opNameFunc(info.FullMethod) } newCtx, serverSpan := newServerSpanFromInbound(ctx, o.tracer, o.traceHeaderName, opName) if o.unaryRequestHandlerFunc != nil { o.unaryRequestHandlerFunc(serverSpan, req) } resp, err := handler(newCtx, req) finishServerSpan(ctx, serverSpan, err) return resp, err } } // StreamServerInterceptor returns a new streaming server interceptor for OpenTracing. func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { o := evaluateOptions(opts) return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { if o.filterOutFunc != nil && !o.filterOutFunc(stream.Context(), info.FullMethod) { return handler(srv, stream) } opName := info.FullMethod if o.opNameFunc != nil { opName = o.opNameFunc(info.FullMethod) } newCtx, serverSpan := newServerSpanFromInbound(stream.Context(), o.tracer, o.traceHeaderName, opName) wrappedStream := grpc_middleware.WrapServerStream(stream) wrappedStream.WrappedContext = newCtx err := handler(srv, wrappedStream) finishServerSpan(newCtx, serverSpan, err) return err } } func newServerSpanFromInbound(ctx context.Context, tracer opentracing.Tracer, traceHeaderName, opName string) (context.Context, opentracing.Span) { md := metautils.ExtractIncoming(ctx) parentSpanContext, err := tracer.Extract(opentracing.HTTPHeaders, metadataTextMap(md)) if err != nil && err != opentracing.ErrSpanContextNotFound { grpclog.Infof("grpc_opentracing: failed parsing trace information: %v", err) } serverSpan := tracer.StartSpan( opName, // this is magical, it attaches the new span to the parent 
parentSpanContext, and creates an unparented one if empty. ext.RPCServerOption(parentSpanContext), grpcTag, ) injectOpentracingIdsToTags(traceHeaderName, serverSpan, grpc_ctxtags.Extract(ctx)) return opentracing.ContextWithSpan(ctx, serverSpan), serverSpan } func finishServerSpan(ctx context.Context, serverSpan opentracing.Span, err error) { // Log context information tags := grpc_ctxtags.Extract(ctx) for k, v := range tags.Values() { // Don't tag errors, log them instead. if vErr, ok := v.(error); ok { serverSpan.LogKV(k, vErr.Error()) } else { serverSpan.SetTag(k, v) } } if err != nil { ext.Error.Set(serverSpan, true) serverSpan.LogFields(log.String("event", "error"), log.String("message", err.Error())) } serverSpan.Finish() } go-grpc-middleware-1.3.0/util/000077500000000000000000000000001404040257500161425ustar00rootroot00000000000000go-grpc-middleware-1.3.0/util/backoffutils/000077500000000000000000000000001404040257500206165ustar00rootroot00000000000000go-grpc-middleware-1.3.0/util/backoffutils/backoff.go000066400000000000000000000013461404040257500225440ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* Backoff Helper Utilities Implements common backoff features. */ package backoffutils import ( "math/rand" "time" ) // JitterUp adds random jitter to the duration. // // This adds or subtracts time from the duration within a given jitter fraction. // For example for 10s and jitter 0.1, it will return a time within [9s, 11s]) func JitterUp(duration time.Duration, jitter float64) time.Duration { multiplier := jitter * (rand.Float64()*2 - 1) return time.Duration(float64(duration) * (1 + multiplier)) } // ExponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 0. 
func ExponentBase2(a uint) uint { return (1 << a) >> 1 } go-grpc-middleware-1.3.0/util/backoffutils/backoff_test.go000066400000000000000000000021611404040257500235770ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package backoffutils_test import ( "testing" "time" "github.com/grpc-ecosystem/go-grpc-middleware/util/backoffutils" "github.com/stretchr/testify/assert" ) // scale duration by a factor func scaleDuration(d time.Duration, factor float64) time.Duration { return time.Duration(float64(d) * factor) } func TestJitterUp(t *testing.T) { // arguments to jitterup duration := 10 * time.Second variance := 0.10 // bound to check max := 11000 * time.Millisecond min := 9000 * time.Millisecond high := scaleDuration(max, 0.98) low := scaleDuration(min, 1.02) highCount := 0 lowCount := 0 for i := 0; i < 1000; i++ { out := backoffutils.JitterUp(duration, variance) assert.True(t, out <= max, "value %s must be <= %s", out, max) assert.True(t, out >= min, "value %s must be >= %s", out, min) if out > high { highCount++ } if out < low { lowCount++ } } assert.True(t, highCount != 0, "at least one sample should reach to >%s", high) assert.True(t, lowCount != 0, "at least one sample should to <%s", low) } go-grpc-middleware-1.3.0/util/metautils/000077500000000000000000000000001404040257500201515ustar00rootroot00000000000000go-grpc-middleware-1.3.0/util/metautils/doc.go000066400000000000000000000014631404040257500212510ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* Package `metautils` provides convenience functions for dealing with gRPC metadata.MD objects inside Context handlers. While the upstream grpc-go package contains decent functionality (see https://github.com/grpc/grpc-go/blob/master/Documentation/grpc-metadata.md) they are hard to use. 
The majority of functions center around the NiceMD, which is a convenience wrapper around metadata.MD. For example the following code allows you to easily extract incoming metadata (server handler) and put it into a new client context metadata. nmd := metautils.ExtractIncoming(serverCtx).Clone(":authorization", ":custom") clientCtx := nmd.Set("x-client-header", "2").Set("x-another", "3").ToOutgoing(ctx) */ package metautils go-grpc-middleware-1.3.0/util/metautils/nicemd.go000066400000000000000000000065141404040257500217450ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package metautils import ( "context" "strings" "google.golang.org/grpc/metadata" ) // NiceMD is a convenience wrapper definiting extra functions on the metadata. type NiceMD metadata.MD // ExtractIncoming extracts an inbound metadata from the server-side context. // // This function always returns a NiceMD wrapper of the metadata.MD, in case the context doesn't have metadata it returns // a new empty NiceMD. func ExtractIncoming(ctx context.Context) NiceMD { md, ok := metadata.FromIncomingContext(ctx) if !ok { return NiceMD(metadata.Pairs()) } return NiceMD(md) } // ExtractOutgoing extracts an outbound metadata from the client-side context. // // This function always returns a NiceMD wrapper of the metadata.MD, in case the context doesn't have metadata it returns // a new empty NiceMD. func ExtractOutgoing(ctx context.Context) NiceMD { md, ok := metadata.FromOutgoingContext(ctx) if !ok { return NiceMD(metadata.Pairs()) } return NiceMD(md) } // Clone performs a *deep* copy of the metadata.MD. // // You can specify the lower-case copiedKeys to only copy certain whitelisted keys. If no keys are explicitly whitelisted // all keys get copied. 
func (m NiceMD) Clone(copiedKeys ...string) NiceMD { newMd := NiceMD(metadata.Pairs()) for k, vv := range m { found := false if len(copiedKeys) == 0 { found = true } else { for _, allowedKey := range copiedKeys { if strings.EqualFold(allowedKey, k) { found = true break } } } if !found { continue } newMd[k] = make([]string, len(vv)) copy(newMd[k], vv) } return NiceMD(newMd) } // ToOutgoing sets the given NiceMD as a client-side context for dispatching. func (m NiceMD) ToOutgoing(ctx context.Context) context.Context { return metadata.NewOutgoingContext(ctx, metadata.MD(m)) } // ToIncoming sets the given NiceMD as a server-side context for dispatching. // // This is mostly useful in ServerInterceptors.. func (m NiceMD) ToIncoming(ctx context.Context) context.Context { return metadata.NewIncomingContext(ctx, metadata.MD(m)) } // Get retrieves a single value from the metadata. // // It works analogously to http.Header.Get, returning the first value if there are many set. If the value is not set, // an empty string is returned. // // The function is binary-key safe. func (m NiceMD) Get(key string) string { k := strings.ToLower(key) vv, ok := m[k] if !ok { return "" } return vv[0] } // Del retrieves a single value from the metadata. // // It works analogously to http.Header.Del, deleting all values if they exist. // // The function is binary-key safe. func (m NiceMD) Del(key string) NiceMD { k := strings.ToLower(key) delete(m, k) return m } // Set sets the given value in a metadata. // // It works analogously to http.Header.Set, overwriting all previous metadata values. // // The function is binary-key safe. func (m NiceMD) Set(key string, value string) NiceMD { k := strings.ToLower(key) m[k] = []string{value} return m } // Add retrieves a single value from the metadata. // // It works analogously to http.Header.Add, as it appends to any existing values associated with key. // // The function is binary-key safe. 
func (m NiceMD) Add(key string, value string) NiceMD { k := strings.ToLower(key) m[k] = append(m[k], value) return m } go-grpc-middleware-1.3.0/util/metautils/nicemd_test.go000066400000000000000000000100021404040257500227670ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package metautils_test import ( "context" "testing" "github.com/grpc-ecosystem/go-grpc-middleware/util/metautils" "github.com/stretchr/testify/assert" "google.golang.org/grpc/metadata" ) var ( testPairs = []string{"singlekey", "uno", "multikey", "one", "multikey", "two", "multikey", "three"} parentCtx = context.WithValue(context.TODO(), "parentKey", "parentValue") ) func assertRetainsParentContext(t *testing.T, ctx context.Context) { x := ctx.Value("parentKey") assert.EqualValues(t, "parentValue", x, "context must contain parentCtx") } func TestNiceMD_Get(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) assert.Equal(t, "uno", nmd.Get("singlekey"), "for present single-key value it should return it") assert.Equal(t, "one", nmd.Get("multikey"), "for present multi-key should return first value") assert.Empty(t, nmd.Get("nokey"), "for non existing key should return stuff") } func TestNiceMD_Del(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) assert.Equal(t, "uno", nmd.Get("singlekey"), "for present single-key value it should return it") nmd.Del("singlekey").Del("doesnt exist") assert.Empty(t, nmd.Get("singlekey"), "after deletion singlekey shouldn't exist") } func TestNiceMD_Add(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) nmd.Add("multikey", "four").Add("newkey", "something") assert.EqualValues(t, []string{"one", "two", "three", "four"}, nmd["multikey"], "append should add a new four at the end") assert.EqualValues(t, []string{"something"}, nmd["newkey"], "append should be able to create new keys") } func TestNiceMD_Set(t *testing.T) { nmd := 
metautils.NiceMD(metadata.Pairs(testPairs...)) nmd.Set("multikey", "one").Set("newkey", "something").Set("newkey", "another") assert.EqualValues(t, []string{"one"}, nmd["multikey"], "set should override existing multi keys") assert.EqualValues(t, []string{"another"}, nmd["newkey"], "set should override new keys") } func TestNiceMD_SetGet(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) nmd.Set("another-key", "onetwothree") assert.EqualValues(t, "onetwothree", nmd.Get("another-key")) nmd.Set("another-key-bin", "binarydata") assert.EqualValues(t, "binarydata", nmd.Get("another-key-bin")) } func TestNiceMD_Clone(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) fullCopied := nmd.Clone() assert.Equal(t, len(fullCopied), len(nmd), "clone full should copy all keys") assert.Equal(t, "uno", fullCopied.Get("singlekey"), "full copied should have content") subCopied := nmd.Clone("multikey") assert.Len(t, subCopied, 1, "sub copied clone should only have one key") assert.Empty(t, subCopied.Get("singlekey"), "there shouldn't be a singlekey in the subcopied") // Test side effects and full copying: assert.EqualValues(t, subCopied["multikey"], nmd["multikey"], "before overwrites multikey should have the same values") subCopied["multikey"][1] = "modifiedtwo" assert.NotEqual(t, subCopied["multikey"], nmd["multikey"], "before overwrites multikey should have the same values") } func TestNiceMD_ToOutgoing(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) nCtx := nmd.ToOutgoing(parentCtx) assertRetainsParentContext(t, nCtx) eCtx := metautils.ExtractOutgoing(nCtx).Clone().Set("newvalue", "something").ToOutgoing(nCtx) assertRetainsParentContext(t, eCtx) assert.NotEqual(t, metautils.ExtractOutgoing(nCtx), metautils.ExtractOutgoing(eCtx), "the niceMD pointed to by ectx and nctx are different.") } func TestNiceMD_ToIncoming(t *testing.T) { nmd := metautils.NiceMD(metadata.Pairs(testPairs...)) nCtx := nmd.ToIncoming(parentCtx) 
assertRetainsParentContext(t, nCtx) eCtx := metautils.ExtractIncoming(nCtx).Clone().Set("newvalue", "something").ToIncoming(nCtx) assertRetainsParentContext(t, eCtx) assert.NotEqual(t, metautils.ExtractIncoming(nCtx), metautils.ExtractIncoming(eCtx), "the niceMD pointed to by ectx and nctx are different.") } go-grpc-middleware-1.3.0/validator/000077500000000000000000000000001404040257500171525ustar00rootroot00000000000000go-grpc-middleware-1.3.0/validator/doc.go000066400000000000000000000040361404040257500202510ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. /* `grpc_validator` is a generic request contents validator server-side middleware for gRPC. Request Validator Middleware Validating input is important, and hard. It also causes a lot of boilerplate code. This middleware checks for the existence of a `Validate` method on each of the messages of a gRPC request. This includes the single request of the `Unary` calls, as well as each message of the inbound Stream calls. In case of a validation failure, an `InvalidArgument` gRPC status is returned, along with a description of the validation failure. While it is generic, it was intended to be used with https://github.com/mwitkow/go-proto-validators, a Go protocol buffers codegen plugin that creates the `Validate` methods (including nested messages) based on declarative options in the `.proto` files themselves. For example: syntax = "proto3"; package validator.examples; import "github.com/mwitkow/go-proto-validators/validator.proto"; message InnerMessage { // some_integer can only be in range (1, 100). int32 some_integer = 1 [(validator.field) = {int_gt: 0, int_lt: 100}]; // some_float can only be in range (0;1). double some_float = 2 [(validator.field) = {float_gte: 0, float_lte: 1}]; } message OuterMessage { // important_string must be a lowercase alpha-numeric of 5 to 30 characters (RE2 syntax). 
string important_string = 1 [(validator.field) = {regex: "^[a-z]{2,5}$"}]; // proto3 doesn't have `required`, the `msg_exist` enforces presence of InnerMessage. InnerMessage inner = 2 [(validator.field) = {msg_exists : true}]; } The `OuterMessage.Validate` would include validation of regexes, existence of the InnerMessage and the range values within it. The `grpc_validator` middleware would then automatically use that to check all messages processed by the server. Please consult https://github.com/mwitkow/go-proto-validators for details on `protoc` invocation and other parameters of customization. */ package grpc_validator go-grpc-middleware-1.3.0/validator/validator.go000066400000000000000000000054021404040257500214670ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_validator import ( "context" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) // The validate interface starting with protoc-gen-validate v0.6.0. // See https://github.com/envoyproxy/protoc-gen-validate/pull/455. type validator interface { Validate(all bool) error } // The validate interface prior to protoc-gen-validate v0.6.0. type validatorLegacy interface { Validate() error } func validate(req interface{}) error { switch v := req.(type) { case validatorLegacy: if err := v.Validate(); err != nil { return status.Error(codes.InvalidArgument, err.Error()) } case validator: if err := v.Validate(false); err != nil { return status.Error(codes.InvalidArgument, err.Error()) } } return nil } // UnaryServerInterceptor returns a new unary server interceptor that validates incoming messages. // // Invalid messages will be rejected with `InvalidArgument` before reaching any userspace handlers. 
func UnaryServerInterceptor() grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { if err := validate(req); err != nil { return nil, err } return handler(ctx, req) } } // UnaryClientInterceptor returns a new unary client interceptor that validates outgoing messages. // // Invalid messages will be rejected with `InvalidArgument` before sending the request to server. func UnaryClientInterceptor() grpc.UnaryClientInterceptor { return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { if err := validate(req); err != nil { return err } return invoker(ctx, method, req, reply, cc, opts...) } } // StreamServerInterceptor returns a new streaming server interceptor that validates incoming messages. // // The stage at which invalid messages will be rejected with `InvalidArgument` varies based on the // type of the RPC. For `ServerStream` (1:m) requests, it will happen before reaching any userspace // handlers. For `ClientStream` (n:1) or `BidiStream` (n:m) RPCs, the messages will be rejected on // calls to `stream.Recv()`. func StreamServerInterceptor() grpc.StreamServerInterceptor { return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { wrapper := &recvWrapper{stream} return handler(srv, wrapper) } } type recvWrapper struct { grpc.ServerStream } func (s *recvWrapper) RecvMsg(m interface{}) error { if err := s.ServerStream.RecvMsg(m); err != nil { return err } if err := validate(m); err != nil { return err } return nil } go-grpc-middleware-1.3.0/validator/validator_test.go000066400000000000000000000101461404040257500225270ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_validator import ( "io" "math" "testing" grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing" pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) var ( // See test.manual_validator.pb.go for the validator check of SleepTimeMs. goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} badPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 10001} // See test.manual_validator.pb.go for the validator check of the counter. goodPingResponse = &pb_testproto.PingResponse{Counter: 100} badPingResponse = &pb_testproto.PingResponse{Counter: math.MaxInt16 + 1} ) func TestValidateWrapper(t *testing.T) { assert.NoError(t, validate(goodPing)) assert.Error(t, validate(badPing)) assert.NoError(t, validate(goodPingResponse)) assert.Error(t, validate(badPingResponse)) } func TestValidatorTestSuite(t *testing.T) { s := &ValidatorTestSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ ServerOpts: []grpc.ServerOption{ grpc.StreamInterceptor(StreamServerInterceptor()), grpc.UnaryInterceptor(UnaryServerInterceptor()), }, }, } suite.Run(t, s) cs := &ClientValidatorTestSuite{ InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ ClientOpts: []grpc.DialOption{ grpc.WithUnaryInterceptor(UnaryClientInterceptor()), }, }, } suite.Run(t, cs) } type ValidatorTestSuite struct { *grpc_testing.InterceptorTestSuite } func (s *ValidatorTestSuite) TestValidPasses_Unary() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "no error expected") } func (s *ValidatorTestSuite) TestInvalidErrors_Unary() { _, err := s.Client.Ping(s.SimpleCtx(), badPing) assert.Error(s.T(), err, "no error expected") assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be 
InvalidArgument") } func (s *ValidatorTestSuite) TestValidPasses_ServerStream() { stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) require.NoError(s.T(), err, "no error on stream establishment expected") for true { _, err := stream.Recv() if err == io.EOF { break } assert.NoError(s.T(), err, "no error on messages sent occured") } } func (s *ValidatorTestSuite) TestInvalidErrors_ServerStream() { stream, err := s.Client.PingList(s.SimpleCtx(), badPing) require.NoError(s.T(), err, "no error on stream establishment expected") _, err = stream.Recv() assert.Error(s.T(), err, "error should be received on first message") assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") } func (s *ValidatorTestSuite) TestInvalidErrors_BidiStream() { stream, err := s.Client.PingStream(s.SimpleCtx()) require.NoError(s.T(), err, "no error on stream establishment expected") stream.Send(goodPing) _, err = stream.Recv() assert.NoError(s.T(), err, "receiving a good ping should return a good pong") stream.Send(goodPing) _, err = stream.Recv() assert.NoError(s.T(), err, "receiving a good ping should return a good pong") stream.Send(badPing) _, err = stream.Recv() assert.Error(s.T(), err, "receiving a good ping should return a good pong") assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") err = stream.CloseSend() assert.NoError(s.T(), err, "there should be no error closing the stream on send") } type ClientValidatorTestSuite struct { *grpc_testing.InterceptorTestSuite } func (s *ClientValidatorTestSuite) TestValidPasses_Unary() { _, err := s.Client.Ping(s.SimpleCtx(), goodPing) assert.NoError(s.T(), err, "no error expected") } func (s *ClientValidatorTestSuite) TestInvalidErrors_Unary() { _, err := s.Client.Ping(s.SimpleCtx(), badPing) assert.Error(s.T(), err, "error expected") assert.Equal(s.T(), codes.InvalidArgument, status.Code(err), "gRPC status must be InvalidArgument") } 
go-grpc-middleware-1.3.0/wrappers.go000066400000000000000000000016411404040257500173610ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. package grpc_middleware import ( "context" "google.golang.org/grpc" ) // WrappedServerStream is a thin wrapper around grpc.ServerStream that allows modifying context. type WrappedServerStream struct { grpc.ServerStream // WrappedContext is the wrapper's own Context. You can assign it. WrappedContext context.Context } // Context returns the wrapper's WrappedContext, overwriting the nested grpc.ServerStream.Context() func (w *WrappedServerStream) Context() context.Context { return w.WrappedContext } // WrapServerStream returns a ServerStream that has the ability to overwrite context. func WrapServerStream(stream grpc.ServerStream) *WrappedServerStream { if existing, ok := stream.(*WrappedServerStream); ok { return existing } return &WrappedServerStream{ServerStream: stream, WrappedContext: stream.Context()} } go-grpc-middleware-1.3.0/wrappers_test.go000066400000000000000000000025421404040257500204210ustar00rootroot00000000000000// Copyright 2016 Michal Witkowski. All Rights Reserved. // See LICENSE for licensing terms. 
package grpc_middleware import ( "context" "testing" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) func TestWrapServerStream(t *testing.T) { ctx := context.WithValue(context.TODO(), "something", 1) fake := &fakeServerStream{ctx: ctx} wrapped := WrapServerStream(fake) assert.NotNil(t, wrapped.Context().Value("something"), "values from fake must propagate to wrapper") wrapped.WrappedContext = context.WithValue(wrapped.Context(), "other", 2) assert.NotNil(t, wrapped.Context().Value("other"), "values from wrapper must be set") } type fakeServerStream struct { grpc.ServerStream ctx context.Context recvMessage interface{} sentMessage interface{} } func (f *fakeServerStream) Context() context.Context { return f.ctx } func (f *fakeServerStream) SendMsg(m interface{}) error { if f.sentMessage != nil { return status.Errorf(codes.AlreadyExists, "fakeServerStream only takes one message, sorry") } f.sentMessage = m return nil } func (f *fakeServerStream) RecvMsg(m interface{}) error { if f.recvMessage == nil { return status.Errorf(codes.NotFound, "fakeServerStream has no message, sorry") } return nil } type fakeClientStream struct { grpc.ClientStream }