pax_global_header00006660000000000000000000000064144764227330014526gustar00rootroot0000000000000052 comment=e563407504ce5a1c63fe732e550ae0de75266967 go-proton-api-1.0.0/000077500000000000000000000000001447642273300142175ustar00rootroot00000000000000go-proton-api-1.0.0/.github/000077500000000000000000000000001447642273300155575ustar00rootroot00000000000000go-proton-api-1.0.0/.github/workflows/000077500000000000000000000000001447642273300176145ustar00rootroot00000000000000go-proton-api-1.0.0/.github/workflows/check.yml000066400000000000000000000010621447642273300214130ustar00rootroot00000000000000name: Lint and Test on: push jobs: check: runs-on: ubuntu-latest steps: - name: Get sources uses: actions/checkout@v3 - name: Set up Go 1.18 uses: actions/setup-go@v3 with: go-version: '1.18' - name: Run golangci-lint uses: golangci/golangci-lint-action@v3 with: version: v1.50.0 args: --timeout=180s skip-cache: true - name: Run tests run: go test -v ./... - name: Run tests with race check run: go test -v -race ./... go-proton-api-1.0.0/.gitignore000066400000000000000000000000471447642273300162100ustar00rootroot00000000000000# Editor files .*.sw? *~ .idea .vscode go-proton-api-1.0.0/CONTRIBUTING.md000066400000000000000000000007531447642273300164550ustar00rootroot00000000000000# Contribution Policy By making a contribution to this project: 1. I assign any and all copyright related to the contribution to Proton AG; 2. I certify that the contribution was created in whole by me; 3. 
I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it) is maintained indefinitely and may be redistributed with this project or the open source license(s) involved.go-proton-api-1.0.0/COPYING_NOTES.md000066400000000000000000000220521447642273300166220ustar00rootroot00000000000000# Copying The MIT License (MIT) Copyright (c) 2020 James Houlahan Copyright (c) 2022 Proton AG Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# Dependencies Go Proton API includes the following 3rd party software: * [The Go Project libraries](https://golang.org/project/) | Available under [BSD license](https://golang.org/LICENSE) * [semver](https://github.com/Masterminds/semver/v3) available under [license](https://github.com/Masterminds/semver/v3/blob/master/LICENSE) * [gluon](https://github.com/ProtonMail/gluon) available under [license](https://github.com/ProtonMail/gluon/blob/master/LICENSE) * [go-crypto](https://github.com/ProtonMail/go-crypto) available under [license](https://github.com/ProtonMail/go-crypto/blob/master/LICENSE) * [go-srp](https://github.com/ProtonMail/go-srp) available under [license](https://github.com/ProtonMail/go-srp/blob/master/LICENSE) * [gopenpgp](https://github.com/ProtonMail/gopenpgp/v2) available under [license](https://github.com/ProtonMail/gopenpgp/v2/blob/master/LICENSE) * [goquery](https://github.com/PuerkitoBio/goquery) available under [license](https://github.com/PuerkitoBio/goquery/blob/master/LICENSE) * [juniper](https://github.com/bradenaw/juniper) available under [license](https://github.com/bradenaw/juniper/blob/master/LICENSE) * [go-message](https://github.com/emersion/go-message) available under [license](https://github.com/emersion/go-message/blob/master/LICENSE) * [go-vcard](https://github.com/emersion/go-vcard) available under [license](https://github.com/emersion/go-vcard/blob/master/LICENSE) * [gin](https://github.com/gin-gonic/gin) available under [license](https://github.com/gin-gonic/gin/blob/master/LICENSE) * [resty](https://github.com/go-resty/resty/v2) available under [license](https://github.com/go-resty/resty/v2/blob/master/LICENSE) * [uuid](https://github.com/google/uuid) available under [license](https://github.com/google/uuid/blob/master/LICENSE) * [logrus](https://github.com/sirupsen/logrus) available under [license](https://github.com/sirupsen/logrus/blob/master/LICENSE) * [testify](https://github.com/stretchr/testify) available under 
[license](https://github.com/stretchr/testify/blob/master/LICENSE) * [cli](https://github.com/urfave/cli/v2) available under [license](https://github.com/urfave/cli/v2/blob/master/LICENSE) * [goleak](https://go.uber.org/goleak) available under [license](https://pkg.go.dev/go.uber.org/goleak?tab=licenses) * [exp](https://golang.org/x/exp) available under [license](https://cs.opensource.google/go/x/exp/+/master:LICENSE) * [net](https://golang.org/x/net) available under [license](https://cs.opensource.google/go/x/net/+/master:LICENSE) * [text](https://golang.org/x/text) available under [license](https://cs.opensource.google/go/x/text/+/master:LICENSE) * [grpc](https://google.golang.org/grpc) available under [license](https://github.com/grpc/grpc-go/blob/master/LICENSE) * [protobuf](https://google.golang.org/protobuf) available under [license](https://github.com/protocolbuffers/protobuf/blob/main/LICENSE) * [bcrypt](https://github.com/ProtonMail/bcrypt) available under [license](https://github.com/ProtonMail/bcrypt/blob/master/LICENSE) * [go-mime](https://github.com/ProtonMail/go-mime) available under [license](https://github.com/ProtonMail/go-mime/blob/master/LICENSE) * [cascadia](https://github.com/andybalholm/cascadia) available under [license](https://github.com/andybalholm/cascadia/blob/master/LICENSE) * [sonic](https://github.com/bytedance/sonic) available under [license](https://github.com/bytedance/sonic/blob/master/LICENSE) * [base64x](https://github.com/chenzhuoyu/base64x) available under [license](https://github.com/chenzhuoyu/base64x/blob/master/LICENSE) * [circl](https://github.com/cloudflare/circl) available under [license](https://github.com/cloudflare/circl/blob/master/LICENSE) * [go-md2man](https://github.com/cpuguy83/go-md2man/v2) available under [license](https://github.com/cpuguy83/go-md2man/v2/blob/master/LICENSE) * [saferith](https://github.com/cronokirby/saferith) available under 
[license](https://github.com/cronokirby/saferith/blob/master/LICENSE) * [go-spew](https://github.com/davecgh/go-spew) available under [license](https://github.com/davecgh/go-spew/blob/master/LICENSE) * [go-textwrapper](https://github.com/emersion/go-textwrapper) available under [license](https://github.com/emersion/go-textwrapper/blob/master/LICENSE) * [mimetype](https://github.com/gabriel-vasile/mimetype) available under [license](https://github.com/gabriel-vasile/mimetype/blob/master/LICENSE) * [sse](https://github.com/gin-contrib/sse) available under [license](https://github.com/gin-contrib/sse/blob/master/LICENSE) * [locales](https://github.com/go-playground/locales) available under [license](https://github.com/go-playground/locales/blob/master/LICENSE) * [universal-translator](https://github.com/go-playground/universal-translator) available under [license](https://github.com/go-playground/universal-translator/blob/master/LICENSE) * [validator](https://github.com/go-playground/validator/v10) available under [license](https://github.com/go-playground/validator/v10/blob/master/LICENSE) * [go-json](https://github.com/goccy/go-json) available under [license](https://github.com/goccy/go-json/blob/master/LICENSE) * [protobuf](https://github.com/golang/protobuf) available under [license](https://github.com/golang/protobuf/blob/master/LICENSE) * [go](https://github.com/json-iterator/go) available under [license](https://github.com/json-iterator/go/blob/master/LICENSE) * [cpuid](https://github.com/klauspost/cpuid/v2) available under [license](https://github.com/klauspost/cpuid/v2/blob/master/LICENSE) * [text](https://github.com/kr/text) available under [license](https://github.com/kr/text/blob/master/LICENSE) * [go-urn](https://github.com/leodido/go-urn) available under [license](https://github.com/leodido/go-urn/blob/master/LICENSE) * [go-isatty](https://github.com/mattn/go-isatty) available under [license](https://github.com/mattn/go-isatty/blob/master/LICENSE) * 
[concurrent](https://github.com/modern-go/concurrent) available under [license](https://github.com/modern-go/concurrent/blob/master/LICENSE) * [reflect2](https://github.com/modern-go/reflect2) available under [license](https://github.com/modern-go/reflect2/blob/master/LICENSE) * [go-toml](https://github.com/pelletier/go-toml/v2) available under [license](https://github.com/pelletier/go-toml/v2/blob/master/LICENSE) * [errors](https://github.com/pkg/errors) available under [license](https://github.com/pkg/errors/blob/master/LICENSE) * [go-difflib](https://github.com/pmezard/go-difflib) available under [license](https://github.com/pmezard/go-difflib/blob/master/LICENSE) * [go-internal](https://github.com/rogpeppe/go-internal) available under [license](https://github.com/rogpeppe/go-internal/blob/master/LICENSE) * [blackfriday](https://github.com/russross/blackfriday/v2) available under [license](https://github.com/russross/blackfriday/v2/blob/master/LICENSE) * [golang-asm](https://github.com/twitchyliquid64/golang-asm) available under [license](https://github.com/twitchyliquid64/golang-asm/blob/master/LICENSE) * [codec](https://github.com/ugorji/go/codec) available under [license](https://github.com/ugorji/go/codec/blob/master/LICENSE) * [smetrics](https://github.com/xrash/smetrics) available under [license](https://github.com/xrash/smetrics/blob/master/LICENSE) * [arch](https://golang.org/x/arch) available under [license](https://cs.opensource.google/go/x/arch/+/master:LICENSE) * [crypto](https://golang.org/x/crypto) available under [license](https://cs.opensource.google/go/x/crypto/+/master:LICENSE) * [sync](https://golang.org/x/sync) available under [license](https://cs.opensource.google/go/x/sync/+/master:LICENSE) * [sys](https://golang.org/x/sys) available under [license](https://cs.opensource.google/go/x/sys/+/master:LICENSE) * [genproto](https://google.golang.org/genproto) available under [license](https://pkg.go.dev/google.golang.org/genproto?tab=licenses) * 
[yaml](https://gopkg.in/yaml.v3) available under [license](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE) go-proton-api-1.0.0/LICENSE000066400000000000000000000021261447642273300152250ustar00rootroot00000000000000The MIT License (MIT) Copyright (c) 2020 James Houlahan Copyright (c) 2022 Proton AG Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. go-proton-api-1.0.0/README.md000066400000000000000000000027731447642273300155070ustar00rootroot00000000000000# Go Proton API CI Status GoDoc Go Report Card License This repository holds Go Proton API, a Go library implementing a client and development server for (a subset of) the Proton REST API. The license can be found in the [LICENSE](./LICENSE) file. For the contribution policy, see [CONTRIBUTING](./CONTRIBUTING.md). ## Environment variables Most of the integration tests run locally. 
The ones that interact with Proton servers require the following environment variables set: - ```GO_PROTON_API_TEST_USERNAME``` - ```GO_PROTON_API_TEST_PASSWORD``` ## Contribution This library is forked from [go-proton-api](https://github.com/ProtonMail/go-proton-api) in order to support the [Proton API Bridge](https://github.com/henrybear327/Proton-API-Bridge) project. Contribution is welcomed! The intention to upstream the changes are planned, once the changes to the codebase has stabalized. go-proton-api-1.0.0/address.go000066400000000000000000000032271447642273300161770ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" "golang.org/x/exp/slices" ) func (c *Client) GetAddresses(ctx context.Context) ([]Address, error) { var res struct { Addresses []Address } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/addresses") }); err != nil { return nil, err } slices.SortFunc(res.Addresses, func(a, b Address) int { return a.Order - b.Order }) return res.Addresses, nil } func (c *Client) GetAddress(ctx context.Context, addressID string) (Address, error) { var res struct { Address Address } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/addresses/" + addressID) }); err != nil { return Address{}, err } return res.Address, nil } func (c *Client) OrderAddresses(ctx context.Context, req OrderAddressesReq) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Put("/core/v4/addresses/order") }) } func (c *Client) EnableAddress(ctx context.Context, addressID string) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.Put("/core/v4/addresses/" + addressID + "/enable") }) } func (c *Client) DisableAddress(ctx context.Context, addressID string) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return 
r.Put("/core/v4/addresses/" + addressID + "/disable") }) } func (c *Client) DeleteAddress(ctx context.Context, addressID string) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.Delete("/core/v4/addresses/" + addressID) }) } go-proton-api-1.0.0/address_test.go000066400000000000000000000042021447642273300172300ustar00rootroot00000000000000package proton_test import ( "context" "testing" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/stretchr/testify/require" ) func TestAddress_Types(t *testing.T) { s := server.New() defer s.Close() // Create a user on the server. userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) id2, err := s.CreateAddress(userID, "user@alias.com", []byte("pass")) require.NoError(t, err) require.NoError(t, s.ChangeAddressType(userID, id2, proton.AddressTypeAlias)) id3, err := s.CreateAddress(userID, "user@custom.com", []byte("pass")) require.NoError(t, err) require.NoError(t, s.ChangeAddressType(userID, id3, proton.AddressTypeCustom)) id4, err := s.CreateAddress(userID, "user@premium.com", []byte("pass")) require.NoError(t, err) require.NoError(t, s.ChangeAddressType(userID, id4, proton.AddressTypePremium)) id5, err := s.CreateAddress(userID, "user@external.com", []byte("pass")) require.NoError(t, err) require.NoError(t, s.ChangeAddressType(userID, id5, proton.AddressTypeExternal)) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Create one session for the user. c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) // Get addresses for the user. 
addrs, err := c.GetAddresses(context.Background()) require.NoError(t, err) for _, addr := range addrs { switch addr.ID { case id2: require.Equal(t, addr.Email, "user@alias.com") require.Equal(t, addr.Type, proton.AddressTypeAlias) case id3: require.Equal(t, addr.Email, "user@custom.com") require.Equal(t, addr.Type, proton.AddressTypeCustom) case id4: require.Equal(t, addr.Email, "user@premium.com") require.Equal(t, addr.Type, proton.AddressTypePremium) case id5: require.Equal(t, addr.Email, "user@external.com") require.Equal(t, addr.Type, proton.AddressTypeExternal) default: require.Equal(t, addr.Email, "user@proton.local") require.Equal(t, addr.Type, proton.AddressTypeOriginal) } } } go-proton-api-1.0.0/address_types.go000066400000000000000000000010161447642273300174150ustar00rootroot00000000000000package proton type Address struct { ID string Email string Send Bool Receive Bool Status AddressStatus Type AddressType Order int DisplayName string Keys Keys } type OrderAddressesReq struct { AddressIDs []string } type AddressStatus int const ( AddressStatusDisabled AddressStatus = iota AddressStatusEnabled AddressStatusDeleting ) type AddressType int const ( AddressTypeOriginal AddressType = iota + 1 AddressTypeAlias AddressTypeCustom AddressTypePremium AddressTypeExternal ) go-proton-api-1.0.0/attachment.go000066400000000000000000000052331447642273300167010ustar00rootroot00000000000000package proton import ( "bytes" "context" "fmt" "io" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2" ) func (c *Client) GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) { var buffer bytes.Buffer if err := c.getAttachment(ctx, attachmentID, &buffer); err != nil { return nil, err } return buffer.Bytes(), nil } func (c *Client) GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error { return c.getAttachment(ctx, attachmentID, reader) } func (c *Client) UploadAttachment(ctx context.Context, addrKR 
*crypto.KeyRing, req CreateAttachmentReq) (Attachment, error) { var res struct { Attachment Attachment } kr, err := addrKR.FirstKey() if err != nil { return res.Attachment, fmt.Errorf("failed to get first key: %w", err) } sig, err := kr.SignDetached(crypto.NewPlainMessage(req.Body)) if err != nil { return Attachment{}, fmt.Errorf("failed to sign attachment: %w", err) } enc, err := kr.EncryptAttachment(crypto.NewPlainMessage(req.Body), req.Filename) if err != nil { return Attachment{}, fmt.Errorf("failed to encrypt attachment: %w", err) } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res). SetMultipartFormData(map[string]string{ "MessageID": req.MessageID, "Filename": req.Filename, "MIMEType": string(req.MIMEType), "Disposition": string(req.Disposition), "ContentID": req.ContentID, }). SetMultipartFields( &resty.MultipartField{ Param: "KeyPackets", FileName: "blob", ContentType: "application/octet-stream", Reader: bytes.NewReader(enc.KeyPacket), }, &resty.MultipartField{ Param: "DataPacket", FileName: "blob", ContentType: "application/octet-stream", Reader: bytes.NewReader(enc.DataPacket), }, &resty.MultipartField{ Param: "Signature", FileName: "blob", ContentType: "application/octet-stream", Reader: bytes.NewReader(sig.GetBinary()), }, ). 
Post("/mail/v4/attachments") }); err != nil { return Attachment{}, err } return res.Attachment, nil } func (c *Client) getAttachment(ctx context.Context, attachmentID string, reader io.ReaderFrom) error { res, err := c.doRes(ctx, func(req *resty.Request) (*resty.Response, error) { res, err := req.SetDoNotParseResponse(true).Get("/mail/v4/attachments/" + attachmentID) return parseResponse(res, err) }) if err != nil { return fmt.Errorf("failed to request attachment: %w", err) } defer res.RawBody().Close() if _, err = reader.ReadFrom(res.RawBody()); err != nil { return err } return nil } go-proton-api-1.0.0/attachment_interfaces.go000066400000000000000000000051131447642273300211010ustar00rootroot00000000000000package proton import ( "bytes" "context" "github.com/ProtonMail/gluon/async" "github.com/bradenaw/juniper/parallel" ) // AttachmentAllocator abstract the attachment download buffer creation. type AttachmentAllocator interface { // NewBuffer should return a new byte buffer for use. Note that this function may be called from multiple go-routines. NewBuffer() *bytes.Buffer } type DefaultAttachmentAllocator struct{} func NewDefaultAttachmentAllocator() *DefaultAttachmentAllocator { return &DefaultAttachmentAllocator{} } func (DefaultAttachmentAllocator) NewBuffer() *bytes.Buffer { return bytes.NewBuffer(nil) } // Scheduler allows the user to specify how the attachment data for the message should be downloaded. type Scheduler interface { Schedule(ctx context.Context, attachmentIDs []string, storageProvider AttachmentAllocator, downloader func(context.Context, string, *bytes.Buffer) error) ([]*bytes.Buffer, error) } // SequentialScheduler downloads the attachments one by one. 
type SequentialScheduler struct{} func NewSequentialScheduler() *SequentialScheduler { return &SequentialScheduler{} } func (SequentialScheduler) Schedule(ctx context.Context, attachmentIDs []string, storageProvider AttachmentAllocator, downloader func(context.Context, string, *bytes.Buffer) error) ([]*bytes.Buffer, error) { result := make([]*bytes.Buffer, len(attachmentIDs)) for i, v := range attachmentIDs { select { case <-ctx.Done(): return nil, ctx.Err() default: } buffer := storageProvider.NewBuffer() if err := downloader(ctx, v, buffer); err != nil { return nil, err } result[i] = buffer } return result, nil } type ParallelScheduler struct { workers int panicHandler async.PanicHandler } func NewParallelScheduler(workers int, panicHandler async.PanicHandler) *ParallelScheduler { if workers == 0 { workers = 1 } return &ParallelScheduler{workers: workers} } func (p ParallelScheduler) Schedule(ctx context.Context, attachmentIDs []string, storageProvider AttachmentAllocator, downloader func(context.Context, string, *bytes.Buffer) error) ([]*bytes.Buffer, error) { // If we have less attachments than the maximum works, reduce worker count to match attachment count. 
workers := p.workers if len(attachmentIDs) < workers { workers = len(attachmentIDs) } return parallel.MapContext(ctx, workers, attachmentIDs, func(ctx context.Context, id string) (*bytes.Buffer, error) { defer async.HandlePanic(p.panicHandler) buffer := storageProvider.NewBuffer() if err := downloader(ctx, id, buffer); err != nil { return nil, err } return buffer, nil }) } go-proton-api-1.0.0/attachment_test.go000066400000000000000000000034421447642273300177400ustar00rootroot00000000000000package proton_test import ( "context" "errors" "net/http" "sync" "testing" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/stretchr/testify/require" ) func TestAttachment_429Response(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() s := server.New() defer s.Close() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) s.AddStatusHook(func(r *http.Request) (int, bool) { return http.StatusTooManyRequests, true }) _, err = c.GetAttachment(ctx, "someID") require.Error(t, err) apiErr := new(proton.APIError) require.True(t, errors.As(err, &apiErr), "expected to be API error") require.Equal(t, 429, apiErr.Status) require.Equal(t, proton.InvalidValue, apiErr.Code) require.Equal(t, "Request failed with status 429", apiErr.Message) } func TestAttachment_ContextCancelled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) s := server.New() defer s.Close() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) wg := sync.WaitGroup{} wg.Add(1) s.AddStatusHook(func(r 
*http.Request) (int, bool) { wg.Wait() return http.StatusTooManyRequests, true }) go func() { _, err = c.GetAttachment(ctx, "someID") wg.Done() }() cancel() wg.Wait() require.Error(t, err) require.True(t, errors.Is(err, context.Canceled)) } go-proton-api-1.0.0/attachment_types.go000066400000000000000000000010521447642273300201200ustar00rootroot00000000000000package proton import ( "github.com/ProtonMail/gluon/rfc822" ) type Attachment struct { ID string Name string Size int64 MIMEType rfc822.MIMEType Disposition Disposition Headers Headers KeyPackets string Signature string } type Disposition string const ( InlineDisposition Disposition = "inline" AttachmentDisposition Disposition = "attachment" ) type CreateAttachmentReq struct { MessageID string Filename string MIMEType rfc822.MIMEType Disposition Disposition ContentID string Body []byte } go-proton-api-1.0.0/auth.go000066400000000000000000000021261447642273300155100ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) Auth2FA(ctx context.Context, req Auth2FAReq) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Post("/auth/v4/2fa") }) } func (c *Client) AuthDelete(ctx context.Context) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.Delete("/auth/v4") }) } func (c *Client) AuthSessions(ctx context.Context) ([]AuthSession, error) { var res struct { Sessions []AuthSession } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/auth/v4/sessions") }); err != nil { return nil, err } return res.Sessions, nil } func (c *Client) AuthRevoke(ctx context.Context, authUID string) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.Delete("/auth/v4/sessions/" + authUID) }) } func (c *Client) AuthRevokeAll(ctx context.Context) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return 
r.Delete("/auth/v4/sessions") }) } go-proton-api-1.0.0/auth_test.go000066400000000000000000000114371447642273300165540ustar00rootroot00000000000000package proton_test import ( "context" "runtime" "testing" "time" "github.com/bradenaw/juniper/parallel" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/stretchr/testify/require" ) func TestAuth(t *testing.T) { s := server.New() defer s.Close() _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Create one session. c1, auth1, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) // Revoke all other sessions. require.NoError(t, c1.AuthRevokeAll(context.Background())) // Create another session. c2, _, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) // There should be two sessions. sessions, err := c1.AuthSessions(context.Background()) require.NoError(t, err) require.Len(t, sessions, 2) // Revoke the first session. require.NoError(t, c2.AuthRevoke(context.Background(), auth1.UID)) // The first session should no longer work. require.Error(t, c1.AuthDelete(context.Background())) // There should be one session remaining. remaining, err := c2.AuthSessions(context.Background()) require.NoError(t, err) require.Len(t, remaining, 1) // Delete the last session. require.NoError(t, c2.AuthDelete(context.Background())) } func TestAuth_Refresh(t *testing.T) { s := server.New() defer s.Close() // Create a user on the server. userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) // The auth is valid for 4 seconds. s.SetAuthLife(4 * time.Second) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Create one session for the user. 
c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) // Wait for 2 seconds. time.Sleep(2 * time.Second) // The client should still be authenticated. { user, err := c.GetUser(context.Background()) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } // Wait for 2 more seconds. time.Sleep(2 * time.Second) // The client's auth token should have expired, but will be refreshed on the next request. { user, err := c.GetUser(context.Background()) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } } func TestAuth_Refresh_Multi(t *testing.T) { s := server.New() defer s.Close() // Create a user on the server. userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) // The auth is valid for 4 seconds. s.SetAuthLife(4 * time.Second) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) time.Sleep(2 * time.Second) // The client should still be authenticated. parallel.Do(runtime.NumCPU(), 100, func(idx int) { user, err := c.GetUser(context.Background()) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) }) // Wait for the auth to expire. time.Sleep(2 * time.Second) // Client auth token should have expired, but will be refreshed on the next request. parallel.Do(runtime.NumCPU(), 100, func(idx int) { user, err := c.GetUser(context.Background()) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) }) } func TestAuth_Refresh_Deauth(t *testing.T) { s := server.New() defer s.Close() // Create a user on the server. 
userID, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Create one session for the user. c, auth, err := m.NewClientWithLogin(context.Background(), "user", []byte("pass")) require.NoError(t, err) require.Equal(t, userID, auth.UserID) deauth := false c.AddDeauthHandler(func() { deauth = true }) // The client should still be authenticated. { user, err := c.GetUser(context.Background()) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, userID, user.ID) } require.NoError(t, s.RevokeUser(userID)) // The client's auth token should have expired, and should not be refreshed { _, err := c.GetUser(context.Background()) require.Error(t, err) } // The client shuold call de-auth handlers. require.Eventually(t, func() bool { return deauth }, time.Second, 300*time.Millisecond) } go-proton-api-1.0.0/block.go000066400000000000000000000021041447642273300156350ustar00rootroot00000000000000package proton import ( "context" "io" "github.com/go-resty/resty/v2" ) func (c *Client) GetBlock(ctx context.Context, bareURL, token string) (io.ReadCloser, error) { res, err := c.doRes(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetHeader("pm-storage-token", token).SetDoNotParseResponse(true).Get(bareURL) }) if err != nil { return nil, err } return res.RawBody(), nil } func (c *Client) RequestBlockUpload(ctx context.Context, req BlockUploadReq) ([]BlockUploadLink, error) { var res struct { UploadLinks []BlockUploadLink } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).SetBody(req).Post("/drive/blocks") }); err != nil { return nil, err } return res.UploadLinks, nil } func (c *Client) UploadBlock(ctx context.Context, bareURL, token string, block io.Reader) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r. 
SetHeader("pm-storage-token", token). SetMultipartField("Block", "blob", "application/octet-stream", block). Post(bareURL) }) } go-proton-api-1.0.0/block_types.go000066400000000000000000000015751447642273300170740ustar00rootroot00000000000000package proton // Block is a block of file contents. They are split in 4MB blocks although this number may change in the future. // Each block is its own data packet separated from the key packet which is held by the node, // which means the sessionKey is the same for every block. type Block struct { Index int BareURL string // URL to the block Token string // Token for download URL Hash string // Encrypted block's sha256 hash, in base64 EncSignature string // Encrypted signature of the block SignatureEmail string // Email used to sign the block } type BlockUploadReq struct { AddressID string ShareID string LinkID string RevisionID string BlockList []BlockUploadInfo } type BlockUploadInfo struct { Index int Size int64 EncSignature string Hash string } type BlockUploadLink struct { Token string BareURL string } go-proton-api-1.0.0/boolean.go000066400000000000000000000013511447642273300161650ustar00rootroot00000000000000package proton import "encoding/json" // Bool is a convenience type for boolean values; it converts from APIBool to Go's builtin bool type. type Bool bool // APIBool is the boolean type used by the API (0 or 1). 
type APIBool int const ( APIFalse APIBool = iota APITrue ) func (b *Bool) UnmarshalJSON(data []byte) error { var v APIBool if err := json.Unmarshal(data, &v); err != nil { return err } *b = Bool(v == APITrue) return nil } func (b Bool) MarshalJSON() ([]byte, error) { var v APIBool if b { v = APITrue } else { v = APIFalse } return json.Marshal(v) } func (b Bool) String() string { if b { return "true" } return "false" } func (b Bool) FormatURL() string { if b { return "1" } return "0" } go-proton-api-1.0.0/calendar.go000066400000000000000000000034431447642273300163230ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) GetCalendars(ctx context.Context) ([]Calendar, error) { var res struct { Calendars []Calendar } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1") }); err != nil { return nil, err } return res.Calendars, nil } func (c *Client) GetCalendar(ctx context.Context, calendarID string) (Calendar, error) { var res struct { Calendar Calendar } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID) }); err != nil { return Calendar{}, err } return res.Calendar, nil } func (c *Client) GetCalendarKeys(ctx context.Context, calendarID string) (CalendarKeys, error) { var res struct { Keys CalendarKeys } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/keys") }); err != nil { return nil, err } return res.Keys, nil } func (c *Client) GetCalendarMembers(ctx context.Context, calendarID string) ([]CalendarMember, error) { var res struct { Members []CalendarMember } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/members") }); err != nil { return nil, err } return res.Members, nil } func (c *Client) 
GetCalendarPassphrase(ctx context.Context, calendarID string) (CalendarPassphrase, error) { var res struct { Passphrase CalendarPassphrase } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/passphrase") }); err != nil { return CalendarPassphrase{}, err } return res.Passphrase, nil } go-proton-api-1.0.0/calendar_event.go000066400000000000000000000035051447642273300175230ustar00rootroot00000000000000package proton import ( "context" "net/url" "strconv" "github.com/go-resty/resty/v2" ) func (c *Client) CountCalendarEvents(ctx context.Context, calendarID string) (int, error) { var res struct { Total int } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events") }); err != nil { return 0, err } return res.Total, nil } // TODO: For now, the query params are partially constant -- should they be configurable? func (c *Client) GetCalendarEvents(ctx context.Context, calendarID string, page, pageSize int, filter url.Values) ([]CalendarEvent, error) { var res struct { Events []CalendarEvent } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetQueryParams(map[string]string{ "Page": strconv.Itoa(page), "PageSize": strconv.Itoa(pageSize), }).SetQueryParamsFromValues(filter).SetResult(&res).Get("/calendar/v1/" + calendarID + "/events") }); err != nil { return nil, err } return res.Events, nil } func (c *Client) GetAllCalendarEvents(ctx context.Context, calendarID string, filter url.Values) ([]CalendarEvent, error) { total, err := c.CountCalendarEvents(ctx, calendarID) if err != nil { return nil, err } return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]CalendarEvent, error) { return c.GetCalendarEvents(ctx, calendarID, page, pageSize, filter) }) } func (c *Client) GetCalendarEvent(ctx context.Context, calendarID, eventID string) (CalendarEvent, 
error) { var res struct { Event CalendarEvent } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/calendar/v1/" + calendarID + "/events/" + eventID) }); err != nil { return CalendarEvent{}, err } return res.Event, nil } go-proton-api-1.0.0/calendar_event_types.go000066400000000000000000000043521447642273300207500ustar00rootroot00000000000000package proton import ( "encoding/base64" "github.com/ProtonMail/gopenpgp/v2/crypto" ) type CalendarEvent struct { ID string UID string CalendarID string SharedEventID string CreateTime int64 LastEditTime int64 StartTime int64 StartTimezone string EndTime int64 EndTimezone string FullDay Bool Author string Permissions CalendarPermissions Attendees []CalendarAttendee SharedKeyPacket string CalendarKeyPacket string SharedEvents []CalendarEventPart CalendarEvents []CalendarEventPart AttendeesEvents []CalendarEventPart PersonalEvents []CalendarEventPart } // TODO: Only personal events have MemberID; should we have a different type for that? 
type CalendarEventPart struct { MemberID string Type CalendarEventType Data string Signature string Author string } func (part CalendarEventPart) Decode(calKR *crypto.KeyRing, addrKR *crypto.KeyRing, kp []byte) error { if part.Type&CalendarEventTypeEncrypted != 0 { var enc *crypto.PGPMessage if kp != nil { raw, err := base64.StdEncoding.DecodeString(part.Data) if err != nil { return err } enc = crypto.NewPGPSplitMessage(kp, raw).GetPGPMessage() } else { var err error if enc, err = crypto.NewPGPMessageFromArmored(part.Data); err != nil { return err } } dec, err := calKR.Decrypt(enc, nil, crypto.GetUnixTime()) if err != nil { return err } part.Data = dec.GetString() } if part.Type&CalendarEventTypeSigned != 0 { sig, err := crypto.NewPGPSignatureFromArmored(part.Signature) if err != nil { return err } if err := addrKR.VerifyDetached(crypto.NewPlainMessageFromString(part.Data), sig, crypto.GetUnixTime()); err != nil { return err } } return nil } type CalendarEventType int const ( CalendarEventTypeClear CalendarEventType = iota CalendarEventTypeEncrypted CalendarEventTypeSigned ) type CalendarAttendee struct { ID string Token string Status CalendarAttendeeStatus Permissions CalendarPermissions } // TODO: What is this? 
type CalendarAttendeeStatus int const ( CalendarAttendeeStatusPending CalendarAttendeeStatus = iota CalendarAttendeeStatusMaybe CalendarAttendeeStatusNo CalendarAttendeeStatusYes ) go-proton-api-1.0.0/calendar_types.go000066400000000000000000000053161447642273300175500ustar00rootroot00000000000000package proton import ( "errors" "github.com/ProtonMail/gopenpgp/v2/crypto" ) type Calendar struct { ID string Name string Description string Color string Display Bool Type CalendarType Flags CalendarFlag } type CalendarFlag int64 const ( CalendarFlagActive CalendarFlag = 1 << iota CalendarFlagUpdatePassphrase CalendarFlagResetNeeded CalendarFlagIncompleteSetup CalendarFlagLostAccess ) type CalendarType int const ( CalendarTypeNormal CalendarType = iota CalendarTypeSubscribed ) type CalendarKey struct { ID string CalendarID string PassphraseID string PrivateKey string Flags CalendarKeyFlag } func (key CalendarKey) Unlock(passphrase []byte) (*crypto.Key, error) { lockedKey, err := crypto.NewKeyFromArmored(key.PrivateKey) if err != nil { return nil, err } return lockedKey.Unlock(passphrase) } type CalendarKeys []CalendarKey func (keys CalendarKeys) Unlock(passphrase []byte) (*crypto.KeyRing, error) { kr, err := crypto.NewKeyRing(nil) if err != nil { return nil, err } for _, key := range keys { if k, err := key.Unlock(passphrase); err != nil { continue } else if err := kr.AddKey(k); err != nil { return nil, err } } return kr, nil } // TODO: What is this? type CalendarKeyFlag int64 const ( CalendarKeyFlagActive CalendarKeyFlag = 1 << iota CalendarKeyFlagPrimary ) type CalendarMember struct { ID string Permissions CalendarPermissions Email string Color string Display Bool CalendarID string } // TODO: What is this? type CalendarPermissions int // TODO: Support invitations. 
type CalendarPassphrase struct { ID string Flags CalendarPassphraseFlag MemberPassphrases []MemberPassphrase } func (passphrase CalendarPassphrase) Decrypt(memberID string, addrKR *crypto.KeyRing) ([]byte, error) { for _, passphrase := range passphrase.MemberPassphrases { if passphrase.MemberID == memberID { return passphrase.decrypt(addrKR) } } return nil, errors.New("no such member passphrase") } // TODO: What is this? type CalendarPassphraseFlag int64 type MemberPassphrase struct { MemberID string Passphrase string Signature string } func (passphrase MemberPassphrase) decrypt(addrKR *crypto.KeyRing) ([]byte, error) { msg, err := crypto.NewPGPMessageFromArmored(passphrase.Passphrase) if err != nil { return nil, err } sig, err := crypto.NewPGPSignatureFromArmored(passphrase.Signature) if err != nil { return nil, err } dec, err := addrKR.Decrypt(msg, nil, crypto.GetUnixTime()) if err != nil { return nil, err } if err := addrKR.VerifyDetached(dec, sig, crypto.GetUnixTime()); err != nil { return nil, err } return dec.GetBinary(), nil } go-proton-api-1.0.0/client.go000066400000000000000000000110611447642273300160230ustar00rootroot00000000000000package proton import ( "context" "errors" "fmt" "net" "net/http" "sync" "sync/atomic" "github.com/go-resty/resty/v2" ) // clientID is a unique identifier for a client. var clientID uint64 // AuthHandler is given any new auths that are returned from the API due to an unexpected auth refresh. type AuthHandler func(Auth) // Handler is a generic function that can be registered for a certain event (e.g. deauth, API code). type Handler func() // Client is the proton client. type Client struct { m *Manager // clientID is this client's unique ID. 
clientID uint64 uid string acc string ref string authLock sync.RWMutex authHandlers []AuthHandler deauthHandlers []Handler hookLock sync.RWMutex deauthOnce sync.Once } func newClient(m *Manager, uid string) *Client { c := &Client{ m: m, uid: uid, clientID: atomic.AddUint64(&clientID, 1), } return c } func (c *Client) AddAuthHandler(handler AuthHandler) { c.hookLock.Lock() defer c.hookLock.Unlock() c.authHandlers = append(c.authHandlers, handler) } func (c *Client) AddDeauthHandler(handler Handler) { c.hookLock.Lock() defer c.hookLock.Unlock() c.deauthHandlers = append(c.deauthHandlers, handler) } func (c *Client) AddPreRequestHook(hook resty.RequestMiddleware) { c.hookLock.Lock() defer c.hookLock.Unlock() c.m.rc.OnBeforeRequest(func(rc *resty.Client, r *resty.Request) error { if clientID, ok := ClientIDFromContext(r.Context()); !ok || clientID != c.clientID { return nil } return hook(rc, r) }) } func (c *Client) AddPostRequestHook(hook resty.ResponseMiddleware) { c.hookLock.Lock() defer c.hookLock.Unlock() c.m.rc.OnAfterResponse(func(rc *resty.Client, r *resty.Response) error { if clientID, ok := ClientIDFromContext(r.Request.Context()); !ok || clientID != c.clientID { return nil } return hook(rc, r) }) } func (c *Client) Close() { c.authLock.Lock() defer c.authLock.Unlock() c.uid = "" c.acc = "" c.ref = "" c.hookLock.Lock() defer c.hookLock.Unlock() c.authHandlers = nil c.deauthHandlers = nil } func (c *Client) withAuth(acc, ref string) *Client { c.acc = acc c.ref = ref return c } func (c *Client) do(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) error { if _, err := c.doRes(ctx, fn); err != nil { return err } return nil } func (c *Client) doRes(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) { c.hookLock.RLock() defer c.hookLock.RUnlock() res, err := c.exec(ctx, fn) if res != nil { // If we receive no response, we can't do anything. 
if res.RawResponse == nil { return nil, newNetError(err, "received no response from API") } // If we receive a net error, we can't do anything. if resErr, ok := err.(*resty.ResponseError); ok { if netErr := new(net.OpError); errors.As(resErr.Err, &netErr) { return nil, newNetError(netErr, "network error while communicating with API") } } // If we receive a 401, we need to refresh the auth. if res.StatusCode() == http.StatusUnauthorized { if err := c.authRefresh(ctx); err != nil { return nil, fmt.Errorf("failed to refresh auth: %w", err) } if res, err = c.exec(ctx, fn); err != nil { return nil, fmt.Errorf("failed to retry request: %w", err) } } } return res, err } func (c *Client) exec(ctx context.Context, fn func(*resty.Request) (*resty.Response, error)) (*resty.Response, error) { c.authLock.RLock() defer c.authLock.RUnlock() r := c.m.r(WithClient(ctx, c.clientID)) if c.uid != "" { r.SetHeader("x-pm-uid", c.uid) } if c.acc != "" { r.SetAuthToken(c.acc) } return fn(r) } func (c *Client) authRefresh(ctx context.Context) error { c.authLock.Lock() defer c.authLock.Unlock() c.hookLock.RLock() defer c.hookLock.RUnlock() auth, err := c.m.authRefresh(ctx, c.uid, c.ref) if err != nil { if respErr, ok := err.(*resty.ResponseError); ok { switch respErr.Response.StatusCode() { case http.StatusBadRequest, http.StatusUnprocessableEntity: c.deauthOnce.Do(func() { for _, handler := range c.deauthHandlers { handler() } }) return fmt.Errorf("failed to refresh auth, de-auth: %w", err) case http.StatusConflict, http.StatusTooManyRequests, http.StatusInternalServerError, http.StatusServiceUnavailable: return fmt.Errorf("failed to refresh auth, server issues: %w", err) default: // } } return fmt.Errorf("failed to refresh auth: %w", err) } c.acc = auth.AccessToken c.ref = auth.RefreshToken for _, handler := range c.authHandlers { handler(auth) } return nil } go-proton-api-1.0.0/contact.go000066400000000000000000000070121447642273300162010ustar00rootroot00000000000000package proton import 
( "context" "strconv" "github.com/go-resty/resty/v2" ) func (c *Client) GetContact(ctx context.Context, contactID string) (Contact, error) { var res struct { Contact Contact } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/contacts/v4/" + contactID) }); err != nil { return Contact{}, err } return res.Contact, nil } func (c *Client) CountContacts(ctx context.Context) (int, error) { var res struct { Total int } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/contacts/v4") }); err != nil { return 0, err } return res.Total, nil } func (c *Client) CountContactEmails(ctx context.Context, email string) (int, error) { var res struct { Total int } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).SetQueryParam("Email", email).Get("/contacts/v4/emails") }); err != nil { return 0, err } return res.Total, nil } func (c *Client) GetContacts(ctx context.Context, page, pageSize int) ([]Contact, error) { var res struct { Contacts []Contact } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetQueryParams(map[string]string{ "Page": strconv.Itoa(page), "PageSize": strconv.Itoa(pageSize), }).SetResult(&res).Get("/contacts/v4") }); err != nil { return nil, err } return res.Contacts, nil } func (c *Client) GetAllContacts(ctx context.Context) ([]Contact, error) { total, err := c.CountContacts(ctx) if err != nil { return nil, err } return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]Contact, error) { return c.GetContacts(ctx, page, pageSize) }) } func (c *Client) GetContactEmails(ctx context.Context, email string, page, pageSize int) ([]ContactEmail, error) { var res struct { ContactEmails []ContactEmail } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetQueryParams(map[string]string{ "Page": strconv.Itoa(page), "PageSize": 
strconv.Itoa(pageSize), "Email": email, }).SetResult(&res).Get("/contacts/v4/emails") }); err != nil { return nil, err } return res.ContactEmails, nil } func (c *Client) GetAllContactEmails(ctx context.Context, email string) ([]ContactEmail, error) { total, err := c.CountContactEmails(ctx, email) if err != nil { return nil, err } return fetchPaged(ctx, total, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]ContactEmail, error) { return c.GetContactEmails(ctx, email, page, pageSize) }) } func (c *Client) CreateContacts(ctx context.Context, req CreateContactsReq) ([]CreateContactsRes, error) { var res struct { Responses []CreateContactsRes } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Post("/contacts/v4") }); err != nil { return nil, err } return res.Responses, nil } func (c *Client) UpdateContact(ctx context.Context, contactID string, req UpdateContactReq) (Contact, error) { var res struct { Contact Contact } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Put("/contacts/v4/" + contactID) }); err != nil { return Contact{}, err } return res.Contact, nil } func (c *Client) DeleteContacts(ctx context.Context, req DeleteContactsReq) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Put("/contacts/v4/delete") }) } go-proton-api-1.0.0/contact_card.go000066400000000000000000000152501447642273300171750ustar00rootroot00000000000000package proton import ( "bytes" "errors" "strings" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/emersion/go-vcard" ) const ( FieldPMScheme = "X-PM-SCHEME" FieldPMSign = "X-PM-SIGN" FieldPMEncrypt = "X-PM-ENCRYPT" FieldPMMIMEType = "X-PM-MIMETYPE" ) type Cards []*Card func (c *Cards) Merge(kr *crypto.KeyRing) (vcard.Card, error) { merged := newVCard() for _, card := range *c { dec, err := card.decode(kr) if err != nil 
{ return nil, err } for k, fields := range dec { for _, f := range fields { merged.Add(k, f) } } } return merged, nil } func (c *Cards) Get(cardType CardType) (*Card, bool) { for _, card := range *c { if card.Type == cardType { return card, true } } return nil, false } type Card struct { Type CardType Data string Signature string } type CardType int const ( CardTypeClear CardType = iota CardTypeEncrypted CardTypeSigned ) func NewCard(kr *crypto.KeyRing, cardType CardType) (*Card, error) { card := &Card{Type: cardType} if err := card.encode(kr, newVCard()); err != nil { return nil, err } return card, nil } func newVCard() vcard.Card { card := make(vcard.Card) card.AddValue(vcard.FieldVersion, "4.0") return card } func (c Card) Get(kr *crypto.KeyRing, key string) ([]*vcard.Field, error) { dec, err := c.decode(kr) if err != nil { return nil, err } return dec[key], nil } func (c *Card) Set(kr *crypto.KeyRing, key, value string) error { dec, err := c.decode(kr) if err != nil { return err } if field := dec.Get(key); field != nil { field.Value = value return c.encode(kr, dec) } dec.AddValue(key, value) return c.encode(kr, dec) } func (c *Card) ChangeType(kr *crypto.KeyRing, cardType CardType) error { dec, err := c.decode(kr) if err != nil { return err } c.Type = cardType return c.encode(kr, dec) } // GetGroup returns a type to manipulate the group defined by the given key/value pair. func (c Card) GetGroup(kr *crypto.KeyRing, groupKey, groupValue string) (CardGroup, error) { group, err := c.getGroup(kr, groupKey, groupValue) if err != nil { return CardGroup{}, err } return CardGroup{Card: c, kr: kr, group: group}, nil } // DeleteGroup removes all values in the group defined by the given key/value pair. 
func (c *Card) DeleteGroup(kr *crypto.KeyRing, groupKey, groupValue string) error { group, err := c.getGroup(kr, groupKey, groupValue) if err != nil { return err } return c.deleteGroup(kr, group) } type CardGroup struct { Card kr *crypto.KeyRing group string } // Get returns the values in the group with the given key. func (g CardGroup) Get(key string) ([]string, error) { dec, err := g.decode(g.kr) if err != nil { return nil, err } var fields []*vcard.Field for _, field := range dec[key] { if field.Group != g.group { continue } fields = append(fields, field) } return xslices.Map(fields, func(field *vcard.Field) string { return field.Value }), nil } // Set sets the value in the group. func (g *CardGroup) Set(key, value string, params vcard.Params) error { dec, err := g.decode(g.kr) if err != nil { return err } for _, field := range dec[key] { if field.Group != g.group { continue } field.Value = value return g.encode(g.kr, dec) } dec.Add(key, &vcard.Field{ Value: value, Group: g.group, Params: params, }) return g.encode(g.kr, dec) } // Add adds a value to the group. func (g *CardGroup) Add(key, value string, params vcard.Params) error { dec, err := g.decode(g.kr) if err != nil { return err } dec.Add(key, &vcard.Field{ Value: value, Group: g.group, Params: params, }) return g.encode(g.kr, dec) } // Remove removes the value in the group with the given key/value. func (g *CardGroup) Remove(key, value string) error { dec, err := g.decode(g.kr) if err != nil { return err } fields, ok := dec[key] if !ok { return errors.New("no such key") } var rest []*vcard.Field for _, field := range fields { if field.Group != g.group { rest = append(rest, field) } else if field.Value != value { rest = append(rest, field) } } if len(rest) > 0 { dec[key] = rest } else { delete(dec, key) } return g.encode(g.kr, dec) } // RemoveAll removes all values in the group with the given key. 
func (g *CardGroup) RemoveAll(key string) error { dec, err := g.decode(g.kr) if err != nil { return err } fields, ok := dec[key] if !ok { return errors.New("no such key") } var rest []*vcard.Field for _, field := range fields { if field.Group != g.group { rest = append(rest, field) } } if len(rest) > 0 { dec[key] = rest } else { delete(dec, key) } return g.encode(g.kr, dec) } func (c Card) getGroup(kr *crypto.KeyRing, groupKey, groupValue string) (string, error) { fields, err := c.Get(kr, groupKey) if err != nil { return "", err } for _, field := range fields { if field.Value != groupValue { continue } return field.Group, nil } return "", errors.New("no such field") } func (c *Card) deleteGroup(kr *crypto.KeyRing, group string) error { dec, err := c.decode(kr) if err != nil { return err } for key, fields := range dec { var rest []*vcard.Field for _, field := range fields { if field.Group != group { rest = append(rest, field) } } if len(rest) > 0 { dec[key] = rest } else { delete(dec, key) } } return c.encode(kr, dec) } func (c Card) decode(kr *crypto.KeyRing) (vcard.Card, error) { if c.Type&CardTypeEncrypted != 0 { enc, err := crypto.NewPGPMessageFromArmored(c.Data) if err != nil { return nil, err } dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) if err != nil { return nil, err } c.Data = dec.GetString() } if c.Type&CardTypeSigned != 0 { sig, err := crypto.NewPGPSignatureFromArmored(c.Signature) if err != nil { return nil, err } if err := kr.VerifyDetached(crypto.NewPlainMessageFromString(c.Data), sig, crypto.GetUnixTime()); err != nil { return nil, err } } return vcard.NewDecoder(strings.NewReader(c.Data)).Decode() } func (c *Card) encode(kr *crypto.KeyRing, card vcard.Card) error { buf := new(bytes.Buffer) if err := vcard.NewEncoder(buf).Encode(card); err != nil { return err } if c.Type&CardTypeSigned != 0 { sig, err := kr.SignDetached(crypto.NewPlainMessageFromString(buf.String())) if err != nil { return err } if c.Signature, err = sig.GetArmored(); err 
!= nil { return err } } if c.Type&CardTypeEncrypted != 0 { enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(buf.String()), nil) if err != nil { return err } if c.Data, err = enc.GetArmored(); err != nil { return err } } else { c.Data = buf.String() } return nil } go-proton-api-1.0.0/contact_types.go000066400000000000000000000056361447642273300174370ustar00rootroot00000000000000package proton import ( "encoding/base64" "strconv" "strings" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/emersion/go-vcard" ) type RecipientType int const ( RecipientTypeInternal RecipientType = iota + 1 RecipientTypeExternal ) type ContactSettings struct { MIMEType *rfc822.MIMEType Scheme *EncryptionScheme Sign *bool Encrypt *bool Keys []*crypto.Key } type Contact struct { ContactMetadata ContactCards } func (c *Contact) GetSettings(kr *crypto.KeyRing, email string) (ContactSettings, error) { signedCard, ok := c.Cards.Get(CardTypeSigned) if !ok { return ContactSettings{}, nil } group, err := signedCard.GetGroup(kr, vcard.FieldEmail, email) if err != nil { return ContactSettings{}, nil } var settings ContactSettings scheme, err := group.Get(FieldPMScheme) if err != nil { return ContactSettings{}, err } if len(scheme) > 0 { switch scheme[0] { case "pgp-inline": settings.Scheme = newPtr(PGPInlineScheme) case "pgp-mime": settings.Scheme = newPtr(PGPMIMEScheme) } } mimeType, err := group.Get(FieldPMMIMEType) if err != nil { return ContactSettings{}, err } if len(mimeType) > 0 { settings.MIMEType = newPtr(rfc822.MIMEType(mimeType[0])) } sign, err := group.Get(FieldPMSign) if err != nil { return ContactSettings{}, err } if len(sign) > 0 { sign, err := strconv.ParseBool(sign[0]) if err != nil { return ContactSettings{}, err } settings.Sign = newPtr(sign) } encrypt, err := group.Get(FieldPMEncrypt) if err != nil { return ContactSettings{}, err } if len(encrypt) > 0 { encrypt, err := strconv.ParseBool(encrypt[0]) if err != nil { return 
ContactSettings{}, err } settings.Encrypt = newPtr(encrypt) } keys, err := group.Get(vcard.FieldKey) if err != nil { return ContactSettings{}, err } if len(keys) > 0 { for _, key := range keys { dec, err := base64.StdEncoding.DecodeString(strings.SplitN(key, ",", 2)[1]) if err != nil { return ContactSettings{}, err } pubKey, err := crypto.NewKey(dec) if err != nil { return ContactSettings{}, err } settings.Keys = append(settings.Keys, pubKey) } } return settings, nil } type ContactMetadata struct { ID string Name string UID string Size int64 CreateTime int64 ModifyTime int64 ContactEmails []ContactEmail LabelIDs []string } type ContactCards struct { Cards Cards } type ContactEmail struct { ID string Name string Email string Type []string ContactID string LabelIDs []string } type CreateContactsReq struct { Contacts []ContactCards Overwrite int Labels int } type CreateContactsRes struct { Index int Response struct { APIError Contact Contact } } type UpdateContactReq struct { Cards Cards } type DeleteContactsReq struct { IDs []string } func newPtr[T any](v T) *T { return &v } go-proton-api-1.0.0/contexts.go000066400000000000000000000011141447642273300164120ustar00rootroot00000000000000package proton import "context" type withClientKeyType struct{} var withClientKey withClientKeyType // WithClient marks this context as originating from the client with the given ID. func WithClient(parent context.Context, clientID uint64) context.Context { return context.WithValue(parent, withClientKey, clientID) } // ClientIDFromContext returns true if this context was marked as originating from a client. 
func ClientIDFromContext(ctx context.Context) (uint64, bool) { clientID, ok := ctx.Value(withClientKey).(uint64) if !ok { return 0, false } return clientID, true } go-proton-api-1.0.0/core_settings.go000066400000000000000000000022461447642273300174220ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) GetUserSettings(ctx context.Context) (UserSettings, error) { var res struct { UserSettings UserSettings } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/settings") }); err != nil { return UserSettings{}, err } return res.UserSettings, nil } func (c *Client) SetUserSettingsTelemetry(ctx context.Context, req SetTelemetryReq) (UserSettings, error) { var res struct { UserSettings UserSettings } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Put("/core/v4/settings/telemetry") }); err != nil { return UserSettings{}, err } return res.UserSettings, nil } func (c *Client) SetUserSettingsCrashReports(ctx context.Context, req SetCrashReportReq) (UserSettings, error) { var res struct { UserSettings UserSettings } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Put("/core/v4/settings/crashreports") }); err != nil { return UserSettings{}, err } return res.UserSettings, nil } go-proton-api-1.0.0/core_settings_type.go000066400000000000000000000004601447642273300204570ustar00rootroot00000000000000package proton type UserSettings struct { Telemetry SettingsBool CrashReports SettingsBool } type SetTelemetryReq struct { Telemetry SettingsBool } type SetCrashReportReq struct { CrashReports SettingsBool } type SettingsBool int const ( SettingDisabled SettingsBool = iota SettingEnabled ) go-proton-api-1.0.0/data.go000066400000000000000000000007541447642273300154650ustar00rootroot00000000000000package proton import ( "context" 
"github.com/go-resty/resty/v2" ) func (c *Client) SendDataEvent(ctx context.Context, req SendStatsReq) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Post("/data/v1/stats") }) } func (c *Client) SendDataEventMultiple(ctx context.Context, req SendStatsMultiReq) error { return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Post("/data/v1/stats/multiple") }) } go-proton-api-1.0.0/data_type.go000066400000000000000000000003361447642273300165220ustar00rootroot00000000000000package proton type SendStatsReq struct { MeasurementGroup string Event string Values map[string]any Dimensions map[string]any } type SendStatsMultiReq struct { EventInfo []SendStatsReq } go-proton-api-1.0.0/event.go000066400000000000000000000043401447642273300156700ustar00rootroot00000000000000package proton import ( "context" "time" "github.com/ProtonMail/gluon/async" "github.com/go-resty/resty/v2" ) func (c *Client) GetLatestEventID(ctx context.Context) (string, error) { var res struct { Event } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/events/latest") }); err != nil { return "", err } return res.EventID, nil } // maxCollectedEvents limits the number of events which are collected per one GetEvent // call. const maxCollectedEvents = 50 func (c *Client) GetEvent(ctx context.Context, eventID string) ([]Event, bool, error) { var events []Event event, more, err := c.getEvent(ctx, eventID) if err != nil { return nil, more, err } events = append(events, event) nCollected := 0 for more { nCollected++ if nCollected >= maxCollectedEvents { break } event, more, err = c.getEvent(ctx, event.EventID) if err != nil { return nil, false, err } events = append(events, event) } return events, more, nil } // NewEventStreamer returns a new event stream. // It polls the API for new events at random intervals between `period` and `period+jitter`. 
func (c *Client) NewEventStream(ctx context.Context, period, jitter time.Duration, lastEventID string) <-chan Event { eventCh := make(chan Event) go func() { defer async.HandlePanic(c.m.panicHandler) defer close(eventCh) ticker := NewTicker(period, jitter, c.m.panicHandler) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: // ... } events, _, err := c.GetEvent(ctx, lastEventID) if err != nil { continue } if events[len(events)-1].EventID == lastEventID { continue } for _, evt := range events { select { case <-ctx.Done(): return case eventCh <- evt: lastEventID = evt.EventID } } } }() return eventCh } func (c *Client) getEvent(ctx context.Context, eventID string) (Event, bool, error) { var res struct { Event More Bool } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/events/" + eventID) }); err != nil { return Event{}, false, err } return res.Event, bool(res.More), nil } go-proton-api-1.0.0/event_drive.go000066400000000000000000000046551447642273300170720ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) GetLatestVolumeEventID(ctx context.Context, volumeID string) (string, error) { var res struct { EventID string } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/volumes/" + volumeID + "/events/latest") }); err != nil { return "", err } return res.EventID, nil } func (c *Client) GetLatestShareEventID(ctx context.Context, shareID string) (string, error) { var res struct { EventID string } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/shares/" + shareID + "/events/latest") }); err != nil { return "", err } return res.EventID, nil } func (c *Client) GetVolumeEvent(ctx context.Context, volumeID, eventID string) (DriveEvent, error) { event, more, err := c.getVolumeEvent(ctx, volumeID, eventID) if err != nil 
{ return DriveEvent{}, err } for more { var next DriveEvent next, more, err = c.getVolumeEvent(ctx, volumeID, event.EventID) if err != nil { return DriveEvent{}, err } event.Events = append(event.Events, next.Events...) } return event, nil } func (c *Client) GetShareEvent(ctx context.Context, shareID, eventID string) (DriveEvent, error) { event, more, err := c.getShareEvent(ctx, shareID, eventID) if err != nil { return DriveEvent{}, err } for more { var next DriveEvent next, more, err = c.getShareEvent(ctx, shareID, event.EventID) if err != nil { return DriveEvent{}, err } event.Events = append(event.Events, next.Events...) } return event, nil } func (c *Client) getVolumeEvent(ctx context.Context, volumeID, eventID string) (DriveEvent, bool, error) { var res struct { DriveEvent More Bool } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/volumes/" + volumeID + "/events/" + eventID) }); err != nil { return DriveEvent{}, false, err } return res.DriveEvent, bool(res.More), nil } func (c *Client) getShareEvent(ctx context.Context, shareID, eventID string) (DriveEvent, bool, error) { var res struct { DriveEvent More Bool } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/shares/" + shareID + "/events/" + eventID) }); err != nil { return DriveEvent{}, false, err } return res.DriveEvent, bool(res.More), nil } go-proton-api-1.0.0/event_drive_types.go000066400000000000000000000005201447642273300203010ustar00rootroot00000000000000package proton type DriveEvent struct { EventID string Events []LinkEvent Refresh Bool } type LinkEvent struct { EventID string EventType LinkEventType CreateTime int Link Link Data any } type LinkEventType int const ( LinkEventDelete LinkEventType = iota LinkEventCreate LinkEventUpdate LinkEventUpdateMetadata ) go-proton-api-1.0.0/event_test.go000066400000000000000000000055131447642273300167320ustar00rootroot00000000000000package 
proton_test import ( "context" "testing" "time" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/stretchr/testify/require" ) func TestEventStreamer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() s := server.New() defer s.Close() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) createTestMessages(t, c, "pass", 10) latestEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) eventCh := make(chan proton.Event) go func() { for event := range c.NewEventStream(ctx, time.Second, 0, latestEventID) { eventCh <- event } }() // Perform some action to generate an event. metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) require.NoError(t, c.LabelMessages(ctx, []string{metadata[0].ID}, proton.TrashLabel)) // Wait for the first event. <-eventCh // Close the client; this should stop the client's event streamer. c.Close() // Create a new client and perform some actions with it to generate more events. cc, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer cc.Close() require.NoError(t, cc.LabelMessages(ctx, []string{metadata[1].ID}, proton.TrashLabel)) // We should not receive any more events from the original client. select { case <-eventCh: require.Fail(t, "received unexpected event") default: // ... 
} } func TestMaxEventMerge(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() s := server.New() defer s.Close() s.SetMaxUpdatesPerEvent(1) m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(proton.InsecureTransport()), ) _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) latestID, err := c.GetLatestEventID(ctx) require.NoError(t, err) label, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) for i := 0; i < 75; i++ { _, err := c.UpdateLabel(ctx, label.ID, proton.UpdateLabelReq{Name: uuid.NewString()}) require.NoError(t, err) } events, more, err := c.GetEvent(ctx, latestID) require.NoError(t, err) require.True(t, more) require.Equal(t, 50, len(events)) events2, more, err := c.GetEvent(ctx, events[len(events)-1].EventID) require.NotEqual(t, events, events2) require.NoError(t, err) require.False(t, more) require.Equal(t, 26, len(events2)) } go-proton-api-1.0.0/event_types.go000066400000000000000000000055561447642273300171260ustar00rootroot00000000000000package proton import ( "fmt" "strings" "github.com/bradenaw/juniper/xslices" ) type Event struct { EventID string Refresh RefreshFlag User *User UserSettings *UserSettings MailSettings *MailSettings Messages []MessageEvent Labels []LabelEvent Addresses []AddressEvent UsedSpace *int } func (event Event) String() string { var parts []string if event.Refresh != 0 { parts = append(parts, fmt.Sprintf("refresh: %v", event.Refresh)) } if event.User != nil { parts = append(parts, "user: [modified]") } if event.MailSettings != nil { parts = append(parts, "mail-settings: [modified]") } if len(event.Messages) > 0 { parts = append(parts, fmt.Sprintf( "messages: created=%d, updated=%d, deleted=%d", xslices.CountFunc(event.Messages, func(e MessageEvent) bool 
{ return e.Action == EventCreate }), xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), xslices.CountFunc(event.Messages, func(e MessageEvent) bool { return e.Action == EventDelete }), )) } if len(event.Labels) > 0 { parts = append(parts, fmt.Sprintf( "labels: created=%d, updated=%d, deleted=%d", xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventCreate }), xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), xslices.CountFunc(event.Labels, func(e LabelEvent) bool { return e.Action == EventDelete }), )) } if len(event.Addresses) > 0 { parts = append(parts, fmt.Sprintf( "addresses: created=%d, updated=%d, deleted=%d", xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventCreate }), xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventUpdate || e.Action == EventUpdateFlags }), xslices.CountFunc(event.Addresses, func(e AddressEvent) bool { return e.Action == EventDelete }), )) } return fmt.Sprintf("Event %s: %s", event.EventID, strings.Join(parts, ", ")) } type RefreshFlag uint8 const ( RefreshMail RefreshFlag = 1 << iota // 1<<0 = 1 _ // 1<<1 = 2 _ // 1<<2 = 4 _ // 1<<3 = 8 _ // 1<<4 = 16 _ // 1<<5 = 32 _ // 1<<6 = 64 _ // 1<<7 = 128 RefreshAll RefreshFlag = 1<5s, and we only allow 1s in the context. // Thus, it will fail. c := m.NewClient("", "", "") defer c.Close() if _, err := c.GetAddresses(ctx); err == nil { t.Fatal("expected error, instead got", err) } } func TestReturnErrNoConnection(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) })) defer ts.Close() // We will fail more times than we retry, so requests should fail with ErrNoConnection. 
m := proton.New( proton.WithHostURL(ts.URL), proton.WithRetryCount(5), proton.WithTransport(newFailingRoundTripper(10)), ) // The call should fail because every dial will fail and we'll run out of retries. c := m.NewClient("", "", "") defer c.Close() if _, err := c.GetAddresses(context.Background()); err == nil { t.Fatal("expected error, instead got", err) } } func TestStatusCallbacks(t *testing.T) { s := server.New() defer s.Close() ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) statusCh := make(chan proton.Status, 1) m.AddStatusObserver(func(status proton.Status) { statusCh <- status }) ctx, cancel := context.WithCancel(context.Background()) defer cancel() ctl.Disable() require.Error(t, m.Ping(ctx)) require.Equal(t, proton.StatusDown, <-statusCh) ctl.Enable() require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, <-statusCh) ctl.SetReadLimit(1) require.Error(t, m.Ping(ctx)) require.Equal(t, proton.StatusDown, <-statusCh) ctl.SetReadLimit(0) require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, <-statusCh) } func Test503IsReportedAsAPIError(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusServiceUnavailable) })) defer ts.Close() m := proton.New( proton.WithHostURL(ts.URL), proton.WithRetryCount(5), ) c := m.NewClient("", "", "") defer c.Close() _, err := c.GetAddresses(context.Background()) require.Error(t, err) var protonErr *proton.APIError require.True(t, errors.As(err, &protonErr)) require.Equal(t, 503, protonErr.Status) } type failingRoundTripper struct { http.RoundTripper fails, calls int } func newFailingRoundTripper(fails int) http.RoundTripper { return &failingRoundTripper{ RoundTripper: http.DefaultTransport, fails: fails, } } func (rt *failingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { rt.calls++ if rt.calls < 
rt.fails { return nil, errors.New("simulating network error") } return rt.RoundTripper.RoundTrip(req) } go-proton-api-1.0.0/manager_user.go000066400000000000000000000017311447642273300172200ustar00rootroot00000000000000package proton import ( "context" ) func (m *Manager) GetCaptcha(ctx context.Context, token string) ([]byte, error) { res, err := m.r(ctx).SetQueryParam("Token", token).SetQueryParam("ForceWebMessaging", "1").Get("/core/v4/captcha") if err != nil { return nil, err } return res.Body(), nil } func (m *Manager) SendVerificationCode(ctx context.Context, req SendVerificationCodeReq) error { if _, err := m.r(ctx).SetBody(req).Post("/core/v4/users/code"); err != nil { return err } return nil } func (m *Manager) CreateUser(ctx context.Context, req CreateUserReq) (User, error) { var res struct { User User } if _, err := m.r(ctx).SetBody(req).SetResult(&res).Post("/core/v4/users"); err != nil { return User{}, err } return res.User, nil } func (m *Manager) GetUsernameAvailable(ctx context.Context, username string) error { if _, err := m.r(ctx).SetQueryParam("Name", username).Get("/core/v4/users/available"); err != nil { return err } return nil } go-proton-api-1.0.0/manager_user_types.go000066400000000000000000000007431447642273300204460ustar00rootroot00000000000000package proton type TokenType string const ( EmailTokenType TokenType = "email" SMSTokenType TokenType = "sms" ) type SendVerificationCodeReq struct { Username string Type TokenType Destination TokenDestination } type TokenDestination struct { Address string Phone string } type UserType int const ( MailUserType UserType = iota + 1 VPNUserType ) type CreateUserReq struct { Type UserType Username string Domain string Auth AuthVerifier } go-proton-api-1.0.0/message.go000066400000000000000000000157061447642273300162030ustar00rootroot00000000000000package proton import ( "bytes" "context" "fmt" "runtime" "strconv" "github.com/ProtonMail/gluon/async" "github.com/bradenaw/juniper/parallel" 
"github.com/bradenaw/juniper/xslices" "github.com/go-resty/resty/v2" ) const maxMessageIDs = 1000 func (c *Client) GetFullMessage(ctx context.Context, messageID string, scheduler Scheduler, storageProvider AttachmentAllocator) (FullMessage, error) { message, err := c.GetMessage(ctx, messageID) if err != nil { return FullMessage{}, err } attDataBuffers, err := scheduler.Schedule(ctx, xslices.Map(message.Attachments, func(att Attachment) string { return att.ID }), storageProvider, func(ctx context.Context, s string, buffer *bytes.Buffer) error { return c.GetAttachmentInto(ctx, s, buffer) }) if err != nil { return FullMessage{}, err } return FullMessage{ Message: message, AttData: xslices.Map(attDataBuffers, func(b *bytes.Buffer) []byte { return b.Bytes() }), }, nil } func (c *Client) GetMessage(ctx context.Context, messageID string) (Message, error) { var res struct { Message Message } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/mail/v4/messages/" + messageID) }); err != nil { return Message{}, err } return res.Message, nil } func (c *Client) CountMessages(ctx context.Context) (int, error) { return c.countMessages(ctx, MessageFilter{}) } func (c *Client) GetMessageMetadata(ctx context.Context, filter MessageFilter) ([]MessageMetadata, error) { count, err := c.countMessages(ctx, filter) if err != nil { return nil, err } return fetchPaged(ctx, count, maxPageSize, c, func(ctx context.Context, page, pageSize int) ([]MessageMetadata, error) { return c.GetMessageMetadataPage(ctx, page, pageSize, filter) }) } func (c *Client) GetAllMessageIDs(ctx context.Context, afterID string) ([]string, error) { var messageIDs []string for ; ; afterID = messageIDs[len(messageIDs)-1] { page, err := c.GetMessageIDs(ctx, afterID, maxMessageIDs) if err != nil { return nil, err } if len(page) == 0 { return messageIDs, nil } messageIDs = append(messageIDs, page...) 
} } func (c *Client) DeleteMessage(ctx context.Context, messageIDs ...string) error { pages := xslices.Chunk(messageIDs, maxPageSize) return parallel.DoContext(ctx, runtime.NumCPU(), len(pages), func(ctx context.Context, idx int) error { defer async.HandlePanic(c.m.panicHandler) return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(MessageActionReq{IDs: pages[idx]}).Put("/mail/v4/messages/delete") }) }) } func (c *Client) MarkMessagesRead(ctx context.Context, messageIDs ...string) error { for _, page := range xslices.Chunk(messageIDs, maxPageSize) { if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(MessageActionReq{IDs: page}).Put("/mail/v4/messages/read") }); err != nil { return err } } return nil } func (c *Client) MarkMessagesUnread(ctx context.Context, messageIDs ...string) error { for _, page := range xslices.Chunk(messageIDs, maxPageSize) { req := MessageActionReq{IDs: page} if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).Put("/mail/v4/messages/unread") }); err != nil { return err } } return nil } func (c *Client) LabelMessages(ctx context.Context, messageIDs []string, labelID string) error { var results []LabelMessagesRes for _, chunk := range xslices.Chunk(messageIDs, maxPageSize) { var res LabelMessagesRes if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(LabelMessagesReq{ LabelID: labelID, IDs: chunk, }).SetResult(&res).Put("/mail/v4/messages/label") }); err != nil { return err } if ok, errStr := res.ok(); !ok { tokens := xslices.Map(results, func(res LabelMessagesRes) UndoToken { return res.UndoToken }) if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil { return fmt.Errorf("failed to undo label actions (undo reason: %v): %w", errStr, undoErr) } return fmt.Errorf("failed to label messages: %v", errStr) } results = append(results, res) } return nil } func (c *Client) UnlabelMessages(ctx 
context.Context, messageIDs []string, labelID string) error { var results []LabelMessagesRes for _, chunk := range xslices.Chunk(messageIDs, maxPageSize) { var res LabelMessagesRes if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(LabelMessagesReq{ LabelID: labelID, IDs: chunk, }).SetResult(&res).Put("/mail/v4/messages/unlabel") }); err != nil { return err } if ok, errStr := res.ok(); !ok { tokens := xslices.Map(results, func(res LabelMessagesRes) UndoToken { return res.UndoToken }) if _, undoErr := c.UndoActions(ctx, tokens...); undoErr != nil { return fmt.Errorf("failed to undo unlabel actions (undo reason: %v): %w", errStr, undoErr) } return fmt.Errorf("failed to unlabel messages: %v", errStr) } results = append(results, res) } return nil } func (c *Client) GetMessageIDs(ctx context.Context, afterID string, limit int) ([]string, error) { if limit > maxMessageIDs { limit = maxMessageIDs } var res struct { IDs []string } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { if afterID != "" { r = r.SetQueryParam("AfterID", afterID) } return r.SetQueryParam("Limit", strconv.Itoa(limit)).SetResult(&res).Get("/mail/v4/messages/ids") }); err != nil { return nil, err } return res.IDs, nil } func (c *Client) countMessages(ctx context.Context, filter MessageFilter) (int, error) { var res struct { Total int } req := struct { MessageFilter Limit int `json:",,string"` }{ MessageFilter: filter, Limit: 0, } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).SetHeader("X-HTTP-Method-Override", "GET").Post("/mail/v4/messages") }); err != nil { return 0, err } return res.Total, nil } func (c *Client) GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter MessageFilter) ([]MessageMetadata, error) { var res struct { Messages []MessageMetadata Stale Bool } req := struct { MessageFilter Page int PageSize int Sort string }{ MessageFilter: filter, Page: page, 
PageSize: pageSize, Sort: "ID", } for { if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).SetHeader("X-HTTP-Method-Override", "GET").Post("/mail/v4/messages") }); err != nil { return nil, err } if !res.Stale { break } } return res.Messages, nil } func (c *Client) GetGroupedMessageCount(ctx context.Context) ([]MessageGroupCount, error) { var res struct { Counts []MessageGroupCount } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/mail/v4/messages/count") }); err != nil { return nil, err } return res.Counts, nil } go-proton-api-1.0.0/message_build.go000066400000000000000000000160041447642273300173520ustar00rootroot00000000000000package proton import ( "bufio" "bytes" "encoding/base64" "io" "mime" "net/mail" "strings" "time" "unicode/utf8" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/emersion/go-message" "github.com/emersion/go-message/textproto" "github.com/google/uuid" ) func BuildRFC822(kr *crypto.KeyRing, msg Message, attData map[string][]byte) ([]byte, error) { if msg.MIMEType == rfc822.MultipartMixed { return buildPGPRFC822(kr, msg) } header, err := getMixedMessageHeader(msg) if err != nil { return nil, err } buf := new(bytes.Buffer) w, err := message.CreateWriter(buf, header) if err != nil { return nil, err } var ( inlineAtts []Attachment inlineData [][]byte attachAtts []Attachment attachData [][]byte ) for _, att := range msg.Attachments { if att.Disposition == InlineDisposition { inlineAtts = append(inlineAtts, att) inlineData = append(inlineData, attData[att.ID]) } else { attachAtts = append(attachAtts, att) attachData = append(attachData, attData[att.ID]) } } if len(inlineAtts) > 0 { if err := writeRelatedParts(w, kr, msg, inlineAtts, inlineData); err != nil { return nil, err } } else if err := writeTextPart(w, kr, msg); err != nil { return nil, err } for i, att := range attachAtts { if err := 
writeAttachmentPart(w, kr, att, attachData[i]); err != nil { return nil, err } } if err := w.Close(); err != nil { return nil, err } return buf.Bytes(), nil } func writeTextPart(w *message.Writer, kr *crypto.KeyRing, msg Message) error { dec, err := msg.Decrypt(kr) if err != nil { return err } part, err := w.CreatePart(getTextPartHeader(dec, msg.MIMEType)) if err != nil { return err } if _, err := part.Write(dec); err != nil { return err } return part.Close() } func writeAttachmentPart(w *message.Writer, kr *crypto.KeyRing, att Attachment, attData []byte) error { kps, err := base64.StdEncoding.DecodeString(att.KeyPackets) if err != nil { return err } msg := crypto.NewPGPSplitMessage(kps, attData).GetPGPMessage() dec, err := kr.Decrypt(msg, nil, crypto.GetUnixTime()) if err != nil { return err } part, err := w.CreatePart(getAttachmentPartHeader(att)) if err != nil { return err } if _, err := part.Write(dec.GetBinary()); err != nil { return err } return part.Close() } func writeRelatedParts(w *message.Writer, kr *crypto.KeyRing, msg Message, atts []Attachment, attData [][]byte) error { var header message.Header header.SetContentType(string(rfc822.MultipartRelated), nil) rel, err := w.CreatePart(header) if err != nil { return err } if err := writeTextPart(rel, kr, msg); err != nil { return err } for i, att := range atts { if err := writeAttachmentPart(rel, kr, att, attData[i]); err != nil { return err } } return rel.Close() } func buildPGPRFC822(kr *crypto.KeyRing, msg Message) ([]byte, error) { raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header))) if err != nil { return nil, err } dec, err := msg.Decrypt(kr) if err != nil { return nil, err } sigs, err := ExtractSignatures(kr, msg.Body) if err != nil { return nil, err } if len(sigs) > 0 { return buildMultipartSignedRFC822(message.Header{Header: raw}, dec, sigs[0]) } return buildMultipartEncryptedRFC822(message.Header{Header: raw}, dec) } func buildMultipartSignedRFC822(header message.Header, 
body []byte, sig Signature) ([]byte, error) { buf := new(bytes.Buffer) boundary := uuid.New().String() header.SetContentType("multipart/signed", map[string]string{ "micalg": sig.Hash, "protocol": "application/pgp-signature", "boundary": boundary, }) if err := textproto.WriteHeader(buf, header.Header); err != nil { return nil, err } w := rfc822.NewMultipartWriter(buf, boundary) bodyHeader, bodyData := rfc822.Split(body) if err := w.AddPart(func(w io.Writer) error { if _, err := w.Write(bodyHeader); err != nil { return err } if _, err := w.Write(bodyData); err != nil { return err } return nil }); err != nil { return nil, err } var sigHeader message.Header sigHeader.SetContentType("application/pgp-signature", map[string]string{"name": "OpenPGP_signature.asc"}) sigHeader.SetContentDisposition("attachment", map[string]string{"filename": "OpenPGP_signature"}) sigHeader.Set("Content-Description", "OpenPGP digital signature") sigData, err := sig.Data.GetArmored() if err != nil { return nil, err } if err := w.AddPart(func(w io.Writer) error { if err := textproto.WriteHeader(w, sigHeader.Header); err != nil { return err } if _, err := w.Write([]byte(sigData)); err != nil { return err } return nil }); err != nil { return nil, err } if err := w.Done(); err != nil { return nil, err } return buf.Bytes(), nil } func buildMultipartEncryptedRFC822(header message.Header, body []byte) ([]byte, error) { buf := new(bytes.Buffer) bodyHeader, bodyData := rfc822.Split(body) parsedHeader, err := rfc822.NewHeader(bodyHeader) if err != nil { return nil, err } parsedHeader.Entries(func(key, val string) { header.Set(key, val) }) if err := textproto.WriteHeader(buf, header.Header); err != nil { return nil, err } if _, err := buf.Write(bodyData); err != nil { return nil, err } return buf.Bytes(), nil } func getMixedMessageHeader(msg Message) (message.Header, error) { raw, err := textproto.ReadHeader(bufio.NewReader(strings.NewReader(msg.Header))) if err != nil { return message.Header{}, err } 
header := message.Header{Header: raw} header.SetContentType(string(rfc822.MultipartMixed), nil) if date, err := mail.ParseDate(header.Get("Date")); err != nil || date.Before(time.Unix(0, 0)) { if msgTime := time.Unix(msg.Time, 0); msgTime.After(time.Unix(0, 0)) { header.Set("Date", msgTime.In(time.UTC).Format(time.RFC1123Z)) } else { header.Del("Date") } header.Set("X-Original-Date", date.In(time.UTC).Format(time.RFC1123Z)) } return header, nil } func getTextPartHeader(body []byte, mimeType rfc822.MIMEType) message.Header { var header message.Header params := make(map[string]string) if utf8.Valid(body) { params["charset"] = "utf-8" } header.SetContentType(string(mimeType), params) // Use quoted-printable for all text/... parts header.Set("Content-Transfer-Encoding", "quoted-printable") return header } func getAttachmentPartHeader(att Attachment) message.Header { var header message.Header for key, val := range att.Headers { for _, val := range val { header.Add(key, val) } } // All attachments have a content type. header.SetContentType(string(att.MIMEType), map[string]string{"name": mime.QEncoding.Encode("utf-8", att.Name)}) // All attachments have a content disposition. header.SetContentDisposition(string(att.Disposition), map[string]string{"filename": mime.QEncoding.Encode("utf-8", att.Name)}) // Use base64 for all attachments except embedded RFC822 messages. 
if att.MIMEType != rfc822.MessageRFC822 { header.Set("Content-Transfer-Encoding", "base64") } else { header.Del("Content-Transfer-Encoding") } return header } go-proton-api-1.0.0/message_draft_types.go000066400000000000000000000013411447642273300205750ustar00rootroot00000000000000package proton import ( "net/mail" "github.com/ProtonMail/gluon/rfc822" ) type DraftTemplate struct { Subject string Sender *mail.Address ToList []*mail.Address CCList []*mail.Address BCCList []*mail.Address Body string MIMEType rfc822.MIMEType Unread Bool ExternalID string `json:",omitempty"` } type CreateDraftAction int const ( ReplyAction CreateDraftAction = iota ReplyAllAction ForwardAction AutoResponseAction ReadReceiptAction ) type CreateDraftReq struct { Message DraftTemplate AttachmentKeyPackets []string ParentID string `json:",omitempty"` Action CreateDraftAction } type UpdateDraftReq struct { Message DraftTemplate AttachmentKeyPackets []string } go-proton-api-1.0.0/message_encrypt.go000066400000000000000000000216061447642273300177430ustar00rootroot00000000000000package proton import ( "bytes" "encoding/base64" "fmt" "io" "mime" "strings" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/google/uuid" "golang.org/x/text/encoding/htmlindex" "golang.org/x/text/encoding/ianaindex" ) // CharsetReader returns a charset decoder for the given charset. // If set, it will be used to decode non-utf8 encoded messages. var CharsetReader func(charset string, input io.Reader) (io.Reader, error) // EncryptRFC822 encrypts the given message literal as a PGP attachment. func EncryptRFC822(kr *crypto.KeyRing, literal []byte) ([]byte, error) { var buf bytes.Buffer if err := tryEncrypt(&buf, kr, rfc822.Parse(literal)); err != nil { return encryptFull(kr, literal) } return buf.Bytes(), nil } // tryEncrypt tries to encrypt the given message section. // It first checks if the message is encrypted/signed or has multiple text parts. 
// If so, it returns an error -- we need to encrypt the whole message as a PGP attachment. func tryEncrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error { var textCount int if err := s.Walk(func(s *rfc822.Section) error { // Ensure we can read the content type. contentType, _, err := s.ContentType() if err != nil { return fmt.Errorf("cannot read content type: %w", err) } // Ensure we can read the content disposition. if header, err := s.ParseHeader(); err != nil { return fmt.Errorf("cannot read header: %w", err) } else if header.Has("Content-Disposition") { if _, _, err := rfc822.ParseMediaType(header.Get("Content-Disposition")); err != nil { return fmt.Errorf("cannot read content disposition: %w", err) } } // Check if the message is already encrypted or signed. if contentType.SubType() == "encrypted" { return fmt.Errorf("already encrypted") } else if contentType.SubType() == "signed" { return fmt.Errorf("already signed") } if contentType.Type() != "text" { return nil } if textCount++; textCount > 1 { return fmt.Errorf("multiple text parts") } return nil }); err != nil { return err } return encrypt(w, kr, s) } // encrypt encrypts the given message section with the given keyring and writes the result to w. func encrypt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error { contentType, contentParams, err := s.ContentType() if err != nil { return err } if contentType.IsMultiPart() { return encryptMultipart(w, kr, s, contentParams["boundary"]) } if contentType.Type() == "text" || contentType.Type() == "message" { return encryptText(w, kr, s) } return encryptAtt(w, kr, s) } // encryptMultipart encrypts the given multipart message section with the given keyring and writes the result to w. func encryptMultipart(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section, boundary string) error { // Write the header. if _, err := w.Write(s.Header()); err != nil { return err } // Create a new multipart writer with the boundary from the header. 
ww := rfc822.NewMultipartWriter(w, boundary) children, err := s.Children() if err != nil { return err } // Encrypt each child part. for _, child := range children { if err := ww.AddPart(func(w io.Writer) error { return encrypt(w, kr, child) }); err != nil { return err } } return ww.Done() } // encryptText encrypts the given text message section with the given keyring and writes the result to w. func encryptText(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error { contentType, contentParams, err := s.ContentType() if err != nil { return err } header, err := s.ParseHeader() if err != nil { return err } body, err := s.DecodedBody() if err != nil { return err } // Remove the Content-Transfer-Encoding header as we decode the body. header.Del("Content-Transfer-Encoding") // If the text part has a charset, decode it to UTF-8. if charset, ok := contentParams["charset"]; ok { decoder, err := getCharsetDecoder(bytes.NewReader(body), charset) if err != nil { return err } if body, err = io.ReadAll(decoder); err != nil { return err } // Remove old content type. header.Del("Content-Type") header.Set("Content-Type", mime.FormatMediaType( string(contentType), replace(contentParams, "charset", "utf-8")), ) } // Encrypt the body. enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil) if err != nil { return err } // Armor the encrypted body. arm, err := enc.GetArmored() if err != nil { return err } // Write the header. if _, err := w.Write(header.Raw()); err != nil { return err } // Write the armored body. if _, err := w.Write([]byte(arm)); err != nil { return err } return nil } // encryptAtt encrypts the given attachment section with the given keyring and writes the result to w. func encryptAtt(w io.Writer, kr *crypto.KeyRing, s *rfc822.Section) error { header, err := s.ParseHeader() if err != nil { return err } body, err := s.DecodedBody() if err != nil { return err } // Set the Content-Transfer-Encoding header to base64. 
header.Set("Content-Transfer-Encoding", "base64") // Encrypt the body. enc, err := kr.Encrypt(crypto.NewPlainMessage(body), nil) if err != nil { return err } // Write the header. if _, err := w.Write(header.Raw()); err != nil { return err } // Write the base64 body. if err := encodeBase64(w, enc.GetBinary()); err != nil { return err } return nil } // encryptFull builds a PGP/MIME encrypted message from the given literal. func encryptFull(kr *crypto.KeyRing, literal []byte) ([]byte, error) { enc, err := kr.Encrypt(crypto.NewPlainMessage(literal), kr) if err != nil { return nil, err } arm, err := enc.GetArmored() if err != nil { return nil, err } header, err := rfc822.Parse(literal).ParseHeader() if err != nil { return nil, err } buf := new(bytes.Buffer) boundary := strings.ReplaceAll(uuid.NewString(), "-", "") multipartWriter := rfc822.NewMultipartWriter(buf, boundary) { newHeader := rfc822.NewEmptyHeader() if value, ok := header.GetChecked("Message-Id"); ok { newHeader.Set("Message-Id", value) } contentType := mime.FormatMediaType("multipart/encrypted", map[string]string{ "boundary": boundary, "protocol": "application/pgp-encrypted", }) newHeader.Set("Mime-version", "1.0") newHeader.Set("Content-Type", contentType) if value, ok := header.GetChecked("From"); ok { newHeader.Set("From", value) } if value, ok := header.GetChecked("To"); ok { newHeader.Set("To", value) } if value, ok := header.GetChecked("Subject"); ok { newHeader.Set("Subject", value) } if value, ok := header.GetChecked("Date"); ok { newHeader.Set("Date", value) } if value, ok := header.GetChecked("Received"); ok { newHeader.Set("Received", value) } buf.Write(newHeader.Raw()) } // Write PGP control data { pgpControlHeader := rfc822.NewEmptyHeader() pgpControlHeader.Set("Content-Description", "PGP/MIME version identification") pgpControlHeader.Set("Content-Type", "application/pgp-encrypted") if err := multipartWriter.AddPart(func(writer io.Writer) error { if _, err := 
writer.Write(pgpControlHeader.Raw()); err != nil { return err } _, err := writer.Write([]byte("Version: 1")) return err }); err != nil { return nil, err } } // write PGP attachment { pgpAttachmentHeader := rfc822.NewEmptyHeader() contentType := mime.FormatMediaType("application/octet-stream", map[string]string{ "name": "encrypted.asc", }) pgpAttachmentHeader.Set("Content-Description", "OpenPGP encrypted message") pgpAttachmentHeader.Set("Content-Disposition", "inline; filename=encrypted.asc") pgpAttachmentHeader.Set("Content-Type", contentType) if err := multipartWriter.AddPart(func(writer io.Writer) error { if _, err := writer.Write(pgpAttachmentHeader.Raw()); err != nil { return err } _, err := writer.Write([]byte(arm)) return err }); err != nil { return nil, err } } // finish messsage if err := multipartWriter.Done(); err != nil { return nil, err } return buf.Bytes(), nil } func encodeBase64(writer io.Writer, b []byte) error { encoder := base64.NewEncoder(base64.StdEncoding, writer) defer encoder.Close() if _, err := encoder.Write(b); err != nil { return err } return nil } func getCharsetDecoder(r io.Reader, charset string) (io.Reader, error) { if CharsetReader != nil { if enc, err := CharsetReader(charset, r); err == nil { return enc, nil } } if enc, err := ianaindex.MIME.Encoding(strings.ToLower(charset)); err == nil { return enc.NewDecoder().Reader(r), nil } if enc, err := ianaindex.MIME.Encoding("cs" + strings.ToLower(charset)); err == nil { return enc.NewDecoder().Reader(r), nil } if enc, err := htmlindex.Get(strings.ToLower(charset)); err == nil { return enc.NewDecoder().Reader(r), nil } return nil, fmt.Errorf("unknown charset: %s", charset) } func replace[Key comparable, Value any](m map[Key]Value, key Key, value Value) map[Key]Value { m[key] = value return m } go-proton-api-1.0.0/message_encrypt_test.go000066400000000000000000000155321447642273300210030ustar00rootroot00000000000000package proton import ( "bytes" "encoding/base64" "io" "testing" 
"github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestEncryptMessage_Simple(t *testing.T) { const message = `From: Nathaniel Borenstein To: Ned Freed Subject: Sample message (import 2) MIME-Version: 1.0 Content-type: text/plain This is explicitly typed plain ASCII text. ` key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0) require.NoError(t, err) kr, err := crypto.NewKeyRing(key) require.NoError(t, err) encryptedMessage, err := EncryptRFC822(kr, []byte(message)) require.NoError(t, err) section := rfc822.Parse(encryptedMessage) // Check root header: header, err := section.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("From"), "Nathaniel Borenstein ") assert.Equal(t, header.Get("To"), "Ned Freed ") assert.Equal(t, header.Get("Subject"), "Sample message (import 2)") assert.Equal(t, header.Get("MIME-Version"), "1.0") // Read the body. body, err := section.DecodedBody() require.NoError(t, err) // Unarmor the PGP message. enc, err := crypto.NewPGPMessageFromArmored(string(body)) require.NoError(t, err) // Decrypt the PGP message. dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) require.NoError(t, err) require.Equal(t, "This is explicitly typed plain ASCII text.\n", dec.GetString()) } func TestEncryptMessage_MultipleTextParts(t *testing.T) { const message = `From: Nathaniel Borenstein To: Ned Freed Subject: Sample message (import 2) MIME-Version: 1.0 Content-type: multipart/mixed; boundary="simple boundary" Received: from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016 This is the preamble. It is to be ignored, though it is a handy place for mail composers to include an explanatory note to non-MIME compliant readers. --simple boundary This is implicitly typed plain ASCII text. It does NOT end with a linebreak. 
--simple boundary Content-type: text/plain; charset=us-ascii This is explicitly typed plain ASCII text. It DOES end with a linebreak. --simple boundary-- This is the epilogue. It is also to be ignored. ` key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0) require.NoError(t, err) kr, err := crypto.NewKeyRing(key) require.NoError(t, err) encryptedMessage, err := EncryptRFC822(kr, []byte(message)) require.NoError(t, err) section := rfc822.Parse(encryptedMessage) { // Check root header: header, err := section.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("From"), "Nathaniel Borenstein ") assert.Equal(t, header.Get("To"), "Ned Freed ") assert.Equal(t, header.Get("Subject"), "Sample message (import 2)") assert.Equal(t, header.Get("MIME-Version"), "1.0") assert.Equal(t, header.Get("Received"), "from mail.protonmail.ch by mail.protonmail.ch; Tue, 25 Nov 2016") mediaType, params, err := rfc822.ParseMediaType(header.Get("Content-Type")) require.NoError(t, err) assert.Equal(t, "multipart/encrypted", mediaType) assert.Equal(t, "application/pgp-encrypted", params["protocol"]) assert.NotEmpty(t, params["boundary"]) } children, err := section.Children() require.NoError(t, err) require.Equal(t, 2, len(children)) { // check first child. child := children[0] header, err := child.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("Content-Description"), "PGP/MIME version identification") assert.Equal(t, header.Get("Content-Type"), "application/pgp-encrypted") assert.Equal(t, []byte("Version: 1"), child.Body()) } { // check second child. 
child := children[1] header, err := child.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("Content-Description"), "OpenPGP encrypted message") assert.Equal(t, header.Get("Content-Disposition"), "inline; filename=encrypted.asc") assert.Equal(t, header.Get("Content-type"), "application/octet-stream; name=encrypted.asc") body := child.Body() assert.True(t, bytes.HasPrefix(body, []byte("-----BEGIN PGP MESSAGE-----"))) assert.True(t, bytes.HasSuffix(body, []byte("-----END PGP MESSAGE-----"))) } } func TestEncryptMessage_Attachment(t *testing.T) { const message = `From: Nathaniel Borenstein To: Ned Freed Subject: Sample message (import 2) MIME-Version: 1.0 Content-type: multipart/mixed; boundary="simple boundary" --simple boundary Content-type: text/plain; charset=us-ascii Hello world --simple boundary Content-Type: application/pdf; name="test.pdf" Content-Disposition: attachment; filename="test.pdf" Content-Transfer-Encoding: base64 SGVsbG8gQXR0YWNobWVudA== --simple boundary-- ` key, err := crypto.GenerateKey("foobar", "foo@bar.com", "x25519", 0) require.NoError(t, err) kr, err := crypto.NewKeyRing(key) require.NoError(t, err) encryptedMessage, err := EncryptRFC822(kr, []byte(message)) require.NoError(t, err) section := rfc822.Parse(encryptedMessage) { // Check root header: header, err := section.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("From"), "Nathaniel Borenstein ") assert.Equal(t, header.Get("To"), "Ned Freed ") assert.Equal(t, header.Get("Subject"), "Sample message (import 2)") assert.Equal(t, header.Get("MIME-Version"), "1.0") mediaType, params, err := rfc822.ParseMediaType(header.Get("Content-Type")) require.NoError(t, err) assert.Equal(t, "multipart/mixed", mediaType) assert.NotEmpty(t, params["boundary"]) } children, err := section.Children() require.NoError(t, err) require.Equal(t, 2, len(children)) { // check first child. 
child := children[0] header, err := child.ParseHeader() require.NoError(t, err) header.Entries(func(key, value string) { // Old header should be deleted. assert.NotEqual(t, key, "Content-type") assert.NotEqual(t, value, "text/plain; charset=us-ascii") }) assert.Equal(t, header.Get("Content-Type"), "text/plain; charset=utf-8") } { // check second child. child := children[1] header, err := child.ParseHeader() require.NoError(t, err) assert.Equal(t, header.Get("Content-Transfer-Encoding"), "base64") assert.Equal(t, header.Get("Content-Disposition"), `attachment; filename="test.pdf"`) assert.Equal(t, header.Get("Content-type"), `application/pdf; name="test.pdf"`) body := child.Body() // Read the body. bodyDecoded, err := io.ReadAll(base64.NewDecoder(base64.StdEncoding, bytes.NewReader(body))) require.NoError(t, err) // Unarmor the PGP message. enc := crypto.NewPGPMessage(bodyDecoded) // Decrypt the PGP message. dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) require.NoError(t, err) require.Equal(t, "Hello Attachment", dec.GetString()) } } go-proton-api-1.0.0/message_import.go000066400000000000000000000070711447642273300175710ustar00rootroot00000000000000package proton import ( "context" "errors" "fmt" "strconv" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/iterator" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/go-resty/resty/v2" ) const ( // maxImportCount is the maximum number of messages that can be imported in a single request. maxImportCount = 10 // maxImportSize is the maximum total request size permitted for a single import request. 
maxImportSize = 30 * 1024 * 1024 ) var ErrImportEncrypt = errors.New("failed to encrypt message") var ErrImportSizeExceeded = errors.New("message exceeds maximum import size of 30MB") func (c *Client) ImportMessages(ctx context.Context, addrKR *crypto.KeyRing, workers, buffer int, req ...ImportReq) (stream.Stream[ImportRes], error) { // Encrypt each message. for idx := range req { enc, err := EncryptRFC822(addrKR, req[idx].Message) if err != nil { return nil, fmt.Errorf("%w %v: %v", ErrImportEncrypt, idx, err) } req[idx].Message = enc } // If any of the messages exceed the maximum import size, return an error. if xslices.Any(req, func(req ImportReq) bool { return len(req.Message) > maxImportSize }) { return nil, ErrImportSizeExceeded } return stream.Flatten(parallel.MapStream( ctx, stream.FromIterator(iterator.Slice(chunkSized(req, maxImportCount, maxImportSize, func(req ImportReq) int { return len(req.Message) }))), workers, buffer, func(ctx context.Context, req []ImportReq) (stream.Stream[ImportRes], error) { defer async.HandlePanic(c.m.panicHandler) res, err := c.importMessages(ctx, req) if err != nil { return nil, fmt.Errorf("failed to import messages: %w", err) } for _, res := range res { if res.Code != SuccessCode { return nil, fmt.Errorf("failed to import message: %w", res.APIError) } } return stream.FromIterator(iterator.Slice(res)), nil }, )), nil } func (c *Client) importMessages(ctx context.Context, req []ImportReq) ([]ImportRes, error) { names := iterator.Collect(iterator.Map(iterator.Counter(len(req)), func(i int) string { return strconv.Itoa(i) })) var named []namedImportReq for idx, name := range names { named = append(named, namedImportReq{ ImportReq: req[idx], Name: name, }) } type namedImportRes struct { Name string Response ImportRes } var res struct { Responses []namedImportRes } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { fields, err := buildImportReqFields(named) if err != nil { return nil, err } return 
r.SetMultipartFields(fields...).SetResult(&res).Post("/mail/v4/messages/import") }); err != nil { return nil, err } namedRes := make(map[string]ImportRes, len(res.Responses)) for _, res := range res.Responses { namedRes[res.Name] = res.Response } return xslices.Map(names, func(name string) ImportRes { return namedRes[name] }), nil } // chunkSized splits a slice into chunks of maximum size and length. // It is assumed that the size of each element is less than the maximum size. func chunkSized[T any](vals []T, maxLen, maxSize int, getSize func(T) int) [][]T { var chunks [][]T for len(vals) > 0 { var ( curChunk []T curSize int ) for len(vals) > 0 && len(curChunk) < maxLen && curSize < maxSize { val, size := vals[0], getSize(vals[0]) if curSize+size > maxSize { break } curChunk = append(curChunk, val) curSize += size vals = vals[1:] } chunks = append(chunks, curChunk) } return chunks } go-proton-api-1.0.0/message_import_test.go000066400000000000000000000021201447642273300206160ustar00rootroot00000000000000package proton import ( "reflect" "testing" ) func Test_chunkSized(t *testing.T) { type args struct { vals []int maxLen int maxSize int getSize func(int) int } tests := []struct { name string args args want [][]int }{ { name: "limit by length", args: args{ vals: []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, maxLen: 3, // Split into chunks of at most 3 maxSize: 100, getSize: func(i int) int { return i }, }, want: [][]int{ {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10}, }, }, { name: "limit by size", args: args{ vals: []int{1, 1, 1, 1, 1, 2, 2, 2, 2, 2}, maxLen: 100, maxSize: 5, // Split into chunks of at most 5 getSize: func(i int) int { return i }, }, want: [][]int{ {1, 1, 1, 1, 1}, {2, 2}, {2, 2}, {2}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := chunkSized(tt.args.vals, tt.args.maxLen, tt.args.maxSize, tt.args.getSize); !reflect.DeepEqual(got, tt.want) { t.Errorf("chunkSized() = %v, want %v", got, tt.want) } }) } } 
go-proton-api-1.0.0/message_import_types.go000066400000000000000000000021521447642273300210100ustar00rootroot00000000000000package proton import ( "bytes" "encoding/json" "github.com/ProtonMail/gluon/rfc822" "github.com/go-resty/resty/v2" ) type ImportReq struct { Metadata ImportMetadata Message []byte } type namedImportReq struct { ImportReq Name string } type ImportMetadata struct { AddressID string LabelIDs []string Unread Bool Flags MessageFlag } type ImportRes struct { APIError MessageID string } func buildImportReqFields(req []namedImportReq) ([]*resty.MultipartField, error) { var fields []*resty.MultipartField metadata := make(map[string]ImportMetadata, len(req)) for _, req := range req { metadata[req.Name] = req.Metadata fields = append(fields, &resty.MultipartField{ Param: req.Name, FileName: req.Name + ".eml", ContentType: string(rfc822.MessageRFC822), Reader: bytes.NewReader(append(req.Message, "\r\n"...)), }) } b, err := json.Marshal(metadata) if err != nil { return nil, err } fields = append(fields, &resty.MultipartField{ Param: "Metadata", ContentType: "application/json", Reader: bytes.NewReader(b), }) return fields, nil } go-proton-api-1.0.0/message_send.go000066400000000000000000000041261447642273300172060ustar00rootroot00000000000000package proton import ( "context" "fmt" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2" ) func (c *Client) CreateDraft(ctx context.Context, addrKR *crypto.KeyRing, req CreateDraftReq) (Message, error) { var res struct { Message Message } kr, err := addrKR.FirstKey() if err != nil { return Message{}, fmt.Errorf("failed to get first key: %w", err) } enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(req.Message.Body), nil) if err != nil { return Message{}, fmt.Errorf("failed to encrypt draft: %w", err) } arm, err := enc.GetArmored() if err != nil { return Message{}, fmt.Errorf("failed to armor draft: %w", err) } req.Message.Body = arm if err := c.do(ctx, func(r *resty.Request) 
(*resty.Response, error) { return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages") }); err != nil { return Message{}, err } return res.Message, nil } func (c *Client) UpdateDraft(ctx context.Context, draftID string, addrKR *crypto.KeyRing, req UpdateDraftReq) (Message, error) { var res struct { Message Message } if req.Message.Body != "" { kr, err := addrKR.FirstKey() if err != nil { return Message{}, fmt.Errorf("failed to get first key: %w", err) } enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(req.Message.Body), nil) if err != nil { return Message{}, fmt.Errorf("failed to encrypt draft: %w", err) } arm, err := enc.GetArmored() if err != nil { return Message{}, fmt.Errorf("failed to armor draft: %w", err) } req.Message.Body = arm } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Put("/mail/v4/messages/" + draftID) }); err != nil { return Message{}, err } return res.Message, nil } func (c *Client) SendDraft(ctx context.Context, draftID string, req SendDraftReq) (Message, error) { var res struct { Sent Message } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(req).SetResult(&res).Post("/mail/v4/messages/" + draftID) }); err != nil { return Message{}, err } return res.Sent, nil } go-proton-api-1.0.0/message_send_types.go000066400000000000000000000204151447642273300204310ustar00rootroot00000000000000package proton import ( "encoding/base64" "fmt" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" ) type EncryptionScheme int const ( InternalScheme EncryptionScheme = 1 << iota EncryptedOutsideScheme ClearScheme PGPInlineScheme PGPMIMEScheme ClearMIMEScheme ) type SignatureType int const ( NoSignature SignatureType = iota DetachedSignature AttachedSignature ) type MessageRecipient struct { Type EncryptionScheme Signature SignatureType BodyKeyPacket string `json:",omitempty"` AttachmentKeyPackets map[string]string 
`json:",omitempty"` } type MessagePackage struct { Addresses map[string]*MessageRecipient MIMEType rfc822.MIMEType Type EncryptionScheme Body string BodyKey *SessionKey `json:",omitempty"` AttachmentKeys map[string]*SessionKey `json:",omitempty"` } func newMessagePackage(mimeType rfc822.MIMEType, encBodyData []byte) *MessagePackage { return &MessagePackage{ Addresses: make(map[string]*MessageRecipient), MIMEType: mimeType, Body: base64.StdEncoding.EncodeToString(encBodyData), AttachmentKeys: make(map[string]*SessionKey), } } type SessionKey struct { Key string Algorithm string } func newSessionKey(key *crypto.SessionKey) *SessionKey { return &SessionKey{ Key: key.GetBase64Key(), Algorithm: key.Algo, } } type SendPreferences struct { // Encrypt indicates whether the email should be encrypted or not. // If it's encrypted, we need to know which public key to use. Encrypt bool // PubKey contains an OpenPGP key that can be used for encryption. PubKey *crypto.KeyRing // SignatureType indicates how the email should be signed. SignatureType SignatureType // EncryptionScheme indicates if we should encrypt body and attachments separately and // what MIME format to give the final encrypted email. The two standard PGP // schemes are PGP/MIME and PGP/Inline. However we use a custom scheme for // internal emails (including the so-called encrypted-to-outside emails, // which even though meant for external users, they don't really get out of // our platform). If the email is sent unencrypted, no PGP scheme is needed. EncryptionScheme EncryptionScheme // MIMEType is the MIME type to use for formatting the body of the email // (before encryption/after decryption). The standard possibilities are the // enriched HTML format, text/html, and plain text, text/plain. But it's // also possible to have a multipart/mixed format, which is typically used // for PGP/MIME encrypted emails, where attachments go into the body too. // Because of this, this option is sometimes called MIME format. 
MIMEType rfc822.MIMEType } type SendDraftReq struct { Packages []*MessagePackage } func (req *SendDraftReq) AddMIMEPackage( kr *crypto.KeyRing, mimeBody string, prefs map[string]SendPreferences, ) error { for _, prefs := range prefs { if prefs.MIMEType != rfc822.MultipartMixed { return fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType) } } pkg, err := newMIMEPackage(kr, mimeBody, prefs) if err != nil { return err } req.Packages = append(req.Packages, pkg) return nil } func (req *SendDraftReq) AddTextPackage( kr *crypto.KeyRing, body string, mimeType rfc822.MIMEType, prefs map[string]SendPreferences, attKeys map[string]*crypto.SessionKey, ) error { pkg, err := newTextPackage(kr, body, mimeType, prefs, attKeys) if err != nil { return err } req.Packages = append(req.Packages, pkg) return nil } func newMIMEPackage( kr *crypto.KeyRing, mimeBody string, prefs map[string]SendPreferences, ) (*MessagePackage, error) { decBodyKey, encBodyData, err := encSplit(kr, mimeBody) if err != nil { return nil, fmt.Errorf("failed to encrypt MIME body: %w", err) } pkg := newMessagePackage(rfc822.MultipartMixed, encBodyData) for addr, prefs := range prefs { if prefs.MIMEType != rfc822.MultipartMixed { return nil, fmt.Errorf("invalid MIME type for MIME package: %s", prefs.MIMEType) } if prefs.SignatureType != DetachedSignature { return nil, fmt.Errorf("invalid signature type for MIME package: %d", prefs.SignatureType) } recipient := &MessageRecipient{ Type: prefs.EncryptionScheme, Signature: prefs.SignatureType, } switch prefs.EncryptionScheme { case PGPMIMEScheme: if prefs.PubKey == nil { return nil, fmt.Errorf("missing public key for %s", addr) } encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey) if err != nil { return nil, fmt.Errorf("failed to encrypt session key: %w", err) } recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey) case ClearMIMEScheme: pkg.BodyKey = &SessionKey{ Key: decBodyKey.GetBase64Key(), Algorithm: decBodyKey.Algo, } 
default: return nil, fmt.Errorf("invalid encryption scheme for MIME package: %d", prefs.EncryptionScheme) } pkg.Addresses[addr] = recipient pkg.Type |= prefs.EncryptionScheme } return pkg, nil } func newTextPackage( kr *crypto.KeyRing, body string, mimeType rfc822.MIMEType, prefs map[string]SendPreferences, attKeys map[string]*crypto.SessionKey, ) (*MessagePackage, error) { if mimeType != rfc822.TextPlain && mimeType != rfc822.TextHTML { return nil, fmt.Errorf("invalid MIME type for package: %s", mimeType) } decBodyKey, encBodyData, err := encSplit(kr, body) if err != nil { return nil, fmt.Errorf("failed to encrypt message body: %w", err) } pkg := newMessagePackage(mimeType, encBodyData) for addr, prefs := range prefs { if prefs.MIMEType != mimeType { return nil, fmt.Errorf("invalid MIME type for package: %s", prefs.MIMEType) } if prefs.SignatureType == DetachedSignature && !prefs.Encrypt { if prefs.EncryptionScheme == PGPInlineScheme { return nil, fmt.Errorf("invalid encryption scheme for %s: %d", addr, prefs.EncryptionScheme) } if prefs.EncryptionScheme == ClearScheme && mimeType != rfc822.TextPlain { return nil, fmt.Errorf("invalid MIME type for clear package: %s", mimeType) } } if prefs.EncryptionScheme == InternalScheme && !prefs.Encrypt { return nil, fmt.Errorf("internal packages must be encrypted") } if prefs.EncryptionScheme == PGPInlineScheme && mimeType != rfc822.TextPlain { return nil, fmt.Errorf("invalid MIME type for PGP inline package: %s", mimeType) } switch prefs.EncryptionScheme { case ClearScheme: pkg.BodyKey = newSessionKey(decBodyKey) for attID, attKey := range attKeys { pkg.AttachmentKeys[attID] = newSessionKey(attKey) } case InternalScheme, PGPInlineScheme: // ... 
default: return nil, fmt.Errorf("invalid encryption scheme for package: %d", prefs.EncryptionScheme) } recipient := &MessageRecipient{ Type: prefs.EncryptionScheme, Signature: prefs.SignatureType, AttachmentKeyPackets: make(map[string]string), } if prefs.Encrypt { if prefs.PubKey == nil { return nil, fmt.Errorf("missing public key for %s", addr) } if prefs.SignatureType != DetachedSignature { return nil, fmt.Errorf("invalid signature type for package: %d", prefs.SignatureType) } encBodyKey, err := prefs.PubKey.EncryptSessionKey(decBodyKey) if err != nil { return nil, fmt.Errorf("failed to encrypt session key: %w", err) } recipient.BodyKeyPacket = base64.StdEncoding.EncodeToString(encBodyKey) for attID, attKey := range attKeys { encAttKey, err := prefs.PubKey.EncryptSessionKey(attKey) if err != nil { return nil, fmt.Errorf("failed to encrypt attachment key: %w", err) } recipient.AttachmentKeyPackets[attID] = base64.StdEncoding.EncodeToString(encAttKey) } } pkg.Addresses[addr] = recipient pkg.Type |= prefs.EncryptionScheme } return pkg, nil } func encSplit(kr *crypto.KeyRing, body string) (*crypto.SessionKey, []byte, error) { encBody, err := kr.Encrypt(crypto.NewPlainMessageFromString(body), kr) if err != nil { return nil, nil, fmt.Errorf("failed to encrypt MIME body: %w", err) } splitEncBody, err := encBody.SplitMessage() if err != nil { return nil, nil, fmt.Errorf("failed to split message: %w", err) } decBodyKey, err := kr.DecryptSessionKey(splitEncBody.GetBinaryKeyPacket()) if err != nil { return nil, nil, fmt.Errorf("failed to decrypt session key: %w", err) } return decBodyKey, splitEncBody.GetBinaryDataPacket(), nil } go-proton-api-1.0.0/message_send_types_test.go000066400000000000000000000261661447642273300215010ustar00rootroot00000000000000package proton import ( "testing" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/stretchr/testify/require" ) func TestSendDraftReq_AddMIMEPackage(t *testing.T) { key, err := 
crypto.GenerateKey("name", "email", "rsa", 2048) require.NoError(t, err) kr, err := crypto.NewKeyRing(key) require.NoError(t, err) tests := []struct { name string mimeBody string prefs map[string]SendPreferences wantErr bool }{ { name: "Clear MIME with detached signature", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-sign@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: ClearMIMEScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: false, }, { name: "Clear MIME with no signature (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-no-sign@email.com": { Encrypt: false, SignatureType: NoSignature, EncryptionScheme: ClearMIMEScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: true, }, { name: "Clear MIME with plain text (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-plain@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: ClearMIMEScheme, MIMEType: rfc822.TextPlain, }}, wantErr: true, }, { name: "Clear MIME with rich text (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-html@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: ClearMIMEScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "PGP MIME with detached signature", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-encrypted@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPMIMEScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: false, }, { name: "PGP MIME with plain text (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-encrypted-plain@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPMIMEScheme, MIMEType: rfc822.TextPlain, }}, wantErr: true, }, { name: "PGP MIME with rich text (error)", mimeBody: "this is a mime body", prefs: 
map[string]SendPreferences{"mime-encrypted-plain@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPMIMEScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "PGP MIME with missing public key (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-encrypted-no-pubkey@email.com": { Encrypt: true, SignatureType: DetachedSignature, EncryptionScheme: PGPMIMEScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: true, }, { name: "PGP MIME with no signature (error)", mimeBody: "this is a mime body", prefs: map[string]SendPreferences{"mime-encrypted-no-signature@email.com": { Encrypt: true, PubKey: kr, SignatureType: NoSignature, EncryptionScheme: PGPMIMEScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var req SendDraftReq if err := req.AddMIMEPackage(kr, tt.mimeBody, tt.prefs); (err != nil) != tt.wantErr { t.Errorf("SendDraftReq.AddMIMEPackage() error = %v, wantErr %v", err, tt.wantErr) } }) } } func TestSendDraftReq_AddPackage(t *testing.T) { key, err := crypto.GenerateKey("name", "email", "rsa", 2048) require.NoError(t, err) kr, err := crypto.NewKeyRing(key) require.NoError(t, err) tests := []struct { name string body string mimeType rfc822.MIMEType prefs map[string]SendPreferences attKeys map[string]*crypto.SessionKey wantErr bool }{ { name: "internal plain text with detached signature", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"internal-plain@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: rfc822.TextPlain, }}, wantErr: false, }, { name: "internal rich text with detached signature", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"internal-html@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: 
InternalScheme, MIMEType: rfc822.TextHTML, }}, wantErr: false, }, { name: "internal rich text with bad package content type (error)", body: "this is a text/html body", mimeType: "bad content type", prefs: map[string]SendPreferences{"internal-bad-package-content-type@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "internal rich text with bad recipient content type (error)", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"internal-bad-recipient-content-type@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: "bad content type", }}, wantErr: true, }, { name: "internal with multipart (error)", body: "this is a text/html body", mimeType: rfc822.MultipartMixed, prefs: map[string]SendPreferences{"internal-multipart-mixed@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: true, }, { name: "internal without encryption (error)", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"internal-no-encrypt@email.com": { Encrypt: false, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "internal without pubkey (error)", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"internal-no-pubkey@email.com": { Encrypt: true, SignatureType: DetachedSignature, EncryptionScheme: InternalScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "internal without signature (error)", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"internal-no-sig@email.com": { Encrypt: true, PubKey: kr, SignatureType: NoSignature, EncryptionScheme: 
InternalScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "clear rich text without signature", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"clear-rich@email.com": { Encrypt: false, SignatureType: NoSignature, EncryptionScheme: ClearScheme, MIMEType: rfc822.TextHTML, }}, wantErr: false, }, { name: "clear plain text without signature", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"clear-plain@email.com": { Encrypt: false, SignatureType: NoSignature, EncryptionScheme: ClearScheme, MIMEType: rfc822.TextPlain, }}, wantErr: false, }, { name: "clear plain text with signature", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: ClearScheme, MIMEType: rfc822.TextPlain, }}, wantErr: false, }, { name: "clear plain text with bad scheme (error)", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: PGPInlineScheme, MIMEType: rfc822.TextPlain, }}, wantErr: true, }, { name: "clear rich text with signature (error)", body: "this is a text/html body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"clear-plain-with-sig@email.com": { Encrypt: false, SignatureType: DetachedSignature, EncryptionScheme: ClearScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "encrypted plain text with signature", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"pgp-inline-with-sig@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPInlineScheme, MIMEType: rfc822.TextPlain, }}, wantErr: false, }, { name: "encrypted html text with signature (error)", body: "this is a text/html 
body", mimeType: rfc822.TextHTML, prefs: map[string]SendPreferences{"pgp-inline-rich-with-sig@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPInlineScheme, MIMEType: rfc822.TextHTML, }}, wantErr: true, }, { name: "encrypted mixed text with signature (error)", body: "this is a multipart/mixed body", mimeType: rfc822.MultipartMixed, prefs: map[string]SendPreferences{"pgp-inline-mixed-with-sig@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: PGPInlineScheme, MIMEType: rfc822.MultipartMixed, }}, wantErr: true, }, { name: "encrypted for outside (error)", body: "this is a text/plain body", mimeType: rfc822.TextPlain, prefs: map[string]SendPreferences{"enc-for-outside@email.com": { Encrypt: true, PubKey: kr, SignatureType: DetachedSignature, EncryptionScheme: EncryptedOutsideScheme, MIMEType: rfc822.TextPlain, }}, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var req SendDraftReq if err := req.AddTextPackage(kr, tt.body, tt.mimeType, tt.prefs, tt.attKeys); (err != nil) != tt.wantErr { t.Errorf("SendDraftReq.AddPackage() error = %v, wantErr %v", err, tt.wantErr) } }) } } go-proton-api-1.0.0/message_types.go000066400000000000000000000113761447642273300174260ustar00rootroot00000000000000package proton import ( "bytes" "io" "net/mail" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-crypto/openpgp/armor" "github.com/ProtonMail/gopenpgp/v2/crypto" "golang.org/x/exp/slices" ) type MessageMetadata struct { ID string AddressID string LabelIDs []string ExternalID string Subject string Sender *mail.Address ToList []*mail.Address CCList []*mail.Address BCCList []*mail.Address ReplyTos []*mail.Address Flags MessageFlag Time int64 Size int Unread Bool IsReplied Bool IsRepliedAll Bool IsForwarded Bool NumAttachments int } func (meta MessageMetadata) Seen() bool { return !bool(meta.Unread) } func (meta MessageMetadata) Starred() bool { return 
slices.Contains(meta.LabelIDs, StarredLabel) } func (meta MessageMetadata) IsDraft() bool { return meta.Flags&(MessageFlagReceived|MessageFlagSent) == 0 } type MessageFilter struct { ID []string `json:",omitempty"` Subject string `json:",omitempty"` AddressID string `json:",omitempty"` ExternalID string `json:",omitempty"` LabelID string `json:",omitempty"` EndID string `json:",omitempty"` Desc Bool } type Message struct { MessageMetadata Header string ParsedHeaders Headers Body string MIMEType rfc822.MIMEType Attachments []Attachment } type MessageFlag int64 const ( MessageFlagReceived MessageFlag = 1 << 0 MessageFlagSent MessageFlag = 1 << 1 MessageFlagInternal MessageFlag = 1 << 2 MessageFlagE2E MessageFlag = 1 << 3 MessageFlagAuto MessageFlag = 1 << 4 MessageFlagReplied MessageFlag = 1 << 5 MessageFlagRepliedAll MessageFlag = 1 << 6 MessageFlagForwarded MessageFlag = 1 << 7 MessageFlagAutoReplied MessageFlag = 1 << 8 MessageFlagImported MessageFlag = 1 << 9 MessageFlagOpened MessageFlag = 1 << 10 MessageFlagReceiptSent MessageFlag = 1 << 11 MessageFlagNotified MessageFlag = 1 << 12 MessageFlagTouched MessageFlag = 1 << 13 MessageFlagReceipt MessageFlag = 1 << 14 MessageFlagReceiptRequest MessageFlag = 1 << 16 MessageFlagPublicKey MessageFlag = 1 << 17 MessageFlagSign MessageFlag = 1 << 18 MessageFlagUnsubscribed MessageFlag = 1 << 19 MessageFlagScheduledSend MessageFlag = 1 << 20 MessageFlagAlias MessageFlag = 1 << 21 MessageFlagDMARCPass MessageFlag = 1 << 23 MessageFlagSPFFail MessageFlag = 1 << 24 MessageFlagDKIMFail MessageFlag = 1 << 25 MessageFlagDMARCFail MessageFlag = 1 << 26 MessageFlagHamManual MessageFlag = 1 << 27 MessageFlagSpamAuto MessageFlag = 1 << 28 MessageFlagSpamManual MessageFlag = 1 << 29 MessageFlagPhishingAuto MessageFlag = 1 << 30 MessageFlagPhishingManual MessageFlag = 1 << 31 ) func (f MessageFlag) Has(flag MessageFlag) bool { return f&flag != 0 } func (f MessageFlag) Matches(flag MessageFlag) bool { return f&flag == flag } func (f 
MessageFlag) HasAny(flags ...MessageFlag) bool { for _, flag := range flags { if f.Has(flag) { return true } } return false } func (f MessageFlag) HasAll(flags ...MessageFlag) bool { for _, flag := range flags { if !f.Has(flag) { return false } } return true } func (f MessageFlag) Add(flag MessageFlag) MessageFlag { return f | flag } func (f MessageFlag) Remove(flag MessageFlag) MessageFlag { return f &^ flag } func (f MessageFlag) Toggle(flag MessageFlag) MessageFlag { if f.Has(flag) { return f.Remove(flag) } return f.Add(flag) } func (m Message) Decrypt(kr *crypto.KeyRing) ([]byte, error) { enc, err := crypto.NewPGPMessageFromArmored(m.Body) if err != nil { return nil, err } dec, err := kr.Decrypt(enc, nil, crypto.GetUnixTime()) if err != nil { return nil, err } return dec.GetBinary(), nil } func (m Message) DecryptInto(kr *crypto.KeyRing, buffer io.ReaderFrom) error { armored, err := armor.Decode(bytes.NewReader([]byte(m.Body))) if err != nil { return err } stream, err := kr.DecryptStream(armored.Body, nil, crypto.GetUnixTime()) if err != nil { return err } if _, err := buffer.ReadFrom(stream); err != nil { return err } return nil } type FullMessage struct { Message AttData [][]byte } type Signature struct { Hash string Data *crypto.PGPSignature } type MessageActionReq struct { IDs []string } type LabelMessagesReq struct { LabelID string IDs []string } type LabelMessagesRes struct { Responses []LabelMessageRes UndoToken UndoToken } func (res LabelMessagesRes) ok() (bool, string) { for _, resp := range res.Responses { if resp.Response.Code != SuccessCode { return false, resp.Response.Error() } } return true, "" } type LabelMessageRes struct { ID string Response APIError } type MessageGroupCount struct { LabelID string Total int Unread int } go-proton-api-1.0.0/message_types_test.go000066400000000000000000000020731447642273300204570ustar00rootroot00000000000000package proton import ( "os" "testing" "github.com/ProtonMail/gopenpgp/v2/crypto" 
"github.com/stretchr/testify/require" ) func TestDecrypt(t *testing.T) { body, err := os.ReadFile("testdata/body.pgp") require.NoError(t, err) pubKR := loadKeyRing(t, "testdata/pub.asc", nil) prvKR := loadKeyRing(t, "testdata/prv.asc", []byte("password")) msg := Message{Body: string(body)} sigs, err := ExtractSignatures(prvKR, msg.Body) require.NoError(t, err) enc, err := crypto.NewPGPMessageFromArmored(msg.Body) require.NoError(t, err) dec, err := prvKR.Decrypt(enc, nil, crypto.GetUnixTime()) require.NoError(t, err) require.NoError(t, pubKR.VerifyDetached(dec, sigs[0].Data, crypto.GetUnixTime())) } func loadKeyRing(t *testing.T, file string, pass []byte) *crypto.KeyRing { f, err := os.Open(file) require.NoError(t, err) defer f.Close() key, err := crypto.NewKeyFromArmoredReader(f) require.NoError(t, err) if pass != nil { key, err = key.Unlock(pass) require.NoError(t, err) } kr, err := crypto.NewKeyRing(key) require.NoError(t, err) return kr } go-proton-api-1.0.0/netctl.go000066400000000000000000000271101447642273300160400ustar00rootroot00000000000000package proton import ( "context" "crypto/tls" "errors" "fmt" "io" "net" "net/http" "sync" "time" ) // Listener wraps a net.Listener. // It can be configured to spawn connections that drop all reads or writes. type Listener struct { net.Listener canRead bool rlock sync.RWMutex canWrite bool wlock sync.RWMutex conns []net.Conn connLock sync.RWMutex done chan struct{} doneOnce sync.Once newConn func(net.Conn, *Listener) net.Conn } // NewListener returns a new DropListener. 
func NewListener(l net.Listener, newConn func(net.Conn, *Listener) net.Conn) *Listener { return &Listener{ Listener: l, canRead: true, canWrite: true, done: make(chan struct{}), newConn: newConn, } } func (l *Listener) Accept() (net.Conn, error) { conn, err := l.Listener.Accept() if err != nil { return nil, err } l.connLock.Lock() defer l.connLock.Unlock() dropConn := l.newConn(conn, l) l.conns = append(l.conns, dropConn) return dropConn, nil } // SetCanRead sets whether the connections spawned by this listener can read. func (l *Listener) SetCanRead(canRead bool) { l.rlock.Lock() defer l.rlock.Unlock() l.canRead = canRead } // SetCanWrite sets whether the connections spawned by this listener can write. func (l *Listener) SetCanWrite(canWrite bool) { l.wlock.Lock() defer l.wlock.Unlock() l.canWrite = canWrite } // Close closes the listener. func (l *Listener) Close() error { defer l.doneOnce.Do(func() { close(l.done) }) return l.Listener.Close() } // Done returns a channel that is closed when the listener is closed. func (l *Listener) Done() <-chan struct{} { return l.done } // DropAll closes all connections spawned by this listener. 
func (l *Listener) DropAll() { l.connLock.RLock() defer l.connLock.RUnlock() for _, conn := range l.conns { _ = conn.Close() } } type hangConn struct { net.Conn l *Listener } func NewHangConn(c net.Conn, l *Listener) net.Conn { return &hangConn{ Conn: c, l: l, } } func (c *hangConn) Read(b []byte) (int, error) { c.l.rlock.RLock() defer c.l.rlock.RUnlock() if !c.l.canRead { c.l.rlock.RUnlock() <-c.l.Done() c.l.rlock.RLock() } return c.Conn.Read(b) } func (c *hangConn) Write(b []byte) (int, error) { c.l.wlock.RLock() defer c.l.wlock.RUnlock() if !c.l.canWrite { c.l.wlock.RUnlock() <-c.l.Done() c.l.wlock.RLock() } return c.Conn.Write(b) } type dropConn struct { net.Conn l *Listener } func NewDropConn(c net.Conn, l *Listener) net.Conn { return &dropConn{ Conn: c, l: l, } } func (c *dropConn) Read(b []byte) (int, error) { c.l.rlock.RLock() defer c.l.rlock.RUnlock() if c.l.canRead { return c.Conn.Read(b) } // Read half the length of the buffer. n, err := c.Conn.Read(b[:len(b)/2]) if err != nil { return n, fmt.Errorf("read: %w", err) } if err := c.Close(); err != nil { return n, fmt.Errorf("close: %w", err) } return n, errors.New("read: connection closed") } func (c *dropConn) Write(b []byte) (int, error) { c.l.wlock.RLock() defer c.l.wlock.RUnlock() if c.l.canWrite { return c.Conn.Write(b) } // Write half the length of the buffer. n, err := c.Conn.Write(b[:len(b)/2]) if err != nil { return n, fmt.Errorf("write: %w", err) } if err := c.Close(); err != nil { return n, fmt.Errorf("close: %w", err) } return n, errors.New("write: connection closed") } func (c *dropConn) Close() error { if tcpConn, ok := c.Conn.(*net.TCPConn); ok { if err := tcpConn.SetLinger(0); err != nil { return err } } return c.Conn.Close() } // InsecureTransport returns an http.Transport with InsecureSkipVerify set to true. 
func InsecureTransport() *http.Transport { return &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } } // ctl can be used to control whether a dialer can dial, and whether the resulting // connection can read or write. type NetCtl struct { canDial bool dialLimit uint64 dialCount uint64 onDial []func(net.Conn) dlock sync.RWMutex canRead bool readLimit uint64 readCount uint64 readSpeed int onRead []func([]byte) rlock sync.RWMutex canWrite bool writeLimit uint64 writeCount uint64 writeSpeed int onWrite []func([]byte) wlock sync.RWMutex conns []net.Conn } // NewNetCtl returns a new ctl with all fields set to true. func NewNetCtl() *NetCtl { return &NetCtl{ canDial: true, canRead: true, canWrite: true, } } // SetCanDial sets whether the dialer can dial. func (c *NetCtl) SetCanDial(canDial bool) { c.dlock.Lock() defer c.dlock.Unlock() c.canDial = canDial } // SetDialLimit sets the maximum number of times dialers using this controller can dial. func (c *NetCtl) SetDialLimit(limit uint64) { c.dlock.Lock() defer c.dlock.Unlock() c.dialLimit = limit } // SetCanRead sets whether the connection can read. func (c *NetCtl) SetCanRead(canRead bool) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.rlock.Lock() defer c.rlock.Unlock() c.canRead = canRead } // SetReadLimit sets the maximum number of bytes that can be read. func (c *NetCtl) SetReadLimit(limit uint64) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.rlock.Lock() defer c.rlock.Unlock() c.readLimit = limit c.readCount = 0 } // SetReadSpeed sets the maximum number of bytes that can be read per second. func (c *NetCtl) SetReadSpeed(speed int) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.rlock.Lock() defer c.rlock.Unlock() c.readSpeed = speed } // SetCanWrite sets whether the connection can write. 
func (c *NetCtl) SetCanWrite(canWrite bool) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.wlock.Lock() defer c.wlock.Unlock() c.canWrite = canWrite } // SetWriteLimit sets the maximum number of bytes that can be written. func (c *NetCtl) SetWriteLimit(limit uint64) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.wlock.Lock() defer c.wlock.Unlock() c.writeLimit = limit c.writeCount = 0 } // SetWriteSpeed sets the maximum number of bytes that can be written per second. func (c *NetCtl) SetWriteSpeed(speed int) { c.dlock.Lock() defer c.dlock.Unlock() for _, conn := range c.conns { conn.Close() } c.wlock.Lock() defer c.wlock.Unlock() c.writeSpeed = speed } // OnDial adds a callback that is called with the created connection when a dial is successful. func (c *NetCtl) OnDial(f func(net.Conn)) { c.dlock.Lock() defer c.dlock.Unlock() c.onDial = append(c.onDial, f) } // OnRead adds a callback that is called with the read bytes when a read is successful. func (c *NetCtl) OnRead(fn func([]byte)) { c.rlock.Lock() defer c.rlock.Unlock() c.onRead = append(c.onRead, fn) } // OnWrite adds a callback that is called with the written bytes when a write is successful. func (c *NetCtl) OnWrite(fn func([]byte)) { c.wlock.Lock() defer c.wlock.Unlock() c.onWrite = append(c.onWrite, fn) } // Disable is equivalent to disallowing dial, read and write. func (c *NetCtl) Disable() { c.SetCanDial(false) c.SetCanRead(false) c.SetCanWrite(false) } // Enable is equivalent to allowing dial, read and write. func (c *NetCtl) Enable() { c.SetCanDial(true) c.SetCanRead(true) c.SetCanWrite(true) } // NewDialer returns a new dialer controlled by the ctl. 
func (c *NetCtl) NewRoundTripper(tlsConfig *tls.Config) http.RoundTripper { return &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return c.dial(ctx, &net.Dialer{}, network, addr) }, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { return c.dial(ctx, &tls.Dialer{Config: tlsConfig}, network, addr) }, TLSClientConfig: tlsConfig, ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, } } // ctxDialer implements DialContext. type ctxDialer interface { DialContext(ctx context.Context, network, addr string) (net.Conn, error) } // dial dials using d, but only if the controller allows it. func (c *NetCtl) dial(ctx context.Context, dialer ctxDialer, network, addr string) (net.Conn, error) { c.dlock.Lock() defer c.dlock.Unlock() if !c.canDial { return nil, errors.New("dial failed (not allowed)") } if c.dialLimit > 0 && c.dialCount >= c.dialLimit { return nil, errors.New("dial failed (limit reached)") } conn, err := dialer.DialContext(ctx, network, addr) if err != nil { return nil, err } c.dialCount++ for _, fn := range c.onDial { fn(conn) } c.conns = append(c.conns, conn) return newConn(conn, c), nil } // read reads from r, but only if the controller allows it. func (c *NetCtl) read(r io.Reader, b []byte) (int, error) { c.rlock.Lock() defer c.rlock.Unlock() if !c.canRead { return 0, errors.New("read failed (not allowed)") } if c.readLimit > 0 && c.readCount >= c.readLimit { return 0, errors.New("read failed (limit reached)") } var rem uint64 if c.readLimit > 0 && c.readLimit-c.readCount < uint64(len(b)) { rem = c.readLimit - c.readCount } else { rem = uint64(len(b)) } c.rlock.Unlock() n, err := newSlowReader(r, c.readSpeed).Read(b[:rem]) c.rlock.Lock() c.readCount += uint64(n) for _, fn := range c.onRead { fn(b[:n]) } return n, err } // write writes to w, but only if the controller allows it. 
func (c *NetCtl) write(w io.Writer, b []byte) (int, error) { c.wlock.Lock() defer c.wlock.Unlock() if !c.canWrite { return 0, errors.New("write failed (not allowed)") } if c.writeLimit > 0 && c.writeCount >= c.writeLimit { return 0, errors.New("write failed (limit exceeded)") } var rem uint64 if c.writeLimit > 0 && c.writeLimit-c.writeCount < uint64(len(b)) { rem = c.writeLimit - c.writeCount } else { rem = uint64(len(b)) } c.wlock.Unlock() n, err := newSlowWriter(w, c.writeSpeed).Write(b[:rem]) c.wlock.Lock() c.writeCount += uint64(n) for _, fn := range c.onWrite { fn(b[:n]) } if uint64(n) < rem { return n, fmt.Errorf("write incomplete (limit reached)") } return n, err } // conn is a wrapper around net.conn that can be used to control whether a connection can read or write. type conn struct { net.Conn ctl *NetCtl } func newConn(c net.Conn, ctl *NetCtl) *conn { return &conn{ Conn: c, ctl: ctl, } } // Read reads from the wrapped connection, but only if the controller allows it. func (c *conn) Read(b []byte) (int, error) { return c.ctl.read(c.Conn, b) } // Write writes to the wrapped connection, but only if the controller allows it. func (c *conn) Write(b []byte) (int, error) { return c.ctl.write(c.Conn, b) } // slowReader is an io.Reader that reads at a fixed rate. type slowReader struct { r io.Reader // bytesPerSec is the number of bytes to read per second. bytesPerSec int } func newSlowReader(r io.Reader, bytesPerSec int) *slowReader { return &slowReader{ r: r, bytesPerSec: bytesPerSec, } } func (r *slowReader) Read(b []byte) (int, error) { start := time.Now() n, err := r.r.Read(b) if r.bytesPerSec > 0 { time.Sleep(time.Until(start.Add(time.Duration(n*r.bytesPerSec) * time.Second))) } return n, err } // slowWriter is an io.Writer that writes at a fixed rate. type slowWriter struct { w io.Writer // bytesPerSec is the number of bytes to write per second. 
bytesPerSec int } func newSlowWriter(w io.Writer, bytesPerSec int) *slowWriter { return &slowWriter{ w: w, bytesPerSec: bytesPerSec, } } func (w *slowWriter) Write(b []byte) (int, error) { start := time.Now() n, err := w.w.Write(b) if w.bytesPerSec > 0 { time.Sleep(time.Until(start.Add(time.Duration(n*w.bytesPerSec) * time.Second))) } return n, err } go-proton-api-1.0.0/netctl_test.go000066400000000000000000000037541447642273300171070ustar00rootroot00000000000000package proton_test import ( "bytes" "crypto/tls" "io" "net/http" "net/http/httptest" "testing" "github.com/henrybear327/go-proton-api" ) func TestNetCtl_ReadLimit(t *testing.T) { // Create a test http server that writes 100 bytes. // Including the header, this is 217 bytes (100 bytes + 117 bytes). ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { if _, err := w.Write(make([]byte, 100)); err != nil { t.Fatal(err) } })) defer ts.Close() // Create a new net controller. ctl := proton.NewNetCtl() // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. ctl.SetReadLimit(300) // Create a new http client with the dialer. client := &http.Client{ Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), } // This should succeed. if resp, err := client.Get(ts.URL); err != nil { t.Fatal(err) } else { resp.Body.Close() } // This should fail. if _, err := client.Get(ts.URL); err == nil { t.Fatal("expected error") } } func TestNetCtl_WriteLimit(t *testing.T) { // Create a test http server that reads the given body. ts := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { if _, err := io.ReadAll(r.Body); err != nil { t.Fatal(err) } })) defer ts.Close() // Create a new net controller. ctl := proton.NewNetCtl() // Set the read limit to 300 bytes -- the first request should succeed, the second should fail. ctl.SetWriteLimit(300) // Create a new http client with the dialer. 
client := &http.Client{ Transport: ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true}), } // This should succeed. if resp, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err != nil { t.Fatal(err) } else { resp.Body.Close() } // This should fail. if _, err := client.Post(ts.URL, "application/octet-stream", bytes.NewReader(make([]byte, 100))); err == nil { t.Fatal("expected error") } } go-proton-api-1.0.0/option.go000066400000000000000000000052341447642273300160620ustar00rootroot00000000000000package proton import ( "net/http" "github.com/ProtonMail/gluon/async" "github.com/go-resty/resty/v2" ) // Option represents a type that can be used to configure the manager. type Option interface { config(*managerBuilder) } func WithHostURL(hostURL string) Option { return &withHostURL{ hostURL: hostURL, } } type withHostURL struct { hostURL string } func (opt withHostURL) config(builder *managerBuilder) { builder.hostURL = opt.hostURL } func WithAppVersion(appVersion string) Option { return &withAppVersion{ appVersion: appVersion, } } type withUserAgent struct { userAgent string } func (opt withUserAgent) config(builder *managerBuilder) { builder.userAgent = opt.userAgent } func WithUserAgent(userAgent string) Option { return &withUserAgent{ userAgent: userAgent, } } type withAppVersion struct { appVersion string } func (opt withAppVersion) config(builder *managerBuilder) { builder.appVersion = opt.appVersion } func WithTransport(transport http.RoundTripper) Option { return &withTransport{ transport: transport, } } type withTransport struct { transport http.RoundTripper } func (opt withTransport) config(builder *managerBuilder) { builder.transport = opt.transport } type withSkipVerifyProofs struct { skipVerifyProofs bool } func (opt withSkipVerifyProofs) config(builder *managerBuilder) { builder.verifyProofs = !opt.skipVerifyProofs } func WithSkipVerifyProofs() Option { return &withSkipVerifyProofs{ skipVerifyProofs: true, } } 
func WithRetryCount(retryCount int) Option { return &withRetryCount{ retryCount: retryCount, } } type withRetryCount struct { retryCount int } func (opt withRetryCount) config(builder *managerBuilder) { builder.retryCount = opt.retryCount } func WithCookieJar(jar http.CookieJar) Option { return &withCookieJar{ jar: jar, } } type withCookieJar struct { jar http.CookieJar } func (opt withCookieJar) config(builder *managerBuilder) { builder.cookieJar = opt.jar } func WithLogger(logger resty.Logger) Option { return &withLogger{ logger: logger, } } type withLogger struct { logger resty.Logger } func (opt withLogger) config(builder *managerBuilder) { builder.logger = opt.logger } func WithDebug(debug bool) Option { return &withDebug{ debug: debug, } } type withDebug struct { debug bool } func (opt withDebug) config(builder *managerBuilder) { builder.debug = opt.debug } func WithPanicHandler(panicHandler async.PanicHandler) Option { return &withPanicHandler{ panicHandler: panicHandler, } } type withPanicHandler struct { panicHandler async.PanicHandler } func (opt withPanicHandler) config(builder *managerBuilder) { builder.panicHandler = opt.panicHandler } go-proton-api-1.0.0/package.go000066400000000000000000000001201447642273300161320ustar00rootroot00000000000000// Package proton implements types for accessing the Proton API. 
package proton go-proton-api-1.0.0/paging.go000066400000000000000000000014761447642273300160230ustar00rootroot00000000000000package proton import ( "context" "runtime" "github.com/ProtonMail/gluon/async" "github.com/bradenaw/juniper/iterator" "github.com/bradenaw/juniper/parallel" "github.com/bradenaw/juniper/stream" ) const maxPageSize = 150 func fetchPaged[T any]( ctx context.Context, total, pageSize int, c *Client, fn func(ctx context.Context, page, pageSize int) ([]T, error), ) ([]T, error) { return stream.Collect(ctx, stream.Flatten(parallel.MapStream( ctx, stream.FromIterator(iterator.Counter(total/pageSize+1)), runtime.NumCPU(), runtime.NumCPU(), func(ctx context.Context, page int) (stream.Stream[T], error) { defer async.HandlePanic(c.m.panicHandler) values, err := fn(ctx, page, pageSize) if err != nil { return nil, err } return stream.FromIterator(iterator.Slice(values)), nil }, ))) } go-proton-api-1.0.0/pool.go000066400000000000000000000070751447642273300155300ustar00rootroot00000000000000package proton import ( "context" "errors" "fmt" "sync" "github.com/ProtonMail/gluon/async" ) // ErrJobCancelled indicates the job was cancelled. var ErrJobCancelled = errors.New("job cancelled by surrounding context") // Pool is a worker pool that handles input of type In and returns results of type Out. type Pool[In comparable, Out any] struct { queue *async.QueuedChannel[*job[In, Out]] wg sync.WaitGroup panicHandler async.PanicHandler } // doneFunc must be called to free up pool resources. type doneFunc func() // New returns a new pool. 
func NewPool[In comparable, Out any](size int, panicHandler async.PanicHandler, work func(context.Context, In) (Out, error)) *Pool[In, Out] { pool := &Pool[In, Out]{ queue: async.NewQueuedChannel[*job[In, Out]](0, 0, panicHandler, "gpa-pool"), } for i := 0; i < size; i++ { pool.wg.Add(1) go func() { defer async.HandlePanic(pool.panicHandler) defer pool.wg.Done() for job := range pool.queue.GetChannel() { select { case <-job.ctx.Done(): job.postFailure(ErrJobCancelled) default: res, err := work(job.ctx, job.req) if err != nil { job.postFailure(err) } else { job.postSuccess(res) } job.waitDone() } } }() } return pool } // Process submits jobs to the pool. The callback provides access to the result, or an error if one occurred. func (pool *Pool[In, Out]) Process(ctx context.Context, reqs []In, fn func(int, In, Out, error) error) error { ctx, cancel := context.WithCancel(ctx) defer cancel() var ( wg sync.WaitGroup errList []error lock sync.Mutex ) for i, req := range reqs { req := req wg.Add(1) go func(index int) { defer async.HandlePanic(pool.panicHandler) defer wg.Done() job, done, err := pool.newJob(ctx, req) if err != nil { lock.Lock() defer lock.Unlock() // Cancel ongoing jobs. cancel() // Collect the error. errList = append(errList, err) return } defer done() res, err := job.result() if err := fn(index, req, res, err); err != nil { lock.Lock() defer lock.Unlock() // Cancel ongoing jobs. cancel() // Collect the error. errList = append(errList, err) } }(i) } wg.Wait() // TODO: Join the errors somehow? if len(errList) > 0 { return errList[0] } return nil } // ProcessAll submits jobs to the pool. All results are returned once available. 
func (pool *Pool[In, Out]) ProcessAll(ctx context.Context, reqs []In) ([]Out, error) { data := make([]Out, len(reqs)) if err := pool.Process(ctx, reqs, func(index int, req In, res Out, err error) error { if err != nil { return err } data[index] = res return nil }); err != nil { return nil, err } return data, nil } // ProcessOne submits one job to the pool and returns the result. func (pool *Pool[In, Out]) ProcessOne(ctx context.Context, req In) (Out, error) { job, done, err := pool.newJob(ctx, req) if err != nil { var o Out return o, err } defer done() return job.result() } func (pool *Pool[In, Out]) Done() { pool.queue.Close() pool.wg.Wait() } // newJob submits a job to the pool. It returns a job handle and a DoneFunc. // The job handle allows the job result to be obtained. The DoneFunc is used to mark the job as done, // which frees up the worker in the pool for reuse. func (pool *Pool[In, Out]) newJob(ctx context.Context, req In) (*job[In, Out], doneFunc, error) { job := newJob[In, Out](ctx, req) if !pool.queue.Enqueue(job) { return nil, nil, fmt.Errorf("pool closed") } return job, func() { close(job.done) }, nil } go-proton-api-1.0.0/pool_test.go000066400000000000000000000070341447642273300165620ustar00rootroot00000000000000package proton import ( "context" "errors" "runtime" "sync" "testing" "time" "github.com/ProtonMail/gluon/async" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestPool_NewJob(t *testing.T) { doubler := newDoubler(runtime.NumCPU()) defer doubler.Done() job1, done1, err := doubler.newJob(context.Background(), 1) require.NoError(t, err) defer done1() job2, done2, err := doubler.newJob(context.Background(), 2) require.NoError(t, err) defer done2() res2, err := job2.result() require.NoError(t, err) res1, err := job1.result() require.NoError(t, err) assert.Equal(t, 2, res1) assert.Equal(t, 4, res2) } func TestPool_NewJob_Done(t *testing.T) { // Create a doubler pool with 2 workers. 
doubler := newDoubler(2) defer doubler.Done() // Start two jobs. Don't mark the jobs as done yet. job1, done1, err := doubler.newJob(context.Background(), 1) require.NoError(t, err) job2, done2, err := doubler.newJob(context.Background(), 2) require.NoError(t, err) // Get the first result. res1, _ := job1.result() assert.Equal(t, 2, res1) // Get the first result. res2, _ := job2.result() assert.Equal(t, 4, res2) // Additional jobs will wait. job3, done3, err := doubler.newJob(context.Background(), 3) require.NoError(t, err) job4, done4, err := doubler.newJob(context.Background(), 4) require.NoError(t, err) // Channel to collect results from jobs 3 and 4. resCh := make(chan int, 2) go func() { res, _ := job3.result() resCh <- res }() go func() { res, _ := job4.result() resCh <- res }() // Mark jobs 1 and 2 as done, freeing up the workers. done1() done2() assert.ElementsMatch(t, []int{6, 8}, []int{<-resCh, <-resCh}) // Mark jobs 3 and 4 as done, freeing up the workers. done3() done4() } func TestPool_Process(t *testing.T) { doubler := newDoubler(runtime.NumCPU()) defer doubler.Done() res := make([]int, 5) require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(index, reqVal, resVal int, err error) error { require.NoError(t, err) res[index] = resVal return nil })) assert.Equal(t, []int{ 2, 4, 6, 8, 10, }, res) } func TestPool_Process_Error(t *testing.T) { doubler := newDoublerWithError(runtime.NumCPU()) defer doubler.Done() assert.Error(t, doubler.Process(context.Background(), []int{1, 2, 3, 4, 5}, func(_int, _ int, _ int, err error) error { return err })) } func TestPool_Process_Parallel(t *testing.T) { doubler := newDoubler(runtime.NumCPU(), 100*time.Millisecond) defer doubler.Done() var wg sync.WaitGroup for i := 0; i < 8; i++ { wg.Add(1) go func() { defer wg.Done() require.NoError(t, doubler.Process(context.Background(), []int{1, 2, 3, 4}, func(_ int, _ int, _ int, err error) error { return nil })) }() } wg.Wait() } func 
TestPool_ProcessAll(t *testing.T) { doubler := newDoubler(runtime.NumCPU()) defer doubler.Done() res, err := doubler.ProcessAll(context.Background(), []int{1, 2, 3, 4, 5}) require.NoError(t, err) assert.Equal(t, []int{ 2, 4, 6, 8, 10, }, res) } func newDoubler(workers int, delay ...time.Duration) *Pool[int, int] { return NewPool(workers, async.NoopPanicHandler{}, func(ctx context.Context, req int) (int, error) { if len(delay) > 0 { time.Sleep(delay[0]) } return 2 * req, nil }) } func newDoublerWithError(workers int) *Pool[int, int] { return NewPool(workers, async.NoopPanicHandler{}, func(ctx context.Context, req int) (int, error) { if req%2 == 0 { return 0, errors.New("oops") } return 2 * req, nil }) } go-proton-api-1.0.0/response.go000066400000000000000000000133051447642273300164060ustar00rootroot00000000000000package proton import ( "encoding/json" "errors" "fmt" "io" "math/rand" "net" "net/http" "strconv" "time" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/go-resty/resty/v2" "github.com/sirupsen/logrus" ) type Code int const ( SuccessCode Code = 1000 MultiCode Code = 1001 InvalidValue Code = 2001 AFileOrFolderNameExist Code = 2500 ADraftExist Code = 2500 AppVersionMissingCode Code = 5001 AppVersionBadCode Code = 5003 UsernameInvalid Code = 6003 // Deprecated, but still used. PasswordWrong Code = 8002 HumanVerificationRequired Code = 9001 PaidPlanRequired Code = 10004 AuthRefreshTokenInvalid Code = 10013 ) var ( ErrFileNameExist = errors.New("a file with that name already exists (Code=2500, Status=422)") ErrFolderNameExist = errors.New("a folder with that name already exists (Code=2500, Status=422)") ErrADraftExist = errors.New("draft already exists on this revision (Code=2500, Status=409)") ) // APIError represents an error returned by the API. type APIError struct { // Status is the HTTP status code of the response that caused the error. Status int // Code is the error code returned by the API. 
Code Code // Message is the error message returned by the API. Message string `json:"Error"` // Details contains optional error details which are specific to each request. Details any } func (err APIError) Error() string { return fmt.Sprintf("%v (Code=%v, Status=%v)", err.Message, err.Code, err.Status) } func (err APIError) DetailsToString() string { if err.Details == nil { return "" } bytes, e := json.Marshal(err.Details) if e != nil { return fmt.Sprintf("Failed to generate json: %v", e) } return string(bytes) } // NetError represents a network error. It is returned when the API is unreachable. type NetError struct { // Cause is the underlying error that caused the network error. Cause error // Message is an additional message that describes the network error. Message string } func newNetError(err error, message string) *NetError { return &NetError{Cause: err, Message: message} } func (err *NetError) Error() string { return fmt.Sprintf("%s: %v", err.Message, err.Cause) } func (err *NetError) Unwrap() error { return err.Cause } func (err *NetError) Is(target error) bool { _, ok := target.(*NetError) return ok } func catchAPIError(_ *resty.Client, res *resty.Response) error { if !res.IsError() { return nil } method := "NONE" route := "N/A" if res.Request != nil { method = res.Request.Method route = res.Request.URL } var err error if apiErr, ok := res.Error().(*APIError); ok { apiErr.Status = res.StatusCode() err = apiErr } else { statusCode := res.StatusCode() statusText := res.Status() // Catch error that may slip through when APIError deserialization routine fails for whichever reason. 
if statusCode >= 400 { err = &APIError{ Status: statusCode, Code: 0, Message: statusText, } } else { err = fmt.Errorf("%v", res.Status()) } } return fmt.Errorf( "%v %s %s: %w", res.StatusCode(), method, route, err, ) } func updateTime(_ *resty.Client, res *resty.Response) error { date, err := time.Parse(time.RFC1123, res.Header().Get("Date")) if err != nil { return err } crypto.UpdateTime(date.Unix()) return nil } // nolint:gosec func catchRetryAfter(_ *resty.Client, res *resty.Response) (time.Duration, error) { // 0 and no error means default behaviour which is exponential backoff with jitter. if res.StatusCode() != http.StatusTooManyRequests && res.StatusCode() != http.StatusServiceUnavailable { return 0, nil } // Parse the Retry-After header, or fallback to 10 seconds. after, err := strconv.Atoi(res.Header().Get("Retry-After")) if err != nil { after = 10 } // Add some jitter to the delay. after += rand.Intn(10) logrus.WithFields(logrus.Fields{ "pkg": "go-proton-api", "status": res.StatusCode(), "url": res.Request.URL, "method": res.Request.Method, "after": after, }).Warn("Too many requests, retrying after delay") return time.Duration(after) * time.Second, nil } func catchTooManyRequests(res *resty.Response, _ error) bool { return res.StatusCode() == http.StatusTooManyRequests || res.StatusCode() == http.StatusServiceUnavailable } func catchDialError(res *resty.Response, err error) bool { return res.RawResponse == nil } func catchDropError(_ *resty.Response, err error) bool { if netErr := new(net.OpError); errors.As(err, &netErr) { return true } return false } // parseResponse should be used as post-processing of response when request is // called with resty.SetDoNotParseResponse(off). // // In this case the resty is not processing request at all including http // status check or APIerror parsing. Hence, the returned error would be nil // even on non-200 reponsenses. // // This function also closes the response body. 
func parseResponse(res *resty.Response, err error) (*resty.Response, error) { if err != nil || res.StatusCode() == 200 { return res, err } method := "NONE" route := "N/A" if res.Request != nil { method = res.Request.Method route = res.Request.URL } apiErr, ok := parseRawAPIError(res.RawBody()) if !ok { apiErr = &APIError{ Code: 0, Message: res.Status(), } } apiErr.Status = res.StatusCode() return res, fmt.Errorf( "%v %s %s: %w", res.StatusCode(), method, route, apiErr, ) } func parseRawAPIError(rawResponse io.ReadCloser) (*APIError, bool) { apiErr := APIError{} defer rawResponse.Close() body, err := io.ReadAll(rawResponse) if err != nil { return &apiErr, false } if err := json.Unmarshal(body, &apiErr); err != nil { return &apiErr, false } return &apiErr, true } go-proton-api-1.0.0/response_test.go000066400000000000000000000053101447642273300174420ustar00rootroot00000000000000package proton_test import ( "context" "encoding/json" "errors" "net" "net/http" "net/http/httptest" "net/url" "testing" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/stretchr/testify/require" ) func TestNetError_DropOnWrite(t *testing.T) { l, err := net.Listen("tcp", ":0") require.NoError(t, err) dropListener := proton.NewListener(l, proton.NewDropConn) // Use a custom listener that drops all writes. dropListener.SetCanWrite(false) // Simulate a server that refuses to write. s := server.New(server.WithListener(dropListener)) defer s.Close() m := proton.New(proton.WithHostURL(s.GetHostURL())) defer m.Close() // This should fail with a URL error. 
pingErr := m.Ping(context.Background()) if urlErr := new(url.Error); !errors.As(pingErr, &urlErr) { t.Fatalf("expected a url.Error, got %T: %v", pingErr, pingErr) } } func TestAPIError_DeserializeWithoutDetails(t *testing.T) { errJson := ` { "Status": 400, "Code": 1000, "Error": "Foo Bar" } ` var err proton.APIError require.NoError(t, json.Unmarshal([]byte(errJson), &err)) require.Nil(t, err.Details) } func TestAPIError_DeserializeWithoutDetailsValue(t *testing.T) { errJson := ` { "Status": 400, "Code": 1000, "Error": "Foo Bar", "Details": 20 } ` var err proton.APIError require.NoError(t, json.Unmarshal([]byte(errJson), &err)) require.NotNil(t, err.Details) require.Equal(t, `20`, err.DetailsToString()) } func TestAPIError_DeserializeWithDetailsObject(t *testing.T) { errJson := ` { "Status": 400, "Code": 1000, "Error": "Foo Bar", "Details": { "object2": { "v": 20 }, "foo": "bar" } } ` var err proton.APIError require.NoError(t, json.Unmarshal([]byte(errJson), &err)) require.NotNil(t, err.Details) require.Equal(t, `{"foo":"bar","object2":{"v":20}}`, err.DetailsToString()) } func TestAPIError_DeserializeWithDetailsArray(t *testing.T) { errJson := ` { "Status": 400, "Code": 1000, "Error": "Foo Bar", "Details": [ { "object2": { "v": 20 }, "foo": "bar" }, 499, "hello" ] } ` var err proton.APIError require.NoError(t, json.Unmarshal([]byte(errJson), &err)) require.NotNil(t, err.Details) require.Equal(t, `[{"foo":"bar","object2":{"v":20}},499,"hello"]`, err.DetailsToString()) } func TestNetError_RouteInErrorMessage(t *testing.T) { s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) })) defer s.Close() m := proton.New(proton.WithHostURL(s.URL)) defer m.Close() pingErr := m.Quark(context.Background(), "test/ping") require.Error(t, pingErr) require.Contains(t, pingErr.Error(), "GET") require.Contains(t, pingErr.Error(), "/test/ping") } 
go-proton-api-1.0.0/salt.go000066400000000000000000000005611447642273300155130ustar00rootroot00000000000000package proton

import (
	"context"

	"github.com/go-resty/resty/v2"
)

// GetSalts fetches the key salts from the /core/v4/keys/salts endpoint.
func (c *Client) GetSalts(ctx context.Context) (Salts, error) {
	var res struct {
		KeySalts []Salt
	}

	if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) {
		return r.SetResult(&res).Get("/core/v4/keys/salts")
	}); err != nil {
		return nil, err
	}

	return res.KeySalts, nil
}

go-proton-api-1.0.0/salt_types.go000066400000000000000000000012631447642273300167370ustar00rootroot00000000000000package proton

import (
	"encoding/base64"
	"fmt"

	"github.com/ProtonMail/go-srp"
	"github.com/bradenaw/juniper/xslices"
)

// Salt associates a key ID with its base64-encoded key salt.
type Salt struct {
	ID, KeySalt string
}

// Salts is a list of key salts, searchable by key ID via SaltForKey.
type Salts []Salt

// SaltForKey derives the salted passphrase for the key identified by keyID.
// It base64-decodes that key's salt, passes keyPass and the salt to
// srp.MailboxPassword, and returns the last 31 bytes of the result.
// It returns an error if no salt exists for keyID, if the salt is not valid
// base64, or if srp.MailboxPassword fails.
func (salts Salts) SaltForKey(keyPass []byte, keyID string) ([]byte, error) {
	idx := xslices.IndexFunc(salts, func(salt Salt) bool {
		return salt.ID == keyID
	})

	if idx < 0 {
		return nil, fmt.Errorf("no salt found for key %s", keyID)
	}

	keySalt, err := base64.StdEncoding.DecodeString(salts[idx].KeySalt)
	if err != nil {
		return nil, err
	}

	saltedKeyPass, err := srp.MailboxPassword(keyPass, keySalt)
	if err != nil {
		// Propagate the failure. Returning (nil, nil) would look like success to callers.
		return nil, err
	}

	// NOTE(review): the trailing 31 bytes are presumably the bcrypt hash part
	// of MailboxPassword's output. Confirm against go-srp.
	return saltedKeyPass[len(saltedKeyPass)-31:], nil
}

go-proton-api-1.0.0/server/000077500000000000000000000000001447642273300155255ustar00rootroot00000000000000go-proton-api-1.0.0/server/addresses.go000066400000000000000000000044011447642273300200300ustar00rootroot00000000000000package server

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/henrybear327/go-proton-api"
	"golang.org/x/exp/slices"
)

// handleGetAddresses responds with every address of the user whose ID is in the request context ("UserID").
func (s *Server) handleGetAddresses() gin.HandlerFunc {
	return func(c *gin.Context) {
		addresses, err := s.b.GetAddresses(c.GetString("UserID"))
		if err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}

		c.JSON(http.StatusOK, gin.H{
			"Addresses": addresses,
		})
	}
}

// handleGetAddress responds with the address named by the "addressID" path parameter.
func (s *Server) handleGetAddress() gin.HandlerFunc {
	return func(c *gin.Context) {
		address, err := s.b.GetAddress(c.GetString("UserID"), c.Param("addressID"))
		if err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}

		c.JSON(http.StatusOK, gin.H{
			"Address": address,
		})
	}
}

// handlePutAddressEnable enables the address named by the "addressID" path parameter.
func (s *Server) handlePutAddressEnable() gin.HandlerFunc {
	return func(c *gin.Context) {
		if err := s.b.EnableAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
	}
}

// handlePutAddressDisable disables the address named by the "addressID" path parameter.
func (s *Server) handlePutAddressDisable() gin.HandlerFunc {
	return func(c *gin.Context) {
		if err := s.b.DisableAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
	}
}

// handleDeleteAddress deletes the address named by the "addressID" path parameter.
func (s *Server) handleDeleteAddress() gin.HandlerFunc {
	return func(c *gin.Context) {
		if err := s.b.DeleteAddress(c.GetString("UserID"), c.Param("addressID")); err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
	}
}

// handlePutAddressesOrder reorders the user's addresses. The request must list
// exactly the user's existing address IDs, otherwise it is rejected with 400.
func (s *Server) handlePutAddressesOrder() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req proton.OrderAddressesReq

		if err := c.BindJSON(&req); err != nil {
			c.AbortWithStatus(http.StatusBadRequest)
			return
		}

		addresses, err := s.b.GetAddresses(c.GetString("UserID"))
		if err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}

		if len(req.AddressIDs) != len(addresses) {
			c.AbortWithStatus(http.StatusBadRequest)
			return
		}

		for _, address := range addresses {
			if !slices.Contains(req.AddressIDs, address.ID) {
				c.AbortWithStatus(http.StatusBadRequest)
				return
			}
		}

		if err := s.b.SetAddressOrder(c.GetString("UserID"), req.AddressIDs); err != nil {
			c.AbortWithStatus(http.StatusInternalServerError)
			return
		}
	}
}

go-proton-api-1.0.0/server/attachments.go000066400000000000000000000026731447642273300203770ustar00rootroot00000000000000package server

import (
	"io"
	"mime/multipart"
	"net/http"

	"github.com/ProtonMail/gluon/rfc822"
	"github.com/gin-gonic/gin"
	"github.com/henrybear327/go-proton-api"
)

// handlePostMailAttachments creates an attachment from a multipart form. The
// form carries the message ID, metadata values and the KeyPackets, DataPacket
// and Signature files.
func (s *Server) handlePostMailAttachments() gin.HandlerFunc {
	return func(c *gin.Context) {
		form, err := c.MultipartForm()
		if err != nil {
			c.AbortWithStatus(http.StatusBadRequest)
			return
		}

		attachment, err := s.b.CreateAttachment(
			c.GetString("UserID"),
			form.Value["MessageID"][0],
			form.Value["Filename"][0],
			rfc822.MIMEType(form.Value["MIMEType"][0]),
			proton.Disposition(form.Value["Disposition"][0]),
			form.Value["ContentID"][0],
			mustReadFileHeader(form.File["KeyPackets"][0]),
			mustReadFileHeader(form.File["DataPacket"][0]),
			string(mustReadFileHeader(form.File["Signature"][0])),
		)
		if err != nil {
			_ = c.AbortWithError(http.StatusUnprocessableEntity, err)
			return
		}

		c.JSON(http.StatusOK, gin.H{
			"Attachment": attachment,
		})
	}
}

// handleGetMailAttachment responds with the raw data of the attachment named by the "attachID" path parameter.
func (s *Server) handleGetMailAttachment() gin.HandlerFunc {
	return func(c *gin.Context) {
		attData, err := s.b.GetAttachment(c.Param("attachID"))
		if err != nil {
			_ = c.AbortWithError(http.StatusUnprocessableEntity, err)
			return
		}

		c.Data(http.StatusOK, "application/octet-stream", attData)
	}
}

// mustReadFileHeader returns the full contents of the uploaded file described
// by fh. It panics if the file cannot be opened or read.
func mustReadFileHeader(fh *multipart.FileHeader) []byte {
	f, err := fh.Open()
	if err != nil {
		panic(err)
	}

	// Release the file once read. FileHeader.Open may return an on-disk
	// temporary file, which was previously never closed.
	defer func() { _ = f.Close() }()

	data, err := io.ReadAll(f)
	if err != nil {
		panic(err)
	}

	return data
}

go-proton-api-1.0.0/server/auth.go000066400000000000000000000054451447642273300170250ustar00rootroot00000000000000package server

import (
	"encoding/base64"
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/henrybear327/go-proton-api"
)

// handlePostAuthInfo responds with the auth info the backend produces for the requested username.
func (s *Server) handlePostAuthInfo() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req proton.AuthInfoReq

		// NOTE(review): BindJSON presumably writes the error response itself on failure.
		if err := c.BindJSON(&req); err != nil {
			return
		}

		info, err := s.b.NewAuthInfo(req.Username)
		if err != nil {
			_ = c.AbortWithError(http.StatusUnauthorized, err)
			return
		}

		c.JSON(http.StatusOK, info)
	}
}

// handlePostAuth decodes the base64 client ephemeral and proof, then responds
// with the auth the backend creates for them.
func (s *Server) handlePostAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req proton.AuthReq

		if err := c.BindJSON(&req); err != nil {
			return
		}

		clientEphemeral, err := base64.StdEncoding.DecodeString(req.ClientEphemeral)
		if err != nil {
			_ = c.AbortWithError(http.StatusBadRequest, err)
			return
		}

		clientProof, err := base64.StdEncoding.DecodeString(req.ClientProof)
		if err != nil {
			_ = c.AbortWithError(http.StatusBadRequest, err)
			return
		}

		auth, err := s.b.NewAuth(req.Username, clientEphemeral, clientProof, req.SRPSession)
		if err != nil {
			_ = c.AbortWithError(http.StatusUnauthorized, err)
			return
		}

		c.JSON(http.StatusOK, auth)
	}
}

// handlePostAuthRefresh responds with a new auth derived from the given UID and refresh token.
func (s *Server) handlePostAuthRefresh() gin.HandlerFunc {
	return func(c *gin.Context) {
		var req proton.AuthRefreshReq

		if err := c.BindJSON(&req); err != nil {
			return
		}

		auth, err := s.b.NewAuthRef(req.UID, req.RefreshToken)
		if err != nil {
			_ = c.AbortWithError(http.StatusUnprocessableEntity, err)
			return
		}

		c.JSON(http.StatusOK, auth)
	}
}

// handleDeleteAuth deletes the session identified by the context's "AuthUID".
func (s *Server) handleDeleteAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		if err := s.b.DeleteSession(c.GetString("UserID"), c.GetString("AuthUID")); err != nil {
			_ = c.AbortWithError(http.StatusUnauthorized, err)
			return
		}
	}
}

// handleGetAuthSessions responds with all sessions of the user.
func (s *Server) handleGetAuthSessions() gin.HandlerFunc {
	return func(c *gin.Context) {
		sessions, err := s.b.GetSessions(c.GetString("UserID"))
		if err != nil {
			_ = c.AbortWithError(http.StatusInternalServerError, err)
			return
		}

		c.JSON(http.StatusOK, gin.H{"Sessions": sessions})
	}
}

// handleDeleteAuthSessions deletes every session of the user except the one
// identified by the context's "AuthUID".
func (s *Server) handleDeleteAuthSessions() gin.HandlerFunc {
	return func(c *gin.Context) {
		sessions, err := s.b.GetSessions(c.GetString("UserID"))
		if err != nil {
			_ = c.AbortWithError(http.StatusInternalServerError, err)
			return
		}

		for _, session := range sessions {
			if session.UID != c.GetString("AuthUID") {
				if err := s.b.DeleteSession(c.GetString("UserID"), session.UID); err != nil {
					_ = c.AbortWithError(http.StatusInternalServerError, err)
					return
				}
			}
		}
	}
}

// handleDeleteAuthSession deletes the session named by the "authUID" path parameter.
func (s *Server) handleDeleteAuthSession() gin.HandlerFunc {
	return func(c *gin.Context) {
		if err := s.b.DeleteSession(c.GetString("UserID"), c.Param("authUID")); err != nil {
			_ = c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
	}
}
go-proton-api-1.0.0/server/backend/000077500000000000000000000000001447642273300171145ustar00rootroot00000000000000go-proton-api-1.0.0/server/backend/account.go000066400000000000000000000034421447642273300211020ustar00rootroot00000000000000package backend import ( "sync" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) type account struct { userID string username string addresses map[string]*address mailSettings *mailSettings userSettings proton.UserSettings auth map[string]auth authLock sync.RWMutex keys []key salt []byte verifier []byte labelIDs []string messageIDs []string updateIDs []ID } func newAccount(userID, username string, armKey string, salt, verifier []byte) *account { return &account{ userID: userID, username: username, addresses: make(map[string]*address), mailSettings: newMailSettings(username), userSettings: newUserSettings(), auth: make(map[string]auth), keys: []key{{keyID: uuid.NewString(), key: armKey}}, salt: salt, verifier: verifier, } } func (acc *account) toUser() proton.User { return proton.User{ ID: acc.userID, Name: acc.username, DisplayName: acc.username, Email: acc.primary().email, Keys: xslices.Map(acc.keys, func(key key) proton.Key { privKey, err := crypto.NewKeyFromArmored(key.key) if err != nil { panic(err) } rawKey, err := privKey.Serialize() if err != nil { panic(err) } return proton.Key{ ID: key.keyID, PrivateKey: rawKey, Primary: key == acc.keys[0], Active: true, } }), } } func (acc *account) primary() *address { for _, addr := range acc.addresses { if addr.order == 1 { return addr } } panic("no primary address") } func (acc *account) getAddr(email string) (*address, bool) { for _, addr := range acc.addresses { if addr.email == email { return addr, true } } return nil, false } go-proton-api-1.0.0/server/backend/address.go000066400000000000000000000017221447642273300210720ustar00rootroot00000000000000package backend import ( 
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/henrybear327/go-proton-api" ) type address struct { addrID string email string order int status proton.AddressStatus addrType proton.AddressType keys []key } func (add *address) toAddress() proton.Address { return proton.Address{ ID: add.addrID, Email: add.email, Send: true, Receive: true, Status: add.status, Type: add.addrType, Order: add.order, DisplayName: add.email, Keys: xslices.Map(add.keys, func(key key) proton.Key { privKey, err := crypto.NewKeyFromArmored(key.key) if err != nil { panic(err) } rawKey, err := privKey.Serialize() if err != nil { panic(err) } return proton.Key{ ID: key.keyID, PrivateKey: rawKey, Token: key.tok, Signature: key.sig, Primary: key == add.keys[0], Active: true, } }), } } go-proton-api-1.0.0/server/backend/api.go000066400000000000000000000714761447642273300202330ustar00rootroot00000000000000package backend import ( "encoding/base64" "errors" "fmt" "strings" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/henrybear327/go-proton-api" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) func (b *Backend) GetUser(userID string) (proton.User, error) { return withAcc(b, userID, func(acc *account) (proton.User, error) { return acc.toUser(), nil }) } func (b *Backend) GetKeySalts(userID string) ([]proton.Salt, error) { return withAcc(b, userID, func(acc *account) ([]proton.Salt, error) { return xslices.Map(acc.keys, func(key key) proton.Salt { return proton.Salt{ ID: key.keyID, KeySalt: base64.StdEncoding.EncodeToString(acc.salt), } }), nil }) } func (b *Backend) GetMailSettings(userID string) (proton.MailSettings, error) { return withAcc(b, userID, func(acc *account) (proton.MailSettings, error) { return acc.mailSettings.toMailSettings(), nil }) } func (b *Backend) SetMailSettingsAttachPublicKey(userID string, attach bool) (proton.MailSettings, error) { return 
withAcc(b, userID, func(acc *account) (proton.MailSettings, error) { acc.mailSettings.attachPubKey = attach return acc.mailSettings.toMailSettings(), nil }) } func (b *Backend) GetUserSettings(userID string) (proton.UserSettings, error) { return withAcc(b, userID, func(acc *account) (proton.UserSettings, error) { return acc.userSettings, nil }) } func (b *Backend) SetUserSettingsTelemetry(userID string, telemetry proton.SettingsBool) (proton.UserSettings, error) { return withAcc(b, userID, func(acc *account) (proton.UserSettings, error) { if telemetry != proton.SettingDisabled && telemetry != proton.SettingEnabled { return proton.UserSettings{}, errors.New("bad value") } acc.userSettings.Telemetry = telemetry updateID, err := b.newUpdate(&userSettingsUpdate{settings: acc.userSettings}) if err != nil { return acc.userSettings, err } acc.updateIDs = append(acc.updateIDs, updateID) return acc.userSettings, nil }) } func (b *Backend) SetUserSettingsCrashReports(userID string, crashReports proton.SettingsBool) (proton.UserSettings, error) { return withAcc(b, userID, func(acc *account) (proton.UserSettings, error) { if crashReports != proton.SettingDisabled && crashReports != proton.SettingEnabled { return proton.UserSettings{}, errors.New("bad value") } acc.userSettings.CrashReports = crashReports updateID, err := b.newUpdate(&userSettingsUpdate{settings: acc.userSettings}) if err != nil { return acc.userSettings, err } acc.updateIDs = append(acc.updateIDs, updateID) return acc.userSettings, nil }) } func (b *Backend) GetAddressID(email string) (string, error) { return withAccEmail(b, email, func(acc *account) (string, error) { addr, ok := acc.getAddr(email) if !ok { return "", fmt.Errorf("no such address: %s", email) } return addr.addrID, nil }) } func (b *Backend) GetAddress(userID, addrID string) (proton.Address, error) { return withAcc(b, userID, func(acc *account) (proton.Address, error) { if addr, ok := acc.addresses[addrID]; ok { return addr.toAddress(), nil } 
return proton.Address{}, errors.New("no such address") }) } func (b *Backend) GetAddresses(userID string) ([]proton.Address, error) { return withAcc(b, userID, func(acc *account) ([]proton.Address, error) { return xslices.Map(maps.Values(acc.addresses), func(add *address) proton.Address { return add.toAddress() }), nil }) } func (b *Backend) EnableAddress(userID, addrID string) error { return b.withAcc(userID, func(acc *account) error { acc.addresses[addrID].status = proton.AddressStatusEnabled updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) DisableAddress(userID, addrID string) error { return b.withAcc(userID, func(acc *account) error { acc.addresses[addrID].status = proton.AddressStatusDisabled updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) DeleteAddress(userID, addrID string) error { return b.withAcc(userID, func(acc *account) error { if acc.addresses[addrID].status != proton.AddressStatusDisabled { return errors.New("address is not disabled") } delete(acc.addresses, addrID) updateID, err := b.newUpdate(&addressDeleted{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) SetAddressOrder(userID string, addrIDs []string) error { return b.withAcc(userID, func(acc *account) error { for i, addrID := range addrIDs { if add, ok := acc.addresses[addrID]; ok { add.order = i + 1 } else { return fmt.Errorf("no such address: %s", addrID) } } return nil }) } func (b *Backend) HasLabel(userID, labelName string) (string, bool, error) { labels, err := b.GetLabels(userID) if err != nil { return "", false, err } for _, label := range labels { if label.Name == labelName { return label.ID, true, nil } } return "", false, nil } func (b *Backend) 
GetLabel(userID, labelID string) (proton.Label, error) { labels, err := b.GetLabels(userID) if err != nil { return proton.Label{}, err } for _, label := range labels { if label.ID == labelID { return label, nil } } return proton.Label{}, fmt.Errorf("no such label: %s", labelID) } func (b *Backend) GetLabels(userID string, types ...proton.LabelType) ([]proton.Label, error) { return withAcc(b, userID, func(acc *account) ([]proton.Label, error) { return withLabels(b, func(labels map[string]*label) ([]proton.Label, error) { res := xslices.Map(acc.labelIDs, func(labelID string) proton.Label { return labels[labelID].toLabel(labels) }) for labelName, labelID := range map[string]string{ "Inbox": proton.InboxLabel, "AllDrafts": proton.AllDraftsLabel, "AllSent": proton.AllSentLabel, "Trash": proton.TrashLabel, "Spam": proton.SpamLabel, "All Mail": proton.AllMailLabel, "Archive": proton.ArchiveLabel, "Sent": proton.SentLabel, "Drafts": proton.DraftsLabel, "Outbox": proton.OutboxLabel, "Starred": proton.StarredLabel, "Scheduled": proton.AllScheduledLabel, } { res = append(res, proton.Label{ ID: labelID, Name: labelName, Path: []string{labelName}, Type: proton.LabelTypeSystem, }) } if len(types) > 0 { res = xslices.Filter(res, func(label proton.Label) bool { return slices.Contains(types, label.Type) }) } return res, nil }) }) } func (b *Backend) CreateLabel(userID, labelName, parentID string, labelType proton.LabelType) (proton.Label, error) { return withAcc(b, userID, func(acc *account) (proton.Label, error) { return withLabels(b, func(labels map[string]*label) (proton.Label, error) { if parentID != "" { if labelType != proton.LabelTypeFolder { return proton.Label{}, fmt.Errorf("parentID can only be set for folders") } if _, ok := labels[parentID]; !ok { return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID) } } label := newLabel(labelName, parentID, labelType) labels[label.labelID] = label updateID, err := b.newUpdate(&labelCreated{labelID: label.labelID}) if 
err != nil { return proton.Label{}, err } acc.labelIDs = append(acc.labelIDs, label.labelID) acc.updateIDs = append(acc.updateIDs, updateID) return label.toLabel(labels), nil }) }) } func (b *Backend) UpdateLabel(userID, labelID, name, parentID string) (proton.Label, error) { return withAcc(b, userID, func(acc *account) (proton.Label, error) { return withLabels(b, func(labels map[string]*label) (proton.Label, error) { if parentID != "" { if labels[labelID].labelType != proton.LabelTypeFolder { return proton.Label{}, fmt.Errorf("parentID can only be set for folders") } if _, ok := labels[parentID]; !ok { return proton.Label{}, fmt.Errorf("no such parent label: %s", parentID) } } labels[labelID].name = name labels[labelID].parentID = parentID updateID, err := b.newUpdate(&labelUpdated{labelID: labelID}) if err != nil { return proton.Label{}, err } acc.updateIDs = append(acc.updateIDs, updateID) return labels[labelID].toLabel(labels), nil }) }) } func (b *Backend) DeleteLabel(userID, labelID string) error { return b.withAcc(userID, func(acc *account) error { return b.withLabels(func(labels map[string]*label) error { if _, ok := labels[labelID]; !ok { return errors.New("label not found") } for _, labelID := range getLabelIDsToDelete(labelID, labels) { delete(labels, labelID) updateID, err := b.newUpdate(&labelDeleted{labelID: labelID}) if err != nil { return err } acc.labelIDs = xslices.Filter(acc.labelIDs, func(otherID string) bool { return otherID != labelID }) acc.updateIDs = append(acc.updateIDs, updateID) } return nil }) }) } func (b *Backend) CountMessages(userID string) (int, error) { return withAcc(b, userID, func(acc *account) (int, error) { return len(acc.messageIDs), nil }) } func (b *Backend) GetMessageIDs(userID string, afterID string, limit int) ([]string, error) { return withAcc(b, userID, func(acc *account) ([]string, error) { if len(acc.messageIDs) == 0 { return nil, nil } var lo, hi int if afterID == "" { lo = 0 } else { lo = 
slices.Index(acc.messageIDs, afterID) + 1 } if limit == 0 { hi = len(acc.messageIDs) } else { hi = lo + limit if hi > len(acc.messageIDs) { hi = len(acc.messageIDs) } } return acc.messageIDs[lo:hi], nil }) } func (b *Backend) GetMessages(userID string, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) { return withAcc(b, userID, func(acc *account) ([]proton.MessageMetadata, error) { return withMessages(b, func(messages map[string]*message) ([]proton.MessageMetadata, error) { metadata, err := withAtts(b, func(atts map[string]*attachment) ([]proton.MessageMetadata, error) { return xslices.Map(acc.messageIDs, func(messageID string) proton.MessageMetadata { return messages[messageID].toMetadata(b.attData, atts) }), nil }) if err != nil { return nil, err } if filter.Desc { xslices.Reverse(metadata) } // Note that this not a perfect replacement as we don't handle the case where this message could have been // deleted in between this metadata request. The backend has the information stored differently and can // resolve these gaps. 
if filter.EndID != "" { index := xslices.IndexFunc(metadata, func(metadata proton.MessageMetadata) bool { return metadata.ID == filter.EndID }) if index >= 0 { metadata = metadata[index:] } } metadata = xslices.Filter(metadata, func(metadata proton.MessageMetadata) bool { if len(filter.ID) > 0 { if !slices.Contains(filter.ID, metadata.ID) { return false } } if filter.Subject != "" { if !strings.Contains(metadata.Subject, filter.Subject) { return false } } if filter.AddressID != "" { if filter.AddressID != metadata.AddressID { return false } } if filter.ExternalID != "" { if filter.ExternalID != metadata.ExternalID { return false } } if filter.LabelID != "" { if !slices.Contains(metadata.LabelIDs, filter.LabelID) { return false } } return true }) pages := xslices.Chunk(metadata, pageSize) if page >= len(pages) { return nil, nil } return pages[page], nil }) }) } func (b *Backend) GetMessage(userID, messageID string) (proton.Message, error) { return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { message, ok := messages[messageID] if !ok { return proton.Message{}, errors.New("no such message") } return message.toMessage(b.attData, atts), nil }) }) }) } func (b *Backend) SetMessagesRead(userID string, read bool, messageIDs ...string) error { return b.withAcc(userID, func(acc *account) error { return b.withMessages(func(messages map[string]*message) error { for _, messageID := range messageIDs { messages[messageID].unread = !read updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) } return nil }) }) } func (b *Backend) LabelMessages(userID, labelID string, messageIDs ...string) error { return b.labelMessages(userID, labelID, true, messageIDs...) 
} func (b *Backend) LabelMessagesNoEvents(userID, labelID string, messageIDs ...string) error { return b.labelMessages(userID, labelID, false, messageIDs...) } func (b *Backend) labelMessages(userID, labelID string, doEvents bool, messageIDs ...string) error { if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel { return fmt.Errorf("not allowed") } return b.withAcc(userID, func(acc *account) error { return b.withMessages(func(messages map[string]*message) error { return b.withLabels(func(labels map[string]*label) error { for _, messageID := range messageIDs { message, ok := messages[messageID] if !ok { continue } message.addLabel(labelID, labels) if doEvents { updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) } } return nil }) }) }) } func (b *Backend) UnlabelMessages(userID, labelID string, messageIDs ...string) error { if labelID == proton.AllMailLabel || labelID == proton.AllDraftsLabel || labelID == proton.AllSentLabel { return fmt.Errorf("not allowed") } return b.withAcc(userID, func(acc *account) error { return b.withMessages(func(messages map[string]*message) error { return b.withLabels(func(labels map[string]*label) error { for _, messageID := range messageIDs { messages[messageID].remLabel(labelID, labels) updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) } return nil }) }) }) } func (b *Backend) DeleteMessage(userID, messageID string) error { return b.withAcc(userID, func(acc *account) error { return b.withMessages(func(messages map[string]*message) error { message, ok := messages[messageID] if !ok { return errors.New("no such message") } for _, attID := range message.attIDs { if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { return att.attDataID == b.attachments[attID].attDataID }) == 1 { 
delete(b.attData, b.attachments[attID].attDataID) } delete(b.attachments, attID) } delete(b.messages, messageID) updateID, err := b.newUpdate(&messageDeleted{messageID: messageID}) if err != nil { return err } acc.messageIDs = xslices.Filter(acc.messageIDs, func(otherID string) bool { return otherID != messageID }) acc.updateIDs = append(acc.updateIDs, updateID) return nil }) }) } func (b *Backend) CreateDraft(userID, addrID string, draft proton.DraftTemplate, parentID string) (proton.Message, error) { return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { return withLabels(b, func(labels map[string]*label) (proton.Message, error) { // Convert the parentID into externalRef.\ var parentRef string if parentID != "" { parentMsg, ok := messages[parentID] if ok { parentRef = "<" + parentMsg.externalID + ">" } } msg := newMessageFromTemplate(addrID, draft, parentRef) // Drafts automatically get the sysLabel "Drafts". 
msg.addLabel(proton.DraftsLabel, labels) messages[msg.messageID] = msg updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) if err != nil { return proton.Message{}, err } acc.messageIDs = append(acc.messageIDs, msg.messageID) acc.updateIDs = append(acc.updateIDs, updateID) return msg.toMessage(nil, nil), nil }) }) }) } func (b *Backend) UpdateDraft(userID, draftID string, changes proton.DraftTemplate) (proton.Message, error) { if changes.Sender == nil { return proton.Message{}, errors.New("the Sender is required") } if changes.Sender.Address == "" { return proton.Message{}, errors.New("the Address is required") } if changes.MIMEType != rfc822.TextPlain && changes.MIMEType != rfc822.TextHTML { return proton.Message{}, errors.New("the MIMEType must be text/plain or text/html") } return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { if _, ok := messages[draftID]; !ok { return proton.Message{}, fmt.Errorf("message %q not found", draftID) } messages[draftID].applyChanges(changes) updateID, err := b.newUpdate(&messageUpdated{messageID: draftID}) if err != nil { return proton.Message{}, err } acc.updateIDs = append(acc.updateIDs, updateID) return messages[draftID].toMessage(b.attData, atts), nil }) }) }) } func (b *Backend) SendMessage(userID, messageID string, packages []*proton.MessagePackage) (proton.Message, error) { return withAcc(b, userID, func(acc *account) (proton.Message, error) { return withMessages(b, func(messages map[string]*message) (proton.Message, error) { return withLabels(b, func(labels map[string]*label) (proton.Message, error) { return withAtts(b, func(atts map[string]*attachment) (proton.Message, error) { msg := messages[messageID] msg.flags |= proton.MessageFlagSent msg.addLabel(proton.SentLabel, labels) updateID, err := 
b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return proton.Message{}, err } acc.updateIDs = append(acc.updateIDs, updateID) for _, pkg := range packages { bodyData, err := base64.StdEncoding.DecodeString(pkg.Body) if err != nil { return proton.Message{}, err } for email, recipient := range pkg.Addresses { if recipient.Type != proton.InternalScheme { continue } if err := b.withAccEmail(email, func(acc *account) error { bodyKey, err := base64.StdEncoding.DecodeString(recipient.BodyKeyPacket) if err != nil { return err } armBody, err := crypto.NewPGPSplitMessage(bodyKey, bodyData).GetPGPMessage().GetArmored() if err != nil { return err } addrID, err := b.GetAddressID(email) if err != nil { return err } newMsg := newMessageFromSent(addrID, armBody, msg) newMsg.flags |= proton.MessageFlagReceived newMsg.addLabel(proton.InboxLabel, labels) newMsg.unread = true messages[newMsg.messageID] = newMsg for _, attID := range msg.attIDs { attKey, err := base64.StdEncoding.DecodeString(recipient.AttachmentKeyPackets[attID]) if err != nil { return err } att := newAttachment( atts[attID].filename, atts[attID].mimeType, atts[attID].disposition, attKey, atts[attID].attDataID, atts[attID].armSig, ) atts[att.attachID] = att messages[newMsg.messageID].attIDs = append(messages[newMsg.messageID].attIDs, att.attachID) } updateID, err := b.newUpdate(&messageCreated{messageID: newMsg.messageID}) if err != nil { return err } acc.messageIDs = append(acc.messageIDs, newMsg.messageID) acc.updateIDs = append(acc.updateIDs, updateID) return nil }); err != nil { return proton.Message{}, err } } } return msg.toMessage(b.attData, atts), nil }) }) }) }) } func (b *Backend) CreateAttachment( userID string, messageID string, filename string, mimeType rfc822.MIMEType, disposition proton.Disposition, contentID string, keyPackets, dataPacket []byte, armSig string, ) (proton.Attachment, error) { if disposition != proton.InlineDisposition && disposition != proton.AttachmentDisposition { 
return proton.Attachment{}, errors.New("The Disposition only allows 'attachment', or 'inline'") } if disposition == proton.InlineDisposition && contentID == "" { return proton.Attachment{}, errors.New("The 'inline' Disposition is only allowed with Content ID") } return withAcc(b, userID, func(acc *account) (proton.Attachment, error) { return withMessages(b, func(messages map[string]*message) (proton.Attachment, error) { return withAtts(b, func(atts map[string]*attachment) (proton.Attachment, error) { att := newAttachment( filename, mimeType, disposition, keyPackets, b.createAttData(dataPacket), armSig, ) atts[att.attachID] = att messages[messageID].attIDs = append(messages[messageID].attIDs, att.attachID) updateID, err := b.newUpdate(&messageUpdated{messageID: messageID}) if err != nil { return proton.Attachment{}, err } acc.updateIDs = append(acc.updateIDs, updateID) return att.toAttachment(), nil }) }) }) } func (b *Backend) GetAttachment(attachID string) ([]byte, error) { return withAtts(b, func(atts map[string]*attachment) ([]byte, error) { att, ok := atts[attachID] if !ok { return nil, fmt.Errorf("no such attachment: %s", attachID) } return b.attData[att.attDataID], nil }) } func (b *Backend) GetLatestEventID(userID string) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) { return acc.updateIDs[len(acc.updateIDs)-1].String(), nil }) } func getLastUpdateIndex(total, first, max int) int { if max <= 0 || first+max > total { return total } return first + max } func (b *Backend) GetEvent(userID, rawEventID string) (event proton.Event, more bool, err error) { var eventID ID if err := eventID.FromString(rawEventID); err != nil { return proton.Event{}, false, fmt.Errorf("invalid event ID: %s", rawEventID) } more = false event, err = withAcc(b, userID, func(acc *account) (proton.Event, error) { return withMessages(b, func(messages map[string]*message) (proton.Event, error) { return withLabels(b, func(labels map[string]*label) 
(proton.Event, error) { return withAtts(b, func(attachments map[string]*attachment) (proton.Event, error) { firstUpdate := xslices.Index(acc.updateIDs, eventID) + 1 lastUpdate := getLastUpdateIndex(len(acc.updateIDs), firstUpdate, b.maxUpdatesPerEvent) updates, err := withUpdates(b, func(updates map[ID]update) ([]update, error) { return merge(xslices.Map(acc.updateIDs[firstUpdate:lastUpdate], func(updateID ID) update { return updates[updateID] })), nil }) if err != nil { return proton.Event{}, fmt.Errorf("failed to merge updates: %w", err) } more = lastUpdate != len(acc.updateIDs) return buildEvent(updates, acc.addresses, messages, labels, acc.updateIDs[lastUpdate-1].String(), b.attData, attachments), nil }) }) }) }) if err != nil { return proton.Event{}, false, err } return event, more, nil } func (b *Backend) GetPublicKeys(email string) ([]proton.PublicKey, error) { return withAccEmail(b, email, func(acc *account) ([]proton.PublicKey, error) { var keys []proton.PublicKey for _, addr := range acc.addresses { if addr.email == email { for _, key := range addr.keys { pubKey, err := key.getPubKey() if err != nil { return nil, err } armKey, err := pubKey.GetArmoredPublicKey() if err != nil { return nil, err } keys = append(keys, proton.PublicKey{ Flags: proton.KeyStateTrusted | proton.KeyStateActive, PublicKey: armKey, }) } } } return keys, nil }) } func getLabelIDsToDelete(labelID string, labels map[string]*label) []string { labelIDs := []string{labelID} for _, label := range labels { if label.parentID == labelID { labelIDs = append(labelIDs, getLabelIDsToDelete(label.labelID, labels)...) 
} } return labelIDs } func (b *Backend) AddAddressCreatedUpdate(userID, addrID string) error { return b.withAcc(userID, func(acc *account) error { updateID, err := b.newUpdate(&addressCreated{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) AddLabelCreatedUpdate(userID, labelID string) error { return b.withAcc(userID, func(acc *account) error { updateID, err := b.newUpdate(&labelCreated{labelID: labelID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) AddMessageCreatedUpdate(userID, messageID string) error { return b.withAcc(userID, func(acc *account) error { updateID, err := b.newUpdate(&messageCreated{messageID: messageID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) GetMessageGroupCount(userID string) ([]proton.MessageGroupCount, error) { var result []proton.MessageGroupCount err := b.withAcc(userID, func(acc *account) error { return b.withMessages(func(m map[string]*message) error { type stats struct { total int unread int } labelStats := make(map[string]stats) for _, msg := range m { for _, lbl := range msg.getLabelIDs() { v, ok := labelStats[lbl] if !ok { v = stats{} } v.total++ if msg.unread { v.unread++ } labelStats[lbl] = v } } result = make([]proton.MessageGroupCount, 0, len(labelStats)) for lbl, stats := range labelStats { result = append(result, proton.MessageGroupCount{ LabelID: lbl, Total: stats.total, Unread: stats.unread, }) } return nil }) }) return result, err } func buildEvent( updates []update, addresses map[string]*address, messages map[string]*message, labels map[string]*label, eventID string, attachmentData map[string][]byte, attachments map[string]*attachment, ) proton.Event { event := proton.Event{EventID: eventID} for _, update := range updates { switch update := update.(type) { case *userRefreshed: event.Refresh = update.refresh 
case *messageCreated: event.Messages = append(event.Messages, proton.MessageEvent{ EventItem: proton.EventItem{ ID: update.messageID, Action: proton.EventCreate, }, Message: messages[update.messageID].toMetadata(attachmentData, attachments), }) case *messageUpdated: event.Messages = append(event.Messages, proton.MessageEvent{ EventItem: proton.EventItem{ ID: update.messageID, Action: proton.EventUpdate, }, Message: messages[update.messageID].toMetadata(attachmentData, attachments), }) case *messageDeleted: event.Messages = append(event.Messages, proton.MessageEvent{ EventItem: proton.EventItem{ ID: update.messageID, Action: proton.EventDelete, }, }) case *labelCreated: event.Labels = append(event.Labels, proton.LabelEvent{ EventItem: proton.EventItem{ ID: update.labelID, Action: proton.EventCreate, }, Label: labels[update.labelID].toLabel(labels), }) case *labelUpdated: event.Labels = append(event.Labels, proton.LabelEvent{ EventItem: proton.EventItem{ ID: update.labelID, Action: proton.EventUpdate, }, Label: labels[update.labelID].toLabel(labels), }) case *labelDeleted: event.Labels = append(event.Labels, proton.LabelEvent{ EventItem: proton.EventItem{ ID: update.labelID, Action: proton.EventDelete, }, }) case *addressCreated: event.Addresses = append(event.Addresses, proton.AddressEvent{ EventItem: proton.EventItem{ ID: update.addressID, Action: proton.EventCreate, }, Address: addresses[update.addressID].toAddress(), }) case *addressUpdated: event.Addresses = append(event.Addresses, proton.AddressEvent{ EventItem: proton.EventItem{ ID: update.addressID, Action: proton.EventUpdate, }, Address: addresses[update.addressID].toAddress(), }) case *addressDeleted: event.Addresses = append(event.Addresses, proton.AddressEvent{ EventItem: proton.EventItem{ ID: update.addressID, Action: proton.EventDelete, }, }) case *userSettingsUpdate: event.UserSettings = &proton.UserSettings{ Telemetry: update.settings.Telemetry, CrashReports: update.settings.CrashReports, } } } return 
event } go-proton-api-1.0.0/server/backend/api_auth.go000066400000000000000000000055431447642273300212440ustar00rootroot00000000000000package backend import ( "encoding/base64" "fmt" "github.com/ProtonMail/go-srp" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) func (b *Backend) NewAuthInfo(username string) (proton.AuthInfo, error) { return withAccName(b, username, func(acc *account) (proton.AuthInfo, error) { server, err := srp.NewServerFromSigned(modulus, acc.verifier, 2048) if err != nil { return proton.AuthInfo{}, nil } challenge, err := server.GenerateChallenge() if err != nil { return proton.AuthInfo{}, nil } session := uuid.NewString() b.srpLock.Lock() defer b.srpLock.Unlock() b.srp[session] = server return proton.AuthInfo{ Version: 4, Modulus: modulus, ServerEphemeral: base64.StdEncoding.EncodeToString(challenge), Salt: base64.StdEncoding.EncodeToString(acc.salt), SRPSession: session, }, nil }) } func (b *Backend) NewAuth(username string, ephemeral, proof []byte, session string) (proton.Auth, error) { return withAccName(b, username, func(acc *account) (proton.Auth, error) { b.srpLock.Lock() defer b.srpLock.Unlock() server, ok := b.srp[session] if !ok { return proton.Auth{}, fmt.Errorf("invalid session") } delete(b.srp, session) serverProof, err := server.VerifyProofs(ephemeral, proof) if !ok { return proton.Auth{}, fmt.Errorf("invalid proof: %w", err) } authUID, auth := uuid.NewString(), newAuth(b.authLife) acc.authLock.Lock() defer acc.authLock.Unlock() acc.auth[authUID] = auth return auth.toAuth(acc.userID, authUID, serverProof), nil }) } func (b *Backend) NewAuthRef(authUID, authRef string) (proton.Auth, error) { b.accLock.RLock() defer b.accLock.RUnlock() for _, acc := range b.accounts { acc.authLock.Lock() defer acc.authLock.Unlock() auth, ok := acc.auth[authUID] if !ok { continue } if auth.ref != authRef { return proton.Auth{}, fmt.Errorf("invalid auth ref") } newAuth := newAuth(b.authLife) acc.auth[authUID] = newAuth return 
newAuth.toAuth(acc.userID, authUID, nil), nil } return proton.Auth{}, fmt.Errorf("invalid auth") } func (b *Backend) VerifyAuth(authUID, authAcc string) (string, error) { return withAccAuth(b, authUID, authAcc, func(acc *account) (string, error) { return acc.userID, nil }) } func (b *Backend) GetSessions(userID string) ([]proton.AuthSession, error) { return withAcc(b, userID, func(acc *account) ([]proton.AuthSession, error) { acc.authLock.RLock() defer acc.authLock.RUnlock() var sessions []proton.AuthSession for authUID, auth := range acc.auth { sessions = append(sessions, auth.toAuthSession(authUID)) } return sessions, nil }) } func (b *Backend) DeleteSession(userID, authUID string) error { return b.withAcc(userID, func(acc *account) error { acc.authLock.Lock() defer acc.authLock.Unlock() delete(acc.auth, authUID) return nil }) } go-proton-api-1.0.0/server/backend/attachment.go000066400000000000000000000023031447642273300215710ustar00rootroot00000000000000package backend import ( "encoding/base64" "github.com/ProtonMail/gluon/rfc822" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) func (b *Backend) createAttData(dataPacket []byte) string { attDataID := uuid.NewString() b.attDataLock.Lock() defer b.attDataLock.Unlock() b.attData[attDataID] = dataPacket return attDataID } type attachment struct { attachID string attDataID string filename string mimeType rfc822.MIMEType disposition proton.Disposition keyPackets []byte armSig string } func newAttachment( filename string, mimeType rfc822.MIMEType, disposition proton.Disposition, keyPackets []byte, dataPacketID string, armSig string, ) *attachment { return &attachment{ attachID: uuid.NewString(), attDataID: dataPacketID, filename: filename, mimeType: mimeType, disposition: disposition, keyPackets: keyPackets, armSig: armSig, } } func (att *attachment) toAttachment() proton.Attachment { return proton.Attachment{ ID: att.attachID, Name: att.filename, MIMEType: att.mimeType, Disposition: att.disposition, 
KeyPackets: base64.StdEncoding.EncodeToString(att.keyPackets), Signature: att.armSig, } } go-proton-api-1.0.0/server/backend/backend.go000066400000000000000000000335561447642273300210460ustar00rootroot00000000000000package backend import ( "fmt" "net/mail" "sync" "time" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/go-srp" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" "golang.org/x/exp/maps" "golang.org/x/exp/slices" ) type Backend struct { domain string accounts map[string]*account accLock sync.RWMutex attachments map[string]*attachment attLock sync.Mutex attData map[string][]byte attDataLock sync.Mutex messages map[string]*message msgLock sync.Mutex labels map[string]*label lblLock sync.Mutex updates map[ID]update updatesLock sync.RWMutex maxUpdatesPerEvent int srp map[string]*srp.Server srpLock sync.Mutex authLife time.Duration enableDedup bool } func New(authLife time.Duration, domain string, enableDedup bool) *Backend { return &Backend{ domain: domain, accounts: make(map[string]*account), attachments: make(map[string]*attachment), attData: make(map[string][]byte), messages: make(map[string]*message), labels: make(map[string]*label), updates: make(map[ID]update), maxUpdatesPerEvent: 0, srp: make(map[string]*srp.Server), authLife: authLife, enableDedup: enableDedup, } } func (b *Backend) SetAuthLife(authLife time.Duration) { b.authLife = authLife } func (b *Backend) SetMaxUpdatesPerEvent(max int) { b.maxUpdatesPerEvent = max } func (b *Backend) CreateUser(username string, password []byte) (string, error) { b.accLock.Lock() defer b.accLock.Unlock() salt, err := crypto.RandomToken(16) if err != nil { return "", err } passphrase, err := hashPassword(password, salt) if err != nil { return "", err } srpAuth, err := srp.NewAuthForVerifier(password, modulus, salt) if err != nil { return "", err } verifier, err := srpAuth.GenerateVerifier(2048) if err != nil 
{ return "", err } armKey, err := GenerateKey(username, username, passphrase, "rsa", 2048) if err != nil { return "", err } userID := uuid.NewString() b.accounts[userID] = newAccount(userID, username, armKey, salt, verifier) return userID, nil } func (b *Backend) RemoveUser(userID string) error { b.accLock.Lock() defer b.accLock.Unlock() user, ok := b.accounts[userID] if !ok { return fmt.Errorf("user %s does not exist", userID) } for _, labelID := range user.labelIDs { delete(b.labels, labelID) } for _, messageID := range user.messageIDs { for _, attID := range b.messages[messageID].attIDs { if xslices.CountFunc(maps.Values(b.attachments), func(att *attachment) bool { return att.attDataID == b.attachments[attID].attDataID }) == 1 { delete(b.attData, b.attachments[attID].attDataID) } delete(b.attachments, attID) } delete(b.messages, messageID) } delete(b.accounts, userID) return nil } func (b *Backend) RefreshUser(userID string, refresh proton.RefreshFlag) error { return b.withAcc(userID, func(acc *account) error { updateID, err := b.newUpdate(&userRefreshed{refresh: refresh}) if err != nil { return err } if refresh == proton.RefreshAll { acc.updateIDs = []ID{updateID} } else { acc.updateIDs = append(acc.updateIDs, updateID) } return nil }) } func (b *Backend) CreateUserKey(userID string, password []byte) error { b.accLock.Lock() defer b.accLock.Unlock() user, ok := b.accounts[userID] if !ok { return fmt.Errorf("user %s does not exist", userID) } salt, err := crypto.RandomToken(16) if err != nil { return err } passphrase, err := hashPassword(password, salt) if err != nil { return err } armKey, err := GenerateKey(user.username, user.username, passphrase, "rsa", 2048) if err != nil { return err } user.keys = append(user.keys, key{keyID: uuid.NewString(), key: armKey}) return nil } func (b *Backend) RemoveUserKey(userID, keyID string) error { b.accLock.Lock() defer b.accLock.Unlock() user, ok := b.accounts[userID] if !ok { return fmt.Errorf("user %s does not exist", 
userID) } idx := xslices.IndexFunc(user.keys, func(key key) bool { return key.keyID == keyID }) if idx == -1 { return fmt.Errorf("key %s does not exist", keyID) } user.keys = append(user.keys[:idx], user.keys[idx+1:]...) return nil } func (b *Backend) CreateAddress(userID, email string, password []byte, withKey bool, status proton.AddressStatus, addrType proton.AddressType) (string, error) { return b.createAddress(userID, email, password, withKey, status, addrType, false) } func (b *Backend) CreateAddressAsUpdate(userID, email string, password []byte, withKey bool, status proton.AddressStatus, addrType proton.AddressType) (string, error) { return b.createAddress(userID, email, password, withKey, status, addrType, true) } func (b *Backend) createAddress(userID, email string, password []byte, withKey bool, status proton.AddressStatus, addrType proton.AddressType, issueUpdateInsteadOfCreate bool) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) { var keys []key if withKey { token, err := crypto.RandomToken(32) if err != nil { return "", err } armKey, err := GenerateKey(acc.username, email, token, "rsa", 2048) if err != nil { return "", err } passphrase, err := hashPassword([]byte(password), acc.salt) if err != nil { return "", err } userKR, err := acc.keys[0].unlock(passphrase) if err != nil { return "", err } encToken, sigToken, err := encryptWithSignature(userKR, token) if err != nil { return "", err } keys = append(keys, key{ keyID: uuid.NewString(), key: armKey, tok: encToken, sig: sigToken, }) } addressID := uuid.NewString() acc.addresses[addressID] = &address{ addrID: addressID, email: email, order: len(acc.addresses) + 1, status: status, addrType: addrType, keys: keys, } var update update if issueUpdateInsteadOfCreate { update = &addressUpdated{addressID: addressID} } else { update = &addressCreated{addressID: addressID} } updateID, err := b.newUpdate(update) if err != nil { return "", err } acc.updateIDs = append(acc.updateIDs, 
updateID) return addressID, nil }) } func (b *Backend) ChangeAddressType(userID, addrId string, addrType proton.AddressType) error { return b.withAcc(userID, func(acc *account) error { for _, addr := range acc.addresses { if addr.addrID == addrId { addr.addrType = addrType return nil } } return fmt.Errorf("no addrID matching %s for user %s", addrId, userID) }) } func (b *Backend) CreateAddressKey(userID, addrID string, password []byte) error { return b.withAcc(userID, func(acc *account) error { token, err := crypto.RandomToken(32) if err != nil { return err } armKey, err := GenerateKey(acc.username, acc.addresses[addrID].email, token, "rsa", 2048) if err != nil { return err } passphrase, err := hashPassword([]byte(password), acc.salt) if err != nil { return err } userKR, err := acc.keys[0].unlock(passphrase) if err != nil { return err } encToken, sigToken, err := encryptWithSignature(userKR, token) if err != nil { return err } acc.addresses[addrID].keys = append(acc.addresses[addrID].keys, key{ keyID: uuid.NewString(), key: armKey, tok: encToken, sig: sigToken, }) updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) RemoveAddress(userID, addrID string) error { return b.withAcc(userID, func(acc *account) error { if _, ok := acc.addresses[addrID]; !ok { return fmt.Errorf("address %s not found", addrID) } delete(acc.addresses, addrID) updateID, err := b.newUpdate(&addressDeleted{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } func (b *Backend) RemoveAddressKey(userID, addrID, keyID string) error { return b.withAcc(userID, func(acc *account) error { idx := xslices.IndexFunc(acc.addresses[addrID].keys, func(key key) bool { return key.keyID == keyID }) if idx < 0 { return fmt.Errorf("key %s not found", keyID) } acc.addresses[addrID].keys = append(acc.addresses[addrID].keys[:idx], 
acc.addresses[addrID].keys[idx+1:]...) updateID, err := b.newUpdate(&addressUpdated{addressID: addrID}) if err != nil { return err } acc.updateIDs = append(acc.updateIDs, updateID) return nil }) } // TODO: Implement this when we support subscriptions in the test server. func (b *Backend) CreateSubscription(userID, planID string) error { return nil } func (b *Backend) CreateMessage( userID, addrID string, subject string, sender *mail.Address, toList, ccList, bccList, replytos []*mail.Address, armBody string, mimeType rfc822.MIMEType, flags proton.MessageFlag, date time.Time, unread, starred bool, ) (string, error) { return withAcc(b, userID, func(acc *account) (string, error) { return withMessages(b, func(messages map[string]*message) (string, error) { msg := newMessage(addrID, subject, sender, toList, ccList, bccList, replytos, armBody, mimeType, "", date) msg.flags |= flags msg.unread = unread msg.starred = starred addrListEqual := func(l1 []*mail.Address, l2 []*mail.Address) bool { s1 := xslices.Map(l1, func(addr *mail.Address) string { return addr.Address }) s2 := xslices.Map(l2, func(addr *mail.Address) string { return addr.Address }) return slices.Equal(s1, s2) } var foundDuplicate bool if b.enableDedup { for _, m := range messages { if m.addrID != msg.addrID { continue } toEqual := addrListEqual(m.toList, msg.toList) bccEqual := addrListEqual(m.bccList, msg.bccList) ccEqual := addrListEqual(m.ccList, msg.ccList) if m.sender.Address == msg.sender.Address && toEqual && bccEqual && ccEqual && m.subject == msg.subject { msg.messageID = m.messageID foundDuplicate = true break } } } if !foundDuplicate { messages[msg.messageID] = msg updateID, err := b.newUpdate(&messageCreated{messageID: msg.messageID}) if err != nil { return "", err } acc.messageIDs = append(acc.messageIDs, msg.messageID) acc.updateIDs = append(acc.updateIDs, updateID) } return msg.messageID, nil }) }) } func (b *Backend) Encrypt(userID, addrID, decBody string) (string, error) { return withAcc(b, 
userID, func(acc *account) (string, error) { pubKey, err := acc.addresses[addrID].keys[0].getPubKey() if err != nil { return "", err } kr, err := crypto.NewKeyRing(pubKey) if err != nil { return "", err } enc, err := kr.Encrypt(crypto.NewPlainMessageFromString(decBody), nil) if err != nil { return "", err } return enc.GetArmored() }) } func (b *Backend) withAcc(userID string, fn func(acc *account) error) error { b.accLock.RLock() defer b.accLock.RUnlock() acc, ok := b.accounts[userID] if !ok { return fmt.Errorf("account %s not found", userID) } return fn(acc) } func (b *Backend) withAccEmail(email string, fn func(acc *account) error) error { b.accLock.RLock() defer b.accLock.RUnlock() for _, acc := range b.accounts { for _, addr := range acc.addresses { if addr.email == email { return fn(acc) } } } return fmt.Errorf("account %s not found", email) } func withAcc[T any](b *Backend, userID string, fn func(acc *account) (T, error)) (T, error) { b.accLock.RLock() defer b.accLock.RUnlock() for _, acc := range b.accounts { if acc.userID == userID { return fn(acc) } } return *new(T), fmt.Errorf("account not found") } func withAccName[T any](b *Backend, username string, fn func(acc *account) (T, error)) (T, error) { b.accLock.RLock() defer b.accLock.RUnlock() for _, acc := range b.accounts { if acc.username == username { return fn(acc) } } return *new(T), fmt.Errorf("account not found") } func withAccEmail[T any](b *Backend, email string, fn func(acc *account) (T, error)) (T, error) { b.accLock.RLock() defer b.accLock.RUnlock() for _, acc := range b.accounts { if _, ok := acc.getAddr(email); ok { return fn(acc) } } return *new(T), fmt.Errorf("account not found") } func withAccAuth[T any](b *Backend, authUID, authAcc string, fn func(acc *account) (T, error)) (T, error) { b.accLock.Lock() defer b.accLock.Unlock() for _, acc := range b.accounts { acc.authLock.Lock() defer acc.authLock.Unlock() val, ok := acc.auth[authUID] if !ok { continue } if time.Since(val.creation) > 
b.authLife { acc.auth[authUID] = auth{ref: val.ref, creation: val.creation} } else if val.acc == authAcc { return fn(acc) } } return *new(T), fmt.Errorf("account not found") } func (b *Backend) withMessages(fn func(map[string]*message) error) error { b.msgLock.Lock() defer b.msgLock.Unlock() return fn(b.messages) } func withMessages[T any](b *Backend, fn func(map[string]*message) (T, error)) (T, error) { b.msgLock.Lock() defer b.msgLock.Unlock() return fn(b.messages) } func withAtts[T any](b *Backend, fn func(map[string]*attachment) (T, error)) (T, error) { b.attLock.Lock() defer b.attLock.Unlock() return fn(b.attachments) } func (b *Backend) withLabels(fn func(map[string]*label) error) error { b.lblLock.Lock() defer b.lblLock.Unlock() return fn(b.labels) } func withLabels[T any](b *Backend, fn func(map[string]*label) (T, error)) (T, error) { b.lblLock.Lock() defer b.lblLock.Unlock() return fn(b.labels) } func (b *Backend) newUpdate(event update) (ID, error) { return withUpdates(b, func(updates map[ID]update) (ID, error) { updateID := ID(len(updates)) updates[updateID] = event return updateID, nil }) } func withUpdates[T any](b *Backend, fn func(map[ID]update) (T, error)) (T, error) { b.updatesLock.Lock() defer b.updatesLock.Unlock() return fn(b.updates) } go-proton-api-1.0.0/server/backend/core_settings.go000066400000000000000000000003301447642273300223070ustar00rootroot00000000000000package backend import ( "github.com/henrybear327/go-proton-api" ) func newUserSettings() proton.UserSettings { return proton.UserSettings{Telemetry: proton.SettingEnabled, CrashReports: proton.SettingEnabled} } go-proton-api-1.0.0/server/backend/crypto.go000066400000000000000000000015171447642273300207670ustar00rootroot00000000000000package backend import ( "github.com/ProtonMail/go-srp" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/helper" ) var GenerateKey = helper.GenerateKey func hashPassword(password, salt []byte) ([]byte, error) { passphrase, err 
:= srp.MailboxPassword(password, salt) if err != nil { return nil, err } return passphrase[len(passphrase)-31:], nil } func encryptWithSignature(kr *crypto.KeyRing, b []byte) (string, string, error) { enc, err := kr.Encrypt(crypto.NewPlainMessage(b), nil) if err != nil { return "", "", err } encArm, err := enc.GetArmored() if err != nil { return "", "", err } sig, err := kr.SignDetached(crypto.NewPlainMessage(b)) if err != nil { return "", "", err } sigArm, err := sig.GetArmored() if err != nil { return "", "", err } return encArm, sigArm, nil } go-proton-api-1.0.0/server/backend/crypto_fast.go000066400000000000000000000010621447642273300217770ustar00rootroot00000000000000package backend import "github.com/ProtonMail/gopenpgp/v2/crypto" var preCompKey *crypto.Key func init() { key, err := crypto.GenerateKey("name", "email", "rsa", 1024) if err != nil { panic(err) } preCompKey = key } // FastGenerateKey is a fast version of GenerateKey that uses a pre-computed key. // This is useful for testing but is incredibly insecure. 
func FastGenerateKey(_, _ string, passphrase []byte, _ string, _ int) (string, error) { encKey, err := preCompKey.Lock(passphrase) if err != nil { return "", err } return encKey.Armor() } go-proton-api-1.0.0/server/backend/label.go000066400000000000000000000015471447642273300205310ustar00rootroot00000000000000package backend import ( "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) type label struct { labelID string parentID string name string labelType proton.LabelType messageIDs map[string]struct{} } func newLabel(labelName, parentID string, labelType proton.LabelType) *label { return &label{ labelID: uuid.NewString(), parentID: parentID, name: labelName, labelType: labelType, messageIDs: make(map[string]struct{}), } } func (label *label) toLabel(labels map[string]*label) proton.Label { var path []string for labelID := label.labelID; labelID != ""; labelID = labels[labelID].parentID { path = append([]string{labels[labelID].name}, path...) } return proton.Label{ ID: label.labelID, ParentID: label.parentID, Name: label.name, Path: path, Type: label.labelType, } } go-proton-api-1.0.0/server/backend/mail_settings.go000066400000000000000000000011171447642273300223050ustar00rootroot00000000000000package backend import ( "github.com/ProtonMail/gluon/rfc822" "github.com/henrybear327/go-proton-api" ) type mailSettings struct { displayName string draftMIMEType rfc822.MIMEType attachPubKey bool } func newMailSettings(displayName string) *mailSettings { return &mailSettings{ displayName: displayName, attachPubKey: false, } } func (settings *mailSettings) toMailSettings() proton.MailSettings { return proton.MailSettings{ DisplayName: settings.displayName, DraftMIMEType: settings.draftMIMEType, AttachPublicKey: proton.Bool(settings.attachPubKey), } } go-proton-api-1.0.0/server/backend/message.go000066400000000000000000000221101447642273300210630ustar00rootroot00000000000000package backend import ( "net/mail" "strings" "time" 
"github.com/ProtonMail/gluon/rfc822" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" "golang.org/x/exp/slices" ) type message struct { messageID string externalID string addrID string labelIDs []string attIDs []string inReplyTo string // sysLabel is the system label for the message. // If nil, the message's flags are used to determine the system label (inbox, sent, drafts). // If "", the message has no system label (e.g. is in a custom folder or all mail). // If non-nil and non-empty, the message has the system label with the given ID (e.g. spam, trash). sysLabel *string subject string sender *mail.Address toList []*mail.Address ccList []*mail.Address bccList []*mail.Address replytos []*mail.Address date time.Time armBody string mimeType rfc822.MIMEType flags proton.MessageFlag unread bool starred bool } func newMessage( addrID string, subject string, sender *mail.Address, toList, ccList, bccList, replytos []*mail.Address, armBody string, mimeType rfc822.MIMEType, externalID string, date time.Time, ) *message { return &message{ messageID: uuid.NewString(), externalID: externalID, addrID: addrID, sysLabel: pointer(""), subject: subject, sender: sender, toList: toList, ccList: ccList, bccList: bccList, replytos: replytos, date: date, armBody: armBody, mimeType: mimeType, } } func newMessageFromSent(addrID, armBody string, msg *message) *message { return &message{ messageID: uuid.NewString(), externalID: msg.externalID, addrID: addrID, sysLabel: pointer(""), subject: msg.subject, sender: msg.sender, toList: msg.toList, ccList: msg.ccList, bccList: nil, // BCC is not sent to the recipient replytos: msg.replytos, date: time.Now(), armBody: armBody, mimeType: msg.mimeType, inReplyTo: msg.inReplyTo, } } func newMessageFromTemplate(addrID string, template proton.DraftTemplate, parentRef string) *message { return &message{ messageID: uuid.NewString(), externalID: template.ExternalID, addrID: addrID, sysLabel: 
pointer(""), inReplyTo: parentRef, subject: template.Subject, sender: template.Sender, toList: template.ToList, ccList: template.CCList, bccList: template.BCCList, unread: bool(template.Unread), armBody: template.Body, mimeType: template.MIMEType, } } func (msg *message) toMessage(attData map[string][]byte, att map[string]*attachment) proton.Message { return proton.Message{ MessageMetadata: msg.toMetadata(attData, att), Header: msg.getHeader(), ParsedHeaders: msg.getParsedHeaders(), Body: msg.armBody, MIMEType: msg.mimeType, Attachments: xslices.Map(msg.attIDs, func(attID string) proton.Attachment { return att[attID].toAttachment() }), } } func (msg *message) getLabelIDs() []string { labelIDs := []string{proton.AllMailLabel} if msg.flags.HasAny(proton.MessageFlagSent, proton.MessageFlagScheduledSend) { labelIDs = append(labelIDs, proton.AllSentLabel) } if !msg.flags.HasAny(proton.MessageFlagSent, proton.MessageFlagScheduledSend, proton.MessageFlagReceived) { labelIDs = append(labelIDs, proton.AllDraftsLabel) } if msg.starred { labelIDs = append(labelIDs, proton.StarredLabel) } if msg.sysLabel != nil { if *msg.sysLabel != "" { labelIDs = append(labelIDs, *msg.sysLabel) } } else { switch { case msg.flags.Has(proton.MessageFlagReceived): labelIDs = append(labelIDs, proton.InboxLabel) case msg.flags.Has(proton.MessageFlagSent): labelIDs = append(labelIDs, proton.SentLabel) case msg.flags.Has(proton.MessageFlagScheduledSend): labelIDs = append(labelIDs, proton.AllScheduledLabel) default: labelIDs = append(labelIDs, proton.DraftsLabel) } } return labelIDs } func (msg *message) toMetadata(attData map[string][]byte, att map[string]*attachment) proton.MessageMetadata { labelIDs := msg.getLabelIDs() messageSize := len(msg.armBody) for _, a := range msg.attIDs { messageSize += len(attData[att[a].attDataID]) } return proton.MessageMetadata{ ID: msg.messageID, ExternalID: msg.externalID, AddressID: msg.addrID, LabelIDs: append(msg.labelIDs, labelIDs...), Subject: msg.subject, 
Sender: msg.sender, ToList: msg.toList, CCList: msg.ccList, BCCList: msg.bccList, ReplyTos: msg.replytos, Size: messageSize, Flags: msg.flags, Unread: proton.Bool(msg.unread), NumAttachments: len(attData), } } func (msg *message) getHeader() string { builder := new(strings.Builder) builder.WriteString("Subject: " + msg.subject + "\r\n") if msg.sender != nil && (msg.sender.Name != "" || msg.sender.Address != "") { builder.WriteString("From: " + msg.sender.String() + "\r\n") } if len(msg.toList) > 0 { builder.WriteString("To: " + toAddressList(msg.toList) + "\r\n") } if len(msg.ccList) > 0 { builder.WriteString("Cc: " + toAddressList(msg.ccList) + "\r\n") } if len(msg.bccList) > 0 { builder.WriteString("Bcc: " + toAddressList(msg.bccList) + "\r\n") } if msg.mimeType != "" { builder.WriteString("Content-Type: " + string(msg.mimeType) + "\r\n") } if len(msg.inReplyTo) > 0 { builder.WriteString("References: " + msg.inReplyTo + "\r\n") } if msg.inReplyTo != "" { builder.WriteString("In-Reply-To: " + msg.inReplyTo + "\r\n") } builder.WriteString("Date: " + msg.date.Format(time.RFC822) + "\r\n") return builder.String() } func (msg *message) getParsedHeaders() proton.Headers { header, err := rfc822.NewHeader([]byte(msg.getHeader())) if err != nil { panic(err) } parsed := make(proton.Headers) header.Entries(func(key, value string) { parsed[key] = append(parsed[key], value) }) return parsed } // applyChanges will apply non-nil field from passed message. // // NOTE: This is not feature complete. It might panic on non-implemented changes. func (msg *message) applyChanges(changes proton.DraftTemplate) { if changes.Subject != "" { msg.subject = changes.Subject } if changes.Sender != nil { msg.sender = changes.Sender } if changes.ToList != nil { msg.toList = append([]*mail.Address{}, changes.ToList...) } if changes.CCList != nil { msg.ccList = append([]*mail.Address{}, changes.CCList...) } if changes.BCCList != nil { msg.bccList = append([]*mail.Address{}, changes.BCCList...) 
} if changes.Body != "" { msg.armBody = changes.Body } if changes.MIMEType != "" { msg.mimeType = changes.MIMEType } if changes.ExternalID != "" { msg.externalID = changes.ExternalID } } func (msg *message) addLabel(labelID string, labels map[string]*label) { switch labelID { case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel, proton.AllScheduledLabel: msg.addFlagLabel(labelID, labels) case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel: msg.addSystemLabel(labelID, labels) case proton.StarredLabel: msg.starred = true default: if label, ok := labels[labelID]; ok { msg.addUserLabel(label, labels) } } } func (msg *message) addFlagLabel(labelID string, labels map[string]*label) { msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) msg.sysLabel = nil } func (msg *message) addSystemLabel(labelID string, labels map[string]*label) { msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) msg.sysLabel = &labelID } func (msg *message) addUserLabel(label *label, labels map[string]*label) { if label.labelType != proton.LabelTypeLabel { msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { return labels[otherLabelID].labelType == proton.LabelTypeLabel }) msg.sysLabel = pointer("") } if !slices.Contains(msg.labelIDs, label.labelID) { msg.labelIDs = append(msg.labelIDs, label.labelID) } } func (msg *message) remLabel(labelID string, labels map[string]*label) { switch labelID { case proton.InboxLabel, proton.SentLabel, proton.DraftsLabel, proton.AllScheduledLabel: msg.remFlagLabel(labelID, labels) case proton.TrashLabel, proton.SpamLabel, proton.ArchiveLabel: msg.remSystemLabel(labelID, labels) case proton.StarredLabel: msg.starred = false default: if label, ok := labels[labelID]; ok { msg.remUserLabel(label, labels) } } } func (msg *message) remFlagLabel(labelID 
string, labels map[string]*label) { if msg.sysLabel == nil { msg.sysLabel = pointer("") } } func (msg *message) remSystemLabel(labelID string, labels map[string]*label) { if msg.sysLabel != nil && *msg.sysLabel == labelID { msg.sysLabel = pointer("") } } func (msg *message) remUserLabel(label *label, labels map[string]*label) { msg.labelIDs = xslices.Filter(msg.labelIDs, func(otherLabelID string) bool { return otherLabelID != label.labelID }) } func toAddressList(addrs []*mail.Address) string { res := make([]string, len(addrs)) for i, addr := range addrs { res[i] = addr.String() } return strings.Join(res, ", ") } func pointer[T any](v T) *T { return &v } go-proton-api-1.0.0/server/backend/modulus.asc000066400000000000000000000005301447642273300212720ustar00rootroot00000000000000+88jb48lF5TyDBveyHZ7QhSvtc4V3pN8/eQW6kk6ok2egy4lr5Wz9h8iZP3erN9lReSx1Lk+WsLu1b3soDhXX/twTCUhxYwjS8r983aEshZJJq7p5tNroQ5pzrZMbK8Oszjajgdg2YzcMcaJqb9+Doi7egj/esUQ+Q7BWdxeK77Wafj9v7PiW6Ozx6ulppu1mZ+YGnXSXJsl1Cl4nPm7PNkgj4BQT3HLrxakh7Xc3agmepRKO/1jLaOBU/oO17URbA5rwh/ZlAOqEAKH5vJ+hA2acM3Bwsa/K8I/jWicxOoaLZ4RZFpLYvOxGbb4DggR2Ri/C6tNyeEQQKAtxpeV5g==go-proton-api-1.0.0/server/backend/modulus.go000066400000000000000000000004671447642273300211420ustar00rootroot00000000000000package backend import ( _ "embed" "github.com/ProtonMail/gopenpgp/v2/crypto" ) var modulus string func init() { arm, err := crypto.NewClearTextMessage(asc, sig).GetArmored() if err != nil { panic(err) } modulus = arm } //go:embed modulus.asc var asc []byte //go:embed modulus.sig var sig []byte go-proton-api-1.0.0/server/backend/modulus.sig000066400000000000000000000001401447642273300213030ustar00rootroot00000000000000Â^\Ö= 5…ÄéQ&úÖ£E³ûw¬*ç™ÎLyBCùC_ßF(¡bÎôE´ÇÿRë=Ihûîöߟ†àE&P°á~É1/†ü» ¼ý€ go-proton-api-1.0.0/server/backend/quark.go000066400000000000000000000075711447642273300206000ustar00rootroot00000000000000package backend import ( "flag" "fmt" "github.com/henrybear327/go-proton-api" ) func (s *Backend) RunQuarkCommand(command 
string, args ...string) (any, error) { switch command { case "encryption:id": return s.quarkEncryptionID(args...) case "user:create": return s.quarkUserCreate(args...) case "user:create:address": return s.quarkUserCreateAddress(args...) case "user:create:subscription": return s.quarkUserCreateSubscription(args...) default: return nil, fmt.Errorf("unknown command: %s", command) } } func (s *Backend) quarkEncryptionID(args ...string) (string, error) { fs := flag.NewFlagSet("encryption:id", flag.ContinueOnError) // Positional arguments. // arg0: value decrypt := fs.Bool("decrypt", false, "decrypt the given encrypted ID") if err := fs.Parse(args); err != nil { return "", err } // TODO: Encrypt/decrypt are currently no-op. if *decrypt { return fs.Arg(0), nil } else { return fs.Arg(0), nil } } func (s *Backend) quarkUserCreate(args ...string) (proton.User, error) { fs := flag.NewFlagSet("user:create", flag.ContinueOnError) // Flag arguments. name := fs.String("name", "", "new user's name") pass := fs.String("password", "", "new user's password") newAddr := fs.Bool("create-address", false, "create the user's default address, will not automatically setup the address key") genKeys := fs.String("gen-keys", "", "generate new address keys for the user") status := fs.Int("status", 2, "User status") if err := fs.Parse(args); err != nil { return proton.User{}, err } addressStatus, err := quarkStatusToAddressStatus(*status) if err != nil { return proton.User{}, err } userID, err := s.CreateUser(*name, []byte(*pass)) if err != nil { return proton.User{}, fmt.Errorf("failed to create user: %w", err) } // TODO: Create keys of different types (we always use RSA2048). 
if *newAddr || *genKeys != "" { if _, err := s.CreateAddress(userID, *name+"@"+s.domain, []byte(*pass), *genKeys != "", addressStatus, proton.AddressTypeOriginal); err != nil { return proton.User{}, fmt.Errorf("failed to create address with keys: %w", err) } } return s.GetUser(userID) } func (s *Backend) quarkUserCreateAddress(args ...string) (proton.Address, error) { fs := flag.NewFlagSet("user:create:address", flag.ContinueOnError) // Positional arguments. // arg0: userID // arg1: password // arg2: email // Flag arguments. genKeys := fs.String("gen-keys", "", "generate new address keys for the user") status := fs.Int("status", 2, "User status") if err := fs.Parse(args); err != nil { return proton.Address{}, err } addressStatus, err := quarkStatusToAddressStatus(*status) if err != nil { return proton.Address{}, err } // TODO: Create keys of different types (we always use RSA2048). addrID, err := s.CreateAddress(fs.Arg(0), fs.Arg(2), []byte(fs.Arg(1)), *genKeys != "", addressStatus, proton.AddressTypeOriginal) if err != nil { return proton.Address{}, fmt.Errorf("failed to create address with keys: %w", err) } return s.GetAddress(fs.Arg(0), addrID) } func (s *Backend) quarkUserCreateSubscription(args ...string) (any, error) { fs := flag.NewFlagSet("user:create:subscription", flag.ContinueOnError) // Positional arguments. // arg0: userID // Flag arguments. 
planID := fs.String("planID", "", "plan ID for the user") if err := fs.Parse(args); err != nil { return nil, err } if err := s.CreateSubscription(fs.Arg(0), *planID); err != nil { return proton.Address{}, fmt.Errorf("failed to create subscription: %w", err) } return nil, nil } func quarkStatusToAddressStatus(status int) (proton.AddressStatus, error) { switch status { case 0: return proton.AddressStatusDeleting, nil case 1: return proton.AddressStatusDisabled, nil case 2: fallthrough case 3: fallthrough case 4: fallthrough case 5: return proton.AddressStatusEnabled, nil } return 0, fmt.Errorf("invalid status value") } go-proton-api-1.0.0/server/backend/types.go000066400000000000000000000035451447642273300206160ustar00rootroot00000000000000package backend import ( "encoding/base64" "math/big" "time" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) type ID uint64 func (v ID) String() string { return base64.URLEncoding.EncodeToString(v.Bytes()) } func (v ID) Bytes() []byte { if v == 0 { return []byte{0} } return new(big.Int).SetUint64(uint64(v)).Bytes() } func (v *ID) FromString(s string) error { b, err := base64.URLEncoding.DecodeString(s) if err != nil { return err } *v = ID(new(big.Int).SetBytes(b).Uint64()) return nil } type auth struct { acc string ref string creation time.Time } func newAuth(authLife time.Duration) auth { return auth{ acc: uuid.NewString(), ref: uuid.NewString(), creation: time.Now(), } } func (auth *auth) toAuth(userID, authUID string, proof []byte) proton.Auth { return proton.Auth{ UserID: userID, UID: authUID, AccessToken: auth.acc, RefreshToken: auth.ref, ServerProof: base64.StdEncoding.EncodeToString(proof), PasswordMode: proton.OnePasswordMode, } } func (auth *auth) toAuthSession(authUID string) proton.AuthSession { return proton.AuthSession{ UID: authUID, CreateTime: auth.creation.Unix(), Revocable: true, } } type key struct { keyID string key string tok string sig string } func 
(key key) unlock(passphrase []byte) (*crypto.KeyRing, error) { lockedKey, err := crypto.NewKeyFromArmored(key.key) if err != nil { return nil, err } unlockedKey, err := lockedKey.Unlock(passphrase) if err != nil { return nil, err } return crypto.NewKeyRing(unlockedKey) } func (key key) getPubKey() (*crypto.Key, error) { privKey, err := crypto.NewKeyFromArmored(key.key) if err != nil { return nil, err } pubKeyBin, err := privKey.GetPublicKey() if err != nil { return nil, err } return crypto.NewKey(pubKeyBin) } go-proton-api-1.0.0/server/backend/types_test.go000066400000000000000000000006511447642273300216500ustar00rootroot00000000000000package backend import ( "testing" "github.com/stretchr/testify/require" ) func TestID(t *testing.T) { var v ID // We can set the ID from a string. require.NoError(t, v.FromString("AQIDBA==")) // We can get the ID as a string. require.Equal(t, "AQIDBA==", v.String()) // We can get the ID as bytes. require.Equal(t, []byte{1, 2, 3, 4}, v.Bytes()) // The ID is correct. require.Equal(t, ID(0x01020304), v) } go-proton-api-1.0.0/server/backend/updates.go000066400000000000000000000061051447642273300211120ustar00rootroot00000000000000package backend import ( "github.com/bradenaw/juniper/xslices" "github.com/henrybear327/go-proton-api" ) func merge(updates []update) []update { if len(updates) < 2 { return updates } if merged := merge(updates[1:]); xslices.IndexFunc(merged, func(other update) bool { return other.replaces(updates[0]) }) < 0 { return append([]update{updates[0]}, merged...) 
} else { return merged } } type update interface { replaces(other update) bool _isUpdate() } type baseUpdate struct{} func (baseUpdate) replaces(update) bool { return false } func (baseUpdate) _isUpdate() {} type userRefreshed struct { baseUpdate refresh proton.RefreshFlag } type messageCreated struct { baseUpdate messageID string } type messageUpdated struct { baseUpdate messageID string } func (update *messageUpdated) replaces(other update) bool { switch other := other.(type) { case *messageUpdated: return update.messageID == other.messageID default: return false } } type messageDeleted struct { baseUpdate messageID string } func (update *messageDeleted) replaces(other update) bool { switch other := other.(type) { case *messageCreated: return update.messageID == other.messageID case *messageUpdated: return update.messageID == other.messageID case *messageDeleted: if update.messageID != other.messageID { return false } panic("message deleted twice") default: return false } } type labelCreated struct { baseUpdate labelID string } type labelUpdated struct { baseUpdate labelID string } func (update *labelUpdated) replaces(other update) bool { switch other := other.(type) { case *labelUpdated: return update.labelID == other.labelID default: return false } } type labelDeleted struct { baseUpdate labelID string } func (update *labelDeleted) replaces(other update) bool { switch other := other.(type) { case *labelCreated: return update.labelID == other.labelID case *labelUpdated: return update.labelID == other.labelID case *labelDeleted: if update.labelID != other.labelID { return false } panic("label deleted twice") default: return false } } type addressCreated struct { baseUpdate addressID string } type addressUpdated struct { baseUpdate addressID string } func (update *addressUpdated) replaces(other update) bool { switch other := other.(type) { case *addressUpdated: return update.addressID == other.addressID default: return false } } type addressDeleted struct { 
baseUpdate addressID string } func (update *addressDeleted) replaces(other update) bool { switch other := other.(type) { case *addressCreated: return update.addressID == other.addressID case *addressUpdated: return update.addressID == other.addressID case *addressDeleted: if update.addressID != other.addressID { return false } panic("address deleted twice") default: return false } } type userSettingsUpdate struct { baseUpdate settings proton.UserSettings } func (update *userSettingsUpdate) replaces(other update) bool { switch other.(type) { case *userSettingsUpdate: return true default: return false } } go-proton-api-1.0.0/server/backend/updates_test.go000066400000000000000000000023021447642273300221440ustar00rootroot00000000000000package backend import ( "reflect" "testing" ) func Test_mergeUpdates(t *testing.T) { tests := []struct { name string have []update want []update }{ { name: "single", have: []update{&labelCreated{labelID: "1"}}, want: []update{&labelCreated{labelID: "1"}}, }, { name: "multiple", have: []update{ &labelCreated{labelID: "1"}, &labelCreated{labelID: "2"}, }, want: []update{ &labelCreated{labelID: "1"}, &labelCreated{labelID: "2"}, }, }, { name: "replace with updated", have: []update{ &labelCreated{labelID: "1"}, &labelUpdated{labelID: "1"}, &labelUpdated{labelID: "1"}, }, want: []update{ &labelCreated{labelID: "1"}, &labelUpdated{labelID: "1"}, }, }, { name: "replace with delete", have: []update{ &labelCreated{labelID: "1"}, &labelUpdated{labelID: "1"}, &labelUpdated{labelID: "1"}, &labelDeleted{labelID: "1"}, }, want: []update{ &labelDeleted{labelID: "1"}, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := merge(tt.have); !reflect.DeepEqual(got, tt.want) { t.Errorf("mergeUpdates() = %v, want %v", got, tt.want) } }) } } go-proton-api-1.0.0/server/cache.go000066400000000000000000000016321447642273300171210ustar00rootroot00000000000000package server import ( "sync" "github.com/henrybear327/go-proton-api" ) func 
NewAuthCache() AuthCacher { return &authCache{ info: make(map[string]proton.AuthInfo), auth: make(map[string]proton.Auth), } } type authCache struct { info map[string]proton.AuthInfo auth map[string]proton.Auth lock sync.RWMutex } func (c *authCache) GetAuthInfo(username string) (proton.AuthInfo, bool) { c.lock.RLock() defer c.lock.RUnlock() info, ok := c.info[username] return info, ok } func (c *authCache) SetAuthInfo(username string, info proton.AuthInfo) { c.lock.Lock() defer c.lock.Unlock() c.info[username] = info } func (c *authCache) GetAuth(username string) (proton.Auth, bool) { c.lock.RLock() defer c.lock.RUnlock() auth, ok := c.auth[username] return auth, ok } func (c *authCache) SetAuth(username string, auth proton.Auth) { c.lock.Lock() defer c.lock.Unlock() c.auth[username] = auth } go-proton-api-1.0.0/server/call.go000066400000000000000000000014551447642273300167740ustar00rootroot00000000000000package server import ( "net/http" "net/url" "time" ) type Call struct { URL *url.URL Method string Status int Time time.Time Duration time.Duration RequestHeader http.Header RequestBody []byte ResponseHeader http.Header ResponseBody []byte } type callWatcher struct { paths map[string]struct{} callFn func(Call) } func newCallWatcher(fn func(Call), paths ...string) callWatcher { pathMap := make(map[string]struct{}, len(paths)) for _, path := range paths { pathMap[path] = struct{}{} } return callWatcher{ paths: pathMap, callFn: fn, } } func (watcher *callWatcher) isWatching(path string) bool { if len(watcher.paths) == 0 { return true } _, ok := watcher.paths[path] return ok } func (watcher *callWatcher) publish(call Call) { watcher.callFn(call) } 
go-proton-api-1.0.0/server/cmd/000077500000000000000000000000001447642273300162705ustar00rootroot00000000000000go-proton-api-1.0.0/server/cmd/client/000077500000000000000000000000001447642273300175465ustar00rootroot00000000000000go-proton-api-1.0.0/server/cmd/client/client.go000066400000000000000000000126261447642273300213620ustar00rootroot00000000000000package main import ( "encoding/json" "fmt" "io" "log" "net" "os" "github.com/henrybear327/go-proton-api/server/proto" "github.com/urfave/cli/v2" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) func main() { app := cli.NewApp() app.Flags = []cli.Flag{ &cli.StringFlag{ Name: "host", Usage: "host to connect to", Value: "localhost", }, &cli.IntFlag{ Name: "port", Usage: "port to connect to", Value: 8080, }, } app.Commands = []*cli.Command{ { Name: "info", Action: getInfoAction, }, { Name: "auth", Subcommands: []*cli.Command{ { Name: "revoke", Action: revokeUserAction, Flags: []cli.Flag{ &cli.StringFlag{ Name: "userID", Usage: "user ID to revoke", Required: true, }, }, }, }, }, { Name: "user", Subcommands: []*cli.Command{ { Name: "create", Action: createUserAction, Flags: []cli.Flag{ &cli.StringFlag{ Name: "username", Usage: "username of the account", Required: true, }, &cli.StringFlag{ Name: "password", Usage: "password of the account", Required: true, }, }, }, }, }, { Name: "address", Subcommands: []*cli.Command{ { Name: "create", Action: createAddressAction, Flags: []cli.Flag{ &cli.StringFlag{ Name: "userID", Usage: "ID of the user to create the address for", Required: true, }, &cli.StringFlag{ Name: "email", Usage: "email of the account", Required: true, }, &cli.StringFlag{ Name: "password", Usage: "password of the account", Required: true, }, }, }, { Name: "remove", Action: removeAddressAction, Flags: []cli.Flag{ &cli.StringFlag{ Name: "userID", Usage: "ID of the user to remove the address from", Required: true, }, &cli.StringFlag{ Name: "addressID", Usage: "ID of the address to remove", 
Required: true, }, }, }, }, }, { Name: "label", Subcommands: []*cli.Command{ { Name: "create", Action: createLabelAction, Flags: []cli.Flag{ &cli.StringFlag{ Name: "userID", Usage: "ID of the user to create the label for", Required: true, }, &cli.StringFlag{ Name: "name", Usage: "name of the label", Required: true, }, &cli.StringFlag{ Name: "parentID", Usage: "the ID of the parent label", }, &cli.BoolFlag{ Name: "exclusive", Usage: "Create an exclusive label (i.e. a folder)", }, }, }, }, }, } if err := app.Run(os.Args); err != nil { log.Fatal(err) } } func getInfoAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } res, err := client.GetInfo(c.Context, &proto.GetInfoRequest{}) if err != nil { return err } return pretty(c.App.Writer, res) } func createUserAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } res, err := client.CreateUser(c.Context, &proto.CreateUserRequest{ Username: c.String("username"), Password: []byte(c.String("password")), }) if err != nil { return err } return pretty(c.App.Writer, res) } func revokeUserAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } res, err := client.RevokeUser(c.Context, &proto.RevokeUserRequest{ UserID: c.String("userID"), }) if err != nil { return err } return pretty(c.App.Writer, res) } func createAddressAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } res, err := client.CreateAddress(c.Context, &proto.CreateAddressRequest{ UserID: c.String("userID"), Email: c.String("email"), Password: []byte(c.String("password")), }) if err != nil { return err } return pretty(c.App.Writer, res) } func removeAddressAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } res, err := client.RemoveAddress(c.Context, &proto.RemoveAddressRequest{ UserID: c.String("userID"), AddrID: c.String("addressID"), }) if err != nil { return err } 
return pretty(c.App.Writer, res) } func createLabelAction(c *cli.Context) error { client, err := newServerClient(c) if err != nil { return err } var labelType proto.LabelType if c.Bool("exclusive") { labelType = proto.LabelType_FOLDER } else { labelType = proto.LabelType_LABEL } res, err := client.CreateLabel(c.Context, &proto.CreateLabelRequest{ UserID: c.String("userID"), Name: c.String("name"), Type: labelType, }) if err != nil { return err } return pretty(c.App.Writer, res) } func newServerClient(c *cli.Context) (proto.ServerClient, error) { cc, err := grpc.DialContext( c.Context, net.JoinHostPort(c.String("host"), fmt.Sprint(c.Int("port"))), grpc.WithTransportCredentials(insecure.NewCredentials()), ) if err != nil { return nil, fmt.Errorf("failed to dial: %w", err) } return proto.NewServerClient(cc), nil } func pretty[T any](w io.Writer, v T) error { enc, err := json.MarshalIndent(v, "", " ") if err != nil { return err } if _, err := w.Write(enc); err != nil { return err } return nil } go-proton-api-1.0.0/server/cmd/server/000077500000000000000000000000001447642273300175765ustar00rootroot00000000000000go-proton-api-1.0.0/server/cmd/server/main.go000066400000000000000000000057211447642273300210560ustar00rootroot00000000000000package main import ( "context" "fmt" "log" "net" "os" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server" "github.com/henrybear327/go-proton-api/server/proto" "github.com/urfave/cli/v2" "google.golang.org/grpc" ) func main() { app := cli.NewApp() app.Flags = []cli.Flag{ &cli.IntFlag{ Name: "port", Aliases: []string{"p"}, Usage: "port to serve gRPC on", Value: 8080, }, &cli.BoolFlag{ Name: "tls", }, } app.Action = run if err := app.Run(os.Args); err != nil { log.Fatal(err) } } func run(c *cli.Context) error { s := server.New(server.WithTLS(c.Bool("tls"))) defer s.Close() return newService(s).run(c.Int("port")) } type service struct { proto.UnimplementedServerServer server *server.Server gRPCServer 
*grpc.Server } func newService(server *server.Server) *service { s := &service{ server: server, gRPCServer: grpc.NewServer(), } proto.RegisterServerServer(s.gRPCServer, s) return s } func (s *service) GetInfo(ctx context.Context, req *proto.GetInfoRequest) (*proto.GetInfoResponse, error) { return &proto.GetInfoResponse{ HostURL: s.server.GetHostURL(), ProxyURL: s.server.GetProxyURL(), }, nil } func (s *service) CreateUser(ctx context.Context, req *proto.CreateUserRequest) (*proto.CreateUserResponse, error) { userID, addrID, err := s.server.CreateUser(req.Username, req.Password) if err != nil { return nil, err } return &proto.CreateUserResponse{ UserID: userID, AddrID: addrID, }, nil } func (s *service) RevokeUser(ctx context.Context, req *proto.RevokeUserRequest) (*proto.RevokeUserResponse, error) { if err := s.server.RevokeUser(req.UserID); err != nil { return nil, err } return &proto.RevokeUserResponse{}, nil } func (s *service) CreateAddress(ctx context.Context, req *proto.CreateAddressRequest) (*proto.CreateAddressResponse, error) { addrID, err := s.server.CreateAddress(req.UserID, req.Email, req.Password) if err != nil { return nil, err } return &proto.CreateAddressResponse{ AddrID: addrID, }, nil } func (s *service) RemoveAddress(ctx context.Context, req *proto.RemoveAddressRequest) (*proto.RemoveAddressResponse, error) { if err := s.server.RemoveAddress(req.UserID, req.AddrID); err != nil { return nil, err } return &proto.RemoveAddressResponse{}, nil } func (s *service) CreateLabel(ctx context.Context, req *proto.CreateLabelRequest) (*proto.CreateLabelResponse, error) { var labelType proton.LabelType switch req.Type { case proto.LabelType_FOLDER: labelType = proton.LabelTypeFolder case proto.LabelType_LABEL: labelType = proton.LabelTypeLabel } labelID, err := s.server.CreateLabel(req.UserID, req.Name, req.ParentID, labelType) if err != nil { return nil, err } return &proto.CreateLabelResponse{ LabelID: labelID, }, nil } func (s *service) run(port int) error 
{ listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) if err != nil { return err } return s.gRPCServer.Serve(listener) } go-proton-api-1.0.0/server/contacts.go000066400000000000000000000004401447642273300176700ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetContactsEmails() gin.HandlerFunc { return func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "ContactEmails": []proton.ContactEmail{}, }) } } go-proton-api-1.0.0/server/core_settings.go000066400000000000000000000025651447642273300207340ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetUserSettings() gin.HandlerFunc { return func(c *gin.Context) { settings, err := s.b.GetUserSettings(c.GetString("UserID")) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "UserSettings": settings, }) } } func (s *Server) handlePutUserSettingsTelemetry() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SetTelemetryReq if err := c.ShouldBindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } settings, err := s.b.SetUserSettingsTelemetry(c.GetString("UserID"), req.Telemetry) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "UserSettings": settings, }) } } func (s *Server) handlePutUserSettingsCrashReports() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SetCrashReportReq if err := c.ShouldBindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } settings, err := s.b.SetUserSettingsCrashReports(c.GetString("UserID"), req.CrashReports) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "UserSettings": settings, }) } } 
go-proton-api-1.0.0/server/data.go000066400000000000000000000020211447642273300167600ustar00rootroot00000000000000package server import ( "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" "net/http" ) func (s *Server) handlePostDataStats() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SendStatsReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if !validateSendStatReq(&req) { c.AbortWithStatus(http.StatusBadRequest) return } c.JSON(http.StatusOK, gin.H{ "Code": proton.SuccessCode, }) } } func (s *Server) handlePostDataStatsMultiple() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SendStatsMultiReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } for _, event := range req.EventInfo { if !validateSendStatReq(&event) { c.AbortWithStatus(http.StatusBadRequest) return } } c.JSON(http.StatusOK, gin.H{ "Code": proton.SuccessCode, }) } } func validateSendStatReq(req *proton.SendStatsReq) bool { return req.MeasurementGroup != "" } go-proton-api-1.0.0/server/domains.go000066400000000000000000000003561447642273300175120ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" ) func (s *Server) handleGetDomainsAvailable() gin.HandlerFunc { return func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "Domains": []string{s.domain}, }) } } go-proton-api-1.0.0/server/errors.go000066400000000000000000000002721447642273300173710ustar00rootroot00000000000000package server import "errors" var ( ErrNoSuchUser = errors.New("no such user") ErrNoSuchAddress = errors.New("no such address") ErrNoSuchLabel = errors.New("no such label") ) go-proton-api-1.0.0/server/events.go000066400000000000000000000014301447642273300173560ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetEvents() gin.HandlerFunc { return func(c 
*gin.Context) { event, more, err := s.b.GetEvent(c.GetString("UserID"), c.Param("eventID")) if err != nil { _ = c.AbortWithError(http.StatusBadRequest, err) return } c.JSON( http.StatusOK, struct { proton.Event More proton.Bool }{ event, proton.Bool(more), }, ) } } func (s *Server) handleGetEventsLatest() gin.HandlerFunc { return func(c *gin.Context) { eventID, err := s.b.GetLatestEventID(c.GetString("UserID")) if err != nil { _ = c.AbortWithError(http.StatusBadRequest, err) return } c.JSON(http.StatusOK, gin.H{ "EventID": eventID, }) } } go-proton-api-1.0.0/server/helper_test.go000066400000000000000000000006301447642273300203710ustar00rootroot00000000000000package server import ( "fmt" "github.com/google/uuid" ) func newMessageLiteral(from, to string) []byte { return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, uuid.New())) } func newMessageLiteralWithSubject(from, to, subject string) []byte { return []byte(fmt.Sprintf("From: %v\r\nReceiver: %v\r\nSubject: %v\r\n\r\nHello World!", from, to, subject)) } go-proton-api-1.0.0/server/init_test.go000066400000000000000000000002161447642273300200550ustar00rootroot00000000000000package server import "github.com/henrybear327/go-proton-api/server/backend" func init() { backend.GenerateKey = backend.FastGenerateKey } go-proton-api-1.0.0/server/keys.go000066400000000000000000000014421447642273300170300ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetKeys() gin.HandlerFunc { return func(c *gin.Context) { if pubKeys, err := s.b.GetPublicKeys(c.Query("Email")); err == nil && len(pubKeys) > 0 { c.JSON(http.StatusOK, gin.H{ "Keys": pubKeys, "RecipientType": proton.RecipientTypeInternal, }) } else { c.JSON(http.StatusOK, gin.H{ "RecipientType": proton.RecipientTypeExternal, }) } } } func (s *Server) handleGetKeySalts() gin.HandlerFunc { return func(c *gin.Context) { salts, err 
:= s.b.GetKeySalts(c.GetString("UserID")) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "KeySalts": salts, }) } } go-proton-api-1.0.0/server/labels.go000066400000000000000000000044021447642273300173160ustar00rootroot00000000000000package server import ( "net/http" "strconv" "github.com/bradenaw/juniper/xslices" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetMailLabels() gin.HandlerFunc { return func(c *gin.Context) { types := xslices.Map(c.QueryArray("Type"), func(val string) proton.LabelType { labelType, err := strconv.Atoi(val) if err != nil { panic(err) } return proton.LabelType(labelType) }) labels, err := s.b.GetLabels(c.GetString("UserID"), types...) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "Labels": labels, }) } } func (s *Server) handlePostMailLabels() gin.HandlerFunc { return func(c *gin.Context) { var req proton.CreateLabelReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if _, has, err := s.b.HasLabel(c.GetString("UserID"), req.Name); err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } else if has { c.AbortWithStatus(http.StatusConflict) return } label, err := s.b.CreateLabel(c.GetString("UserID"), req.Name, req.ParentID, req.Type) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "Label": label, }) } } func (s *Server) handlePutMailLabel() gin.HandlerFunc { return func(c *gin.Context) { var req proton.UpdateLabelReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if labelID, has, err := s.b.HasLabel(c.GetString("UserID"), req.Name); err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } else if has && labelID != c.Param("labelID") { c.AbortWithStatus(http.StatusConflict) return } label, err := 
s.b.UpdateLabel(c.GetString("UserID"), c.Param("labelID"), req.Name, req.ParentID) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "Label": label, }) } } func (s *Server) handleDeleteMailLabel() gin.HandlerFunc { return func(c *gin.Context) { if err := s.b.DeleteLabel(c.GetString("UserID"), c.Param("labelID")); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } } } go-proton-api-1.0.0/server/mail_settings.go000066400000000000000000000016441447642273300207230ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func (s *Server) handleGetMailSettings() gin.HandlerFunc { return func(c *gin.Context) { settings, err := s.b.GetMailSettings(c.GetString("UserID")) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "MailSettings": settings, }) } } func (s *Server) handlePutMailSettingsAttachPublicKey() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SetAttachPublicKeyReq if err := c.ShouldBindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } settings, err := s.b.SetMailSettingsAttachPublicKey(c.GetString("UserID"), bool(req.AttachPublicKey)) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "MailSettings": settings, }) } } go-proton-api-1.0.0/server/main_test.go000066400000000000000000000002201447642273300200310ustar00rootroot00000000000000package server import ( "testing" "go.uber.org/goleak" ) func TestMain(m *testing.M) { goleak.VerifyTestMain(m, goleak.IgnoreCurrent()) } go-proton-api-1.0.0/server/messages.go000066400000000000000000000351051447642273300176670ustar00rootroot00000000000000package server import ( "encoding/base64" "encoding/json" "fmt" "mime" "net/http" "net/mail" "strconv" "time" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" 
"github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" "golang.org/x/exp/slices" ) const ( defaultPage = 0 defaultPageSize = 100 ) func (s *Server) handleGetMailMessages() gin.HandlerFunc { return func(c *gin.Context) { s.getMailMessages( c, mustParseInt(c.DefaultQuery("Page", strconv.Itoa(defaultPage))), mustParseInt(c.DefaultQuery("PageSize", strconv.Itoa(defaultPageSize))), proton.MessageFilter{ID: c.QueryArray("ID")}, ) } } func (s *Server) getMailMessages(c *gin.Context, page, pageSize int, filter proton.MessageFilter) { // Set default page. if page <= 0 { page = defaultPage } // Set default page size. if pageSize <= 0 { pageSize = defaultPageSize } messages, err := s.b.GetMessages(c.GetString("UserID"), page, pageSize, filter) if err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } total, err := s.b.CountMessages(c.GetString("UserID")) if err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } c.JSON(http.StatusOK, gin.H{ "Messages": messages, "Total": total, "Stale": proton.APIFalse, }) } func (s *Server) handlePostMailMessages() gin.HandlerFunc { return func(c *gin.Context) { switch c.GetHeader("X-HTTP-Method-Override") { case "GET": var req struct { proton.MessageFilter Page int PageSize int } if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } s.getMailMessages(c, req.Page, req.PageSize, req.MessageFilter) default: s.postMailMessages(c) } } } func (s *Server) postMailMessages(c *gin.Context) { var req proton.CreateDraftReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } addrID, err := s.b.GetAddressID(req.Message.Sender.Address) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } message, err := s.b.CreateDraft(c.GetString("UserID"), addrID, req.Message, req.ParentID) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } c.JSON(http.StatusOK, gin.H{ "Message": message, 
}) } func (s *Server) handleGetMailMessageIDs() gin.HandlerFunc { return func(c *gin.Context) { limit, err := strconv.Atoi(c.Query("Limit")) if err != nil { c.AbortWithStatus(http.StatusBadRequest) return } messageIDs, err := s.b.GetMessageIDs(c.GetString("UserID"), c.Query("AfterID"), limit) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } c.JSON(http.StatusOK, gin.H{ "IDs": messageIDs, }) } } func (s *Server) handleGetMailMessage() gin.HandlerFunc { return func(c *gin.Context) { message, err := s.b.GetMessage(c.GetString("UserID"), c.Param("messageID")) if err != nil { c.AbortWithStatusJSON(http.StatusUnprocessableEntity, proton.APIError{ Code: proton.InvalidValue, Message: fmt.Sprintf("Message %s not found", c.Param("messageID")), }) return } c.JSON(http.StatusOK, gin.H{ "Message": message, }) } } func (s *Server) handlePostMailMessage() gin.HandlerFunc { return func(c *gin.Context) { var req proton.SendDraftReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } message, err := s.b.SendMessage(c.GetString("UserID"), c.Param("messageID"), req.Packages) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } c.JSON(http.StatusOK, gin.H{ "Sent": message, }) } } func (s *Server) handlePutMailMessage() gin.HandlerFunc { return func(c *gin.Context) { var req proton.UpdateDraftReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } message, err := s.b.UpdateDraft(c.GetString("UserID"), c.Param("messageID"), req.Message) if err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } c.JSON(http.StatusOK, gin.H{ "Message": message, }) } } func (s *Server) handlePutMailMessagesRead() gin.HandlerFunc { return func(c *gin.Context) { var req proton.MessageActionReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if err := s.b.SetMessagesRead(c.GetString("UserID"), true, req.IDs...); err != nil { 
c.AbortWithStatus(http.StatusUnprocessableEntity) return } } } func (s *Server) handlePutMailMessagesUnread() gin.HandlerFunc { return func(c *gin.Context) { var req proton.MessageActionReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if err := s.b.SetMessagesRead(c.GetString("UserID"), false, req.IDs...); err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } } } func (s *Server) handlePutMailMessagesLabel() gin.HandlerFunc { return func(c *gin.Context) { var req proton.LabelMessagesReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if err := s.b.LabelMessages(c.GetString("UserID"), req.LabelID, req.IDs...); err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } } } func (s *Server) handlePutMailMessagesUnlabel() gin.HandlerFunc { return func(c *gin.Context) { var req proton.LabelMessagesReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } if err := s.b.UnlabelMessages(c.GetString("UserID"), req.LabelID, req.IDs...); err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } } } func (s *Server) handlePutMailMessagesImport() gin.HandlerFunc { return func(c *gin.Context) { form, err := c.MultipartForm() if err != nil { c.AbortWithStatus(http.StatusBadRequest) return } var metadata map[string]proton.ImportMetadata if err := json.Unmarshal([]byte(form.Value["Metadata"][0]), &metadata); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } files := make(map[string][]byte) for name, file := range form.File { files[name] = mustReadFileHeader(file[0]) } type response struct { Name string Response proton.ImportRes } var responses []response for name, literal := range files { res := response{Name: name} messageID, err := s.importMessage( c.GetString("UserID"), metadata[name].AddressID, metadata[name].LabelIDs, literal, metadata[name].Flags, bool(metadata[name].Unread), ) if err != nil { 
res.Response = proton.ImportRes{ APIError: proton.APIError{ Code: proton.InvalidValue, Message: fmt.Sprintf("failed to import: %v", err), }, } } else { res.Response = proton.ImportRes{ APIError: proton.APIError{Code: proton.SuccessCode}, MessageID: messageID, } } responses = append(responses, res) } c.JSON(http.StatusOK, gin.H{ "Code": proton.MultiCode, "Responses": responses, }) } } func (s *Server) handleDeleteMailMessages() gin.HandlerFunc { return func(c *gin.Context) { var req proton.MessageActionReq if err := c.BindJSON(&req); err != nil { c.AbortWithStatus(http.StatusBadRequest) return } for _, messageID := range req.IDs { if err := s.b.DeleteMessage(c.GetString("UserID"), messageID); err != nil { c.AbortWithStatus(http.StatusUnprocessableEntity) return } } } } func (s *Server) handleMessageGroupCount() gin.HandlerFunc { return func(c *gin.Context) { count, err := s.b.GetMessageGroupCount(c.GetString("UserID")) if err != nil { c.AbortWithStatusJSON(http.StatusUnprocessableEntity, proton.APIError{ Code: proton.InvalidValue, Message: fmt.Sprintf("Message %s not found", c.Param("messageID")), }) return } c.JSON(http.StatusOK, gin.H{ "Counts": count, }) } } func (s *Server) importMessage( userID, addrID string, labelIDs []string, literal []byte, flags proton.MessageFlag, unread bool, ) (string, error) { var exclusive int for _, labelID := range labelIDs { switch labelID { case proton.AllDraftsLabel, proton.AllSentLabel, proton.AllMailLabel, proton.OutboxLabel: return "", fmt.Errorf("invalid label ID: %s", labelID) } label, err := s.b.GetLabel(userID, labelID) if err != nil { return "", fmt.Errorf("invalid label ID: %s", labelID) } if label.Type != proton.LabelTypeLabel { exclusive++ } } if exclusive > 1 { return "", fmt.Errorf("too many exclusive labels") } header, body, atts, mimeType, err := s.parseMessage(literal) if err != nil { return "", fmt.Errorf("failed to parse message: %w", err) } messageID, err := s.importBody(userID, addrID, header, body, mimeType, 
flags, unread, slices.Contains(labelIDs, proton.StarredLabel)) if err != nil { return "", fmt.Errorf("failed to import message: %w", err) } for _, att := range atts { if _, err := s.importAttachment(userID, messageID, att); err != nil { return "", fmt.Errorf("failed to import attachment: %w", err) } } for _, labelID := range labelIDs { if err := s.b.LabelMessagesNoEvents(userID, labelID, messageID); err != nil { return "", fmt.Errorf("failed to label message: %w", err) } } return messageID, nil } func (s *Server) parseMessage(literal []byte) (*rfc822.Header, []string, []*rfc822.Section, rfc822.MIMEType, error) { root := rfc822.Parse(literal) header, err := root.ParseHeader() if err != nil { return nil, nil, nil, "", fmt.Errorf("failed to parse header: %w", err) } body, atts, err := collect(root) if err != nil { return nil, nil, nil, "", fmt.Errorf("failed to collect body and attachments: %w", err) } mimeType, _, err := root.ContentType() if err != nil { return nil, nil, nil, "", fmt.Errorf("failed to parse content type: %w", err) } // Force all multipart types to be multipart/mixed. 
if mimeType.Type() == "multipart" { mimeType = "multipart/mixed" } return header, body, atts, mimeType, nil } func collect(section *rfc822.Section) ([]string, []*rfc822.Section, error) { mimeType, _, err := section.ContentType() if err != nil { return nil, nil, fmt.Errorf("failed to parse content type: %w", err) } switch mimeType.Type() { case "text": return []string{string(section.Body())}, nil, nil case "multipart": children, err := section.Children() if err != nil { return nil, nil, fmt.Errorf("failed to parse children: %w", err) } switch mimeType.SubType() { case "encrypted": if len(children) != 2 { return nil, nil, fmt.Errorf("expected two children for multipart/encrypted, got %d", len(children)) } return []string{string(children[1].Body())}, nil, nil default: var ( multiBody []string multiAtts []*rfc822.Section ) for _, child := range children { body, atts, err := collect(child) if err != nil { return nil, nil, fmt.Errorf("failed to collect child: %w", err) } multiBody = append(multiBody, body...) multiAtts = append(multiAtts, atts...) 
} return multiBody, multiAtts, nil } default: return nil, []*rfc822.Section{section}, nil } } func (s *Server) importBody( userID, addrID string, header *rfc822.Header, body []string, mimeType rfc822.MIMEType, flags proton.MessageFlag, unread, starred bool, ) (string, error) { subject := header.Get("Subject") sender := tryParseAddress(header.Get("From")) toList := tryParseAddressList(header.Get("To")) ccList := tryParseAddressList(header.Get("Cc")) bccList := tryParseAddressList(header.Get("Bcc")) replytos := tryParseAddressList(header.Get("Reply-To")) date := time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC) headerDate := header.Get("Date") if len(headerDate) != 0 { d, err := mail.ParseDate(headerDate) if err != nil { return "", err } date = d } // NOTE: Importing without sender adds empty sender on API side if sender == nil { sender = &mail.Address{} } // NOTE: Importing without sender adds empty reply to on API side if len(replytos) == 0 { replytos = []*mail.Address{{}} } // NOTE: Importing just the first body part matches API behaviour but sucks! 
return s.b.CreateMessage( userID, addrID, subject, sender, toList, ccList, bccList, replytos, string(body[0]), rfc822.MIMEType(mimeType), flags, date, unread, starred, ) } func (s *Server) importAttachment(userID, messageID string, att *rfc822.Section) (proton.Attachment, error) { header, err := att.ParseHeader() if err != nil { return proton.Attachment{}, fmt.Errorf("failed to parse attachment header: %w", err) } mimeType, _, err := att.ContentType() if err != nil { return proton.Attachment{}, fmt.Errorf("failed to parse attachment content type: %w", err) } var disposition, filename string if !header.Has("Content-Disposition") { disposition = "attachment" filename = "attachment.bin" } else if dispType, dispParams, err := mime.ParseMediaType(header.Get("Content-Disposition")); err == nil { disposition = dispType filename = dispParams["filename"] } else { disposition = "attachment" filename = "attachment.bin" } var body *crypto.PGPSplitMessage if header.Get("Content-Transfer-Encoding") == "base64" { b := make([]byte, base64.StdEncoding.DecodedLen(len(att.Body()))) n, err := base64.StdEncoding.Decode(b, att.Body()) if err != nil { return proton.Attachment{}, fmt.Errorf("failed to decode attachment body: %w", err) } split, err := crypto.NewPGPMessage(b[:n]).SplitMessage() if err != nil { return proton.Attachment{}, fmt.Errorf("failed to split attachment body: %w", err) } body = split } else { msg, err := crypto.NewPGPMessageFromArmored(string(att.Body())) if err != nil { return proton.Attachment{}, fmt.Errorf("failed to parse attachment body: %w", err) } split, err := msg.SplitMessage() if err != nil { return proton.Attachment{}, fmt.Errorf("failed to split attachment body: %w", err) } body = split } // TODO: What about the signature? 
return s.b.CreateAttachment( userID, messageID, filename, mimeType, proton.Disposition(disposition), header.Get("Content-Id"), body.GetBinaryKeyPacket(), body.GetBinaryDataPacket(), "", ) } func tryParseAddress(s string) *mail.Address { if s == "" { return nil } addr, err := mail.ParseAddress(s) if err != nil { return &mail.Address{ Name: s, } } return addr } func tryParseAddressList(s string) []*mail.Address { if s == "" { return nil } addrs, err := mail.ParseAddressList(s) if err != nil { return []*mail.Address{{ Name: s, }} } return addrs } func mustParseInt(num string) int { val, err := strconv.Atoi(num) if err != nil { panic(err) } return val } go-proton-api-1.0.0/server/ping.go000066400000000000000000000002101447642273300170020ustar00rootroot00000000000000package server import "github.com/gin-gonic/gin" func (s *Server) handleGetPing() gin.HandlerFunc { return func(c *gin.Context) {} } go-proton-api-1.0.0/server/proto/000077500000000000000000000000001447642273300166705ustar00rootroot00000000000000go-proton-api-1.0.0/server/proto/server.go000066400000000000000000000002171447642273300205250ustar00rootroot00000000000000package proto //go:generate protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative server.proto go-proton-api-1.0.0/server/proto/server.pb.go000066400000000000000000000751701447642273300211370ustar00rootroot00000000000000// Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.28.0 // protoc v3.21.10 // source: server.proto package proto import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" ) const ( // Verify that this generated code is sufficiently up-to-date. _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) // Verify that runtime/protoimpl is sufficiently up-to-date. 
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) type LabelType int32 const ( LabelType_FOLDER LabelType = 0 LabelType_LABEL LabelType = 1 ) // Enum value maps for LabelType. var ( LabelType_name = map[int32]string{ 0: "FOLDER", 1: "LABEL", } LabelType_value = map[string]int32{ "FOLDER": 0, "LABEL": 1, } ) func (x LabelType) Enum() *LabelType { p := new(LabelType) *p = x return p } func (x LabelType) String() string { return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) } func (LabelType) Descriptor() protoreflect.EnumDescriptor { return file_server_proto_enumTypes[0].Descriptor() } func (LabelType) Type() protoreflect.EnumType { return &file_server_proto_enumTypes[0] } func (x LabelType) Number() protoreflect.EnumNumber { return protoreflect.EnumNumber(x) } // Deprecated: Use LabelType.Descriptor instead. func (LabelType) EnumDescriptor() ([]byte, []int) { return file_server_proto_rawDescGZIP(), []int{0} } type GetInfoRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields } func (x *GetInfoRequest) Reset() { *x = GetInfoRequest{} if protoimpl.UnsafeEnabled { mi := &file_server_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *GetInfoRequest) String() string { return protoimpl.X.MessageStringOf(x) } func (*GetInfoRequest) ProtoMessage() {} func (x *GetInfoRequest) ProtoReflect() protoreflect.Message { mi := &file_server_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use GetInfoRequest.ProtoReflect.Descriptor instead. 
func (*GetInfoRequest) Descriptor() ([]byte, []int) { return file_server_proto_rawDescGZIP(), []int{0} } type GetInfoResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields HostURL string `protobuf:"bytes,1,opt,name=hostURL,proto3" json:"hostURL,omitempty"` ProxyURL string `protobuf:"bytes,2,opt,name=proxyURL,proto3" json:"proxyURL,omitempty"` } func (x *GetInfoResponse) Reset() { *x = GetInfoResponse{} if protoimpl.UnsafeEnabled { mi := &file_server_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *GetInfoResponse) String() string { return protoimpl.X.MessageStringOf(x) } func (*GetInfoResponse) ProtoMessage() {} func (x *GetInfoResponse) ProtoReflect() protoreflect.Message { mi := &file_server_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use GetInfoResponse.ProtoReflect.Descriptor instead. 
func (*GetInfoResponse) Descriptor() ([]byte, []int) { return file_server_proto_rawDescGZIP(), []int{1} } func (x *GetInfoResponse) GetHostURL() string { if x != nil { return x.HostURL } return "" } func (x *GetInfoResponse) GetProxyURL() string { if x != nil { return x.ProxyURL } return "" } type CreateUserRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"` Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"` } func (x *CreateUserRequest) Reset() { *x = CreateUserRequest{} if protoimpl.UnsafeEnabled { mi := &file_server_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *CreateUserRequest) String() string { return protoimpl.X.MessageStringOf(x) } func (*CreateUserRequest) ProtoMessage() {} func (x *CreateUserRequest) ProtoReflect() protoreflect.Message { mi := &file_server_proto_msgTypes[2] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use CreateUserRequest.ProtoReflect.Descriptor instead. 
func (*CreateUserRequest) Descriptor() ([]byte, []int) { return file_server_proto_rawDescGZIP(), []int{2} } func (x *CreateUserRequest) GetUsername() string { if x != nil { return x.Username } return "" } func (x *CreateUserRequest) GetPassword() []byte { if x != nil { return x.Password } return nil } type CreateUserResponse struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"` AddrID string `protobuf:"bytes,2,opt,name=addrID,proto3" json:"addrID,omitempty"` } func (x *CreateUserResponse) Reset() { *x = CreateUserResponse{} if protoimpl.UnsafeEnabled { mi := &file_server_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } } func (x *CreateUserResponse) String() string { return protoimpl.X.MessageStringOf(x) } func (*CreateUserResponse) ProtoMessage() {} func (x *CreateUserResponse) ProtoReflect() protoreflect.Message { mi := &file_server_proto_msgTypes[3] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) } return ms } return mi.MessageOf(x) } // Deprecated: Use CreateUserResponse.ProtoReflect.Descriptor instead. 
func (*CreateUserResponse) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{3}
}

// GetUserID returns x.UserID, or "" when x is nil.
func (x *CreateUserResponse) GetUserID() string {
	if x != nil {
		return x.UserID
	}
	return ""
}

// GetAddrID returns x.AddrID, or "" when x is nil.
func (x *CreateUserResponse) GetAddrID() string {
	if x != nil {
		return x.AddrID
	}
	return ""
}

// RevokeUserRequest is the input message of the Server.RevokeUser RPC.
type RevokeUserRequest struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[4].
func (x *RevokeUserRequest) Reset() {
	*x = RevokeUserRequest{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[4]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *RevokeUserRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *RevokeUserRequest as a protobuf message.
func (*RevokeUserRequest) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[4]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *RevokeUserRequest) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[4]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RevokeUserRequest.ProtoReflect.Descriptor instead.
func (*RevokeUserRequest) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{4}
}

// GetUserID returns x.UserID, or "" when x is nil.
func (x *RevokeUserRequest) GetUserID() string {
	if x != nil {
		return x.UserID
	}
	return ""
}

// RevokeUserResponse is the output message of the Server.RevokeUser RPC; it
// declares no fields of its own.
type RevokeUserResponse struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[5].
func (x *RevokeUserResponse) Reset() {
	*x = RevokeUserResponse{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[5]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *RevokeUserResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *RevokeUserResponse as a protobuf message.
func (*RevokeUserResponse) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[5]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *RevokeUserResponse) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[5]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RevokeUserResponse.ProtoReflect.Descriptor instead.
func (*RevokeUserResponse) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{5}
}

// CreateAddressRequest is the input message of the Server.CreateAddress RPC.
type CreateAddressRequest struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	UserID   string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"`
	Email    string `protobuf:"bytes,2,opt,name=email,proto3" json:"email,omitempty"`
	Password []byte `protobuf:"bytes,3,opt,name=password,proto3" json:"password,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[6].
func (x *CreateAddressRequest) Reset() {
	*x = CreateAddressRequest{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[6]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *CreateAddressRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *CreateAddressRequest as a protobuf message.
func (*CreateAddressRequest) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[6]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *CreateAddressRequest) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[6]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CreateAddressRequest.ProtoReflect.Descriptor instead.
func (*CreateAddressRequest) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{6}
}

// GetUserID returns x.UserID, or "" when x is nil.
func (x *CreateAddressRequest) GetUserID() string {
	if x != nil {
		return x.UserID
	}
	return ""
}

// GetEmail returns x.Email, or "" when x is nil.
func (x *CreateAddressRequest) GetEmail() string {
	if x != nil {
		return x.Email
	}
	return ""
}

// GetPassword returns x.Password (the slice itself, not a copy), or nil when
// x is nil.
func (x *CreateAddressRequest) GetPassword() []byte {
	if x != nil {
		return x.Password
	}
	return nil
}

// CreateAddressResponse is the output message of the Server.CreateAddress RPC.
type CreateAddressResponse struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	AddrID string `protobuf:"bytes,1,opt,name=addrID,proto3" json:"addrID,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[7].
func (x *CreateAddressResponse) Reset() {
	*x = CreateAddressResponse{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[7]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *CreateAddressResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *CreateAddressResponse as a protobuf message.
func (*CreateAddressResponse) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[7]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *CreateAddressResponse) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[7]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CreateAddressResponse.ProtoReflect.Descriptor instead.
func (*CreateAddressResponse) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{7}
}

// GetAddrID returns x.AddrID, or "" when x is nil.
func (x *CreateAddressResponse) GetAddrID() string {
	if x != nil {
		return x.AddrID
	}
	return ""
}

// RemoveAddressRequest is the input message of the Server.RemoveAddress RPC.
type RemoveAddressRequest struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	UserID string `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"`
	AddrID string `protobuf:"bytes,2,opt,name=addrID,proto3" json:"addrID,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[8].
func (x *RemoveAddressRequest) Reset() {
	*x = RemoveAddressRequest{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[8]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *RemoveAddressRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *RemoveAddressRequest as a protobuf message.
func (*RemoveAddressRequest) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[8]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *RemoveAddressRequest) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[8]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemoveAddressRequest.ProtoReflect.Descriptor instead.
func (*RemoveAddressRequest) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{8}
}

// GetUserID returns x.UserID, or "" when x is nil.
func (x *RemoveAddressRequest) GetUserID() string {
	if x != nil {
		return x.UserID
	}
	return ""
}

// GetAddrID returns x.AddrID, or "" when x is nil.
func (x *RemoveAddressRequest) GetAddrID() string {
	if x != nil {
		return x.AddrID
	}
	return ""
}

// RemoveAddressResponse is the output message of the Server.RemoveAddress
// RPC; it declares no fields of its own.
type RemoveAddressResponse struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[9].
func (x *RemoveAddressResponse) Reset() {
	*x = RemoveAddressResponse{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[9]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *RemoveAddressResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *RemoveAddressResponse as a protobuf message.
func (*RemoveAddressResponse) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[9]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *RemoveAddressResponse) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[9]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use RemoveAddressResponse.ProtoReflect.Descriptor instead.
func (*RemoveAddressResponse) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{9}
}

// CreateLabelRequest is the input message of the Server.CreateLabel RPC.
type CreateLabelRequest struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	UserID   string    `protobuf:"bytes,1,opt,name=userID,proto3" json:"userID,omitempty"`
	Name     string    `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"`
	ParentID string    `protobuf:"bytes,3,opt,name=parentID,proto3" json:"parentID,omitempty"`
	// Type is the proto.LabelType enum, encoded as a varint on tag 4.
	Type     LabelType `protobuf:"varint,4,opt,name=type,proto3,enum=proto.LabelType" json:"type,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[10].
func (x *CreateLabelRequest) Reset() {
	*x = CreateLabelRequest{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[10]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *CreateLabelRequest) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *CreateLabelRequest as a protobuf message.
func (*CreateLabelRequest) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[10]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *CreateLabelRequest) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[10]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CreateLabelRequest.ProtoReflect.Descriptor instead.
func (*CreateLabelRequest) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{10}
}

// GetUserID returns x.UserID, or "" when x is nil.
func (x *CreateLabelRequest) GetUserID() string {
	if x != nil {
		return x.UserID
	}
	return ""
}

// GetName returns x.Name, or "" when x is nil.
func (x *CreateLabelRequest) GetName() string {
	if x != nil {
		return x.Name
	}
	return ""
}

// GetParentID returns x.ParentID, or "" when x is nil.
func (x *CreateLabelRequest) GetParentID() string {
	if x != nil {
		return x.ParentID
	}
	return ""
}

// GetType returns x.Type, or LabelType_FOLDER (the enum's zero value) when x
// is nil.
func (x *CreateLabelRequest) GetType() LabelType {
	if x != nil {
		return x.Type
	}
	return LabelType_FOLDER
}

// CreateLabelResponse is the output message of the Server.CreateLabel RPC.
type CreateLabelResponse struct {
	// protoimpl bookkeeping. NOTE(review): presumably layout-sensitive for
	// protoimpl's unsafe path; regenerate rather than reorder by hand.
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	LabelID string `protobuf:"bytes,1,opt,name=labelID,proto3" json:"labelID,omitempty"`
}

// Reset sets x to its zero value; when protoimpl.UnsafeEnabled it also
// re-stores the message info from file_server_proto_msgTypes[11].
func (x *CreateLabelResponse) Reset() {
	*x = CreateLabelResponse{}
	if protoimpl.UnsafeEnabled {
		mi := &file_server_proto_msgTypes[11]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns the protoimpl text rendering of x.
func (x *CreateLabelResponse) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks *CreateLabelResponse as a protobuf message.
func (*CreateLabelResponse) ProtoMessage() {}

// ProtoReflect returns a reflective view of x backed by
// file_server_proto_msgTypes[11]. On the unsafe path (non-nil x) the message
// info is stored on first use; otherwise the view comes from mi.MessageOf.
func (x *CreateLabelResponse) ProtoReflect() protoreflect.Message {
	mi := &file_server_proto_msgTypes[11]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use CreateLabelResponse.ProtoReflect.Descriptor instead.
func (*CreateLabelResponse) Descriptor() ([]byte, []int) {
	return file_server_proto_rawDescGZIP(), []int{11}
}

// GetLabelID returns x.LabelID, or "" when x is nil.
func (x *CreateLabelResponse) GetLabelID() string {
	if x != nil {
		return x.LabelID
	}
	return ""
}

// File_server_proto is the descriptor of server.proto; it is assigned from
// the TypeBuilder output in file_server_proto_init.
var File_server_proto protoreflect.FileDescriptor

// file_server_proto_rawDesc is the serialized descriptor of server.proto
// (package "proto": twelve messages, the LabelType enum and the Server
// service). file_server_proto_init sets it to nil once Build has run.
//
// NOTE(review): the go_package option encoded at the tail of these bytes is
// "github.com/ProtonMail/go-proton-api/server/proto", while server.proto in
// this tree declares "github.com/henrybear327/go-proton-api/server/proto".
// The generated code looks stale; regenerate with protoc rather than
// hand-editing these bytes.
var file_server_proto_rawDesc = []byte{
	0x0a, 0x0c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x10, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x47, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x68, 0x6f, 0x73, 0x74, 0x55, 0x52, 0x4c, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x68, 0x6f, 0x73, 0x74, 0x55, 0x52, 0x4c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x55, 0x52, 0x4c, 0x22, 0x4b, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x1a, 0x0a, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x75, 0x73, 0x65, 0x72, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x44, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x2b, 0x0a, 0x11, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a,
	0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x60, 0x0a, 0x14, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x14, 0x0a, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x05, 0x65, 0x6d, 0x61, 0x69, 0x6c, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x08, 0x70, 0x61, 0x73, 0x73, 0x77, 0x6f, 0x72, 0x64, 0x22, 0x2f, 0x0a, 0x15, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x46, 0x0a, 0x14, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49, 0x44, 0x12, 0x16, 0x0a, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x64, 0x64, 0x72, 0x49, 0x44, 0x22, 0x17, 0x0a, 0x15, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x82, 0x01, 0x0a, 0x12, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x75, 0x73, 0x65, 0x72, 0x49,
	0x44, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x61, 0x72, 0x65, 0x6e, 0x74, 0x49, 0x44, 0x12, 0x24, 0x0a, 0x04, 0x74, 0x79, 0x70, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x10, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x52, 0x04, 0x74, 0x79, 0x70, 0x65, 0x22, 0x2f, 0x0a, 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6c, 0x61, 0x62, 0x65, 0x6c, 0x49, 0x44, 0x2a, 0x22, 0x0a, 0x09, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0a, 0x0a, 0x06, 0x46, 0x4f, 0x4c, 0x44, 0x45, 0x52, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x4c, 0x41, 0x42, 0x45, 0x4c, 0x10, 0x01, 0x32, 0xa6, 0x03, 0x0a, 0x06, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x12, 0x38, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x15, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x47, 0x65, 0x74, 0x49, 0x6e, 0x66, 0x6f, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x41, 0x0a, 0x0a, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65,
	0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x76, 0x6f, 0x6b, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4a, 0x0a, 0x0d, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x12, 0x1b, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x52, 0x65, 0x6d, 0x6f, 0x76, 0x65, 0x41, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x44, 0x0a, 0x0b, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x12, 0x19, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4c, 0x61, 0x62, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x32, 0x5a, 0x30, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x4d, 0x61, 0x69, 0x6c, 0x2f, 0x67, 0x6f, 0x2d, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x6e, 0x2d, 0x61, 0x70, 0x69, 0x2f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

var (
	file_server_proto_rawDescOnce sync.Once
	// Captured during package variable initialization, which completes before
	// init() calls file_server_proto_init and nils file_server_proto_rawDesc,
	// so file_server_proto_rawDescGZIP keeps working afterwards.
	file_server_proto_rawDescData = file_server_proto_rawDesc
)

// file_server_proto_rawDescGZIP gzip-compresses the raw descriptor on its
// first call and returns the cached compressed bytes on every call. It backs
// the deprecated Descriptor methods.
func file_server_proto_rawDescGZIP() []byte {
	file_server_proto_rawDescOnce.Do(func() {
		file_server_proto_rawDescData = protoimpl.X.CompressGZIP(file_server_proto_rawDescData)
	})
	return file_server_proto_rawDescData
}

// Type-info storage filled in by TypeBuilder.Build in file_server_proto_init;
// the lengths match NumEnums (1) and NumMessages (12) passed to the builder.
var file_server_proto_enumTypes = make([]protoimpl.EnumInfo, 1)
var file_server_proto_msgTypes = make([]protoimpl.MessageInfo, 12)

// file_server_proto_goTypes lists the Go types in descriptor order; the
// indices in file_server_proto_depIdxs refer to positions in this slice.
var file_server_proto_goTypes = []interface{}{
	(LabelType)(0),                // 0: proto.LabelType
	(*GetInfoRequest)(nil),        // 1: proto.GetInfoRequest
	(*GetInfoResponse)(nil),       // 2: proto.GetInfoResponse
	(*CreateUserRequest)(nil),     // 3: proto.CreateUserRequest
	(*CreateUserResponse)(nil),    // 4: proto.CreateUserResponse
	(*RevokeUserRequest)(nil),     // 5: proto.RevokeUserRequest
	(*RevokeUserResponse)(nil),    // 6: proto.RevokeUserResponse
	(*CreateAddressRequest)(nil),  // 7: proto.CreateAddressRequest
	(*CreateAddressResponse)(nil), // 8: proto.CreateAddressResponse
	(*RemoveAddressRequest)(nil),  // 9: proto.RemoveAddressRequest
	(*RemoveAddressResponse)(nil), // 10: proto.RemoveAddressResponse
	(*CreateLabelRequest)(nil),    // 11: proto.CreateLabelRequest
	(*CreateLabelResponse)(nil),   // 12: proto.CreateLabelResponse
}

// file_server_proto_depIdxs records, per dependency slot, the index into
// file_server_proto_goTypes; the trailing five entries delimit the sub-lists.
var file_server_proto_depIdxs = []int32{
	0,  // 0: proto.CreateLabelRequest.type:type_name -> proto.LabelType
	1,  // 1: proto.Server.GetInfo:input_type -> proto.GetInfoRequest
	3,  // 2: proto.Server.CreateUser:input_type -> proto.CreateUserRequest
	5,  // 3: proto.Server.RevokeUser:input_type -> proto.RevokeUserRequest
	7,  // 4: proto.Server.CreateAddress:input_type -> proto.CreateAddressRequest
	9,  // 5: proto.Server.RemoveAddress:input_type -> proto.RemoveAddressRequest
	11, // 6: proto.Server.CreateLabel:input_type -> proto.CreateLabelRequest
	2,  // 7: proto.Server.GetInfo:output_type -> proto.GetInfoResponse
	4,  // 8: proto.Server.CreateUser:output_type -> proto.CreateUserResponse
	6,  // 9: proto.Server.RevokeUser:output_type -> proto.RevokeUserResponse
	8,  // 10: proto.Server.CreateAddress:output_type -> proto.CreateAddressResponse
	10, // 11: proto.Server.RemoveAddress:output_type -> proto.RemoveAddressResponse
	12, // 12: proto.Server.CreateLabel:output_type -> proto.CreateLabelResponse
	7,  // [7:13] is the sub-list for method output_type
	1,  // [1:7] is the sub-list for method input_type
	1,  // [1:1] is the sub-list for extension type_name
	1,  // [1:1] is the sub-list for extension extendee
	0,  // [0:1] is the sub-list for field type_name
}

// init builds the server.proto descriptor during package initialization.
func init() { file_server_proto_init() }

// file_server_proto_init builds File_server_proto and the enum/message type
// info from the raw descriptor tables. It returns immediately if
// File_server_proto is already set, installs field Exporters when
// protoimpl.UnsafeEnabled is false, and drops the raw tables after Build.
func file_server_proto_init() {
	if File_server_proto != nil {
		return
	}
	if !protoimpl.UnsafeEnabled {
		// Each Exporter maps field index 0/1/2 to the message's state,
		// sizeCache and unknownFields fields, and any other index to nil.
		file_server_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GetInfoRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*GetInfoResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateUserRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateUserResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*RevokeUserRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*RevokeUserResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateAddressRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateAddressResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*RemoveAddressRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*RemoveAddressResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateLabelRequest); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
		file_server_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*CreateLabelResponse); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			// GoPackagePath comes from this compiled package's import path,
			// obtained by reflecting on the local type x.
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_server_proto_rawDesc,
			NumEnums:      1,
			NumMessages:   12,
			NumExtensions: 0,
			NumServices:   1,
		},
		GoTypes:           file_server_proto_goTypes,
		DependencyIndexes: file_server_proto_depIdxs,
		EnumInfos:         file_server_proto_enumTypes,
		MessageInfos:      file_server_proto_msgTypes,
	}.Build()
	File_server_proto = out.File
	// The construction tables are dropped once Build has consumed them.
	file_server_proto_rawDesc = nil
	file_server_proto_goTypes = nil
	file_server_proto_depIdxs = nil
}
go-proton-api-1.0.0/server/proto/server.proto000066400000000000000000000034141447642273300212650ustar00rootroot00000000000000syntax = "proto3"; option go_package = "github.com/henrybear327/go-proton-api/server/proto"; package proto; //********************************************************************************************************************** // Service Declaration //********************************************************************************************************************** service Server { rpc GetInfo (GetInfoRequest) returns (GetInfoResponse); rpc CreateUser(CreateUserRequest) returns (CreateUserResponse); rpc RevokeUser(RevokeUserRequest) returns (RevokeUserResponse); rpc CreateAddress(CreateAddressRequest) returns (CreateAddressResponse); rpc RemoveAddress(RemoveAddressRequest) returns (RemoveAddressResponse); rpc CreateLabel(CreateLabelRequest) returns (CreateLabelResponse); } //********************************************************************************************************************** message GetInfoRequest { } message GetInfoResponse { string hostURL = 1; string proxyURL = 2; } message CreateUserRequest { string username = 1; bytes password = 3; } message CreateUserResponse { string userID = 1; string addrID = 2; } message RevokeUserRequest { string userID = 1; } message RevokeUserResponse { } message CreateAddressRequest { string userID = 1; string email = 2; bytes password = 3; } message CreateAddressResponse { string addrID = 1; } message RemoveAddressRequest { string userID = 1; string addrID = 2; } message RemoveAddressResponse { } enum LabelType { FOLDER = 0; LABEL = 1; } message CreateLabelRequest { string userID = 1; string name = 2; string parentID = 3; LabelType type = 4; } message CreateLabelResponse { string labelID = 1; } go-proton-api-1.0.0/server/proto/server_grpc.pb.go000066400000000000000000000244261447642273300221500ustar00rootroot00000000000000// Code generated by protoc-gen-go-grpc. DO NOT EDIT. 
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.10
// source: server.proto

package proto

import (
	context "context"
	grpc "google.golang.org/grpc"
	codes "google.golang.org/grpc/codes"
	status "google.golang.org/grpc/status"
)

// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7

// ServerClient is the client API for Server service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type ServerClient interface {
	GetInfo(ctx context.Context, in *GetInfoRequest, opts ...grpc.CallOption) (*GetInfoResponse, error)
	CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error)
	RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*RevokeUserResponse, error)
	CreateAddress(ctx context.Context, in *CreateAddressRequest, opts ...grpc.CallOption) (*CreateAddressResponse, error)
	RemoveAddress(ctx context.Context, in *RemoveAddressRequest, opts ...grpc.CallOption) (*RemoveAddressResponse, error)
	CreateLabel(ctx context.Context, in *CreateLabelRequest, opts ...grpc.CallOption) (*CreateLabelResponse, error)
}

// serverClient implements ServerClient by issuing unary Invoke calls on cc.
type serverClient struct {
	cc grpc.ClientConnInterface
}

// NewServerClient returns a ServerClient whose RPCs are sent over cc.
func NewServerClient(cc grpc.ClientConnInterface) ServerClient {
	return &serverClient{cc}
}

// GetInfo invokes /proto.Server/GetInfo; on error it returns a nil response
// together with the Invoke error.
func (c *serverClient) GetInfo(ctx context.Context, in *GetInfoRequest, opts ...grpc.CallOption) (*GetInfoResponse, error) {
	out := new(GetInfoResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/GetInfo", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// CreateUser invokes /proto.Server/CreateUser; on error it returns a nil
// response together with the Invoke error.
func (c *serverClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) {
	out := new(CreateUserResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/CreateUser", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// RevokeUser invokes /proto.Server/RevokeUser; on error it returns a nil
// response together with the Invoke error.
func (c *serverClient) RevokeUser(ctx context.Context, in *RevokeUserRequest, opts ...grpc.CallOption) (*RevokeUserResponse, error) {
	out := new(RevokeUserResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/RevokeUser", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// CreateAddress invokes /proto.Server/CreateAddress; on error it returns a
// nil response together with the Invoke error.
func (c *serverClient) CreateAddress(ctx context.Context, in *CreateAddressRequest, opts ...grpc.CallOption) (*CreateAddressResponse, error) {
	out := new(CreateAddressResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/CreateAddress", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// RemoveAddress invokes /proto.Server/RemoveAddress; on error it returns a
// nil response together with the Invoke error.
func (c *serverClient) RemoveAddress(ctx context.Context, in *RemoveAddressRequest, opts ...grpc.CallOption) (*RemoveAddressResponse, error) {
	out := new(RemoveAddressResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/RemoveAddress", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// CreateLabel invokes /proto.Server/CreateLabel; on error it returns a nil
// response together with the Invoke error.
func (c *serverClient) CreateLabel(ctx context.Context, in *CreateLabelRequest, opts ...grpc.CallOption) (*CreateLabelResponse, error) {
	out := new(CreateLabelResponse)
	err := c.cc.Invoke(ctx, "/proto.Server/CreateLabel", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// ServerServer is the server API for Server service.
// All implementations must embed UnimplementedServerServer // for forward compatibility type ServerServer interface { GetInfo(context.Context, *GetInfoRequest) (*GetInfoResponse, error) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) RevokeUser(context.Context, *RevokeUserRequest) (*RevokeUserResponse, error) CreateAddress(context.Context, *CreateAddressRequest) (*CreateAddressResponse, error) RemoveAddress(context.Context, *RemoveAddressRequest) (*RemoveAddressResponse, error) CreateLabel(context.Context, *CreateLabelRequest) (*CreateLabelResponse, error) mustEmbedUnimplementedServerServer() } // UnimplementedServerServer must be embedded to have forward compatible implementations. type UnimplementedServerServer struct { } func (UnimplementedServerServer) GetInfo(context.Context, *GetInfoRequest) (*GetInfoResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method GetInfo not implemented") } func (UnimplementedServerServer) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method CreateUser not implemented") } func (UnimplementedServerServer) RevokeUser(context.Context, *RevokeUserRequest) (*RevokeUserResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method RevokeUser not implemented") } func (UnimplementedServerServer) CreateAddress(context.Context, *CreateAddressRequest) (*CreateAddressResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method CreateAddress not implemented") } func (UnimplementedServerServer) RemoveAddress(context.Context, *RemoveAddressRequest) (*RemoveAddressResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method RemoveAddress not implemented") } func (UnimplementedServerServer) CreateLabel(context.Context, *CreateLabelRequest) (*CreateLabelResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method CreateLabel not implemented") } func 
(UnimplementedServerServer) mustEmbedUnimplementedServerServer() {} // UnsafeServerServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to ServerServer will // result in compilation errors. type UnsafeServerServer interface { mustEmbedUnimplementedServerServer() } func RegisterServerServer(s grpc.ServiceRegistrar, srv ServerServer) { s.RegisterService(&Server_ServiceDesc, srv) } func _Server_GetInfo_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetInfoRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).GetInfo(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.Server/GetInfo", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).GetInfo(ctx, req.(*GetInfoRequest)) } return interceptor(ctx, in, info, handler) } func _Server_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateUserRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).CreateUser(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.Server/CreateUser", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).CreateUser(ctx, req.(*CreateUserRequest)) } return interceptor(ctx, in, info, handler) } func _Server_RevokeUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(RevokeUserRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).RevokeUser(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, 
FullMethod: "/proto.Server/RevokeUser", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).RevokeUser(ctx, req.(*RevokeUserRequest)) } return interceptor(ctx, in, info, handler) } func _Server_CreateAddress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateAddressRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).CreateAddress(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.Server/CreateAddress", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).CreateAddress(ctx, req.(*CreateAddressRequest)) } return interceptor(ctx, in, info, handler) } func _Server_RemoveAddress_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(RemoveAddressRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).RemoveAddress(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.Server/RemoveAddress", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).RemoveAddress(ctx, req.(*RemoveAddressRequest)) } return interceptor(ctx, in, info, handler) } func _Server_CreateLabel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateLabelRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { return srv.(ServerServer).CreateLabel(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, FullMethod: "/proto.Server/CreateLabel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(ServerServer).CreateLabel(ctx, 
req.(*CreateLabelRequest)) } return interceptor(ctx, in, info, handler) } // Server_ServiceDesc is the grpc.ServiceDesc for Server service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) var Server_ServiceDesc = grpc.ServiceDesc{ ServiceName: "proto.Server", HandlerType: (*ServerServer)(nil), Methods: []grpc.MethodDesc{ { MethodName: "GetInfo", Handler: _Server_GetInfo_Handler, }, { MethodName: "CreateUser", Handler: _Server_CreateUser_Handler, }, { MethodName: "RevokeUser", Handler: _Server_RevokeUser_Handler, }, { MethodName: "CreateAddress", Handler: _Server_CreateAddress_Handler, }, { MethodName: "RemoveAddress", Handler: _Server_RemoveAddress_Handler, }, { MethodName: "CreateLabel", Handler: _Server_CreateLabel_Handler, }, }, Streams: []grpc.StreamDesc{}, Metadata: "server.proto", } go-proton-api-1.0.0/server/proxy.go000066400000000000000000000125721447642273300172440ustar00rootroot00000000000000package server import ( "bytes" "compress/gzip" "encoding/json" "io" "net/http" "net/http/httputil" "net/url" "strings" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" ) func newProxy(proxyOrigin, base, path string, transport http.RoundTripper) http.HandlerFunc { origin, err := url.Parse(proxyOrigin) if err != nil { panic(err) } return (&httputil.ReverseProxy{ Director: func(req *http.Request) { req.URL.Scheme = origin.Scheme req.URL.Host = origin.Host req.URL.Path = origin.Path + strings.TrimPrefix(path, base) req.Host = origin.Host }, Transport: transport, }).ServeHTTP } func (s *Server) handleProxy(base string) gin.HandlerFunc { return func(c *gin.Context) { proxy := newProxyServer(s.proxyOrigin, base, s.proxyTransport) proxy.handle("/", s.handleProxyAll) if s.authCacher != nil { proxy.handle("/auth/v4", s.handleProxyAuth) proxy.handle("/auth/v4/info", s.handleProxyAuthInfo) } proxy.ServeHTTP(c.Writer, c.Request) } } func (s *Server) handleProxyAll(proxier func(string) 
HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if _, err := proxier(r.URL.Path)(w, r); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } } } func (s *Server) handleProxyAuth(proxier func(string) HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodPost: s.handleProxyAuthPost(w, r, proxier(r.URL.Path)) case http.MethodDelete: s.handleProxyAuthDelete(w, r, proxier(r.URL.Path)) } } } func (s *Server) handleProxyAuthPost(w http.ResponseWriter, r *http.Request, proxier HandlerFunc) { req, err := readFromBody[proton.AuthReq](r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } if info, ok := s.authCacher.GetAuth(req.Username); ok { if err := writeBody(w, info); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } } else { b, err := proxier(w, r) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } res, err := readFrom[proton.Auth](b) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } s.authCacher.SetAuth(req.Username, res) } } func (s *Server) handleProxyAuthDelete(w http.ResponseWriter, r *http.Request, proxier HandlerFunc) { // When caching, we don't need to do anything here. 
} func (s *Server) handleProxyAuthInfo(proxier func(string) HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { req, err := readFromBody[proton.AuthInfoReq](r) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } if info, ok := s.authCacher.GetAuthInfo(req.Username); ok { if err := writeBody(w, info); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } } else { b, err := proxier(r.URL.Path)(w, r) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } res, err := readFrom[proton.AuthInfo](b) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } s.authCacher.SetAuthInfo(req.Username, res) } } } type HandlerFunc func(http.ResponseWriter, *http.Request) ([]byte, error) type proxyServer struct { mux *http.ServeMux origin, base string transport http.RoundTripper } func newProxyServer(origin, base string, transport http.RoundTripper) *proxyServer { return &proxyServer{ mux: http.NewServeMux(), origin: origin, base: base, transport: transport, } } func (s *proxyServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { s.mux.ServeHTTP(w, r) } func (s *proxyServer) handle(path string, h func(func(string) HandlerFunc) http.HandlerFunc) { s.mux.Handle(s.base+path, h(func(path string) HandlerFunc { return func(w http.ResponseWriter, r *http.Request) ([]byte, error) { buf := new(bytes.Buffer) // Call the proxy, capturing whatever data it writes. newProxy(s.origin, s.base, path, s.transport)(&writerWrapper{w, buf}, r) // If there is a gzip header entry, decode it. if strings.Contains(w.Header().Get("Content-Encoding"), "gzip") { return gzipDecode(buf.Bytes()) } // Otherwise, return the original written data. 
return buf.Bytes(), nil } })) } type writerWrapper struct { http.ResponseWriter buf *bytes.Buffer } func (w *writerWrapper) Write(b []byte) (int, error) { if _, err := w.buf.Write(b); err != nil { return 0, err } return w.ResponseWriter.Write(b) } func readFrom[T any](b []byte) (T, error) { var v T if err := json.Unmarshal(b, &v); err != nil { return *new(T), err } return v, nil } func readFromBody[T any](r *http.Request) (T, error) { b, err := io.ReadAll(r.Body) if err != nil { return *new(T), err } defer r.Body.Close() v, err := readFrom[T](b) if err != nil { return *new(T), err } r.Body = io.NopCloser(bytes.NewReader(b)) return v, nil } func writeBody[T any](w http.ResponseWriter, v T) error { b, err := json.Marshal(v) if err != nil { return err } w.Header().Set("Content-Type", "application/json") if _, err := w.Write(b); err != nil { return err } return nil } func gzipDecode(b []byte) ([]byte, error) { r, err := gzip.NewReader(bytes.NewReader(b)) if err != nil { return nil, err } defer r.Close() return io.ReadAll(r) } go-proton-api-1.0.0/server/quark.go000066400000000000000000000022451447642273300172020ustar00rootroot00000000000000package server import ( "encoding/json" "html/template" "net/http" "strings" "github.com/gin-gonic/gin" ) // TODO: This is a disgusting hack to match the output of the internal quark command. // They should return JSON instead of HTML! func (s *Server) handleQuarkCommand() gin.HandlerFunc { return func(c *gin.Context) { res, err := s.b.RunQuarkCommand(c.Param("command"), strings.Split(c.Query("strInput"), " ")...) if err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } var out string switch res := res.(type) { case string: out = res default: b, err := json.MarshalIndent(res, "", " ") if err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } out = string(b) } tmp, err := template.New("quarkCommand").Parse(`
{{.Content}}
`) if err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } if err := tmp.Execute(c.Writer, map[string]string{ "Content": template.HTMLEscapeString(out), }); err != nil { _ = c.AbortWithError(http.StatusInternalServerError, err) return } } } go-proton-api-1.0.0/server/quark_test.go000066400000000000000000000053371447642273300202460ustar00rootroot00000000000000package server import ( "context" "testing" "github.com/henrybear327/go-proton-api" "github.com/stretchr/testify/require" ) func TestServer_Quark_CreateUser(t *testing.T) { withServer(t, func(ctx context.Context, _ *Server, m *proton.Manager) { // Create two users, one with keys and one without. require.NoError(t, m.Quark(ctx, "user:create", "--name", "user-no-keys", "--password", "test", "--create-address")) require.NoError(t, m.Quark(ctx, "user:create", "--name", "user-keys", "--password", "test", "--gen-keys", "rsa2048")) require.NoError(t, m.Quark(ctx, "user:create", "--name", "user-disabled", "--password", "test", "--gen-keys", "rsa2048", "--status", "1")) { // The address should be created but should have no keys. c, _, err := m.NewClientWithLogin(ctx, "user-no-keys", []byte("test")) require.NoError(t, err) defer c.Close() addr, err := c.GetAddresses(ctx) require.NoError(t, err) require.Len(t, addr, 1) require.Len(t, addr[0].Keys, 0) } { // The address should be created and should have keys. 
c, _, err := m.NewClientWithLogin(ctx, "user-keys", []byte("test")) require.NoError(t, err) defer c.Close() addr, err := c.GetAddresses(ctx) require.NoError(t, err) require.Len(t, addr, 1) require.Len(t, addr[0].Keys, 1) } { // The address should be created and should be disabled c, _, err := m.NewClientWithLogin(ctx, "user-disabled", []byte("test")) require.NoError(t, err) defer c.Close() addr, err := c.GetAddresses(ctx) require.NoError(t, err) require.Len(t, addr, 1) require.Len(t, addr[0].Keys, 1) require.Equal(t, addr[0].Status, proton.AddressStatusDisabled) } }) } func TestServer_Quark_CreateAddress(t *testing.T) { withServer(t, func(ctx context.Context, _ *Server, m *proton.Manager) { // Create a user with one address. require.NoError(t, m.Quark(ctx, "user:create", "--name", "user", "--password", "test", "--gen-keys", "rsa2048")) // Login to the user. c, _, err := m.NewClientWithLogin(ctx, "user", []byte("test")) require.NoError(t, err) defer c.Close() // Get the user. user, err := c.GetUser(ctx) require.NoError(t, err) // Initially the user should have one address and it should have keys. addr, err := c.GetAddresses(ctx) require.NoError(t, err) require.Len(t, addr, 1) require.Len(t, addr[0].Keys, 1) // Create a new address. require.NoError(t, m.Quark(ctx, "user:create:address", "--gen-keys", "rsa2048", user.ID, "test", "alias@proton.local")) // Now the user should have two addresses, and they should both have keys. newAddr, err := c.GetAddresses(ctx) require.NoError(t, err) require.Len(t, newAddr, 2) require.Len(t, newAddr[0].Keys, 1) require.Len(t, newAddr[1].Keys, 1) }) } go-proton-api-1.0.0/server/rate_limit.go000066400000000000000000000022141447642273300202040ustar00rootroot00000000000000package server import ( "sync" "time" ) // rateLimiter is a rate limiter for the server. // If more than limit requests are made in the time window, the server will return 429. type rateLimiter struct { // limit is the rate limit to apply to the server. 
limit int // window is the window in which to apply the rate limit. window time.Duration // nextReset is the time at which the rate limit will reset. nextReset time.Time // count is the number of calls made to the server. count int // countLock is a mutex for the callCount. countLock sync.Mutex // statusCode to reply with statusCode int } func newRateLimiter(limit int, window time.Duration, statusCode int) *rateLimiter { return &rateLimiter{ limit: limit, window: window, statusCode: statusCode, } } // exceeded checks the rate limit and returns how long to wait before the next request. func (r *rateLimiter) exceeded() time.Duration { r.countLock.Lock() defer r.countLock.Unlock() if time.Now().After(r.nextReset) { r.count = 0 r.nextReset = time.Now().Add(r.window) } r.count++ if r.count > r.limit { return time.Until(r.nextReset) } return 0 } go-proton-api-1.0.0/server/reports.go000066400000000000000000000004231447642273300175510ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" ) func (s *Server) handlePostReportBug() gin.HandlerFunc { return func(c *gin.Context) { if _, err := c.MultipartForm(); err != nil { _ = c.AbortWithError(http.StatusBadRequest, err) return } } } go-proton-api-1.0.0/server/router.go000066400000000000000000000220741447642273300174010ustar00rootroot00000000000000package server import ( "bytes" "errors" "fmt" "io" "net" "net/http" "net/url" "strconv" "strings" "time" "github.com/Masterminds/semver/v3" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" ) func initRouter(s *Server) { s.r.Use( s.requireValidAppVersion(), s.setSessionCookie(), s.applyStatusHooks(), s.applyRateLimit(), ) if core := s.r.Group("/core/v4"); core != nil { // Domains routes don't need authentication. if domains := core.Group("/domains"); domains != nil { domains.GET("/available", s.handleGetDomainsAvailable()) } // Reporting a bug is also possible without authentication. 
if reports := core.Group("/reports"); reports != nil { reports.POST("/bug", s.handlePostReportBug()) } // These routes require auth. if core := core.Group("", s.requireAuth()); core != nil { if users := core.Group("/users"); users != nil { users.GET("", s.handleGetUsers()) } if addresses := core.Group("/addresses"); addresses != nil { addresses.GET("", s.handleGetAddresses()) addresses.GET("/:addressID", s.handleGetAddress()) addresses.DELETE("/:addressID", s.handleDeleteAddress()) addresses.PUT("/:addressID/enable", s.handlePutAddressEnable()) addresses.PUT("/:addressID/disable", s.handlePutAddressDisable()) addresses.PUT("/order", s.handlePutAddressesOrder()) } if labels := core.Group("/labels"); labels != nil { labels.GET("", s.handleGetMailLabels()) labels.POST("", s.handlePostMailLabels()) labels.PUT("/:labelID", s.handlePutMailLabel()) labels.DELETE("/:labelID", s.handleDeleteMailLabel()) } if keys := core.Group("/keys"); keys != nil { keys.GET("", s.handleGetKeys()) keys.GET("/salts", s.handleGetKeySalts()) } if events := core.Group("/events"); events != nil { events.GET("/:eventID", s.handleGetEvents()) events.GET("/latest", s.handleGetEventsLatest()) } if settings := core.Group("/settings"); settings != nil { settings.GET("", s.handleGetUserSettings()) settings.PUT("/telemetry", s.handlePutUserSettingsTelemetry()) settings.PUT("/crashreports", s.handlePutUserSettingsCrashReports()) } } } // All mail routes need authentication. 
if mail := s.r.Group("/mail/v4", s.requireAuth()); mail != nil { if settings := mail.Group("/settings"); settings != nil { settings.GET("", s.handleGetMailSettings()) settings.PUT("/attachpublic", s.handlePutMailSettingsAttachPublicKey()) } if messages := mail.Group("/messages"); messages != nil { messages.GET("", s.handleGetMailMessages()) messages.POST("", s.handlePostMailMessages()) messages.GET("/ids", s.handleGetMailMessageIDs()) messages.GET("/:messageID", s.handleGetMailMessage()) messages.POST("/:messageID", s.handlePostMailMessage()) messages.PUT("/:messageID", s.handlePutMailMessage()) messages.PUT("/read", s.handlePutMailMessagesRead()) messages.PUT("/unread", s.handlePutMailMessagesUnread()) messages.PUT("/label", s.handlePutMailMessagesLabel()) messages.PUT("/unlabel", s.handlePutMailMessagesUnlabel()) messages.POST("/import", s.handlePutMailMessagesImport()) messages.PUT("/delete", s.handleDeleteMailMessages()) messages.GET("/count", s.handleMessageGroupCount()) } if attachments := mail.Group("/attachments"); attachments != nil { attachments.POST("", s.handlePostMailAttachments()) attachments.GET(":attachID", s.handleGetMailAttachment()) } } // All contacts routes need authentication. if contacts := s.r.Group("/contacts/v4", s.requireAuth()); contacts != nil { contacts.GET("/emails", s.handleGetContactsEmails()) } // All data routes need authentication. if data := s.r.Group("/data/v1", s.requireAuth()); data != nil { if stats := data.Group("/stats"); stats != nil { stats.POST("", s.handlePostDataStats()) stats.POST("/multiple", s.handlePostDataStatsMultiple()) } } // Top level auth routes don't need authentication. if auth := s.r.Group("/auth/v4"); auth != nil { auth.POST("", s.handlePostAuth()) auth.POST("/info", s.handlePostAuthInfo()) auth.POST("/refresh", s.handlePostAuthRefresh()) // These routes require auth. 
if auth := auth.Group("", s.requireAuth()); auth != nil { auth.DELETE("", s.handleDeleteAuth()) if sessions := auth.Group("/sessions"); sessions != nil { sessions.GET("", s.handleGetAuthSessions()) sessions.DELETE("", s.handleDeleteAuthSessions()) sessions.DELETE("/:authUID", s.handleDeleteAuthSession()) } } } // Test routes don't need authentication. if tests := s.r.Group("/tests"); tests != nil { tests.GET("/ping", s.handleGetPing()) } // Quark routes don't need authentication. if quark := s.r.Group("/internal/quark"); quark != nil { quark.GET("/:command", s.handleQuarkCommand()) } // Proxy any calls to the upstream server. if proxy := s.r.Group("/proxy"); proxy != nil { proxy.Any("/*path", s.handleProxy(proxy.BasePath())) } } func (s *Server) requireValidAppVersion() gin.HandlerFunc { return func(c *gin.Context) { appVersion := c.Request.Header.Get("x-pm-appversion") if appVersion == "" { c.AbortWithStatusJSON(http.StatusBadRequest, proton.APIError{ Code: proton.AppVersionMissingCode, Message: "Missing x-pm-appversion header", }) } else if ok := s.validateAppVersion(appVersion); !ok { c.AbortWithStatusJSON(http.StatusBadRequest, proton.APIError{ Code: proton.AppVersionBadCode, Message: "This version of the app is no longer supported, please update to continue using the app", }) } } } func (s *Server) setSessionCookie() gin.HandlerFunc { return func(c *gin.Context) { url, err := url.Parse(s.s.URL) if err != nil { panic(err) } host, _, err := net.SplitHostPort(url.Host) if err != nil { panic(err) } if cookie, err := c.Request.Cookie("Session-Id"); errors.Is(err, http.ErrNoCookie) { c.SetCookie("Session-Id", uuid.NewString(), int(90*24*time.Hour.Seconds()), "/", host, true, true) } else { c.SetCookie("Session-Id", cookie.Value, int(90*24*time.Hour.Seconds()), "/", host, true, true) } } } func (s *Server) applyStatusHooks() gin.HandlerFunc { return func(c *gin.Context) { s.statusHooksLock.RLock() defer s.statusHooksLock.RUnlock() for _, hook := range s.statusHooks { 
if status, ok := hook(c.Request); ok { c.AbortWithStatusJSON(status, proton.APIError{ Code: proton.InvalidValue, Message: fmt.Sprintf("Request failed with status %d", status), }) return } } } } func (s *Server) applyRateLimit() gin.HandlerFunc { return func(c *gin.Context) { if s.rateLimit == nil { return } if wait := s.rateLimit.exceeded(); wait > 0 { c.Header("Retry-After", strconv.Itoa(int(wait.Seconds()))) c.AbortWithStatus(s.rateLimit.statusCode) } } } func (s *Server) logCalls() gin.HandlerFunc { return func(c *gin.Context) { start := time.Now() req, err := io.ReadAll(c.Request.Body) if err != nil { panic(err) } else { c.Request.Body = io.NopCloser(bytes.NewReader(req)) } res, err := newBodyWriter(c.Writer) if err != nil { panic(err) } else { c.Writer = res } c.Next() s.callWatchersLock.RLock() defer s.callWatchersLock.RUnlock() for _, call := range s.callWatchers { if call.isWatching(c.Request.URL.Path) { call.publish(Call{ URL: c.Request.URL, Method: c.Request.Method, Status: c.Writer.Status(), Time: start, Duration: time.Since(start), RequestHeader: c.Request.Header, RequestBody: req, ResponseHeader: c.Writer.Header(), ResponseBody: res.bytes(), }) } } } } func (s *Server) handleOffline() gin.HandlerFunc { return func(c *gin.Context) { if s.offline { c.AbortWithStatus(http.StatusServiceUnavailable) return } } } func (s *Server) requireAuth() gin.HandlerFunc { return func(c *gin.Context) { authUID := c.Request.Header.Get("x-pm-uid") if authUID == "" { c.AbortWithStatus(http.StatusUnauthorized) return } auth := c.Request.Header.Get("Authorization") if auth == "" { c.AbortWithStatus(http.StatusUnauthorized) return } userID, err := s.b.VerifyAuth(authUID, strings.Split(auth, " ")[1]) if err != nil { c.AbortWithStatus(http.StatusUnauthorized) return } c.Set("UserID", userID) c.Set("AuthUID", authUID) } } func (s *Server) validateAppVersion(appVersion string) bool { if s.minAppVersion == nil { return true } split := strings.Split(appVersion, "_") if len(split) 
!= 2 { return false } version, err := semver.NewVersion(split[1]) if err != nil { return false } if version.LessThan(s.minAppVersion) { return false } return true } type bodyWriter struct { gin.ResponseWriter buf *bytes.Buffer } func newBodyWriter(w gin.ResponseWriter) (*bodyWriter, error) { if w == nil { return nil, errors.New("response writer is nil") } return &bodyWriter{ ResponseWriter: w, buf: &bytes.Buffer{}, }, nil } func (w bodyWriter) Write(b []byte) (int, error) { if n, err := w.buf.Write(b); err != nil { return n, err } return w.ResponseWriter.Write(b) } func (w bodyWriter) bytes() []byte { return w.buf.Bytes() } go-proton-api-1.0.0/server/server.go000066400000000000000000000147171447642273300173740ustar00rootroot00000000000000package server import ( "net/http" "net/http/httptest" "sync" "time" "github.com/Masterminds/semver/v3" "github.com/bradenaw/juniper/xslices" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server/backend" ) type AuthCacher interface { GetAuthInfo(username string) (proton.AuthInfo, bool) SetAuthInfo(username string, info proton.AuthInfo) GetAuth(username string) (proton.Auth, bool) SetAuth(username string, auth proton.Auth) } // StatusHook is a function that can be used to modify the response code of a call. type StatusHook func(*http.Request) (int, bool) type Server struct { // r is the gin router. r *gin.Engine // s is the underlying server. s *httptest.Server // b is the server backend, which manages accounts, messages, attachments, etc. b *backend.Backend // callWatchers records callWatchers received by the server. callWatchers []callWatcher callWatchersLock sync.RWMutex // statusHooks are hooks that can be used to modify the response code of a call. statusHooks []StatusHook statusHooksLock sync.RWMutex // domain is the test server domain. domain string // minAppVersion is the minimum app version that the server will accept. 
minAppVersion *semver.Version // proxyOrigin is the URL of the origin server when the server is a proxy. proxyOrigin string // proxyTransport is the transport to use when the server is a proxy. proxyTransport *http.Transport // authCacher can optionally be set to cache proxied auth calls. authCacher AuthCacher // offline is whether to pretend the server is offline and return 5xx errors. offline bool // rateLimit is the rate limiter for the server. rateLimit *rateLimiter } func New(opts ...Option) *Server { builder := newServerBuilder() for _, opt := range opts { opt.config(builder) } return builder.build() } // GetHostURL returns the API root to make calls to. func (s *Server) GetHostURL() string { return s.s.URL } // GetProxyURL returns the API root to make calls to which should be proxied. func (s *Server) GetProxyURL() string { return s.s.URL + "/proxy" } // GetDomain returns the domain of the server (e.g. "proton.local"). func (s *Server) GetDomain() string { return s.domain } // AddCallWatcher adds a call watcher to the server. func (s *Server) AddCallWatcher(fn func(Call), paths ...string) { s.callWatchersLock.Lock() defer s.callWatchersLock.Unlock() s.callWatchers = append(s.callWatchers, newCallWatcher(fn, paths...)) } // AddStatusHook adds a status hook to the server. func (s *Server) AddStatusHook(fn StatusHook) { s.statusHooksLock.Lock() defer s.statusHooksLock.Unlock() s.statusHooks = append(s.statusHooks, fn) } // CreateUser creates a new server user with the given username and password. // A single address will be created for the user, derived from the username and the server's domain. 
func (s *Server) CreateUser(username string, password []byte) (string, string, error) { userID, err := s.b.CreateUser(username, password) if err != nil { return "", "", err } addrID, err := s.b.CreateAddress(userID, username+"@"+s.domain, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal) if err != nil { return "", "", err } return userID, addrID, nil } func (s *Server) RemoveUser(userID string) error { return s.b.RemoveUser(userID) } func (s *Server) RefreshUser(userID string, refresh proton.RefreshFlag) error { return s.b.RefreshUser(userID, refresh) } func (s *Server) GetUserKeyIDs(userID string) ([]string, error) { user, err := s.b.GetUser(userID) if err != nil { return nil, err } return xslices.Map(user.Keys, func(key proton.Key) string { return key.ID }), nil } func (s *Server) CreateUserKey(userID string, password []byte) error { return s.b.CreateUserKey(userID, password) } func (s *Server) RemoveUserKey(userID, keyID string) error { return s.b.RemoveUserKey(userID, keyID) } func (s *Server) CreateAddress(userID, email string, password []byte) (string, error) { return s.b.CreateAddress(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal) } func (s *Server) CreateAddressAsUpdate(userID, email string, password []byte) (string, error) { return s.b.CreateAddressAsUpdate(userID, email, password, true, proton.AddressStatusEnabled, proton.AddressTypeOriginal) } func (s *Server) ChangeAddressType(userID, addrId string, addrType proton.AddressType) error { return s.b.ChangeAddressType(userID, addrId, addrType) } func (s *Server) RemoveAddress(userID, addrID string) error { return s.b.RemoveAddress(userID, addrID) } func (s *Server) CreateAddressKey(userID, addrID string, password []byte) error { return s.b.CreateAddressKey(userID, addrID, password) } func (s *Server) RemoveAddressKey(userID, addrID, keyID string) error { return s.b.RemoveAddressKey(userID, addrID, keyID) } func (s *Server) CreateLabel(userID, 
name, parentID string, labelType proton.LabelType) (string, error) { label, err := s.b.CreateLabel(userID, name, parentID, labelType) if err != nil { return "", err } return label.ID, nil } func (s *Server) GetLabels(userID string) ([]proton.Label, error) { return s.b.GetLabels(userID) } func (s *Server) LabelMessage(userID, msgID, labelID string) error { return s.b.LabelMessages(userID, labelID, msgID) } func (s *Server) UnlabelMessage(userID, msgID, labelID string) error { return s.b.UnlabelMessages(userID, labelID, msgID) } func (s *Server) AddAddressCreatedEvent(userID, addrID string) error { return s.b.AddAddressCreatedUpdate(userID, addrID) } func (s *Server) AddLabelCreatedEvent(userID, labelID string) error { return s.b.AddLabelCreatedUpdate(userID, labelID) } func (s *Server) AddMessageCreatedEvent(userID, messageID string) error { return s.b.AddMessageCreatedUpdate(userID, messageID) } // SetMaxUpdatesPerEvent func (s *Server) SetMaxUpdatesPerEvent(max int) { s.b.SetMaxUpdatesPerEvent(max) } func (s *Server) SetAuthLife(authLife time.Duration) { s.b.SetAuthLife(authLife) } func (s *Server) SetMinAppVersion(minAppVersion *semver.Version) { s.minAppVersion = minAppVersion } func (s *Server) SetOffline(offline bool) { s.offline = offline } func (s *Server) RevokeUser(userID string) error { sessions, err := s.b.GetSessions(userID) if err != nil { return err } for _, session := range sessions { if err := s.b.DeleteSession(userID, session.UID); err != nil { return err } } return nil } func (s *Server) Close() { s.proxyTransport.CloseIdleConnections() s.s.Close() } go-proton-api-1.0.0/server/server_builder.go000066400000000000000000000121211447642273300210650ustar00rootroot00000000000000package server import ( "io" "net" "net/http" "net/http/httptest" "os" "time" "github.com/gin-gonic/gin" "github.com/henrybear327/go-proton-api" "github.com/henrybear327/go-proton-api/server/backend" ) type serverBuilder struct { config *http.Server listener net.Listener withTLS 
bool domain string logger io.Writer origin string proxyTransport *http.Transport cacher AuthCacher rateLimiter *rateLimiter enableDedup bool } func newServerBuilder() *serverBuilder { var logger io.Writer if os.Getenv("GO_PROTON_API_SERVER_LOGGER_ENABLED") != "" { logger = gin.DefaultWriter } else { logger = io.Discard } return &serverBuilder{ config: &http.Server{}, withTLS: true, domain: "proton.local", logger: logger, origin: proton.DefaultHostURL, proxyTransport: &http.Transport{}, } } func (builder *serverBuilder) build() *Server { gin.SetMode(gin.ReleaseMode) s := &Server{ r: gin.New(), b: backend.New(time.Hour, builder.domain, builder.enableDedup), domain: builder.domain, proxyOrigin: builder.origin, authCacher: builder.cacher, rateLimit: builder.rateLimiter, proxyTransport: builder.proxyTransport, } s.r.Use(gin.CustomRecovery(func(c *gin.Context, recovered any) { c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{ "Code": http.StatusInternalServerError, "Error": "Internal server error", "Details": recovered, }) })) var l net.Listener if builder.listener == nil { var err error if l, err = net.Listen("tcp", "127.0.0.1:0"); err != nil { panic(err) } } else { l = builder.listener } // Create the test server. s.s = &httptest.Server{ Listener: l, Config: builder.config, } // Set the server to use the custom handler. s.s.Config.Handler = s.r // Start the server. if builder.withTLS { s.s.StartTLS() } else { s.s.Start() } s.r.Use( gin.LoggerWithConfig(gin.LoggerConfig{Output: builder.logger}), gin.Recovery(), s.logCalls(), s.handleOffline(), ) initRouter(s) return s } // Option represents a type that can be used to configure the server. type Option interface { config(*serverBuilder) } // WithTLS controls whether the server should serve over TLS. 
func WithTLS(tls bool) Option { return &withTLS{ withTLS: tls, } } type withTLS struct { withTLS bool } func (opt withTLS) config(builder *serverBuilder) { builder.withTLS = opt.withTLS } // WithDomain controls the domain of the server. func WithDomain(domain string) Option { return &withDomain{ domain: domain, } } type withDomain struct { domain string } func (opt withDomain) config(builder *serverBuilder) { builder.domain = opt.domain } // WithLogger controls where Gin logs to. func WithLogger(logger io.Writer) Option { return &withLogger{ logger: logger, } } type withLogger struct { logger io.Writer } func (opt withLogger) config(builder *serverBuilder) { builder.logger = opt.logger } func WithProxyOrigin(origin string) Option { return &withProxyOrigin{ origin: origin, } } type withProxyOrigin struct { origin string } func (opt withProxyOrigin) config(builder *serverBuilder) { builder.origin = opt.origin } func WithAuthCacher(cacher AuthCacher) Option { return &withAuthCache{ cacher: cacher, } } type withAuthCache struct { cacher AuthCacher } func (opt withAuthCache) config(builder *serverBuilder) { builder.cacher = opt.cacher } func WithRateLimit(limit int, window time.Duration) Option { return &withRateLimit{ limit: limit, window: window, statusCode: http.StatusTooManyRequests, } } func WithRateLimitAndCustomStatusCode(limit int, window time.Duration, code int) Option { return &withRateLimit{ limit: limit, window: window, statusCode: code, } } type withRateLimit struct { limit int statusCode int window time.Duration } func (opt withRateLimit) config(builder *serverBuilder) { builder.rateLimiter = newRateLimiter(opt.limit, opt.window, opt.statusCode) } func WithProxyTransport(transport *http.Transport) Option { return &withProxyTransport{ transport: transport, } } type withProxyTransport struct { transport *http.Transport } func (opt withProxyTransport) config(builder *serverBuilder) { builder.proxyTransport = opt.transport } type withServerConfig struct { cfg 
*http.Server } func (opt withServerConfig) config(builder *serverBuilder) { builder.config = opt.cfg } // WithServerConfig allows you to configure the underlying HTTP server. func WithServerConfig(cfg *http.Server) Option { return withServerConfig{ cfg: cfg, } } type withNetListener struct { listener net.Listener } func (opt withNetListener) config(builder *serverBuilder) { builder.listener = opt.listener } // WithListener allows you to set the net.Listener to use. func WithListener(listener net.Listener) Option { return withNetListener{ listener: listener, } } type withMessageDedup struct{} func (withMessageDedup) config(builder *serverBuilder) { builder.enableDedup = true } func WithMessageDedup() Option { return &withMessageDedup{} } go-proton-api-1.0.0/server/server_test.go000066400000000000000000002127661447642273300204370ustar00rootroot00000000000000package server import ( "context" "crypto/tls" "encoding/json" "errors" "fmt" "net/http" "net/mail" "net/url" "os" "runtime" "sync" "sync/atomic" "testing" "time" "github.com/bradenaw/juniper/parallel" "github.com/Masterminds/semver/v3" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/rfc822" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/iterator" "github.com/bradenaw/juniper/stream" "github.com/bradenaw/juniper/xslices" "github.com/google/uuid" "github.com/henrybear327/go-proton-api" "github.com/stretchr/testify/require" "golang.org/x/exp/slices" ) func TestServer_LoginLogout(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, "user@"+s.GetDomain(), user.Email) // Logout from the test API. require.NoError(t, c.AuthDelete(ctx)) // Future requests should fail. 
require.Error(t, c.AuthDelete(ctx)) }) }) } func TestServerMulti(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { _, _, err := s.CreateUser("user", []byte("pass")) require.NoError(t, err) // Create one client. c1, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer c1.Close() // Create another client. c2, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer c2.Close() // Both clients should be able to make requests. must(c1.GetUser(ctx)) must(c2.GetUser(ctx)) // Logout the first client; it should no longer be able to make requests. require.NoError(t, c1.AuthDelete(ctx)) require.Panics(t, func() { must(c1.GetUser(ctx)) }) // The second client should still be able to make requests. must(c2.GetUser(ctx)) }) } func TestServer_Ping(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, _ *proton.Manager) { ctl := proton.NewNetCtl() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) var status proton.Status m.AddStatusObserver(func(s proton.Status) { status = s }) // When the network goes down, ping should fail. ctl.Disable() require.Error(t, m.Ping(ctx)) require.Equal(t, proton.StatusDown, status) // When the network goes up, ping should succeed. ctl.Enable() require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, status) // When the API is down, ping should still succeed if the API is at least reachable. s.SetOffline(true) require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, status) // When the API is down, ping should fail if the API cannot be reached. ctl.Disable() require.Error(t, m.Ping(ctx)) require.Equal(t, proton.StatusDown, status) // When the network goes up, ping should succeed, even if the API is down. 
ctl.Enable() require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, status) // When the API comes back alive, ping should succeed. s.SetOffline(false) require.NoError(t, m.Ping(ctx)) require.Equal(t, proton.StatusUp, status) }) } func TestServer_Bool(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1, func([]string) { metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) // By default the message is unread. require.True(t, bool(must(c.GetMessage(ctx, metadata[0].ID)).Unread)) // Mark the message as read. require.NoError(t, c.MarkMessagesRead(ctx, metadata[0].ID)) // Now the message is read. require.False(t, bool(must(c.GetMessage(ctx, metadata[0].ID)).Unread)) }) }) }) } func TestServer_Messages(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { // Get the messages. metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) // The messages should be the ones we created. require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { return metadata.ID })) // The messages should be in All Mail and should be unread. for _, message := range metadata { require.True(t, bool(message.Unread)) require.Equal(t, []string{proton.AllMailLabel}, message.LabelIDs) } // Mark the first three messages as read and put them in archive. require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[0], messageIDs[1], messageIDs[2])) require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[0], messageIDs[1], messageIDs[2]}, proton.ArchiveLabel)) // They should now be read. 
require.False(t, bool(must(c.GetMessage(ctx, messageIDs[0])).Unread)) require.False(t, bool(must(c.GetMessage(ctx, messageIDs[1])).Unread)) require.False(t, bool(must(c.GetMessage(ctx, messageIDs[2])).Unread)) // They should now be in archive. require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[0])).LabelIDs) require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[1])).LabelIDs) require.ElementsMatch(t, []string{proton.ArchiveLabel, proton.AllMailLabel}, must(c.GetMessage(ctx, messageIDs[2])).LabelIDs) // Put them in inbox. require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[0], messageIDs[1], messageIDs[2]}, proton.ArchiveLabel)) }) }) }) } func TestServer_GetMessageMetadataPage(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { for _, chunk := range xslices.Chunk(messageIDs, 150) { // Get the messages. metadata, err := c.GetMessageMetadataPage(ctx, 0, 150, proton.MessageFilter{ID: chunk}) require.NoError(t, err) // The messages should be the ones we created. require.ElementsMatch(t, chunk, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { return metadata.ID })) } }) }) }) } func TestServer_MessageFilter(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { // Get the messages. metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) // The messages should be the ones we created. require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { return metadata.ID })) // Get metadata for just the first three messages. 
partial, err := c.GetMessageMetadata(ctx, proton.MessageFilter{ ID: []string{ metadata[0].ID, metadata[1].ID, metadata[2].ID, }, }) require.NoError(t, err) // The messages should be just the first three. require.Equal(t, metadata[:3], partial) }) }) }) } func TestServer_MessageFilterDesc(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 100, func(messageIDs []string) { allMetadata := make([]proton.MessageMetadata, 0, 100) // first request. { metadata, err := c.GetMessageMetadataPage(ctx, 0, 10, proton.MessageFilter{Desc: true}) require.NoError(t, err) allMetadata = append(allMetadata, metadata...) } for i := 1; i < 11; i++ { // Get the messages. metadata, err := c.GetMessageMetadataPage(ctx, 0, 10, proton.MessageFilter{Desc: true, EndID: allMetadata[len(allMetadata)-1].ID}) require.NoError(t, err) require.NotEmpty(t, metadata) require.Equal(t, metadata[0].ID, allMetadata[len(allMetadata)-1].ID) allMetadata = append(allMetadata, metadata[1:]...) } // Final check. Asking for EndID as last message multiple times will always return the last id. metadata, err := c.GetMessageMetadataPage(ctx, 0, 10, proton.MessageFilter{Desc: true, EndID: allMetadata[len(allMetadata)-1].ID}) require.NoError(t, err) require.Len(t, metadata, 1) require.Equal(t, metadata[0].ID, allMetadata[len(allMetadata)-1].ID) // The messages should be the ones we created. 
require.ElementsMatch(t, messageIDs, xslices.Map(allMetadata, func(metadata proton.MessageMetadata) string { return metadata.ID })) }) }) }) } func TestServer_MessageIDs(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 10000, func(wantMessageIDs []string) { allMessageIDs, err := c.GetAllMessageIDs(ctx, "") require.NoError(t, err) require.ElementsMatch(t, wantMessageIDs, allMessageIDs) halfMessageIDs, err := c.GetAllMessageIDs(ctx, allMessageIDs[len(allMessageIDs)/2]) require.NoError(t, err) require.ElementsMatch(t, allMessageIDs[len(allMessageIDs)/2+1:], halfMessageIDs) }) }) }) } func TestServer_MessagesDelete(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { // Get the messages. metadata, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) // The messages should be the ones we created. require.ElementsMatch(t, messageIDs, xslices.Map(metadata, func(metadata proton.MessageMetadata) string { return metadata.ID })) // Delete half the messages. require.NoError(t, c.DeleteMessage(ctx, messageIDs[0:500]...)) // Get the remaining messages. remaining, err := c.GetMessageMetadata(ctx, proton.MessageFilter{}) require.NoError(t, err) // The remaining messages should be the ones we didn't delete. require.ElementsMatch(t, messageIDs[500:], xslices.Map(remaining, func(metadata proton.MessageMetadata) string { return metadata.ID })) }) }) }) } func TestServer_MessagesDeleteAfterUpdate(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { // Get the initial event ID. 
eventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) // Put half the messages in archive. require.NoError(t, c.LabelMessages(ctx, messageIDs[0:500], proton.ArchiveLabel)) // Delete half the messages. require.NoError(t, c.DeleteMessage(ctx, messageIDs[0:500]...)) // Get the event reflecting this change. event, more, err := c.GetEvent(ctx, eventID) require.NoError(t, err) require.False(t, more) require.Equal(t, 1, len(event)) // The event should have the correct number of message events. require.Len(t, event[0].Messages, 500) // All the events should be delete events. for _, message := range event[0].Messages { require.Equal(t, proton.EventDelete, message.Action) } }) }) }) } func TestServer_Events(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 3, func(messageIDs []string) { // Get the latest event ID to stream from. fromEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) // Begin collecting events. eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) // Mark a message as read. require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[0])) // The message should eventually be read. require.Eventually(t, func() bool { event := <-eventCh if len(event.Messages) != 1 { return false } if event.Messages[0].ID != messageIDs[0] { return false } return !bool(event.Messages[0].Message.Unread) }, 5*time.Second, time.Millisecond*100) // Add another message to archive. require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[1]}, proton.ArchiveLabel)) // The message should eventually be in archive and all mail. 
require.Eventually(t, func() bool { event := <-eventCh if len(event.Messages) != 1 { return false } if event.Messages[0].ID != messageIDs[1] { return false } return elementsMatch([]string{proton.ArchiveLabel, proton.AllMailLabel}, event.Messages[0].Message.LabelIDs) }, 5*time.Second, time.Millisecond*100) // Perform a sequence of actions on the same message. require.NoError(t, c.LabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel)) require.NoError(t, c.UnlabelMessages(ctx, []string{messageIDs[2]}, proton.TrashLabel)) require.NoError(t, c.MarkMessagesRead(ctx, messageIDs[2])) require.NoError(t, c.MarkMessagesUnread(ctx, messageIDs[2])) // The final state of the message should be correct. require.Eventually(t, func() bool { event := <-eventCh if len(event.Messages) != 1 { return false } if event.Messages[0].ID != messageIDs[2] { return false } return bool(event.Messages[0].Message.Unread) && elementsMatch([]string{proton.AllMailLabel}, event.Messages[0].Message.LabelIDs) }, 5*time.Second, time.Millisecond*100) // No more events should be sent. select { case <-eventCh: t.Fatal("unexpected event") default: // .... } }) }) }) } func TestServer_Events_Multi(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { for i := 0; i < 10; i++ { withUser(ctx, t, s, m, fmt.Sprintf("user%v", i), "pass", func(c *proton.Client) { latest, err := c.GetLatestEventID(ctx) require.NoError(t, err) // Fetching latest again should return the same event ID. latestAgain, err := c.GetLatestEventID(ctx) require.NoError(t, err) require.Equal(t, latest, latestAgain) events, more, err := c.GetEvent(ctx, latest) require.NoError(t, err) require.False(t, more) // The event should be empty. require.Equal(t, []proton.Event{{EventID: events[0].EventID}}, events) // After fetching an empty event, its ID should still be the latest. 
eventAgain, more, err := c.GetEvent(ctx, events[0].EventID) require.NoError(t, err) require.False(t, more) require.Equal(t, eventAgain[0].EventID, events[0].EventID) }) } }) } func TestServer_Events_Refresh(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) // Get the latest event ID to stream from. fromEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) // Refresh the user's mail. require.NoError(t, s.RefreshUser(user.ID, proton.RefreshMail)) // Begin collecting events. eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) // The user should eventually be refreshed. require.Eventually(t, func() bool { return (<-eventCh).Refresh&proton.RefreshMail != 0 }, 5*time.Second, time.Millisecond*100) }) }) } func TestServer_Events_UserSettings(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) _, err = s.b.SetUserSettingsTelemetry(user.ID, proton.SettingDisabled) require.NoError(t, err) // Get the latest event ID to stream from. fromEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) // Refresh the user's mail. _, err = s.b.SetUserSettingsTelemetry(user.ID, proton.SettingEnabled) require.NoError(t, err) // Begin collecting events. eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) // The user should eventually be refreshed. 
require.Eventually(t, func() bool { e := <-eventCh return e.UserSettings != nil && e.UserSettings.Telemetry == proton.SettingEnabled }, 5*time.Second, time.Millisecond*100) }) }) } func TestServer_RevokeUser(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) require.Equal(t, "user", user.Name) require.Equal(t, "user@"+s.GetDomain(), user.Email) // Revoke the user's auth. require.NoError(t, s.RevokeUser(user.ID)) // Future requests should fail. require.Error(t, c.AuthDelete(ctx)) }) }) } func TestServer_Calls(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { var calls []Call // Watch calls that are made. s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) // Get the user. _, err := c.GetUser(ctx) require.NoError(t, err) // Logout the user. require.NoError(t, c.AuthDelete(ctx)) // The user call should be correct. userCall := calls[0] require.Equal(t, "/core/v4/users", userCall.URL.Path) require.Equal(t, "GET", userCall.Method) require.Equal(t, http.StatusOK, userCall.Status) // The logout call should be correct. logoutCall := calls[1] require.Equal(t, "/auth/v4", logoutCall.URL.Path) require.Equal(t, "DELETE", logoutCall.Method) require.Equal(t, http.StatusOK, logoutCall.Status) }) }) } func TestServer_Calls_Status(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { var calls []Call // Watch calls that are made. s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) // Make a bad call. _, err := c.GetMessage(ctx, "no such message ID") require.Error(t, err) // The user call should have error status. 
require.Equal(t, http.StatusUnprocessableEntity, calls[0].Status) }) }) } func TestServer_Calls_Request(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { var calls []Call s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) { require.Equal( t, calls[0].RequestBody, must(json.Marshal(proton.AuthInfoReq{Username: "user"})), ) }) }) } func TestServer_Calls_Response(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { var calls []Call s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { salts, err := c.GetSalts(ctx) require.NoError(t, err) require.Equal( t, calls[len(calls)-1].ResponseBody, must(json.Marshal(struct{ KeySalts []proton.Salt }{salts})), ) }) }) } func TestServer_Calls_Cookies(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { var calls []Call s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) { // The header in the first call's response should set the Session-Id cookie. resHeader := (&http.Response{Header: calls[len(calls)-2].ResponseHeader}) require.Len(t, resHeader.Cookies(), 1) require.Equal(t, "Session-Id", resHeader.Cookies()[0].Name) require.NotEmpty(t, resHeader.Cookies()[0].Value) // The cookie should be sent in the next call. reqHeader := (&http.Request{Header: calls[len(calls)-1].RequestHeader}) require.Len(t, reqHeader.Cookies(), 1) require.Equal(t, "Session-Id", reqHeader.Cookies()[0].Name) require.NotEmpty(t, reqHeader.Cookies()[0].Value) // The cookie should be the same. 
require.Equal(t, resHeader.Cookies()[0].Value, reqHeader.Cookies()[0].Value) }) }) } func TestServer_Calls_Manager(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { var calls []Call // Watch calls that are made. s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) // Make a non-user request. require.NoError(t, m.ReportBug(ctx, proton.ReportBugReq{})) // The call should be correct. reportCall := calls[0] require.Equal(t, "/core/v4/reports/bug", reportCall.URL.Path) require.Equal(t, "POST", reportCall.Method) require.Equal(t, http.StatusOK, reportCall.Status) }) } func TestServer_CreateMessage(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{ Message: proton.DraftTemplate{ Subject: "My subject", Sender: &mail.Address{Address: addr[0].Email}, ToList: []*mail.Address{{Address: "recipient@example.com"}}, }, }) require.NoError(t, err) require.Equal(t, addr[0].ID, draft.AddressID) require.Equal(t, "My subject", draft.Subject) require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender) require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList) require.ElementsMatch(t, []string{proton.AllMailLabel, proton.AllDraftsLabel, proton.DraftsLabel}, draft.LabelIDs) }) }) } func TestServer_UpdateDraft(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := 
c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) // Create the draft. draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{ Message: proton.DraftTemplate{ Subject: "My subject", Sender: &mail.Address{Address: addr[0].Email}, ToList: []*mail.Address{{Address: "recipient@example.com"}}, }, }) require.NoError(t, err) require.Equal(t, addr[0].ID, draft.AddressID) require.Equal(t, "My subject", draft.Subject) require.Equal(t, &mail.Address{Address: "user@" + s.GetDomain()}, draft.Sender) require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, draft.ToList) // Create an event stream to watch for an update event. fromEventID, err := c.GetLatestEventID(ctx) require.NoError(t, err) eventCh := c.NewEventStream(ctx, time.Second, 0, fromEventID) // Update the draft subject/to-list. msg, err := c.UpdateDraft(ctx, draft.ID, addrKRs[addr[0].ID], proton.UpdateDraftReq{ Message: proton.DraftTemplate{ Subject: "Edited subject", Sender: &mail.Address{Address: addr[0].Email}, ToList: []*mail.Address{{Address: "edited@example.com"}}, MIMEType: rfc822.TextPlain, }, }) require.NoError(t, err) require.Equal(t, "Edited subject", msg.Subject) // We should eventually get an update event. 
require.Eventually(t, func() bool { event := <-eventCh if len(event.Messages) < 1 { return false } if event.Messages[0].ID != draft.ID { return false } if event.Messages[0].Action != proton.EventUpdate { return false } require.Equal(t, draft.ID, event.Messages[0].ID) require.Equal(t, "Edited subject", event.Messages[0].Message.Subject) require.Equal(t, []*mail.Address{{Address: "edited@example.com"}}, event.Messages[0].Message.ToList) return true }, 5*time.Second, time.Millisecond*100) }) }) } func TestServer_SendMessage(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) draft, err := c.CreateDraft(ctx, addrKRs[addr[0].ID], proton.CreateDraftReq{ Message: proton.DraftTemplate{ Subject: "My subject", Sender: &mail.Address{Address: addr[0].Email}, ToList: []*mail.Address{{Address: "recipient@example.com"}}, }, }) require.NoError(t, err) sent, err := c.SendDraft(ctx, draft.ID, proton.SendDraftReq{}) require.NoError(t, err) require.Equal(t, draft.ID, sent.ID) require.Equal(t, addr[0].ID, sent.AddressID) require.Equal(t, "My subject", sent.Subject) require.Equal(t, []*mail.Address{{Address: "recipient@example.com"}}, sent.ToList) require.Contains(t, sent.LabelIDs, proton.SentLabel) }) }) } func TestServer_AuthDelete(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { require.NoError(t, c.AuthDelete(ctx)) }) }) } func TestServer_ForceUpgrade(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer 
cancel() s := New() defer s.Close() s.SetMinAppVersion(semver.MustParse("1.0.0")) if _, _, err := s.CreateUser("user", []byte("pass")); err != nil { t.Fatal(err) } m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithAppVersion("proton_0.9.0"), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() var called bool m.AddErrorHandler(proton.AppVersionBadCode, func() { called = true }) if _, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")); err == nil { t.Fatal(err) } require.True(t, called) } func TestServer_Import(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { ctx, cancel := context.WithCancel(ctx) defer cancel() user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1) require.NoError(t, err) require.Len(t, res, 1) require.Equal(t, proton.SuccessCode, res[0].Code) message, err := c.GetMessage(ctx, res[0].MessageID) require.NoError(t, err) dec, err := message.Decrypt(addrKRs[message.AddressID]) require.NoError(t, err) require.NotEmpty(t, dec) }) }) } func TestServer_Import_Dedup(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { ctx, cancel := context.WithCancel(ctx) defer cancel() user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, 
err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) subjectGenerator := func() string { return "my Subject" } res := importMessagesWithSubjectGenerator( ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1, subjectGenerator, ) require.NoError(t, err) require.Len(t, res, 1) require.Equal(t, proton.SuccessCode, res[0].Code) message, err := c.GetMessage(ctx, res[0].MessageID) require.NoError(t, err) dec, err := message.Decrypt(addrKRs[message.AddressID]) require.NoError(t, err) require.NotEmpty(t, dec) // Import message again should be deduped. resDedup := importMessagesWithSubjectGenerator( ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, 1, subjectGenerator, ) require.NoError(t, err) require.Len(t, resDedup, 1) require.Equal(t, proton.SuccessCode, resDedup[0].Code) require.Equal(t, res[0].MessageID, resDedup[0].MessageID) }) }, WithMessageDedup()) } func TestServer_Labels(t *testing.T) { type add string type rem string tests := []struct { name string flags proton.MessageFlag actions []any wantLabelIDs []string wantError bool }{ { name: "received flag, no actions", flags: proton.MessageFlagReceived, wantLabelIDs: []string{proton.AllMailLabel}, }, { name: "sent flag, no actions", flags: proton.MessageFlagSent, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, no actions", flags: proton.MessageFlagScheduledSend, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel)}, wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, }, { name: "sent flag, add sent", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel)}, wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled", flags: 
proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel)}, wantLabelIDs: []string{proton.AllScheduledLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox then add archive", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel}, }, { name: "sent flag, add sent then add archive", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled then add archive", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox then remove inbox", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), rem(proton.InboxLabel)}, wantLabelIDs: []string{proton.AllMailLabel}, }, { name: "sent flag, add sent then remove sent", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel), rem(proton.SentLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled then remove scheduled", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel), rem(proton.AllScheduledLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox then remove archive", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), rem(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, }, { name: "sent flag, add sent then remove archive", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel), rem(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.SentLabel, proton.AllMailLabel, 
proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled then remove archive", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel), rem(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.AllScheduledLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox then remove inbox then add archive", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), rem(proton.InboxLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel}, }, { name: "sent flag, add sent then remove sent then add archive", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel), rem(proton.SentLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled then remove scheduled then add archive", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel), rem(proton.AllScheduledLabel), add(proton.ArchiveLabel)}, wantLabelIDs: []string{proton.ArchiveLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add starred", flags: proton.MessageFlagReceived, actions: []any{add(proton.StarredLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel}, }, { name: "sent flag, add starred", flags: proton.MessageFlagSent, actions: []any{add(proton.StarredLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add starred", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.StarredLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox, add starred, remove inbox", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), add(proton.StarredLabel), rem(proton.InboxLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel}, }, { 
name: "sent flag, add sent, add starred, remove sent", flags: proton.MessageFlagSent, actions: []any{add(proton.SentLabel), add(proton.StarredLabel), rem(proton.SentLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add scheduled, add starred, remove scheduled", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.AllScheduledLabel), add(proton.StarredLabel), rem(proton.AllScheduledLabel)}, wantLabelIDs: []string{proton.StarredLabel, proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add trash, remove trash", flags: proton.MessageFlagReceived, actions: []any{add(proton.TrashLabel), rem(proton.TrashLabel)}, wantLabelIDs: []string{proton.AllMailLabel}, }, { name: "sent flag, add trash, remove trash", flags: proton.MessageFlagSent, actions: []any{add(proton.TrashLabel), rem(proton.TrashLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "scheduled flag, add trash, remove trash", flags: proton.MessageFlagScheduledSend, actions: []any{add(proton.TrashLabel), rem(proton.TrashLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.AllSentLabel}, }, { name: "received flag, add inbox, add trash, remove inbox", flags: proton.MessageFlagReceived, actions: []any{add(proton.InboxLabel), add(proton.TrashLabel), rem(proton.InboxLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.TrashLabel}, }, { name: "scheduled & sent flags, add scheduled, add sent", flags: proton.MessageFlagScheduledSend | proton.MessageFlagSent, actions: []any{add(proton.AllScheduledLabel), add(proton.SentLabel)}, wantLabelIDs: []string{proton.AllMailLabel, proton.SentLabel, proton.AllSentLabel}, }, } withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { ctx, cancel := context.WithCancel(ctx) defer cancel() user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := 
c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, tt.flags, 1) require.True(t, (func() error { for _, action := range tt.actions { switch action := action.(type) { case add: if err := c.LabelMessages(ctx, []string{res[0].MessageID}, string(action)); err != nil { return err } case rem: if err := c.UnlabelMessages(ctx, []string{res[0].MessageID}, string(action)); err != nil { return err } } } return nil }() != nil) == tt.wantError) message, err := c.GetMessage(ctx, res[0].MessageID) require.NoError(t, err) // The message should be in the correct labels. require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs) // The flags should be preserved after import. 
require.True(t, message.Flags&tt.flags == tt.flags) }) } }) }) } func TestServer_Import_FlagsAndLabels(t *testing.T) { tests := []struct { name string labelIDs []string flags proton.MessageFlag wantLabelIDs []string wantError bool }{ { name: "received flag --> no label", flags: proton.MessageFlagReceived, wantLabelIDs: []string{proton.AllMailLabel}, }, { name: "received flag --> inbox", labelIDs: []string{proton.InboxLabel}, flags: proton.MessageFlagReceived, wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, }, { name: "sent flag --> sent", labelIDs: []string{proton.SentLabel}, flags: proton.MessageFlagSent, wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel}, }, { name: "received flag --> sent", labelIDs: []string{proton.SentLabel}, flags: proton.MessageFlagReceived, wantLabelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, }, { name: "sent flag --> inbox", labelIDs: []string{proton.InboxLabel}, flags: proton.MessageFlagSent, wantLabelIDs: []string{proton.SentLabel, proton.AllSentLabel, proton.AllMailLabel}, }, { name: "no flag --> drafts", labelIDs: []string{proton.DraftsLabel}, wantLabelIDs: []string{proton.DraftsLabel, proton.AllDraftsLabel, proton.AllMailLabel}, }, { name: "forbidden: received flag --> All Mail", labelIDs: []string{proton.AllMailLabel}, flags: proton.MessageFlagReceived, wantError: true, }, { name: "forbidden: sent flag --> All Mail", labelIDs: []string{proton.AllMailLabel}, flags: proton.MessageFlagSent, wantError: true, }, { name: "forbidden: received flag --> inbox and all mail", labelIDs: []string{proton.InboxLabel, proton.AllMailLabel}, flags: proton.MessageFlagReceived, wantError: true, }, { name: "forbidden: sent flag --> sent and all mail", labelIDs: []string{proton.SentLabel, proton.AllMailLabel}, flags: proton.MessageFlagSent, wantError: true, }, { name: "forbidden: received flag --> inbox and sent", labelIDs: []string{proton.InboxLabel, proton.SentLabel}, flags: 
proton.MessageFlagReceived, wantError: true, }, { name: "forbidden: sent flag --> inbox and sent", labelIDs: []string{proton.InboxLabel, proton.SentLabel}, flags: proton.MessageFlagSent, wantError: true, }, { name: "forbidden: received flag --> inbox and archive", labelIDs: []string{proton.InboxLabel, proton.ArchiveLabel}, flags: proton.MessageFlagReceived, wantError: true, }, } withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { ctx, cancel := context.WithCancel(ctx) defer cancel() user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { str, err := c.ImportMessages(ctx, addrKRs[addr[0].ID], runtime.NumCPU(), runtime.NumCPU(), []proton.ImportReq{{ Metadata: proton.ImportMetadata{ AddressID: addr[0].ID, Flags: tt.flags, LabelIDs: tt.labelIDs, }, Message: newMessageLiteral("sender@example.com", "recipient@example.com"), }}...) require.NoError(t, err) res, err := stream.Collect(ctx, str) if tt.wantError { require.Error(t, err) } else { require.NoError(t, err) require.Equal(t, proton.SuccessCode, res[0].Code) message, err := c.GetMessage(ctx, res[0].MessageID) require.NoError(t, err) // The message should be in the correct labels. require.ElementsMatch(t, tt.wantLabelIDs, message.LabelIDs) // The flags should be preserved after import. 
require.True(t, message.Flags&tt.flags == tt.flags) } }) } }) }) } func TestServer_PublicKeys(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { if _, _, err := s.CreateUser("other", []byte("pass")); err != nil { t.Fatal(err) } withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { intKeys, intType, err := c.GetPublicKeys(ctx, "other@"+s.GetDomain()) require.NoError(t, err) require.Equal(t, proton.RecipientTypeInternal, intType) require.Len(t, intKeys, 1) extKeys, extType, err := c.GetPublicKeys(ctx, "other@example.com") require.NoError(t, err) require.Equal(t, proton.RecipientTypeExternal, extType) require.Len(t, extKeys, 0) }) }) } func TestServer_Proxy(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { var calls []Call s.AddCallWatcher(func(call Call) { calls = append(calls, call) }) withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) { proxy := New( WithProxyOrigin(s.GetHostURL()), WithProxyTransport(proton.InsecureTransport()), ) defer proxy.Close() m := proton.New( proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Login -- the call should be proxied to the upstream server. c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer c.Close() // The results of the call should be correct. user, err := c.GetUser(ctx) require.NoError(t, err) require.Equal(t, "user", user.Name) }) // Assert that the calls were proxied. require.Greater(t, len(calls), 0) }) } func TestServer_Proxy_Cache(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) { proxy := New( WithProxyOrigin(s.GetHostURL()), WithProxyTransport(proton.InsecureTransport()), WithAuthCacher(NewAuthCache()), ) defer proxy.Close() // Need to skip verifying the server proofs for the proxy cache feature to work! 
m := proton.New( proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()), proton.WithSkipVerifyProofs(), ) defer m.Close() // Login 3 times; we should produce 1 unique auth. require.Len(t, xslices.Unique(iterator.Collect(iterator.Map(iterator.Counter(3), func(int) string { c, auth, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer c.Close() return auth.UID }))), 1) }) }) } func TestServer_Proxy_AuthDelete(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(_ *proton.Client) { proxy := New( WithProxyOrigin(s.GetHostURL()), WithProxyTransport(proton.InsecureTransport()), WithAuthCacher(NewAuthCache()), ) defer proxy.Close() // Need to skip verifying the server proofs for the proxy cache feature to work! m := proton.New( proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Watch for login -- the calls should be proxied. var login []Call s.AddCallWatcher(func(call Call) { login = append(login, call) }) // Login -- the call should be proxied to the upstream server. c, _, err := m.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer c.Close() // Assert that the login was proxied. require.NotEmpty(t, len(login)) // Watch for logout -- logout should not be proxied to the upstream server. var logout []Call s.AddCallWatcher(func(call Call) { logout = append(logout, call) }) // Logout -- the call should not be proxied to the upstream server. require.NoError(t, c.AuthDelete(ctx)) // Assert that the logout was not proxied! 
require.Empty(t, len(logout)) }) }) } func TestServer_RealProxy(t *testing.T) { username := os.Getenv("GO_PROTON_API_TEST_USERNAME") password := os.Getenv("GO_PROTON_API_TEST_PASSWORD") if username == "" || password == "" { t.Skip("skipping test, set the username and password to run") } ctx, cancel := context.WithCancel(context.Background()) defer cancel() proxy := New() defer proxy.Close() m := proton.New( proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() // Login -- the call should be proxied to the upstream server. c, _, err := m.NewClientWithLogin(ctx, username, []byte(password)) require.NoError(t, err) defer c.Close() // The results of the call should be correct. user, err := c.GetUser(ctx) require.NoError(t, err) require.Equal(t, username, user.Name) } func TestServer_RealProxy_Cache(t *testing.T) { username := os.Getenv("GO_PROTON_API_TEST_USERNAME") password := os.Getenv("GO_PROTON_API_TEST_PASSWORD") if username == "" || password == "" { t.Skip("skipping test, set the username and password to run") } ctx, cancel := context.WithCancel(context.Background()) defer cancel() proxy := New(WithAuthCacher(NewAuthCache())) defer proxy.Close() m := proton.New( proton.WithHostURL(proxy.GetProxyURL()), proton.WithTransport(proton.InsecureTransport()), proton.WithSkipVerifyProofs(), ) defer m.Close() // Login 3 times; we should produce 1 unique auth. 
require.Len(t, xslices.Unique(iterator.Collect(iterator.Map(iterator.Counter(3), func(int) string { c, auth, err := m.NewClientWithLogin(ctx, username, []byte(password)) require.NoError(t, err) defer c.Close() return auth.UID }))), 1) } func TestServer_Messages_Fetch(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { withMessages(ctx, t, c, "pass", 1000, func(messageIDs []string) { ctl := proton.NewNetCtl() mm := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) defer mm.Close() cc, _, err := mm.NewClientWithLogin(ctx, "user", []byte("pass")) require.NoError(t, err) defer cc.Close() total := countBytesRead(ctl, func() { res, err := stream.Collect(ctx, getFullMessages(ctx, cc, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) require.NoError(t, err) require.NotEmpty(t, res) }) ctl.SetReadLimit(total / 2) require.Less(t, countBytesRead(ctl, func() { res, err := stream.Collect(ctx, getFullMessages(ctx, cc, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) require.Error(t, err) require.Empty(t, res) }), total) ctl.SetReadLimit(0) require.Equal(t, countBytesRead(ctl, func() { res, err := stream.Collect(ctx, getFullMessages(ctx, cc, runtime.NumCPU(), runtime.NumCPU(), messageIDs...)) require.NoError(t, err) require.NotEmpty(t, res) }), total) }) }) }, WithTLS(false)) } func TestServer_Status(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(*proton.Client) { ctl := proton.NewNetCtl() mm := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithTransport(ctl.NewRoundTripper(&tls.Config{InsecureSkipVerify: true})), ) defer mm.Close() statusCh := make(chan proton.Status, 1) mm.AddStatusObserver(func(status proton.Status) { statusCh <- status }) cc, _, err := mm.NewClientWithLogin(ctx, "user", 
[]byte("pass")) require.NoError(t, err) defer cc.Close() { user, err := cc.GetUser(ctx) require.NoError(t, err) require.Equal(t, "user", user.Name) } ctl.SetCanRead(false) { _, err := cc.GetUser(ctx) require.Error(t, err) } require.Equal(t, proton.StatusDown, <-statusCh) }) }, WithTLS(false)) } func TestServer_Labels_Duplicates(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { req := proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeLabel, } label, err := c.CreateLabel(context.Background(), req) require.NoError(t, err) require.Equal(t, req.Name, label.Name) _, err = c.CreateLabel(context.Background(), req) require.Error(t, err) }) }) } func TestServer_Labels_Duplicates_Update(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { label1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeLabel, }) require.NoError(t, err) label2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeLabel, }) require.NoError(t, err) // Updating label1 with label2's name should fail. _, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ Name: label2.Name, Color: label1.Color, }) require.Error(t, err) // Updating label1's color while preserving its name should succeed. _, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ Name: label1.Name, Color: "#f00", }) require.NoError(t, err) // Updating label1 with a new name should succeed. 
_, err = c.UpdateLabel(context.Background(), label1.ID, proton.UpdateLabelReq{ Name: uuid.NewString(), Color: label1.Color, }) require.NoError(t, err) }) }) } func TestServer_Labels_Subfolders(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), ParentID: parent.ID, Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) require.Equal(t, []string{parent.Name, child.Name}, child.Path) child2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), ParentID: child.ID, Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) require.Equal(t, []string{parent.Name, child.Name, child2.Name}, child2.Path) }) }) } func TestServer_Labels_Subfolders_Reassign(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { parent1, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) parent2, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) // Create a child initially under parent1. child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), ParentID: parent1.ID, Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) require.Equal(t, []string{parent1.Name, child.Name}, child.Path) // Reassign the child to parent2. 
child2, err := c.UpdateLabel(context.Background(), child.ID, proton.UpdateLabelReq{ Name: child.Name, Color: child.Color, ParentID: parent2.ID, }) require.NoError(t, err) require.Equal(t, []string{parent2.Name, child.Name}, child2.Path) // Reassign the child to no parent. child3, err := c.UpdateLabel(context.Background(), child.ID, proton.UpdateLabelReq{ Name: child2.Name, Color: child2.Color, ParentID: "", }) require.NoError(t, err) require.Equal(t, []string{child3.Name}, child3.Path) }) }) } func TestServer_Labels_Subfolders_DeleteParentWithChildren(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { parent, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) child, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), ParentID: parent.ID, Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) require.Equal(t, []string{parent.Name, child.Name}, child.Path) other, err := c.CreateLabel(context.Background(), proton.CreateLabelReq{ Name: uuid.NewString(), Color: "#f66", Type: proton.LabelTypeFolder, }) require.NoError(t, err) // Get labels before. before, err := c.GetLabels(context.Background(), proton.LabelTypeFolder) require.NoError(t, err) // Delete the parent. require.NoError(t, c.DeleteLabel(context.Background(), parent.ID)) // Get labels after. after, err := c.GetLabels(context.Background(), proton.LabelTypeFolder) require.NoError(t, err) // Both parent and child are deleted. require.Equal(t, len(before)-2, len(after)) // The only label left is the other one. 
require.Equal(t, other.ID, after[0].ID) }) }) } func TestServer_AddressCreateDelete(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(context.Background()) require.NoError(t, err) // Create an address. alias, err := s.CreateAddress(user.ID, "alias@example.com", []byte("pass")) require.NoError(t, err) // The user should have two addresses, both enabled. { addr, err := c.GetAddresses(context.Background()) require.NoError(t, err) require.Len(t, addr, 2) require.Equal(t, addr[0].Status, proton.AddressStatusEnabled) require.Equal(t, addr[1].Status, proton.AddressStatusEnabled) } // Disable the alias. require.NoError(t, c.DisableAddress(context.Background(), alias)) // The user should have two addresses, the primary enabled and the alias disabled. { addr, err := c.GetAddresses(context.Background()) require.NoError(t, err) require.Len(t, addr, 2) require.Equal(t, addr[0].Status, proton.AddressStatusEnabled) require.Equal(t, addr[1].Status, proton.AddressStatusDisabled) } // Delete the alias. require.NoError(t, c.DeleteAddress(context.Background(), alias)) // The user should have one address, the primary enabled. { addr, err := c.GetAddresses(context.Background()) require.NoError(t, err) require.Len(t, addr, 1) require.Equal(t, addr[0].Status, proton.AddressStatusEnabled) } }) }) } func TestServer_AddressOrder(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { user, err := c.GetUser(context.Background()) require.NoError(t, err) primary, err := c.GetAddresses(context.Background()) require.NoError(t, err) // Create 3 additional addresses. 
addr1, err := s.CreateAddress(user.ID, "addr1@example.com", []byte("pass")) require.NoError(t, err) addr2, err := s.CreateAddress(user.ID, "addr2@example.com", []byte("pass")) require.NoError(t, err) addr3, err := s.CreateAddress(user.ID, "addr3@example.com", []byte("pass")) require.NoError(t, err) addresses, err := c.GetAddresses(context.Background()) require.NoError(t, err) // Check the order. require.Equal(t, primary[0].ID, addresses[0].ID) require.Equal(t, addr1, addresses[1].ID) require.Equal(t, addr2, addresses[2].ID) require.Equal(t, addr3, addresses[3].ID) // Update the order. require.NoError(t, c.OrderAddresses(ctx, proton.OrderAddressesReq{ AddressIDs: []string{addr3, addr2, addr1, primary[0].ID}, })) // Check the order. addresses, err = c.GetAddresses(context.Background()) require.NoError(t, err) require.Equal(t, addr3, addresses[0].ID) require.Equal(t, addr2, addresses[1].ID) require.Equal(t, addr1, addresses[2].ID) require.Equal(t, primary[0].ID, addresses[3].ID) }) }) } func TestServer_MailSettings(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { settings, err := c.GetMailSettings(context.Background()) require.NoError(t, err) require.Equal(t, proton.Bool(false), settings.AttachPublicKey) updated, err := c.SetAttachPublicKey(context.Background(), proton.SetAttachPublicKeyReq{AttachPublicKey: true}) require.NoError(t, err) require.Equal(t, proton.Bool(true), updated.AttachPublicKey) }) }) } func TestServer_UserSettings(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { settings, err := c.GetUserSettings(context.Background()) require.NoError(t, err) require.Equal(t, proton.SettingEnabled, settings.Telemetry) require.Equal(t, proton.SettingEnabled, settings.CrashReports) settings, err = c.SetUserSettingsTelemetry(context.Background(), 
proton.SetTelemetryReq{Telemetry: proton.SettingDisabled}) require.NoError(t, err) require.Equal(t, proton.SettingDisabled, settings.Telemetry) require.Equal(t, proton.SettingEnabled, settings.CrashReports) settings, err = c.SetUserSettingsCrashReports(context.Background(), proton.SetCrashReportReq{CrashReports: proton.SettingDisabled}) require.NoError(t, err) require.Equal(t, proton.SettingDisabled, settings.Telemetry) require.Equal(t, proton.SettingDisabled, settings.CrashReports) settings, err = c.SetUserSettingsTelemetry(context.Background(), proton.SetTelemetryReq{Telemetry: 2}) require.Error(t, err) settings, err = c.SetUserSettingsCrashReports(context.Background(), proton.SetCrashReportReq{CrashReports: 2}) require.Error(t, err) settings, err = c.GetUserSettings(context.Background()) require.NoError(t, err) require.Equal(t, proton.SettingDisabled, settings.Telemetry) require.Equal(t, proton.SettingDisabled, settings.CrashReports) }) }) } func TestServer_Domains(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { domains, err := m.GetDomains(ctx) require.NoError(t, err) require.Equal(t, []string{s.GetDomain()}, domains) }) } func TestServer_StatusHooks(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { s.AddStatusHook(func(req *http.Request) (int, bool) { if req.URL.Path == "/core/v4/addresses" { return http.StatusBadRequest, true } return 0, false }) withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { addr, err := c.GetAddresses(context.Background()) require.Error(t, err) require.Nil(t, addr) if apiErr := new(proton.APIError); errors.As(err, &apiErr) { require.Equal(t, http.StatusBadRequest, apiErr.Status) } else { require.Fail(t, "expected APIError") } }) }) } func TestServer_SendDataEvent(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { // Send data event minimal err := 
c.SendDataEvent(context.Background(), proton.SendStatsReq{MeasurementGroup: "proton.any.test"}) require.NoError(t, err) // Send data event Full. var req proton.SendStatsReq req.MeasurementGroup = "proton.any.test" req.Event = "test" req.Values = map[string]any{"string": "string", "integer": 42} req.Dimensions = map[string]any{"string": "string", "integer": 42} err = c.SendDataEvent(context.Background(), req) require.NoError(t, err) // Send bad data event. err = c.SendDataEvent(context.Background(), proton.SendStatsReq{}) require.Error(t, err) }) }) } func TestServer_SendDataEventMultiple(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { // Send multiple minimal data event. var req proton.SendStatsMultiReq req.EventInfo = append(req.EventInfo, proton.SendStatsReq{MeasurementGroup: "proton.any.test"}) req.EventInfo = append(req.EventInfo, proton.SendStatsReq{MeasurementGroup: "proton.any.test2"}) err := c.SendDataEventMultiple(context.Background(), req) require.NoError(t, err) // send empty multiple data event. err = c.SendDataEventMultiple(context.Background(), proton.SendStatsMultiReq{}) require.NoError(t, err) // Send bad multiple data event. 
var badReq proton.SendStatsMultiReq badReq.EventInfo = append(badReq.EventInfo, proton.SendStatsReq{}) err = c.SendDataEventMultiple(context.Background(), badReq) require.Error(t, err) }) }) } func TestServer_GetMessageGroupCount(t *testing.T) { withServer(t, func(ctx context.Context, s *Server, m *proton.Manager) { withUser(ctx, t, s, m, "user", "pass", func(c *proton.Client) { ctx, cancel := context.WithCancel(ctx) defer cancel() user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) pass, err := salt.SaltForKey([]byte("pass"), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, pass, async.NoopPanicHandler{}) require.NoError(t, err) expected := []proton.MessageGroupCount{ { LabelID: proton.InboxLabel, Total: 10, Unread: 4, }, { LabelID: proton.SentLabel, Total: 4, Unread: 0, }, { LabelID: proton.ArchiveLabel, Total: 3, Unread: 0, }, { LabelID: proton.TrashLabel, Total: 6, Unread: 0, }, { LabelID: proton.AllMailLabel, Total: 23, Unread: 4, }, } for _, st := range expected { if st.LabelID == proton.AllMailLabel { continue } var flags proton.MessageFlag if st.LabelID == proton.InboxLabel { flags = proton.MessageFlagReceived } else if st.LabelID == proton.SentLabel { flags = proton.MessageFlagSent } res := importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, flags, st.Total) msgIDs := xslices.Map(res, func(r proton.ImportRes) string { return r.MessageID }) require.NoError(t, c.LabelMessages(ctx, msgIDs, st.LabelID)) if st.Unread == 0 { require.NoError(t, c.MarkMessagesRead(ctx, msgIDs...)) } else { require.NoError(t, c.MarkMessagesRead(ctx, msgIDs[st.Unread:]...)) } } counts, err := c.GetGroupedMessageCount(ctx) require.NoError(t, err) counts = xslices.Filter(counts, func(t proton.MessageGroupCount) bool { switch t.LabelID { case proton.InboxLabel, proton.TrashLabel, proton.ArchiveLabel, 
proton.AllMailLabel, proton.SentLabel: return true default: return false } }) require.NotEmpty(t, counts) require.ElementsMatch(t, expected, counts) }) }) } func withServer(t *testing.T, fn func(ctx context.Context, s *Server, m *proton.Manager), opts ...Option) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() s := New(opts...) defer s.Close() m := proton.New( proton.WithHostURL(s.GetHostURL()), proton.WithCookieJar(newTestCookieJar()), proton.WithTransport(proton.InsecureTransport()), ) defer m.Close() fn(ctx, s, m) } func withUser(ctx context.Context, t *testing.T, s *Server, m *proton.Manager, user, pass string, fn func(c *proton.Client)) { _, _, err := s.CreateUser(user, []byte(pass)) require.NoError(t, err) c, _, err := m.NewClientWithLogin(ctx, user, []byte(pass)) require.NoError(t, err) defer c.Close() fn(c) } func withMessages(ctx context.Context, t *testing.T, c *proton.Client, pass string, count int, fn func([]string)) { user, err := c.GetUser(ctx) require.NoError(t, err) addr, err := c.GetAddresses(ctx) require.NoError(t, err) salt, err := c.GetSalts(ctx) require.NoError(t, err) keyPass, err := salt.SaltForKey([]byte(pass), user.Keys.Primary().ID) require.NoError(t, err) _, addrKRs, err := proton.Unlock(user, addr, keyPass, async.NoopPanicHandler{}) require.NoError(t, err) fn(xslices.Map(importMessages(ctx, t, c, addr[0].ID, addrKRs[addr[0].ID], []string{}, proton.MessageFlagReceived, count), func(res proton.ImportRes) string { return res.MessageID })) } func importMessagesWithSubjectGenerator( ctx context.Context, t *testing.T, c *proton.Client, addrID string, addrKR *crypto.KeyRing, labelIDs []string, flags proton.MessageFlag, count int, subjectGenerator func() string, ) []proton.ImportRes { req := iterator.Collect(iterator.Map(iterator.Counter(count), func(int) proton.ImportReq { return proton.ImportReq{ Metadata: proton.ImportMetadata{ AddressID: addrID, LabelIDs: labelIDs, Flags: flags, Unread: true, }, Message: 
newMessageLiteralWithSubject("sender@example.com", "recipient@example.com", subjectGenerator()), } })) str, err := c.ImportMessages(ctx, addrKR, runtime.NumCPU(), runtime.NumCPU(), req...) require.NoError(t, err) res, err := stream.Collect(ctx, str) require.NoError(t, err) return res } func importMessages( ctx context.Context, t *testing.T, c *proton.Client, addrID string, addrKR *crypto.KeyRing, labelIDs []string, flags proton.MessageFlag, count int, ) []proton.ImportRes { return importMessagesWithSubjectGenerator(ctx, t, c, addrID, addrKR, labelIDs, flags, count, func() string { return uuid.NewString() }) } func countBytesRead(ctl *proton.NetCtl, fn func()) uint64 { var read uint64 ctl.OnRead(func(b []byte) { atomic.AddUint64(&read, uint64(len(b))) }) fn() return read } type testCookieJar struct { cookies map[string][]*http.Cookie lock sync.RWMutex } func newTestCookieJar() *testCookieJar { return &testCookieJar{ cookies: make(map[string][]*http.Cookie), } } func (j *testCookieJar) SetCookies(u *url.URL, cookies []*http.Cookie) { j.lock.Lock() defer j.lock.Unlock() j.cookies[u.Host] = cookies } func (j *testCookieJar) Cookies(u *url.URL) []*http.Cookie { j.lock.RLock() defer j.lock.RUnlock() return j.cookies[u.Host] } func must[T any](t T, err error) T { if err != nil { panic(err) } return t } func elementsMatch[T comparable](want, got []T) bool { if len(want) != len(got) { return false } for _, w := range want { if !slices.Contains(got, w) { return false } } return true } func getFullMessages(ctx context.Context, c *proton.Client, workers, buffer int, messageIDs ...string) stream.Stream[proton.FullMessage] { scheduler := proton.NewSequentialScheduler() attachmentStorageProvider := proton.NewDefaultAttachmentAllocator() return parallel.MapStream( ctx, stream.FromIterator(iterator.Slice(messageIDs)), workers, buffer, func(ctx context.Context, messageID string) (proton.FullMessage, error) { return c.GetFullMessage(ctx, messageID, scheduler, 
attachmentStorageProvider) }, ) } go-proton-api-1.0.0/server/users.go000066400000000000000000000005321447642273300172150ustar00rootroot00000000000000package server import ( "net/http" "github.com/gin-gonic/gin" ) func (s *Server) handleGetUsers() gin.HandlerFunc { return func(c *gin.Context) { user, err := s.b.GetUser(c.GetString("UserID")) if err != nil { c.AbortWithStatus(http.StatusInternalServerError) return } c.JSON(http.StatusOK, gin.H{ "User": user, }) } } go-proton-api-1.0.0/share.go000066400000000000000000000013571447642273300156560ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) ListShares(ctx context.Context, all bool) ([]ShareMetadata, error) { var res struct { Shares []ShareMetadata } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { if all { r.SetQueryParam("ShowAll", "1") } return r.SetResult(&res).Get("/drive/shares") }); err != nil { return nil, err } return res.Shares, nil } func (c *Client) GetShare(ctx context.Context, shareID string) (Share, error) { var res struct { Share } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/shares/" + shareID) }); err != nil { return Share{}, err } return res.Share, nil } go-proton-api-1.0.0/share_types.go000066400000000000000000000047271447642273300171060ustar00rootroot00000000000000package proton import "github.com/ProtonMail/gopenpgp/v2/crypto" type ShareMetadata struct { ShareID string // Encrypted share ID LinkID string // Encrypted link ID to which the share points (root of share). 
VolumeID string // Encrypted volume ID on which the share is mounted Type ShareType // Type of share State ShareState // The state of the share (active, deleted) CreationTime int64 // Creation time of the share in Unix time ModifyTime int64 // Last modification time of the share in Unix time Creator string // Creator email address Flags ShareFlags // The flag bitmap Locked bool // Whether the share is locked VolumeSoftDeleted bool // Was the volume soft deleted } // Share is an entry point to a location in the file structure (Volume). // It points to a file or folder anywhere in the tree and holds a key called the ShareKey. // To access a file or folder in Drive, a user must be a member of a share. // The membership information is tied to a specific address, and key. // This key then allows the user to decrypt the share key, giving access to the file system rooted at that share. type Share struct { ShareMetadata AddressID string // Encrypted address ID AddressKeyID string // Encrypted address key ID Key string // The private ShareKey, encrypted with a passphrase Passphrase string // The encrypted passphrase PassphraseSignature string // The signature of the passphrase } func (s Share) GetKeyRing(addrKR *crypto.KeyRing) (*crypto.KeyRing, error) { enc, err := crypto.NewPGPMessageFromArmored(s.Passphrase) if err != nil { return nil, err } dec, err := addrKR.Decrypt(enc, nil, crypto.GetUnixTime()) if err != nil { return nil, err } sig, err := crypto.NewPGPSignatureFromArmored(s.PassphraseSignature) if err != nil { return nil, err } if err := addrKR.VerifyDetached(dec, sig, crypto.GetUnixTime()); err != nil { return nil, err } lockedKey, err := crypto.NewKeyFromArmored(s.Key) if err != nil { return nil, err } unlockedKey, err := lockedKey.Unlock(dec.GetBinary()) if err != nil { return nil, err } return crypto.NewKeyRing(unlockedKey) } type ShareType int const ( ShareTypeMain ShareType = 1 ShareTypeStandard ShareType = 2 ShareTypeDevice ShareType = 3 ) type ShareState 
int const ( ShareStateActive ShareState = 1 ShareStateDeleted ShareState = 2 ) type ShareFlags int const ( NoFlags ShareFlags = iota PrimaryShare ) go-proton-api-1.0.0/testdata/000077500000000000000000000000001447642273300160305ustar00rootroot00000000000000go-proton-api-1.0.0/testdata/body.pgp000066400000000000000000000111321447642273300174730ustar00rootroot00000000000000-----BEGIN PGP MESSAGE----- Version: ProtonMail wcBMAwhkvvXurdhrAQf9GTfrNtYdoXtqGDLEEdTZMnnZ56rWUYJASCBnlTZt QeKMG1t4S+YurOWrCJWiRQoYsS67ljHzs0cehzchWlTs0dAM5JKBmtFW0ZmV U9b8vtXwLCHu2S9mH/Yz9MJaR1w8/M5IED4cbFR4If0sSyAEIuXGTBQF3LNc 5bajT9qN1EgeL0ILMEBstLyO95QdPqNkOTwUrzxEEmQXkzjWw0rkdin19UfC s6Z7Ej2YZHQBUH8VAYQDnHKccvYDWcCpO4+r6MlEXQEDGbwt6W3wjUO7yu7R RvK1A8l8JUOCGf0PJnXHr+HzsnkIbmZ9t4AJn2iSvS+MceeYBnGKFNK7Mmi+ 0NLLegFp8UtMpYZTyY/WnjvpG9KEi0RxIKAJOuwZDaGJXhIoyqOqLxiHEnbZ JhDXK5nuJ5UAsrmsbjvzreIbNw8ToEMPd/9N3WwJB28gYNVMJdYluQACfvI1 mk+yd1JR2dPfyP+SA7KMaUPXwsMWN86ViLZE0h+IANUkxAGGMAPU2ikhlbi0 qZzRLx146sKHncZNfn3ge2noFNeJLSUvm5uzOxIjOgjKLyA33SCLWG7GYzaq 4vURCczZ7kJ/Y5M7P53y1uzMdp0AFsqhD0zMQZ35qmGHpmpN44srVxLk0piw TchQcgs/F5lnW+HFsiILuWjwts1hbXeEeMugDNiuqYazP3qHmR8aSDyFZusB fUuNegtuDXaI8HD6mIscToO1Jc2os0fPk/FMeo4vVFBQCv7hMk/hx8ytcDh7 c1963+Rjopq97F8GhDc5S/7xO3oCd+1WSFN7BfyNL0NmfUWM4Fur67Ru0VmW QQmesAeYnGzvCEfGA+tDZYp6gu4/HBfHPMjZs+27MgnkvdOHjLaTQRSmKLeo Oh3BECB9lT0qO85k8+wPZzFQvHpDOwxbW+7u5IZBolGNmO3erE/5NitTWOsJ qfbgYg34+3Juim5E4n8a5npcSMewea7IFtlebOoKSTJ6F8SX4lhky/HdcCjR jzHRh+bpnLEApufKfRplEQLOZoxxLJM++ey2MT32gZLjAXMjAt5eRyhOOrQe v6GdSDbDtMKWhEkTAbBbh2lgi8ZwW1w05pGXmeXXxP4wjBeG1hj4IfT1tFa3 w4IlOU0wzk+jaMBA0Mg4I5VUGNIvhUzU6td8GYN+ZCOQedSJZ/AkkJxeXRLD Sy0NTwW+/tyT9WDL7hfg8T2Dh9dzdXuHS23FAZ2irrzy/bIB/tAOtdqGgf7e KW1JWu5TjKzcMx2aP+ZpOJG6ieAqbU0xPCU2d7HdhHSV0tSwiDfEnV2u2lOJ HtInII03NOhmOe1MzWmmztBw4e1bEtVI62rRup7pFPubIlOiC1fmWzyFmlvY 3g/QF69j2Lg7UjBBlZ1IZ17TMe3oCr5hNZED5vdz8i6ECbhsaGFxWddAoyea V2TEW19pAc5/fs+v87ZTRZzoMXK+uTH+wVWT4AYYZ3GUWod3jiWcjeQqjNB+ 
VhkUSNbBYdJ7CBpUu5eb82N5ZE6v8jV1dw+kKgRE4DRzASP7Z/2yUi+ze2D5 i0TKyFtqMwfw3ptckUAWE7r57B7dbNY8b5woKaVBCkVyEnY0eDIqQUU/RoZq tgKki+4KgwNlIbT6OpxPVAog6fY6M3/3crO52cqWQPX/Uhp0J5rZXLEMRdmX M/7UvFcW96cYtOeCucux9WnNPWqSJ4ddzwt6lxqptsZU7jZS0IGRLqhAn23e u1/NdQpK91YIDeFIlMdyc9/h8VEnaSU/LE1+28INGJsET1dUcw4Qwxb1uOur kleYddoFs4tgi44uxNq3o+Pw9EfFFPu2t1U6FESN4qc+aJ/Sv8Oipi2fyDVv 7iD5XrJ4Oe4a2HXZLfvAaQig4cfa0x4Hqiqp/YSjWvyIZpIoiI+AKDS3bNXs e47mADlHpv4l/rKA/Q5Qjs8zRB74So0Y0oG1qXNRbcPADUrSyMAxqInz/Q3c 9CA35GSp7js2V5rKyv2a11z6VdLtPw/e3KsEvdVmx/7DxSeGvXO5lwN6RluK DJgIl9/7dVLuNvuVXk+tDCrs2dUHOf+a5N34x9B9AQGT3Ixo80le8e9+U5Bc r30iycgQcSISUw5NKCpOTj5BuabYfDQYKCA28l+lrEGq1+q4vumMQgGxEJJn gTqJTyQOE0abbohHl4/+TkKWNdWTwth0eZHHbs0Xrjf1mq7GnCh49QxkIxQt VDwxE/BkiO5SkU0CWdMGLL2LtaincqGtdIlZB7emRB8qdZOQt+5wp+XhiXbG DHoPV3739XzxojZUrOzyh6aoezXmntoHqucr/V85BwSJ5F8rjZiHQDoPuwmf oE3ABLMTC3T9rO2MfxepGbFxNmBaXUrO/muChvq29NO83o+KdPP6YJVu/qAL utMeEhBPgd+if956Ph2twjM1qKO5Hp1Z178KM+2R2PDz1kNmKtsNfiNJhrQ9 1qmLgNCp9qZCFqExHRHcmpgv94KuUYWyONZYBKBCAJKAPM5IRVsryOYUbUlS jeVO3aZES89nCh0zF4XiJ0LaI6kxCCzje2+cZBY5jxiUw4I5F4wOawCwSsv2 mLAuhBV/H0KWq9o0e2RlXMRtrDTgfdnRGHnfXqbQDIgsgQmga7zUej9cVOz2 HA6Oml/h8TjJPQVXkrTuKN4NohWTsz8QX8ISdXuaWoMn9MVq57aYweFV97tT UdzLvJCioq7+AiFu2WN9NmU/bmN2ye5gfn8Qw7eKTfD3VJFkF8Y6lw6RLm0D loMHIv9gaECLcxiJyINjttH+Shx2Nl2Pj5WOgixK+SnClAQ5Uzk7IZTj4CdJ YbIoXg3ukyTeGUaG2flwSbsRUBz0PFPAznTtovtgY55tne4EI7qpdwbo5+fi CBeGylEoozfuKcsnaOJJ4F6pdeBuhwoAV4drvI/qNFRhs82pV85n9aP+7XM/ f16vdig5rn1zodef/mlz989pwfDwjfwabQliNAfczX+rAur8CQcUs62qs41x vQjeSc30NQzR1+B2MlsQtuKxk7HJUlp4tpuldAfC1ffwSThfXKqjSpbDOTRU xVz1xlnjZoQQPhmrNCj0CqpKbqsjq5qEyvgM4llvXedw7wDdXIp52fIlwwO2 CwgXS3sTAa46YWkbqG+iIEFowjUse0T8tK/+3dtnYkS5Hen2rp1lUdl/mm2Z CPqJjOSXGs5Iz+4A2JOSBscYX7GDh/IcF0d9r55K4f5JvJqwYr2MnrQcOJLJ e3EC6weTi0d1a59eafY7m6b4cS2zRQPUtvcOVzkVXbvXMyjWi+WDJYl8EFPr mIOa/Ij7FCiuEOMlazwN5Ot8AIh0qSV9VubqVPUEyc6DozjqdxDL0ZLJxlTy fWTwR//spAgr/pRqhZxNCvfHFn1NeF7TFbDAPZojfEChF9n1u951jgq5sZ01 
aXAiIo5bwVPc/FxHLEf0Zu5f/SS5ZNgIVquJGDNwPVD5S6o0h8w+UjT66Mkp X6wzIP9u4BAkvTHJVwkk79F/MmRPYYjwQooRHcnI+hD1ezliXZ9vgxysFR5g QqftD2tvkwPbxFXOer1mD5MSSMXWGXlhpLBPxiLFRurReBCabL53818Q4P80 Ky8rBUnw5p8XTdQ/fKaMfpQGHzmgiBaOU+sQnWR2IcErpuEWJB2qKHduUgep +if+VOXi3xe9oE/hc7GqX7kV7zmHPBgV9ixczvgouEbRsqSd1LlGYIOd7Suf 1cl7IoxEc2bnbZMdbWPQwyM10D+YDTP+hRs+fjAYoQA+iKy3fg6zHDZKo9BI 6NZQNSGmVn5apP6EpUsFLK9qWltryTTXabbLEdVGOHLSgX3F30wJt1IwtSoV BabBFJARq2JTeYEeOyQryyI18nSNQ3AWn30lbWo/Zm5+MxXCnsm46uvagPBM kbMJT/GbsMVU+gVj6BwJz4wH9AY1K0LZL2+ip8/zkwTMp5kXpQTkslEv3spQ f+wHEduwgfsANaoELc4L3NNcZyR8BD3AWilIEHwmHSoAxNXMjZzIC8JidttG 18GuoiFxf4PG6QR810y1NFw+eR3cFEOa6D5OZP8f8xlHDGJSCv+Rokkc4JLo 1kEGMbYY6IFW9DGaM3s3NOKO14U6YJQUpae23MF/+lTP5LbQadeWh4myU5Wg TefM7rv9dl1yPZnCNINWhSuxshrBse8wbGwUJZ9Ix9KwnU8zR/HoDm0QvPt4 XYXTvVFGF6ynUaDe2xFfgmBMHNuPPJ+pv/rJoGqZ5th2HOhWVuuTRWKK1JBT IKgTnkDXOXDDdqHohr9r5AeQ10thqi0Pi3PzcynNH706qcEdAK9G/dYJ7lXg FDNPwRJYrTrvLwEEA+7rYEoq3PajLndz96hou2sRX28/6J67VBpiOkfcgk9t SJ6Srr3Tn/Ud4YnmcB1LQx0goQC4Bmf/Q8iPHnRimi16g8vSl2wTx2EEr+40 RZcJWj68/3EETH5qPgFdp4m6teRnB7tE6wudCwsNsJ51LcGr0aKxvDVSZyAA jM09plC+5rUAdYF5teVfZZFQLt/PFAfapvvzEWE= =zZ65 -----END PGP MESSAGE----- go-proton-api-1.0.0/testdata/prv.asc000066400000000000000000000071731447642273300173370ustar00rootroot00000000000000-----BEGIN PGP PRIVATE KEY BLOCK----- Version: OpenPGP.js v4.10.10 Comment: https://openpgpjs.org xcMGBF4yECUBCACmt9I8R1+ibe0pHa/PMC1Zs07BFSjTB4k2B8EZJhc3dlgU WoNj06HKAKTnF0tCJfojb9Hhhns9E3da+/pr2mTeGLGCmxVlHmS7vt5MKezh rGmqT9QMJAgjgDlL8+ecWaapRyIQ4NXBX7H3DPuUDSVBxISSbXLxfORiYP5E DueF6aLLMrJFgFamFq5kgpFvXYZmPv+h/VoLP+ZLxnPfS92W3jJ3Y6ByZcC9 bsHiqBFAnGDnh0nIW109rEZRa8vKeSJL/48hUSxucvdAGWQyVpA2vhiWXzba 3O1gmjcphVrQCCgSyvyDvYF/uS3mR1F8d//BpKuYKInq6+Y5MPeKRJGVABEB AAH+CQMISThK/g0sWIlgIp1NC8BxDEFdWTKFvOVMFrmX11zKW5gBJClW6WDE sgbfZi62hLUBAJCwGfH4jH+737RIzFyuNvLTU0Kh7hcOBiEdORIQ6GPq0PFk fDdy4zqd5pFdvvG7bAF+SPWSh9ydFd6kkoelsg0vqUMQSD8cLRw3ESeyzdh1 Rw1sidfpReHO8Q27G95bX1nvK+tBlFyUhJAVP/kfDw0Pn21L45QPuO5ueB1R 
vbuihvHeFydZ6FmsNOI5avxHFmxOqO9qP/Bxs4R14JLMsQFt//n+Dp5pnF4F GzTTYHevY9SCjlMkn1x25FTmJB+aMDJrN+yrEw/oCsDF8Y5ZfueHaqDY46Lv oW4tkDExK045BDNH5yE0Z/3zjz2FpUfLiLqU8YnR72SoxHaRRuIdDlNIuiGc FYy0yZsj30SVSVDKL/A+Nu1Gz2GvHociTqqJQiyBXcp5C0DjEmz97QP8y1o3 atg71RAIOCWReu9YqbAWWdBZo0dannohPHTUeHMhh9bd5GTamG/xf5SrTBTg ooVWwlGuF8/NQpYGqjvENgr5a5Q0SqWUm31CeYcHLI70TooTbmIH6gvlUax1 O9VmzUpX9GlZQ1fwkHp6WAG+0cwMie2Nlwt0Ul4WjGQS0HeyYKGMriY/vMIx IFScwIVzWY3/nbVS5Z7BGrE3sNEOP313o603DKq+lasT3WKpRYQ9EJxvPdjW En0IcSeZnlLBM3ZAcPRlh4dmVQwEYx8nasmbIBj80nANA96jMT2X8Sft1zEN OpLPR86l8eRAkmbdAZ+GZIl6xkxsdxc+cad2OJubs+ze6CbjfCWE6zICUvE1 PQr1fARRmhyL6Ixpio8TTaSztaBh/QjF6vcU2dpApzri/k+w+UX9ZzP0LNbw VncvsJNMEWNzFZ0pmbfZX0rZwb9C+1yKzSdzY2hpem9mcmVuaWNAcG0ubWUg PHNjaGl6b2ZyZW5pY0BwbS5tZT7CwHYEEAEIACAFAl4yECUGCwkHCAMCBBUI CgIEFgIBAAIZAQIbAwIeAQAKCRDssm2jN69QH0w+B/9dpwofZoAMu7eeVS5B tkvLiCWJqWtoBIU2TfB1nAzIbOJA5cWegobKsEtBVI//QAQBwcjg+BjjXmGM KkmO6suDARMrATtct+G2kUl7FpRDE47okq0s+2KJb7bAaPQoBOx5xwFQM3Tj JkD9C+1xSJIrcIgpk1Rs129cNZXKXNc01v02xTrszYnbLqvneYFY4Qt1AUTP bB9us3c6nx0dDq0phGJwRbUOptNZbQrrJ3F3zmVKwGZLju0L0Gmy/F/AOMk4 S3hh/LOTcmSJ0ytZPtTkrTUKmCqDkEspTO4c17y1ffS/4LfzdnPFJJEXjvYU DQNtphFfgsbQeBmuzF2yQTyCx8MGBF4yECUBCADTP/kymrU/DLbGK6kgiUAB UU4zH7Rq6u1NVqKwdaBKOulMKst4QSlVfixI2IDjG2JgUJbCDjhqmgQ3AbDz Z7xOxUqscvM9xsVBbZM5KW+k5cOeAPGNu/GEz62gz/sUTQ5ZGLMjX+C31/3b olKNWuke46mBmPIcv0of7/izanZSRqUeJ4+KsWQjorPmmurqt3TCRq1h3dlm itHQTlQLn9EWRvvTIQagzh5bma7nfwIdnTLfRQW4JX/W6t09O3wj+g5t/X4S dTbXTHnjkLYahXiFDII+2KEcYGWOrs3HeJRb8GEhuOI0g5yK2ezX7RLpzSjj nql4oiQvczPO73fC4N83ABEBAAH+CQMIVc335hmgh/5gkyB7ZXJRkmU5v9yw CXscCDvKBTEFnVPteLUNi7D41TRWwQMIlahu106doJavog9PSwft+tQ590pl x7BwlE5+Rfr48svaHkkm8/AoewScpkqIH0Z/m2/LSpta3Lpjlj0ea8KUkk0V BM/KzwoF6hz85ZglT/s49MohLT6bowhgTE/ZeoLvJJ+NN40KPf9+2ZU3vR8Q NpnzqAXIg5iLuUKrAkic0r3DeOUbKebznMJevN6l6DNBQk1BpRI/talsID99 rk/OaQ8fSXC63NLmXNBg7Oig9iyQJYnqc0le3d3QztzQLgIg79S4PVlQuQvA truSPSrCaLcmqarjuqKDEA4zUzYcGDSKbLXn5JUI9Y74k+CK/IN5OyocwDn3 
THAY63+rLgjEyzd3MKtpS2G6cv7FQVVNvk9j73v6zvlrWXYD8BClklLAhiH5 FjCuEnNf3EV6r6ztMp4/UJoqS5N+qrQHLR8upufHoEmmhtL0HLkcqkT7O/9O LJEKREqcM7w8oWogSQLHbZ5XxWXodDH+smJbu8aySsJT0EED8agiQaN7WTVu yMGYUms/U/SoCvXPLYnnFzSf4W6xiD6o5kreaFg5OvFlyxQevtcbcU3vXSYW QDXqr69Yp4lZXm0gOoDLBCCOkqhQcMJka58eP0hSQkCDsMiwxSk86JRtdw+W H8s1KE9/noKipu2g3v79n1bY0SzkVchGtPi6iV4SarUYEfchkxp8ZlEcnbOy mJJaIIHoseZ23un/78Bu+YUDJ3kjeBSwIKYrCe1+51Z+xXnS5/tanX2LFhB6 gJ7zLABynXrzu7UHxb3zWmHLZxeiYpVYFwrntTPYd7peOiSc/NDWiYLK1SwK nMnpdnLi0/LLuRMsWKKuIITw8sGfki1oxIh9D/6bWEe6eDJIJnAfieGbwsBf BBgBCAAJBQJeMhAlAhsMAAoJEOyybaM3r1AfgywH/01kOihA5/Q8doMipNkZ az3+4ZcAnPeqnIx6ba8xQTLL38Z6xy7SrTQyCLv01dMJVbRqie0ypk4Zeyxw CK7mMqJAq5vMuj/voKCjFZnW3wszxRV3p+U9/SlPn9Rirg2DVFwjScRYro4P 3Tml0oMmFD2jD7QkATwWdhYjTKwET8eCtv7CciKe9EOad6b4vLCiXpT6TiuS MPHS4iUgVMKL4jhVAQoqZDOMMN+odt7yKzhtUaF8VQKLwwHh/JC8BTuweNJN 5doj2cZGjeDu8HQzW/kDSVGQA05rOfZBX4hk8Cpm82ChsjJA7vb3AQFir3CD qF70cirPDl5bNhyYkJm6Asw= =82rN -----END PGP PRIVATE KEY BLOCK----- go-proton-api-1.0.0/testdata/pub.asc000066400000000000000000000046451447642273300173170ustar00rootroot00000000000000-----BEGIN PGP PUBLIC KEY BLOCK----- xsDNBGCwvxYBDACtFOvVIma53f1RLCaE3LtaIaY+sVHHdwsB8g13Kl0x5sK53AchIVR+6RE0JHG1 pbwQX4Hm05w6cjemDo652Cjn946zXQ65GYMYiG9Uw+HVldk3TsmKHdvI3zZNQkihnGSMP65BG5Mi 6M3Yq/5FAEP3cOCUKJKkSd6KEx6x3+mbjoPnb4fV0OlfNZa1+FDVlE1gkH3GKQIdcutF5nMDvxry RHM20vnR1YPrY587Uz6JTnarxCeENn442W/aiG5O2FXgt5QKW66TtTzESry/y6JEpg9EiLKG0Ki4 k6Z2kkP+YS5xvmqSohVqusmBnOk+wppIhrWaxGJ08Rv5HgzGS3gS29XmzxlBDE+FCrOVSOjAQ94g UtHZMIPL91A2JMc3RbOXpqVPNyJ+dRzQZ1obyXoaaoiLCQlBtVSbCKUOLVY+bmpyqUdSx45k31Hf FSUj8KrkjsCw6QFpVEfa5LxKfLHfulZdjL3FquxiYjrLHsYmdlIY2lqtaQocINk6VTa+YkkAEQEA Ac0cQlFBIDxwbS5icmlkZ2UucWFAZ21haWwuY29tPsLBDwQTAQgAORYhBMTS4mxV82UN59X4Y1MP t/KzWl0zBQJgsL8WBQkFo5qAAhsDBQsJCAcCBhUICQoLAgUWAgMBAAAKCRBTD7fys1pdMw0dC/9w Ud0I1lp/AHztlIrPYgwwJcSK7eSxuHXelX6mFImjlieKcWjdBL4Gj8MyOxPgjRDW7JecRZA/7tMI 37+izWH2CukevGrbpdyuzX0AR7I7DpX4tDVFNTxi7vYDk+Q+lVJ5dL4lYww3t7cuzhqUvj4oSJaS 
9cNeFc66owij7juQoQQ7DmOsLMUw9qlMsDvZNvu83x7hIyGLBCY1gY1VtCeb3QT7uCG8LrQrWkI9 RLgzZioegHxMtvUgzQRw8U9mS8lJ4J2LaI3Z4DliyKSEebplVMfl53dSl1wfV5huZKifoo9NAusw lrRw+3Ae+VZ0Obnz14qmyCwevHv6QlkXtntSY1wyprOvzWiu8PE9rHoTmwLI8wMkbiLdFVXCZbon /1Hg0n1K0fv1A8cIc5JSeCe3y8YMm7b5oEie/cnArqDjZ8VB/vm5H9zvHxfJCI5FwlEVBlosSpib Tm/1fSpqDgAmH7IDe3wCY8899kmfbBqJzr+5xaCGt+0mgC8jpJIEIKHOwM0EYLC/FwEMAKtvqck9 78vAr1ttKpOAEQcKf1X04QLy2AvzHGNcud+XC1u0bHLm3OQsYyLaP3DVAvain6vrVVGiswdsexUI yIEpBTo+9Rco7MtwwESfxG10p2bbd8q74EaJZkt/ifL6oxEYgp8tCgAB6tqGoXCmkG0nKszrrTTz Lo/3bHjzfxF01oGDNlQVGVwW+8d5tjV5vowxeSjmdIZXJPNep4Lah/xFisWb71VwdzVEaOi6k7rQ J5k+Dp1wrCqW1H5RZZt6dGweU4LbuTYBWtnw/2YKz+hBOYGDzil9hqTG9fRXu31d4xOZxuZkv61R 3DWrxuECKUHgJvFaao0KSnBDa/T/RMJ9Y/KQ0bx0zXOTtoDOhOhpMA8JUTMfWb3Uul50ikxLI5EJ xnBroy2bLLaRW6ijMgpdnZRAtmhssHipOisxXoxiWMoRfJBR01DhbmSQPTjpsjqM2Z24hPcKN+sf 9kCKTmaJ2hbOfurriPmM0GHdgewbf5cemKgqVaPfhvyBXhnRjwARAQABwsD8BBgBCAAmFiEExNLi bFXzZQ3n1fhjUw+38rNaXTMFAmCwvxcFCQWjmoACGwwACgkQUw+38rNaXTNTSgwAqomSuzK80Goi eOqJ6e0LLiKJTGzMtrtugK9HYzFn1rT7n9W2lZuf4X8Ayo9i32Q4Of1V17EXOyYWHOK/prTDd9DV sRa+fzLVzC6jln3AKeRi9k/DIs7GDs0poQZyttTVLilK8uDkEWM7mWAyjyBTtWyiKTlfFb7W+M3R 1lTKXQsn/wBkboJNZj+VTNo5NZ6vIx4PJRFW2lsDKbYJ+Vh5vZUdTwHXr5gLadtWzrVgBVMiLyEr fgCzdyfMRy+g4uoYxt9JuFvisU/DDVNeAZ8hSgLdI4w65wjeXtT0syzpL9+pJQX0McugEpbIEiOt e55OL1C0hjvHnsLHPkRuUOtQKru/gNl0bLqZ7mYqPNhJbh/58k+N4eoeTvCjMy65anWuiWjPbm16 GH/3erZiijKDGYn8UqldiOK9dTC6DbvyJdxuYFliV7cSWIBtiOeGrajxzkuUHMW+d1d4l2gPqs2+ eT1x4J+7ydQgCvyyI4W01xcFlAL70VRTlYKIbMXJBZ6L =9sH1 -----END PGP PUBLIC KEY BLOCK----- go-proton-api-1.0.0/ticker.go000066400000000000000000000020201447642273300160210ustar00rootroot00000000000000package proton import ( "math/rand" "time" "github.com/ProtonMail/gluon/async" ) type Ticker struct { C chan time.Time stopCh chan struct{} doneCh chan struct{} } // NewTicker returns a new ticker that ticks at a random time between period and period+jitter. // It can be stopped by closing calling Stop(). 
func NewTicker(period, jitter time.Duration, panicHandler async.PanicHandler) *Ticker { t := &Ticker{ C: make(chan time.Time), stopCh: make(chan struct{}), doneCh: make(chan struct{}), } go func() { defer async.HandlePanic(panicHandler) defer close(t.doneCh) for { select { case <-t.stopCh: return case <-time.After(withJitter(period, jitter)): select { case <-t.stopCh: return case t.C <- time.Now(): // ... } } } }() return t } func (t *Ticker) Stop() { close(t.stopCh) <-t.doneCh } func withJitter(period, jitter time.Duration) time.Duration { if jitter == 0 { return period } return period + time.Duration(rand.Int63n(int64(jitter))) } go-proton-api-1.0.0/undo.go000066400000000000000000000012531447642273300155140ustar00rootroot00000000000000package proton import ( "context" "runtime" "time" "github.com/bradenaw/juniper/parallel" "github.com/go-resty/resty/v2" ) func (c *Client) UndoActions(ctx context.Context, tokens ...UndoToken) ([]UndoRes, error) { return parallel.MapContext(ctx, runtime.NumCPU(), tokens, func(ctx context.Context, token UndoToken) (UndoRes, error) { if time.Unix(token.ValidUntil, 0).Before(time.Now()) { return UndoRes{}, ErrUndoTokenExpired } var res UndoRes if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(token).SetResult(&res).Post("/mail/v4/undoactions") }); err != nil { return UndoRes{}, err } return res, nil }) } go-proton-api-1.0.0/undo_types.go000066400000000000000000000003111447642273300167320ustar00rootroot00000000000000package proton import "errors" var ErrUndoTokenExpired = errors.New("undo token expired") type UndoToken struct { Token string ValidUntil int64 } type UndoRes struct { Messages []Message } go-proton-api-1.0.0/unlock.go000066400000000000000000000021141447642273300160370ustar00rootroot00000000000000package proton import ( "fmt" "runtime" "github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/bradenaw/juniper/parallel" ) func Unlock(user User, addresses 
[]Address, saltedKeyPass []byte, panicHandler async.PanicHandler) (*crypto.KeyRing, map[string]*crypto.KeyRing, error) { userKR, err := user.Keys.Unlock(saltedKeyPass, nil) if err != nil { return nil, nil, fmt.Errorf("failed to unlock user keys: %w", err) } else if userKR.CountDecryptionEntities() == 0 { return nil, nil, fmt.Errorf("failed to unlock any user keys") } addrKRs := make(map[string]*crypto.KeyRing) for idx, addrKR := range parallel.Map(runtime.NumCPU(), addresses, func(addr Address) *crypto.KeyRing { defer async.HandlePanic(panicHandler) return addr.Keys.TryUnlock(saltedKeyPass, userKR) }) { if addrKR == nil { continue } else if addrKR.CountDecryptionEntities() == 0 { continue } addrKRs[addresses[idx].ID] = addrKR } if len(addrKRs) == 0 { return nil, nil, fmt.Errorf("failed to unlock any address keys") } return userKR, addrKRs, nil } go-proton-api-1.0.0/user.go000066400000000000000000000024271447642273300155310ustar00rootroot00000000000000package proton import ( "context" "encoding/base64" "github.com/ProtonMail/go-srp" "github.com/go-resty/resty/v2" ) func (c *Client) GetUser(ctx context.Context) (User, error) { var res struct { User User } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/core/v4/users") }); err != nil { return User{}, err } return res.User, nil } func (c *Client) DeleteUser(ctx context.Context, password []byte, req DeleteUserReq) error { user, err := c.GetUser(ctx) if err != nil { return err } info, err := c.m.AuthInfo(ctx, AuthInfoReq{Username: user.Name}) if err != nil { return err } srpAuth, err := srp.NewAuth(info.Version, user.Name, password, info.Salt, info.Modulus, info.ServerEphemeral) if err != nil { return err } proofs, err := srpAuth.GenerateProofs(2048) if err != nil { return err } return c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetBody(struct { DeleteUserReq AuthReq }{ DeleteUserReq: req, AuthReq: AuthReq{ ClientProof: 
base64.StdEncoding.EncodeToString(proofs.ClientProof), ClientEphemeral: base64.StdEncoding.EncodeToString(proofs.ClientEphemeral), SRPSession: info.SRPSession, }, }).Delete("/core/v4/users/delete") }) } go-proton-api-1.0.0/user_types.go000066400000000000000000000004571447642273300167560ustar00rootroot00000000000000package proton type User struct { ID string Name string DisplayName string Email string Keys Keys UsedSpace int64 MaxSpace int64 MaxUpload int64 Credit int64 Currency string } type DeleteUserReq struct { Reason string Feedback string Email string } go-proton-api-1.0.0/utils/000077500000000000000000000000001447642273300153575ustar00rootroot00000000000000go-proton-api-1.0.0/utils/dependency_license.sh000077500000000000000000000060271447642273300215430ustar00rootroot00000000000000#!/bin/bash set -eo pipefail src=go.mod tgt=COPYING_NOTES.md STARTAUTOGEN="" ENDAUTOGEN="" RE_STARTAUTOGEN="^${STARTAUTOGEN}$" RE_ENDAUTOGEN="^${ENDAUTOGEN}$" tmpDepLicenses="" error(){ echo "Error: $*" exit 1 } generate_dep_licenses(){ [ -r $src ] || error "Cannot read file '$src'" tmpDepLicenses="$(mktemp)" # Collect all go.mod lines beginig with tab: # * which no replace # * which have replace grep -E $'^\t[^=>]*$' $src | sed -r 's/\t([^ ]*) v.*/\1/g' > "$tmpDepLicenses" # Replace each line with formated link sed -i -r '/^github.com\/therecipe\/qt\/internal\/binding\/files\/docs\//d;' "$tmpDepLicenses" sed -i -r 's|^(.*)/([[:alnum:]-]+)/(v[[:digit:]]+)$|* [\2](https://\1/\2/\3)|g' "$tmpDepLicenses" sed -i -r 's|^(.*)/([[:alnum:]-]+)$|* [\2](https://\1/\2)|g' "$tmpDepLicenses" sed -i -r 's|^(.*)/([[:alnum:]-]+).(v[[:digit:]]+)$|* [\2](https://\1/\2.\3)|g' "$tmpDepLicenses" ## add license file to github links, and others sed -i -r '/github.com/s|^(.*(https://[^)]+).*)$|\1 available under [license](\2/blob/master/LICENSE) |g' "$tmpDepLicenses" sed -i -r '/golang.org\/x/s|^(.*golang.org/x/([^)]+).*)$|\1 available under [license](https://cs.opensource.google/go/x/\2/+/master:LICENSE) 
|g' "$tmpDepLicenses" sed -i -r '/google.golang.org\/grpc/s|^(.*)$|\1 available under [license](https://github.com/grpc/grpc-go/blob/master/LICENSE) |g' "$tmpDepLicenses" sed -i -r '/google.golang.org\/protobuf/s|^(.*)$|\1 available under [license](https://github.com/protocolbuffers/protobuf/blob/main/LICENSE) |g' "$tmpDepLicenses" sed -i -r '/google.golang.org\/genproto/s|^(.*)$|\1 available under [license](https://pkg.go.dev/google.golang.org/genproto?tab=licenses) |g' "$tmpDepLicenses" sed -i -r '/go.uber.org\/goleak/s|^(.*)$|\1 available under [license](https://pkg.go.dev/go.uber.org/goleak?tab=licenses) |g' "$tmpDepLicenses" sed -i -r '/gopkg.in\/yaml\.v3/s|^(.*)$|\1 available under [license](https://github.com/go-yaml/yaml/blob/v3.0.1/LICENSE) |g' "$tmpDepLicenses" } check_dependecies(){ generate_dep_licenses tmpHaveLicenses=$(mktemp) sed "/${RE_STARTAUTOGEN}/,/${RE_ENDAUTOGEN}/!d;//d" $tgt > "$tmpHaveLicenses" diffOK=0 if ! diff "$tmpHaveLicenses" "$tmpDepLicenses"; then diffOK=1; fi rm "$tmpDepLicenses" || echo "Failed to clean tmp file" rm "$tmpHaveLicenses" || echo "Failed to clean tmp file" [ $diffOK -eq 0 ] || error "Dependency licenses are not up-to-date" exit 0 } update_dependecies(){ generate_dep_licenses sed -i -e "/${RE_STARTAUTOGEN}/,/${RE_ENDAUTOGEN}/!b" \ -e "/${RE_ENDAUTOGEN}/i ${STARTAUTOGEN}" \ -e "/${RE_ENDAUTOGEN}/r $tmpDepLicenses" \ -e "/${RE_ENDAUTOGEN}/a ${ENDAUTOGEN}" \ -e "d" \ $tgt rm "$tmpDepLicenses" || echo "Failed to clean tmp file" exit 0 } case $1 in "check") check_dependecies;; "update") update_dependecies;; *) error "One of actions needed: check update" ;; esac go-proton-api-1.0.0/volume.go000066400000000000000000000012671447642273300160630ustar00rootroot00000000000000package proton import ( "context" "github.com/go-resty/resty/v2" ) func (c *Client) ListVolumes(ctx context.Context) ([]Volume, error) { var res struct { Volumes []Volume } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return 
r.SetResult(&res).Get("/drive/volumes") }); err != nil { return nil, err } return res.Volumes, nil } func (c *Client) GetVolume(ctx context.Context, volumeID string) (Volume, error) { var res struct { Volume Volume } if err := c.do(ctx, func(r *resty.Request) (*resty.Response, error) { return r.SetResult(&res).Get("/drive/volumes/" + volumeID) }); err != nil { return Volume{}, err } return res.Volume, nil } go-proton-api-1.0.0/volume_types.go000066400000000000000000000025361447642273300173070ustar00rootroot00000000000000package proton // Volume is a Proton Drive volume. type Volume struct { VolumeID string // Encrypted volume ID CreationTime int64 // Creation time of the volume in Unix time ModifyTime int64 // Last modification time of the volume in Unix time MaxSpace *int64 // Space limit for the volume in bytes, null if unlimited. UsedSpace int64 // Space used by files in the volume in bytes DownloadedBytes int64 // The amount of downloaded data since last reset UploadedBytes int64 // The amount of uploaded data since the last reset State VolumeState // The state of the volume (active, locked, maybe more in the future) Share VolumeShare // The main share of the volume RestoreStatus *VolumeRestoreStatus // The status of the restore task. Null if not applicable } // VolumeShare is the main share of a volume. type VolumeShare struct { ShareID string // Encrypted share ID LinkID string // Encrypted link ID } // VolumeState is the state of a volume. type VolumeState int const ( VolumeStateActive VolumeState = 1 VolumeStateLocked VolumeState = 3 ) // VolumeRestoreStatus is the status of the restore task. type VolumeRestoreStatus int const ( RestoreStatusDone VolumeRestoreStatus = 0 RestoreStatusInProgress VolumeRestoreStatus = 1 RestoreStatusFailed VolumeRestoreStatus = -1 )