pax_global_header00006660000000000000000000000064143610017150014510gustar00rootroot0000000000000052 comment=0d23249f92db8b5ce8d9834801993354e3b754c4 go-0.11.1/000077500000000000000000000000001436100171500121755ustar00rootroot00000000000000go-0.11.1/.editorconfig000066400000000000000000000002661436100171500146560ustar00rootroot00000000000000root = true [*] indent_style = tab indent_size = 4 end_of_line = lf charset = utf-8 trim_trailing_whitespace = true insert_final_newline = true [*.{yaml,yml}] indent_style = space go-0.11.1/.github/000077500000000000000000000000001436100171500135355ustar00rootroot00000000000000go-0.11.1/.github/workflows/000077500000000000000000000000001436100171500155725ustar00rootroot00000000000000go-0.11.1/.github/workflows/go.yml000066400000000000000000000021331436100171500167210ustar00rootroot00000000000000name: Go on: [push, pull_request] jobs: lint: runs-on: ubuntu-latest strategy: fail-fast: false matrix: go-version: [1.18] steps: - uses: actions/checkout@v3 - name: Set up Go ${{ matrix.go-version }} uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: Install goimports run: | go install golang.org/x/tools/cmd/goimports@latest export PATH="$HOME/go/bin:$PATH" - name: Install pre-commit run: pip install pre-commit - name: Lint run: pre-commit run -a build: runs-on: ubuntu-latest strategy: fail-fast: false matrix: go-version: [1.17, 1.18] steps: - uses: actions/checkout@v3 - name: Set up Go ${{ matrix.go-version }} uses: actions/setup-go@v3 with: go-version: ${{ matrix.go-version }} - name: Install libolm run: sudo apt-get install libolm-dev libolm3 - name: Build run: go build -v ./... - name: Test run: go test -v ./... 
go-0.11.1/.gitignore000066400000000000000000000000201436100171500141550ustar00rootroot00000000000000.idea/ .vscode/ go-0.11.1/.pre-commit-config.yaml000066400000000000000000000005561436100171500164640ustar00rootroot00000000000000repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.1.0 hooks: - id: trailing-whitespace exclude_types: [markdown] - id: end-of-file-fixer - id: check-yaml - id: check-added-large-files - repo: https://github.com/tekwizely/pre-commit-golang rev: v1.0.0-beta.5 hooks: - id: go-imports-repo go-0.11.1/LICENSE000066400000000000000000000405251436100171500132100ustar00rootroot00000000000000Mozilla Public License Version 2.0 ================================== 1. Definitions -------------- 1.1. "Contributor" means each individual or legal entity that creates, contributes to the creation of, or owns Covered Software. 1.2. "Contributor Version" means the combination of the Contributions of others (if any) used by a Contributor and that particular Contributor's Contribution. 1.3. "Contribution" means Covered Software of a particular Contributor. 1.4. "Covered Software" means Source Code Form to which the initial Contributor has attached the notice in Exhibit A, the Executable Form of such Source Code Form, and Modifications of such Source Code Form, in each case including portions thereof. 1.5. "Incompatible With Secondary Licenses" means (a) that the initial Contributor has attached the notice described in Exhibit B to the Covered Software; or (b) that the Covered Software was made available under the terms of version 1.1 or earlier of the License, but not also under the terms of a Secondary License. 1.6. "Executable Form" means any form of the work other than Source Code Form. 1.7. "Larger Work" means a work that combines Covered Software with other material, in a separate file or files, that is not Covered Software. 1.8. "License" means this document. 1.9. 
"Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently, any and all of the rights conveyed by this License. 1.10. "Modifications" means any of the following: (a) any file in Source Code Form that results from an addition to, deletion from, or modification of the contents of Covered Software; or (b) any new file in Source Code Form that contains any Covered Software. 1.11. "Patent Claims" of a Contributor means any patent claim(s), including without limitation, method, process, and apparatus claims, in any patent Licensable by such Contributor that would be infringed, but for the grant of the License, by the making, using, selling, offering for sale, having made, import, or transfer of either its Contributions or its Contributor Version. 1.12. "Secondary License" means either the GNU General Public License, Version 2.0, the GNU Lesser General Public License, Version 2.1, the GNU Affero General Public License, Version 3.0, or any later versions of those licenses. 1.13. "Source Code Form" means the form of the work preferred for making modifications. 1.14. "You" (or "Your") means an individual or a legal entity exercising rights under this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. 2. License Grants and Conditions -------------------------------- 2.1. 
Grants Each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: (a) under intellectual property rights (other than patent or trademark) Licensable by such Contributor to use, reproduce, make available, modify, display, perform, distribute, and otherwise exploit its Contributions, either on an unmodified basis, with Modifications, or as part of a Larger Work; and (b) under Patent Claims of such Contributor to make, use, sell, offer for sale, have made, import, and otherwise transfer either its Contributions or its Contributor Version. 2.2. Effective Date The licenses granted in Section 2.1 with respect to any Contribution become effective for each Contribution on the date the Contributor first distributes such Contribution. 2.3. Limitations on Grant Scope The licenses granted in this Section 2 are the only rights granted under this License. No additional rights or licenses will be implied from the distribution or licensing of Covered Software under this License. Notwithstanding Section 2.1(b) above, no patent license is granted by a Contributor: (a) for any code that a Contributor has removed from Covered Software; or (b) for infringements caused by: (i) Your and any other third party's modifications of Covered Software, or (ii) the combination of its Contributions with other software (except as part of its Contributor Version); or (c) under Patent Claims infringed by Covered Software in the absence of its Contributions. This License does not grant any rights in the trademarks, service marks, or logos of any Contributor (except as may be necessary to comply with the notice requirements in Section 3.4). 2.4. Subsequent Licenses No Contributor makes additional grants as a result of Your choice to distribute the Covered Software under a subsequent version of this License (see Section 10.2) or under the terms of a Secondary License (if permitted under the terms of Section 3.3). 2.5. 
Representation Each Contributor represents that the Contributor believes its Contributions are its original creation(s) or it has sufficient rights to grant the rights to its Contributions conveyed by this License. 2.6. Fair Use This License is not intended to limit any rights You have under applicable copyright doctrines of fair use, fair dealing, or other equivalents. 2.7. Conditions Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted in Section 2.1. 3. Responsibilities ------------------- 3.1. Distribution of Source Form All distribution of Covered Software in Source Code Form, including any Modifications that You create or to which You contribute, must be under the terms of this License. You must inform recipients that the Source Code Form of the Covered Software is governed by the terms of this License, and how they can obtain a copy of this License. You may not attempt to alter or restrict the recipients' rights in the Source Code Form. 3.2. Distribution of Executable Form If You distribute Covered Software in Executable Form then: (a) such Covered Software must also be made available in Source Code Form, as described in Section 3.1, and You must inform recipients of the Executable Form how they can obtain a copy of such Source Code Form by reasonable means in a timely manner, at a charge no more than the cost of distribution to the recipient; and (b) You may distribute such Executable Form under the terms of this License, or sublicense it under different terms, provided that the license for the Executable Form does not attempt to limit or alter the recipients' rights in the Source Code Form under this License. 3.3. Distribution of a Larger Work You may create and distribute a Larger Work under terms of Your choice, provided that You also comply with the requirements of this License for the Covered Software. 
If the Larger Work is a combination of Covered Software with a work governed by one or more Secondary Licenses, and the Covered Software is not Incompatible With Secondary Licenses, this License permits You to additionally distribute such Covered Software under the terms of such Secondary License(s), so that the recipient of the Larger Work may, at their option, further distribute the Covered Software under the terms of either this License or such Secondary License(s). 3.4. Notices You may not remove or alter the substance of any license notices (including copyright notices, patent notices, disclaimers of warranty, or limitations of liability) contained within the Source Code Form of the Covered Software, except that You may alter any license notices to the extent required to remedy known factual inaccuracies. 3.5. Application of Additional Terms You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, You may do so only on Your own behalf, and not on behalf of any Contributor. You must make it absolutely clear that any such warranty, support, indemnity, or liability obligation is offered by You alone, and You hereby agree to indemnify every Contributor for any liability incurred by such Contributor as a result of warranty, support, indemnity or liability terms You offer. You may include additional disclaimers of warranty and limitations of liability specific to any jurisdiction. 4. Inability to Comply Due to Statute or Regulation --------------------------------------------------- If it is impossible for You to comply with any of the terms of this License with respect to some or all of the Covered Software due to statute, judicial order, or regulation then You must: (a) comply with the terms of this License to the maximum extent possible; and (b) describe the limitations and the code they affect. 
Such description must be placed in a text file included with all distributions of the Covered Software under this License. Except to the extent prohibited by statute or regulation, such description must be sufficiently detailed for a recipient of ordinary skill to be able to understand it. 5. Termination -------------- 5.1. The rights granted under this License will terminate automatically if You fail to comply with any of its terms. However, if You become compliant, then the rights granted under this License from a particular Contributor are reinstated (a) provisionally, unless and until such Contributor explicitly and finally terminates Your grants, and (b) on an ongoing basis, if such Contributor fails to notify You of the non-compliance by some reasonable means prior to 60 days after You have come back into compliance. Moreover, Your grants from a particular Contributor are reinstated on an ongoing basis if such Contributor notifies You of the non-compliance by some reasonable means, this is the first time You have received notice of non-compliance with this License from such Contributor, and You become compliant prior to 30 days after Your receipt of the notice. 5.2. If You initiate litigation against any entity by asserting a patent infringement claim (excluding declaratory judgment actions, counter-claims, and cross-claims) alleging that a Contributor Version directly or indirectly infringes any patent, then the rights granted to You by any and all Contributors for the Covered Software under Section 2.1 of this License shall terminate. 5.3. In the event of termination under Sections 5.1 or 5.2 above, all end user license agreements (excluding distributors and resellers) which have been validly granted by You or Your distributors under this License prior to termination shall survive termination. ************************************************************************ * * * 6. 
Disclaimer of Warranty * * ------------------------- * * * * Covered Software is provided under this License on an "as is" * * basis, without warranty of any kind, either expressed, implied, or * * statutory, including, without limitation, warranties that the * * Covered Software is free of defects, merchantable, fit for a * * particular purpose or non-infringing. The entire risk as to the * * quality and performance of the Covered Software is with You. * * Should any Covered Software prove defective in any respect, You * * (not any Contributor) assume the cost of any necessary servicing, * * repair, or correction. This disclaimer of warranty constitutes an * * essential part of this License. No use of any Covered Software is * * authorized under this License except under this disclaimer. * * * ************************************************************************ ************************************************************************ * * * 7. Limitation of Liability * * -------------------------- * * * * Under no circumstances and under no legal theory, whether tort * * (including negligence), contract, or otherwise, shall any * * Contributor, or anyone who distributes Covered Software as * * permitted above, be liable to You for any direct, indirect, * * special, incidental, or consequential damages of any character * * including, without limitation, damages for lost profits, loss of * * goodwill, work stoppage, computer failure or malfunction, or any * * and all other commercial damages or losses, even if such party * * shall have been informed of the possibility of such damages. This * * limitation of liability shall not apply to liability for death or * * personal injury resulting from such party's negligence to the * * extent applicable law prohibits such limitation. Some * * jurisdictions do not allow the exclusion or limitation of * * incidental or consequential damages, so this exclusion and * * limitation may not apply to You. 
* * * ************************************************************************ 8. Litigation ------------- Any litigation relating to this License may be brought only in the courts of a jurisdiction where the defendant maintains its principal place of business and such litigation shall be governed by laws of that jurisdiction, without reference to its conflict-of-law provisions. Nothing in this Section shall prevent a party's ability to bring cross-claims or counter-claims. 9. Miscellaneous ---------------- This License represents the complete agreement concerning the subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not be used to construe this License against a Contributor. 10. Versions of the License --------------------------- 10.1. New Versions Mozilla Foundation is the license steward. Except as provided in Section 10.3, no one other than the license steward has the right to modify or publish new versions of this License. Each version will be given a distinguishing version number. 10.2. Effect of New Versions You may distribute the Covered Software under the terms of the version of the License under which You originally received the Covered Software, or under the terms of any subsequent version published by the license steward. 10.3. Modified Versions If you create software not governed by this License, and you want to create a new license for such software, you may create and use a modified version of this License if you rename the license and remove any references to the name of the license steward (except to note that such modified license differs from this License). 10.4. 
Distributing Source Code Form that is Incompatible With Secondary Licenses If You choose to distribute Source Code Form that is Incompatible With Secondary Licenses under the terms of this version of the License, the notice described in Exhibit B of this License must be attached. Exhibit A - Source Code Form License Notice ------------------------------------------- This Source Code Form is subject to the terms of the Mozilla Public License, v. 2.0. If a copy of the MPL was not distributed with this file, You can obtain one at http://mozilla.org/MPL/2.0/. If it is not possible or desirable to put the notice in a particular file, then You may include the notice in a location (such as a LICENSE file in a relevant directory) where a recipient would be likely to look for such a notice. You may add additional accurate notices of copyright ownership. Exhibit B - "Incompatible With Secondary Licenses" Notice --------------------------------------------------------- This Source Code Form is "Incompatible With Secondary Licenses", as defined by the Mozilla Public License, v. 2.0. go-0.11.1/README.md000066400000000000000000000022151436100171500134540ustar00rootroot00000000000000# mautrix-go [![GoDoc](https://godoc.org/maunium.net/go/mautrix?status.svg)](https://godoc.org/maunium.net/go/mautrix) A Golang Matrix framework. Used by [gomuks](https://matrix.org/docs/projects/client/gomuks), [go-neb](https://github.com/matrix-org/go-neb), [mautrix-whatsapp](https://github.com/mautrix/whatsapp) and others. Matrix room: [`#maunium:maunium.net`](https://matrix.to/#/#maunium:maunium.net) This project is based on [matrix-org/gomatrix](https://github.com/matrix-org/gomatrix). The original project is licensed under [Apache 2.0](https://github.com/matrix-org/gomatrix/blob/master/LICENSE). 
In addition to the basic client API features the original project has, this framework also has: * Appservice support (Intent API like mautrix-python, room state storage, etc) * End-to-end encryption support (incl. interactive SAS verification) * Structs for parsing event content * Helpers for parsing and generating Matrix HTML * Helpers for handling push rules This project contains modules that are licensed under Apache 2.0: * [maunium.net/go/mautrix/crypto/canonicaljson](crypto/canonicaljson) * [maunium.net/go/mautrix/crypto/olm](crypto/olm) go-0.11.1/appservice/000077500000000000000000000000001436100171500143365ustar00rootroot00000000000000go-0.11.1/appservice/appservice.go000066400000000000000000000253151436100171500170340ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package appservice import ( "errors" "fmt" "html/template" "io/ioutil" "net/http" "net/http/cookiejar" "os" "path/filepath" "strings" "sync" "syscall" "time" "github.com/gorilla/mux" "github.com/gorilla/websocket" "golang.org/x/net/publicsuffix" "gopkg.in/yaml.v2" "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // EventChannelSize is the size for the Events channel in Appservice instances. var EventChannelSize = 64 var OTKChannelSize = 4 // Create a blank appservice instance. 
func Create() *AppService { jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) return &AppService{ LogConfig: CreateLogConfig(), clients: make(map[id.UserID]*mautrix.Client), intents: make(map[id.UserID]*IntentAPI), HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar}, StateStore: NewBasicStateStore(), Router: mux.NewRouter(), UserAgent: mautrix.DefaultUserAgent, txnIDC: NewTransactionIDCache(128), Live: true, Ready: false, } } // Load an appservice config from a file. func Load(path string) (*AppService, error) { data, readErr := ioutil.ReadFile(path) if readErr != nil { return nil, readErr } config := Create() return config, yaml.Unmarshal(data, config) } // QueryHandler handles room alias and user ID queries from the homeserver. type QueryHandler interface { QueryAlias(alias string) bool QueryUser(userID id.UserID) bool } type QueryHandlerStub struct{} func (qh *QueryHandlerStub) QueryAlias(alias string) bool { return false } func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool { return false } type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{}) // AppService is the main config for all appservices. // It also serves as the appservice instance struct. 
type AppService struct { HomeserverDomain string `yaml:"homeserver_domain"` HomeserverURL string `yaml:"homeserver_url"` RegistrationPath string `yaml:"registration"` Host HostConfig `yaml:"host"` LogConfig LogConfig `yaml:"logging"` Registration *Registration `yaml:"-"` Log maulogger.Logger `yaml:"-"` txnIDC *TransactionIDCache Events chan *event.Event `yaml:"-"` DeviceLists chan *mautrix.DeviceLists `yaml:"-"` OTKCounts chan *mautrix.OTKCount `yaml:"-"` QueryHandler QueryHandler `yaml:"-"` StateStore StateStore `yaml:"-"` Router *mux.Router `yaml:"-"` UserAgent string `yaml:"-"` server *http.Server HTTPClient *http.Client botClient *mautrix.Client botIntent *IntentAPI MessageSendCheckpointEndpoint string DefaultHTTPRetries int Live bool Ready bool clients map[id.UserID]*mautrix.Client clientsLock sync.RWMutex intents map[id.UserID]*IntentAPI intentsLock sync.RWMutex ws *websocket.Conn wsWriteLock sync.Mutex StopWebsocket func(error) websocketHandlers map[string]WebsocketHandler websocketHandlersLock sync.RWMutex websocketRequests map[int]chan<- *WebsocketCommand websocketRequestsLock sync.RWMutex websocketRequestID int32 // ProcessID is an identifier sent to the websocket proxy for debugging connections ProcessID string } func getDefaultProcessID() string { pid := syscall.Getpid() uid := syscall.Getuid() hostname, _ := os.Hostname() return fmt.Sprintf("%s-%d-%d", hostname, uid, pid) } func (as *AppService) PrepareWebsocket() { if as.websocketHandlers == nil { as.websocketHandlers = make(map[string]WebsocketHandler, 32) as.websocketRequests = make(map[int]chan<- *WebsocketCommand) } } // HostConfig contains info about how to host the appservice. type HostConfig struct { Hostname string `yaml:"hostname"` Port uint16 `yaml:"port"` TLSKey string `yaml:"tls_key,omitempty"` TLSCert string `yaml:"tls_cert,omitempty"` } // Address gets the whole address of the Appservice. 
func (hc *HostConfig) Address() string { return fmt.Sprintf("%s:%d", hc.Hostname, hc.Port) } // Save saves this config into a file at the given path. func (as *AppService) Save(path string) error { data, err := yaml.Marshal(as) if err != nil { return err } return ioutil.WriteFile(path, data, 0644) } // YAML returns the config in YAML format. func (as *AppService) YAML() (string, error) { data, err := yaml.Marshal(as) if err != nil { return "", err } return string(data), nil } func (as *AppService) BotMXID() id.UserID { return id.NewUserID(as.Registration.SenderLocalpart, as.HomeserverDomain) } func (as *AppService) makeIntent(userID id.UserID) *IntentAPI { as.intentsLock.Lock() defer as.intentsLock.Unlock() intent, ok := as.intents[userID] if ok { return intent } localpart, homeserver, err := userID.Parse() if err != nil || len(localpart) == 0 || homeserver != as.HomeserverDomain { if err != nil { as.Log.Fatalfln("Failed to parse user ID %s: %v", userID, err) } else if len(localpart) == 0 { as.Log.Fatalfln("Failed to make intent for %s: localpart is empty", userID) } else if homeserver != as.HomeserverDomain { as.Log.Fatalfln("Failed to make intent for %s: homeserver isn't %s", userID, as.HomeserverDomain) } return nil } intent = as.NewIntentAPI(localpart) as.intents[userID] = intent return intent } func (as *AppService) Intent(userID id.UserID) *IntentAPI { as.intentsLock.RLock() intent, ok := as.intents[userID] as.intentsLock.RUnlock() if !ok { return as.makeIntent(userID) } return intent } func (as *AppService) BotIntent() *IntentAPI { if as.botIntent == nil { as.botIntent = as.makeIntent(as.BotMXID()) } return as.botIntent } func (as *AppService) makeClient(userID id.UserID) *mautrix.Client { as.clientsLock.Lock() defer as.clientsLock.Unlock() client, ok := as.clients[userID] if ok { return client } client, err := mautrix.NewClient(as.HomeserverURL, userID, as.Registration.AppToken) if err != nil { as.Log.Fatalln("Failed to create mautrix client instance:", 
err) return nil } client.UserAgent = as.UserAgent client.Syncer = nil client.Store = nil client.AppServiceUserID = userID client.Logger = as.Log.Sub(string(userID)) client.Client = as.HTTPClient client.DefaultHTTPRetries = as.DefaultHTTPRetries as.clients[userID] = client return client } func (as *AppService) Client(userID id.UserID) *mautrix.Client { as.clientsLock.RLock() client, ok := as.clients[userID] as.clientsLock.RUnlock() if !ok { return as.makeClient(userID) } return client } func (as *AppService) BotClient() *mautrix.Client { if as.botClient == nil { as.botClient = as.makeClient(as.BotMXID()) as.botClient.Logger = as.Log.Sub("Bot") } return as.botClient } // Init initializes the logger and loads the registration of this appservice. func (as *AppService) Init() (bool, error) { as.Events = make(chan *event.Event, EventChannelSize) as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize) as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize) as.QueryHandler = &QueryHandlerStub{} if len(as.UserAgent) == 0 { as.UserAgent = mautrix.DefaultUserAgent } if len(as.ProcessID) == 0 { as.ProcessID = getDefaultProcessID() } as.Log = maulogger.Create() as.LogConfig.Configure(as.Log) as.Log.Debugln("Logger initialized successfully.") if len(as.RegistrationPath) > 0 { var err error as.Registration, err = LoadRegistration(as.RegistrationPath) if err != nil { return false, err } } as.Log.Debugln("Appservice initialized successfully.") return true, nil } // LogConfig contains configs for the logger. 
type LogConfig struct { Directory string `yaml:"directory"` FileNameFormat string `yaml:"file_name_format"` FileDateFormat string `yaml:"file_date_format"` FileMode uint32 `yaml:"file_mode"` TimestampFormat string `yaml:"timestamp_format"` RawPrintLevel string `yaml:"print_level"` JSONStdout bool `yaml:"print_json"` JSONFile bool `yaml:"file_json"` PrintLevel int `yaml:"-"` } type umLogConfig LogConfig func (lc *LogConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { err := unmarshal((*umLogConfig)(lc)) if err != nil { return err } switch strings.ToUpper(lc.RawPrintLevel) { case "TRACE": lc.PrintLevel = -10 case "DEBUG": lc.PrintLevel = maulogger.LevelDebug.Severity case "INFO": lc.PrintLevel = maulogger.LevelInfo.Severity case "WARN", "WARNING": lc.PrintLevel = maulogger.LevelWarn.Severity case "ERR", "ERROR": lc.PrintLevel = maulogger.LevelError.Severity case "FATAL": lc.PrintLevel = maulogger.LevelFatal.Severity default: return errors.New("invalid print level " + lc.RawPrintLevel) } return err } func (lc *LogConfig) MarshalYAML() (interface{}, error) { switch { case lc.PrintLevel >= maulogger.LevelFatal.Severity: lc.RawPrintLevel = maulogger.LevelFatal.Name case lc.PrintLevel >= maulogger.LevelError.Severity: lc.RawPrintLevel = maulogger.LevelError.Name case lc.PrintLevel >= maulogger.LevelWarn.Severity: lc.RawPrintLevel = maulogger.LevelWarn.Name case lc.PrintLevel >= maulogger.LevelInfo.Severity: lc.RawPrintLevel = maulogger.LevelInfo.Name default: lc.RawPrintLevel = maulogger.LevelDebug.Name } return lc, nil } // CreateLogConfig creates a basic LogConfig. func CreateLogConfig() LogConfig { return LogConfig{ Directory: "./logs", FileNameFormat: "%[1]s-%02[2]d.log", TimestampFormat: "Jan _2, 2006 15:04:05", FileMode: 0600, FileDateFormat: "2006-01-02", PrintLevel: 10, } } type FileFormatData struct { Date string Index int } // GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct. 
func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat { if len(lc.Directory) > 0 { _ = os.MkdirAll(lc.Directory, 0700) } path := filepath.Join(lc.Directory, lc.FileNameFormat) tpl, _ := template.New("fileformat").Parse(path) return func(now string, i int) string { var buf strings.Builder _ = tpl.Execute(&buf, FileFormatData{ Date: now, Index: i, }) return buf.String() } } // Configure configures a mauLogger instance with the data in this struct. func (lc LogConfig) Configure(log maulogger.Logger) { basicLogger := log.(*maulogger.BasicLogger) basicLogger.FileFormat = lc.GetFileFormat() basicLogger.FileMode = os.FileMode(lc.FileMode) basicLogger.FileTimeFormat = lc.FileDateFormat basicLogger.TimeFormat = lc.TimestampFormat basicLogger.PrintLevel = lc.PrintLevel basicLogger.JSONFile = lc.JSONFile if lc.JSONStdout { basicLogger.EnableJSONStdout() } } go-0.11.1/appservice/eventprocessor.go000066400000000000000000000065501436100171500177540ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package appservice import ( "encoding/json" "runtime/debug" log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" ) type ExecMode uint8 const ( AsyncHandlers ExecMode = iota AsyncLoop Sync ) type EventHandler func(evt *event.Event) type OTKHandler func(otk *mautrix.OTKCount) type DeviceListHandler func(otk *mautrix.DeviceLists, since string) type EventProcessor struct { ExecMode ExecMode as *AppService log log.Logger stop chan struct{} handlers map[event.Type][]EventHandler otkHandlers []OTKHandler deviceListHandlers []DeviceListHandler } func NewEventProcessor(as *AppService) *EventProcessor { return &EventProcessor{ ExecMode: AsyncHandlers, as: as, log: as.Log.Sub("Events"), stop: make(chan struct{}, 1), handlers: make(map[event.Type][]EventHandler), otkHandlers: make([]OTKHandler, 0), deviceListHandlers: make([]DeviceListHandler, 0), } } func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) { handlers, ok := ep.handlers[evtType] if !ok { handlers = []EventHandler{handler} } else { handlers = append(handlers, handler) } ep.handlers[evtType] = handlers } func (ep *EventProcessor) OnOTK(handler OTKHandler) { ep.otkHandlers = append(ep.otkHandlers, handler) } func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) { ep.deviceListHandlers = append(ep.deviceListHandlers, handler) } func (ep *EventProcessor) recoverFunc(data interface{}) { if err := recover(); err != nil { d, _ := json.Marshal(data) ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack())) } } func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) { defer ep.recoverFunc(evt) handler(evt) } func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) { defer ep.recoverFunc(otk) handler(otk) } func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) { defer ep.recoverFunc(dl) handler(dl, "") } 
// DispatchOTK hands otk to every registered OTKHandler, each in its own
// goroutine. ExecMode is not consulted here.
func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
	for _, handler := range ep.otkHandlers {
		go ep.callOTKHandler(handler, otk)
	}
}

// DispatchDeviceList hands dl to every registered DeviceListHandler, each in
// its own goroutine. ExecMode is not consulted here.
func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
	for _, handler := range ep.deviceListHandlers {
		go ep.callDeviceListHandler(handler, dl)
	}
}

// Dispatch runs the handlers registered for evt.Type according to ExecMode.
// Events with no registered handlers are dropped.
func (ep *EventProcessor) Dispatch(evt *event.Event) {
	handlers, ok := ep.handlers[evt.Type]
	if !ok {
		return
	}
	switch ep.ExecMode {
	case AsyncHandlers:
		// One goroutine per handler: no ordering between handlers.
		for _, handler := range handlers {
			go ep.callHandler(handler, evt)
		}
	case AsyncLoop:
		// One goroutine for all handlers: handlers run in order, off the caller.
		go func() {
			for _, handler := range handlers {
				ep.callHandler(handler, evt)
			}
		}()
	case Sync:
		// Handlers run in order and block the caller (and thus Start's loop).
		for _, handler := range handlers {
			ep.callHandler(handler, evt)
		}
	}
}

// Start reads from the AppService's Events, OTKCounts and DeviceLists
// channels and dispatches each item. It blocks until Stop is called.
func (ep *EventProcessor) Start() {
	for {
		select {
		case evt := <-ep.as.Events:
			ep.Dispatch(evt)
		case otk := <-ep.as.OTKCounts:
			ep.DispatchOTK(otk)
		case dl := <-ep.as.DeviceLists:
			ep.DispatchDeviceList(dl)
		case <-ep.stop:
			return
		}
	}
}

// Stop signals a running Start loop to return. The send blocks if a previous
// stop signal is still sitting unconsumed in the channel buffer.
func (ep *EventProcessor) Stop() {
	ep.stop <- struct{}{}
}
go-0.11.1/appservice/http.go000066400000000000000000000170141436100171500156470ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package appservice

import (
	"context"
	"encoding/json"
	"errors"
	"io/ioutil"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/mux"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// Start starts the HTTP server that listens for calls from the Matrix homeserver.
func (as *AppService) Start() { as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) as.Router.HandleFunc("/users/{userID}", as.GetUser).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut) as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet) as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet) var err error as.server = &http.Server{ Addr: as.Host.Address(), Handler: as.Router, } as.Log.Infoln("Listening on", as.Host.Address()) if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 { err = as.server.ListenAndServe() } else { err = as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey) } if err != nil && err.Error() != "http: Server closed" { as.Log.Fatalln("Error while listening:", err) } else { as.Log.Debugln("Listener stopped.") } } func (as *AppService) Stop() { if as.server == nil { return } ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = as.server.Shutdown(ctx) as.server = nil } // CheckServerToken checks if the given request originated from the Matrix homeserver. 
func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) { authHeader := r.Header.Get("Authorization") if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") { isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken } else { queryToken := r.URL.Query().Get("access_token") if len(queryToken) > 0 { isValid = queryToken == as.Registration.ServerToken } else { Error{ ErrorCode: ErrUnknownToken, HTTPStatus: http.StatusForbidden, Message: "Missing access token", }.Write(w) return } } if !isValid { Error{ ErrorCode: ErrUnknownToken, HTTPStatus: http.StatusForbidden, Message: "Incorrect access token", }.Write(w) } return } // PutTransaction handles a /transactions PUT call from the homeserver. func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) { if !as.CheckServerToken(w, r) { return } vars := mux.Vars(r) txnID := vars["txnID"] if len(txnID) == 0 { Error{ ErrorCode: ErrNoTransactionID, HTTPStatus: http.StatusBadRequest, Message: "Missing transaction ID", }.Write(w) return } defer r.Body.Close() body, err := ioutil.ReadAll(r.Body) if err != nil || len(body) == 0 { Error{ ErrorCode: ErrNotJSON, HTTPStatus: http.StatusBadRequest, Message: "Missing request body", }.Write(w) return } if as.txnIDC.IsProcessed(txnID) { // Duplicate transaction ID: no-op WriteBlankOK(w) as.Log.Debugfln("Ignoring duplicate transaction %s", txnID) return } var txn Transaction err = json.Unmarshal(body, &txn) if err != nil { as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err) Error{ ErrorCode: ErrBadJSON, HTTPStatus: http.StatusBadRequest, Message: "Failed to parse body JSON", }.Write(w) } else { as.handleTransaction(txnID, &txn) WriteBlankOK(w) } } func (as *AppService) handleTransaction(id string, txn *Transaction) { as.Log.Debugfln("Starting handling of transaction %s (%s)", id, txn.ContentString()) if as.Registration.EphemeralEvents { if txn.EphemeralEvents != nil { 
as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType) } else if txn.MSC2409EphemeralEvents != nil { as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType) } } as.handleEvents(txn.Events, event.UnknownEventType) if txn.DeviceLists != nil { as.handleDeviceLists(txn.DeviceLists) } else if txn.MSC3202DeviceLists != nil { as.handleDeviceLists(txn.MSC3202DeviceLists) } if txn.DeviceOTKCount != nil { as.handleOTKCounts(txn.DeviceOTKCount) } else if txn.MSC3202DeviceOTKCount != nil { as.handleOTKCounts(txn.MSC3202DeviceOTKCount) } as.txnIDC.MarkProcessed(id) } func (as *AppService) handleOTKCounts(otks map[id.UserID]mautrix.OTKCount) { for userID, otkCounts := range otks { otkCounts.UserID = userID select { case as.OTKCounts <- &otkCounts: default: as.Log.Warnfln("Dropped OTK count update for %s because channel is full", userID) } } } func (as *AppService) handleDeviceLists(dl *mautrix.DeviceLists) { select { case as.DeviceLists <- dl: default: as.Log.Warnln("Dropped device list update because channel is full") } } func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.TypeClass) { for _, evt := range evts { if len(evt.ToUserID) > 0 { evt.Type.Class = event.ToDeviceEventType } else if defaultTypeClass != event.UnknownEventType { evt.Type.Class = defaultTypeClass } else if evt.StateKey != nil { evt.Type.Class = event.StateEventType } else { evt.Type.Class = event.MessageEventType } err := evt.Content.ParseRaw(evt.Type) if errors.Is(err, event.ErrUnsupportedContentType) { as.Log.Debugfln("Not parsing content of %s: %v", evt.ID, err) } else if err != nil { as.Log.Debugfln("Failed to parse content of %s (type %s): %v", evt.ID, evt.Type.Type, err) } if _, ok := CheckpointTypes[evt.Type]; ok { go as.SendMessageSendCheckpoint(evt, StepBridge, 0) } if evt.Type.IsState() { // TODO remove this check after https://github.com/matrix-org/synapse/pull/11265 historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool) if !ok || 
!historical { as.UpdateState(evt) } } as.Events <- evt } } // GetRoom handles a /rooms GET call from the homeserver. func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) { if !as.CheckServerToken(w, r) { return } vars := mux.Vars(r) roomAlias := vars["roomAlias"] ok := as.QueryHandler.QueryAlias(roomAlias) if ok { WriteBlankOK(w) } else { Error{ ErrorCode: ErrUnknown, HTTPStatus: http.StatusNotFound, }.Write(w) } } // GetUser handles a /users GET call from the homeserver. func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) { if !as.CheckServerToken(w, r) { return } vars := mux.Vars(r) userID := id.UserID(vars["userID"]) ok := as.QueryHandler.QueryUser(userID) if ok { WriteBlankOK(w) } else { Error{ ErrorCode: ErrUnknown, HTTPStatus: http.StatusNotFound, }.Write(w) } } func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") if as.Live { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusInternalServerError) } w.Write([]byte("{}")) } func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "application/json") if as.Ready { w.WriteHeader(http.StatusOK) } else { w.WriteHeader(http.StatusInternalServerError) } w.Write([]byte("{}")) } go-0.11.1/appservice/intent.go000066400000000000000000000321231436100171500161670ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package appservice

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// IntentAPI is a Matrix client acting as a single appservice user. Its
// wrappers keep the AppService state store in sync and, where needed, make
// sure the user is registered and joined before acting.
type IntentAPI struct {
	*mautrix.Client
	bot       *mautrix.Client
	as        *AppService
	Localpart string
	UserID    id.UserID

	IsCustomPuppet bool
}

// NewIntentAPI creates an IntentAPI for localpart on the appservice's
// homeserver domain. The bot client is used to invite the user when a join is
// forbidden; it is left unset when the intent is the bot itself.
func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
	userID := id.NewUserID(localpart, as.HomeserverDomain)
	bot := as.BotClient()
	if userID == bot.UserID {
		bot = nil
	}
	return &IntentAPI{
		Client:    as.Client(userID),
		bot:       bot,
		as:        as,
		Localpart: localpart,
		UserID:    userID,

		IsCustomPuppet: false,
	}
}

// Register registers the user with appservice auth, without logging in.
func (intent *IntentAPI) Register() error {
	_, _, err := intent.Client.Register(&mautrix.ReqRegister{
		Username:     intent.Localpart,
		Type:         mautrix.AuthTypeAppservice,
		InhibitLogin: true,
	})
	return err
}

// EnsureRegistered registers the user unless it is a custom puppet or already
// known to be registered. M_USER_IN_USE is treated as success.
func (intent *IntentAPI) EnsureRegistered() error {
	if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
		return nil
	}

	err := intent.Register()
	if err != nil && !errors.Is(err, mautrix.MUserInUse) {
		return fmt.Errorf("failed to ensure registered: %w", err)
	}
	intent.as.StateStore.MarkRegistered(intent.UserID)
	return nil
}

// EnsureJoinedParams holds optional settings for EnsureJoined.
type EnsureJoinedParams struct {
	// IgnoreCache forces a join attempt even if the state store says the
	// user is already in the room.
	IgnoreCache bool
	// BotOverride replaces the intent's bot client for the invite fallback.
	BotOverride *mautrix.Client
}

// EnsureJoined makes sure the user is joined to roomID. If the join is
// forbidden and a bot client is available, the bot invites the user and the
// join is retried once. At most one EnsureJoinedParams may be passed; more
// cause a panic.
func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error {
	var params EnsureJoinedParams
	if len(extra) > 1 {
		panic("invalid number of extra parameters")
	} else if len(extra) == 1 {
		params = extra[0]
	}
	if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
		return nil
	}

	if err := intent.EnsureRegistered(); err != nil {
		return fmt.Errorf("failed to ensure joined: %w", err)
	}

	resp, err := intent.JoinRoomByID(roomID)
	if err != nil {
		bot := intent.bot
		if params.BotOverride != nil {
			bot = params.BotOverride
		}
		if !errors.Is(err, mautrix.MForbidden) || bot == nil {
			return fmt.Errorf("failed to ensure joined: %w", err)
		}
		_, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{
			UserID: intent.UserID,
		})
		if inviteErr != nil {
			return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr)
		}
		resp, err = intent.JoinRoomByID(roomID)
		if err != nil {
			return fmt.Errorf("failed to ensure joined after invite: %w", err)
		}
	}
	intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
	return nil
}

// SendMessageEvent ensures the user is joined, then sends a message event.
func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendMessageEvent(roomID, eventType, contentJSON)
}

// SendMassagedMessageEvent is SendMessageEvent with an explicit timestamp ts.
func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
}

// updateStoreWithOutgoingEvent builds a synthetic state event from
// contentJSON, round-tripping it through JSON so Content is parsed, and feeds
// it to UpdateState. Failures are only logged at debug level.
func (intent *IntentAPI) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, eventID id.EventID) {
	fakeEvt := &event.Event{
		StateKey: &stateKey,
		Sender:   intent.UserID,
		Type:     eventType,
		ID:       eventID,
		RoomID:   roomID,
		Content:  event.Content{},
	}
	var err error
	fakeEvt.Content.VeryRaw, err = json.Marshal(contentJSON)
	if err != nil {
		intent.Logger.Debugfln("Failed to marshal state event content to update state store: %v", err)
		return
	}
	err = json.Unmarshal(fakeEvt.Content.VeryRaw, &fakeEvt.Content.Raw)
	if err != nil {
		intent.Logger.Debugfln("Failed to unmarshal state event content to update state store: %v", err)
		return
	}
	err = fakeEvt.Content.ParseRaw(fakeEvt.Type)
	if err != nil {
		intent.Logger.Debugfln("Failed to parse state event content to update state store: %v", err)
		return
	}
	intent.as.UpdateState(fakeEvt)
}

// SendStateEvent ensures the user is joined, sends a state event and, on
// success, mirrors it into the state store.
func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	resp, err := intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON)
	if err == nil && resp != nil {
		intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
	}
	return resp, err
}

// SendMassagedStateEvent is SendStateEvent with an explicit timestamp ts.
func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	resp, err := intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts)
	if err == nil && resp != nil {
		intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
	}
	return resp, err
}

// StateEvent ensures the user is joined, fetches one state event into
// outContent and, on success, mirrors it into the state store.
func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
	if err := intent.EnsureJoined(roomID); err != nil {
		return err
	}
	err := intent.Client.StateEvent(roomID, eventType, stateKey, outContent)
	if err == nil {
		intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent, "")
	}
	return err
}

// State ensures the user is joined, fetches the full room state and feeds
// every event to UpdateState.
func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	state, err := intent.Client.State(roomID)
	if err == nil {
		for _, events := range state {
			for _, evt := range events {
				intent.as.UpdateState(evt)
			}
		}
	}
	return state, err
}

// InviteUser invites req.UserID and records the invite in the state store.
func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser) (resp *mautrix.RespInviteUser, err error) {
	resp, err = intent.Client.InviteUser(roomID, req)
	if err == nil {
		intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
	}
	return
}

// KickUser kicks req.UserID and records the leave in the state store.
func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser) (resp *mautrix.RespKickUser, err error) {
	resp, err = intent.Client.KickUser(roomID, req)
	if err == nil {
		intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
	}
	return
}

// BanUser bans req.UserID and records the ban in the state store.
func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser) (resp *mautrix.RespBanUser, err error) {
	resp, err = intent.Client.BanUser(roomID, req)
	if err == nil {
		intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
	}
	return
}

// UnbanUser unbans req.UserID and records the membership as leave.
func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser) (resp *mautrix.RespUnbanUser, err error) {
	resp, err = intent.Client.UnbanUser(roomID, req)
	if err == nil {
		intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
	}
	return
}

// Member returns userID's member event content in roomID. On a cache miss it
// is fetched from the server and stored.
//
// NOTE(review): the fetch error is ignored and the result is cached anyway;
// confirm this is intended.
func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
	member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
	if !ok {
		_ = intent.StateEvent(roomID, event.StateMember, string(userID), &member)
		intent.as.StateStore.SetMember(roomID, userID, member)
	}
	return member
}

// PowerLevels returns the room's power levels. On a cache miss they are
// fetched from the server and stored if the fetch succeeds.
func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
	pl = intent.as.StateStore.GetPowerLevels(roomID)
	if pl == nil {
		pl = &event.PowerLevelsEventContent{}
		err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl)
		if err == nil {
			intent.as.StateStore.SetPowerLevels(roomID, pl)
		}
	}
	return
}

// SetPowerLevels sends levels as the room's power levels event and stores
// them on success.
func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) {
	resp, err = intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels)
	if err == nil {
		intent.as.StateStore.SetPowerLevels(roomID, levels)
	}
	return
}

// SetPowerLevel sets userID's power level in roomID. It returns (nil, nil)
// without sending anything when the level is already correct.
func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) {
	pl, err := intent.PowerLevels(roomID)
	if err != nil {
		return nil, err
	}

	if pl.GetUserLevel(userID) != level {
		pl.SetUserLevel(userID, level)
		return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl)
	}
	return nil, nil
}

// UserTyping sets the user's typing state. The request is skipped when the
// state store already shows the requested state. A stop is recorded with
// timeout -1.
func (intent *IntentAPI) UserTyping(roomID id.RoomID, typing bool, timeout int64) (resp *mautrix.RespTyping, err error) {
	if intent.as.StateStore.IsTyping(roomID, intent.UserID) == typing {
		return
	}
	resp, err = intent.Client.UserTyping(roomID, typing, timeout)
	if err != nil {
		return
	}
	if !typing {
		timeout = -1
	}
	intent.as.StateStore.SetTyping(roomID, intent.UserID, timeout)
	return
}

// SendText ensures the user is joined, then sends a plain text message.
func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendText(roomID, text)
}

// Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (intent *IntentAPI) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendImage(roomID, body, url)
}

// Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (intent *IntentAPI) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendVideo(roomID, body, url)
}

// SendNotice ensures the user is joined, then sends a notice message.
func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.SendNotice(roomID, text)
}

// RedactEvent ensures the user is joined, then redacts eventID.
func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, req ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) {
	if err := intent.EnsureJoined(roomID); err != nil {
		return nil, err
	}
	return intent.Client.RedactEvent(roomID, eventID, req...)
}

// SetRoomName sends an m.room.name state event.
func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) {
	return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{
		"name": roomName,
	})
}

// SetRoomAvatar sends an m.room.avatar state event.
func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) {
	return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{
		"url": avatarURL.String(),
	})
}

// SetRoomTopic sends an m.room.topic state event.
func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) {
	return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{
		"topic": topic,
	})
}

// SetDisplayName updates the user's global display name, skipping the
// request when it already matches.
func (intent *IntentAPI) SetDisplayName(displayName string) error {
	if err := intent.EnsureRegistered(); err != nil {
		return err
	}
	resp, err := intent.Client.GetOwnDisplayName()
	if err != nil {
		return fmt.Errorf("failed to check current displayname: %w", err)
	} else if resp.DisplayName == displayName {
		// No need to update
		return nil
	}
	return intent.Client.SetDisplayName(displayName)
}

// SetAvatarURL updates the user's global avatar, skipping the request when it
// already matches.
func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error {
	if err := intent.EnsureRegistered(); err != nil {
		return err
	}
	resp, err := intent.Client.GetOwnAvatarURL()
	if err != nil {
		return fmt.Errorf("failed to check current avatar URL: %w", err)
	} else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver {
		// No need to update
		return nil
	}
	return intent.Client.SetAvatarURL(avatarURL)
}

// Whoami ensures the user is registered, then calls /whoami.
func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) {
	if err := intent.EnsureRegistered(); err != nil {
		return nil, err
	}
	return intent.Client.Whoami()
}

// JoinedMembers fetches the room's joined members and stores each one as a
// joined member with its display name and avatar.
func (intent *IntentAPI) JoinedMembers(roomID id.RoomID) (resp *mautrix.RespJoinedMembers, err error) {
	resp, err = intent.Client.JoinedMembers(roomID)
	if err != nil {
		return
	}
	for userID, member := range resp.Joined {
		var displayname string
		var avatarURL id.ContentURIString
		if member.DisplayName != nil {
			displayname = *member.DisplayName
		}
		if member.AvatarURL != nil {
			avatarURL = id.ContentURIString(*member.AvatarURL)
		}
		intent.as.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
			Membership:  event.MembershipJoin,
			AvatarURL:   avatarURL,
			Displayname: displayname,
		})
	}
	return
}

// Members fetches the room's member events and feeds each to UpdateState.
func (intent *IntentAPI) Members(roomID id.RoomID, req ...mautrix.ReqMembers) (resp *mautrix.RespMembers, err error) {
	resp, err = intent.Client.Members(roomID, req...)
	if err != nil {
		return
	}
	for _, evt := range resp.Chunk {
		intent.as.UpdateState(evt)
	}
	return
}

// EnsureInvited invites userID unless the state store already shows an
// invite. An "is already in the room" error from the server counts as success.
func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error {
	if intent.as.StateStore.IsInvited(roomID, userID) {
		return nil
	}
	_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{
		UserID: userID,
	})
	// errors.As also finds an HTTPError that was wrapped with %w, which a
	// plain type assertion would miss. It is false for a nil err.
	var httpErr mautrix.HTTPError
	if errors.As(err, &httpErr) && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") {
		return nil
	}
	return err
}
go-0.11.1/appservice/message_send_checkpoint.go000066400000000000000000000122701436100171500215330ustar00rootroot00000000000000// Copyright (c) 2021 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http"
	"time"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// MessageSendCheckpointStep identifies the stage of message delivery that a
// checkpoint reports on.
type MessageSendCheckpointStep string

const (
	StepClient     MessageSendCheckpointStep = "CLIENT"
	StepHomeserver MessageSendCheckpointStep = "HOMESERVER"
	StepBridge     MessageSendCheckpointStep = "BRIDGE"
	StepDecrypted  MessageSendCheckpointStep = "DECRYPTED"
	StepRemote     MessageSendCheckpointStep = "REMOTE"
	StepCommand    MessageSendCheckpointStep = "COMMAND"
)

// MessageSendCheckpointStatus is the outcome reported for a step.
type MessageSendCheckpointStatus string

const (
	// StatusSuccesss keeps its historical spelling for API compatibility.
	StatusSuccesss    MessageSendCheckpointStatus = "SUCCESS"
	StatusWillRetry   MessageSendCheckpointStatus = "WILL_RETRY"
	StatusPermFailure MessageSendCheckpointStatus = "PERM_FAILURE"
	StatusUnsupported MessageSendCheckpointStatus = "UNSUPPORTED"
	StatusTimeout     MessageSendCheckpointStatus = "TIMEOUT"
)

// MessageSendCheckpointReportedBy names the component that sent a checkpoint.
type MessageSendCheckpointReportedBy string

const (
	ReportedByAsmux  MessageSendCheckpointReportedBy = "ASMUX"
	ReportedByBridge MessageSendCheckpointReportedBy = "BRIDGE"
)

// MessageSendCheckpoint is one delivery status report for an event.
type MessageSendCheckpoint struct {
	EventID     id.EventID                      `json:"event_id"`
	RoomID      id.RoomID                       `json:"room_id"`
	Step        MessageSendCheckpointStep       `json:"step"`
	Timestamp   int64                           `json:"timestamp"`
	Status      MessageSendCheckpointStatus     `json:"status"`
	EventType   event.Type                      `json:"event_type"`
	ReportedBy  MessageSendCheckpointReportedBy `json:"reported_by"`
	RetryNum    int                             `json:"retry_num"`
	MessageType event.MessageType               `json:"message_type,omitempty"`
	Info        string                          `json:"info,omitempty"`
}

// CheckpointTypes lists the event types for which handleEvents emits a
// StepBridge checkpoint.
var CheckpointTypes = map[event.Type]struct{}{
	event.EventRedaction:   {},
	event.EventMessage:     {},
	event.EventEncrypted:   {},
	event.EventSticker:     {},
	event.EventReaction:    {},
	event.CallInvite:       {},
	event.CallCandidates:   {},
	event.CallSelectAnswer: {},
	event.CallAnswer:       {},
	event.CallHangup:       {},
	event.CallReject:       {},
	event.CallNegotiate:    {},
}

// NewMessageSendCheckpoint builds a bridge-reported checkpoint for evt. The
// timestamp is the current time in milliseconds. For m.room.message events
// the msgtype is included.
func NewMessageSendCheckpoint(evt *event.Event, step MessageSendCheckpointStep, status MessageSendCheckpointStatus, retryNum int) *MessageSendCheckpoint {
	checkpoint := MessageSendCheckpoint{
		EventID:    evt.ID,
		RoomID:     evt.RoomID,
		Step:       step,
		Timestamp:  time.Now().UnixNano() / int64(time.Millisecond),
		Status:     status,
		EventType:  evt.Type,
		ReportedBy: ReportedByBridge,
		RetryNum:   retryNum,
	}
	if evt.Type == event.EventMessage {
		checkpoint.MessageType = evt.Content.AsMessage().MsgType
	}
	return &checkpoint
}

// SendMessageSendCheckpoint sends a success checkpoint for evt in the background.
func (as *AppService) SendMessageSendCheckpoint(evt *event.Event, step MessageSendCheckpointStep, retryNum int) {
	checkpoint := NewMessageSendCheckpoint(evt, step, StatusSuccesss, retryNum)
	go checkpoint.Send(as)
}

// SendErrorMessageSendCheckpoint sends a failure checkpoint for evt in the
// background. The status is PERM_FAILURE when permanent is true and
// WILL_RETRY otherwise; err's message becomes Info.
func (as *AppService) SendErrorMessageSendCheckpoint(evt *event.Event, step MessageSendCheckpointStep, err error, permanent bool, retryNum int) {
	status := StatusWillRetry
	if permanent {
		status = StatusPermFailure
	}
	checkpoint := NewMessageSendCheckpoint(evt, step, status, retryNum)
	// Guard against a nil error instead of panicking on err.Error().
	if err != nil {
		checkpoint.Info = err.Error()
	}
	go checkpoint.Send(as)
}

// Send sends cp via SendCheckpoints, logging (not returning) any failure.
func (cp *MessageSendCheckpoint) Send(as *AppService) {
	err := SendCheckpoints(as, []*MessageSendCheckpoint{cp})
	if err != nil {
		as.Log.Warnfln("Error sending checkpoint %s/%s for %s: %v", cp.Step, cp.Status, cp.EventID, err)
	}
}

// CheckpointsJSON is the request body for sending checkpoints.
type CheckpointsJSON struct {
	Checkpoints []*MessageSendCheckpoint `json:"checkpoints"`
}

// SendCheckpoints delivers checkpoints over the appservice websocket when one
// is connected. Otherwise it POSTs them to MessageSendCheckpointEndpoint with
// a five-second timeout. It does nothing when no endpoint is configured.
func SendCheckpoints(as *AppService, checkpoints []*MessageSendCheckpoint) error {
	checkpointsJSON := CheckpointsJSON{Checkpoints: checkpoints}

	if as.HasWebsocket() {
		return as.SendWebsocket(&WebsocketRequest{
			Command: "message_checkpoint",
			Data:    checkpointsJSON,
		})
	}

	if as.MessageSendCheckpointEndpoint == "" {
		return nil
	}

	var body bytes.Buffer
	if err := json.NewEncoder(&body).Encode(checkpointsJSON); err != nil {
		return fmt.Errorf("failed to encode message send checkpoint JSON: %w", err)
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, as.MessageSendCheckpointEndpoint, &body)
	if err != nil {
		return err
	}
	req.Header.Set("Authorization", "Bearer "+as.Registration.AppToken)
	req.Header.Set("User-Agent", mautrix.DefaultUserAgent+" checkpoint sender")
	req.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		// The message previously said "bridge state update" (copy-paste).
		return fmt.Errorf("failed to send message send checkpoints: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		respBody, _ := ioutil.ReadAll(resp.Body)
		if respBody != nil {
			// Keep the error message on a single log line.
			respBody = bytes.ReplaceAll(respBody, []byte("\n"), []byte("\\n"))
		}
		return fmt.Errorf("unexpected status code %d sending message send checkpoints: %s", resp.StatusCode, respBody)
	}
	return nil
}
go-0.11.1/appservice/protocol.go000066400000000000000000000066011436100171500165310ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package appservice

import (
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// Transaction contains a list of events.
type Transaction struct { Events []*event.Event `json:"events"` EphemeralEvents []*event.Event `json:"ephemeral,omitempty"` DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"` DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"device_one_time_keys_count,omitempty"` MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"` MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"` MSC3202DeviceOTKCount map[id.UserID]mautrix.OTKCount `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"` } func (txn *Transaction) ContentString() string { var parts []string if len(txn.Events) > 0 { parts = append(parts, fmt.Sprintf("%d PDUs", len(txn.Events))) } if len(txn.EphemeralEvents) > 0 { parts = append(parts, fmt.Sprintf("%d EDUs", len(txn.EphemeralEvents))) } else if len(txn.MSC2409EphemeralEvents) > 0 { parts = append(parts, fmt.Sprintf("%d EDUs (unstable)", len(txn.MSC2409EphemeralEvents))) } if len(txn.DeviceOTKCount) > 0 { parts = append(parts, fmt.Sprintf("OTK counts for %d users", len(txn.DeviceOTKCount))) } else if len(txn.MSC3202DeviceOTKCount) > 0 { parts = append(parts, fmt.Sprintf("OTK counts for %d users (unstable)", len(txn.MSC3202DeviceOTKCount))) } if txn.DeviceLists != nil { parts = append(parts, fmt.Sprintf("%d device list changes", len(txn.DeviceLists.Changed))) } else if txn.MSC3202DeviceLists != nil { parts = append(parts, fmt.Sprintf("%d device list changes (unstable)", len(txn.MSC3202DeviceLists.Changed))) } return strings.Join(parts, ", ") } // EventListener is a function that receives events. type EventListener func(evt *event.Event) // WriteBlankOK writes a blank OK message as a reply to a HTTP request. func WriteBlankOK(w http.ResponseWriter) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("{}")) } // Respond responds to a HTTP request with a JSON object. 
func Respond(w http.ResponseWriter, data interface{}) error { w.Header().Add("Content-Type", "application/json") dataStr, err := json.Marshal(data) if err != nil { return err } _, err = w.Write(dataStr) return err } // Error represents a Matrix protocol error. type Error struct { HTTPStatus int `json:"-"` ErrorCode ErrorCode `json:"errcode"` Message string `json:"error"` } func (err Error) Write(w http.ResponseWriter) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(err.HTTPStatus) _ = Respond(w, &err) } // ErrorCode is the machine-readable code in an Error. type ErrorCode string // Native ErrorCodes const ( ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN" ErrBadJSON ErrorCode = "M_BAD_JSON" ErrNotJSON ErrorCode = "M_NOT_JSON" ErrUnknown ErrorCode = "M_UNKNOWN" ) // Custom ErrorCodes const ( ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID" ) go-0.11.1/appservice/random.go000066400000000000000000000014631436100171500161510ustar00rootroot00000000000000package appservice import ( "math/rand" "time" "unsafe" ) const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const ( letterIdxBits = 6 letterIdxMask = 1<= 0; { if remain == 0 { cache, remain = src.Int63(), letterIdxMax } if idx := int(cache & letterIdxMask); idx < len(letterBytes) { b[i] = letterBytes[idx] i-- } cache >>= letterIdxBits remain-- } return *(*string)(unsafe.Pointer(&b)) } go-0.11.1/appservice/registration.go000066400000000000000000000061021436100171500173760ustar00rootroot00000000000000// Copyright (c) 2019 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package appservice import ( "io/ioutil" "regexp" "gopkg.in/yaml.v2" ) // Registration contains the data in a Matrix appservice registration. 
// See https://matrix.org/docs/spec/application_service/unstable.html#registration type Registration struct { ID string `yaml:"id"` URL string `yaml:"url"` AppToken string `yaml:"as_token"` ServerToken string `yaml:"hs_token"` SenderLocalpart string `yaml:"sender_localpart"` RateLimited *bool `yaml:"rate_limited,omitempty"` Namespaces Namespaces `yaml:"namespaces"` EphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty"` Protocols []string `yaml:"protocols,omitempty"` } // CreateRegistration creates a Registration with random appservice and homeserver tokens. func CreateRegistration() *Registration { return &Registration{ AppToken: RandomString(64), ServerToken: RandomString(64), } } // LoadRegistration loads a YAML file and turns it into a Registration. func LoadRegistration(path string) (*Registration, error) { data, err := ioutil.ReadFile(path) if err != nil { return nil, err } reg := &Registration{} err = yaml.Unmarshal(data, reg) if err != nil { return nil, err } return reg, nil } // Save saves this Registration into a file at the given path. func (reg *Registration) Save(path string) error { data, err := yaml.Marshal(reg) if err != nil { return err } return ioutil.WriteFile(path, data, 0600) } // YAML returns the registration in YAML format. func (reg *Registration) YAML() (string, error) { data, err := yaml.Marshal(reg) if err != nil { return "", err } return string(data), nil } // Namespaces contains the three areas that appservices can reserve parts of. type Namespaces struct { UserIDs []Namespace `yaml:"users,omitempty"` RoomAliases []Namespace `yaml:"aliases,omitempty"` RoomIDs []Namespace `yaml:"rooms,omitempty"` } // Namespace is a reserved namespace in any area. type Namespace struct { Regex string `yaml:"regex"` Exclusive bool `yaml:"exclusive"` } // RegisterUserIDs creates an user ID namespace registration. 
func (nslist *Namespaces) RegisterUserIDs(regex *regexp.Regexp, exclusive bool) { nslist.UserIDs = append(nslist.UserIDs, Namespace{ Regex: regex.String(), Exclusive: exclusive, }) } // RegisterRoomAliases creates an room alias namespace registration. func (nslist *Namespaces) RegisterRoomAliases(regex *regexp.Regexp, exclusive bool) { nslist.RoomAliases = append(nslist.RoomAliases, Namespace{ Regex: regex.String(), Exclusive: exclusive, }) } // RegisterRoomIDs creates an room ID namespace registration. func (nslist *Namespaces) RegisterRoomIDs(regex *regexp.Regexp, exclusive bool) { nslist.RoomIDs = append(nslist.RoomIDs, Namespace{ Regex: regex.String(), Exclusive: exclusive, }) } go-0.11.1/appservice/statestore.go000066400000000000000000000166531436100171500170750ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package appservice import ( "sync" "time" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type StateStore interface { IsRegistered(userID id.UserID) bool MarkRegistered(userID id.UserID) IsTyping(roomID id.RoomID, userID id.UserID) bool SetTyping(roomID id.RoomID, userID id.UserID, timeout int64) IsInRoom(roomID id.RoomID, userID id.UserID) bool IsInvited(roomID id.RoomID, userID id.UserID) bool IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent GetPowerLevel(roomID id.RoomID, userID id.UserID) int GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool } func (as *AppService) UpdateState(evt *event.Event) { switch content := evt.Content.Parsed.(type) { case *event.MemberEventContent: as.StateStore.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content) case *event.PowerLevelsEventContent: as.StateStore.SetPowerLevels(evt.RoomID, content) } } type TypingStateStore struct { typing map[id.RoomID]map[id.UserID]int64 typingLock sync.RWMutex } func NewTypingStateStore() *TypingStateStore { return &TypingStateStore{ typing: make(map[id.RoomID]map[id.UserID]int64), } } func (store *TypingStateStore) IsTyping(roomID id.RoomID, userID id.UserID) bool { store.typingLock.RLock() defer store.typingLock.RUnlock() roomTyping, ok := store.typing[roomID] if !ok { return false } typingEndsAt := roomTyping[userID] return typingEndsAt >= time.Now().Unix() } func (store *TypingStateStore) 
SetTyping(roomID id.RoomID, userID id.UserID, timeout int64) { store.typingLock.Lock() defer store.typingLock.Unlock() roomTyping, ok := store.typing[roomID] if !ok { if timeout >= 0 { roomTyping = map[id.UserID]int64{ userID: time.Now().Unix() + timeout, } } else { return } } else { if timeout >= 0 { roomTyping[userID] = time.Now().Unix() + timeout } else { delete(roomTyping, userID) } } store.typing[roomID] = roomTyping } type BasicStateStore struct { registrationsLock sync.RWMutex `json:"-"` Registrations map[id.UserID]bool `json:"registrations"` membersLock sync.RWMutex `json:"-"` Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"` powerLevelsLock sync.RWMutex `json:"-"` PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"` *TypingStateStore } func NewBasicStateStore() StateStore { return &BasicStateStore{ Registrations: make(map[id.UserID]bool), Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent), PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent), TypingStateStore: NewTypingStateStore(), } } func (store *BasicStateStore) IsRegistered(userID id.UserID) bool { store.registrationsLock.RLock() defer store.registrationsLock.RUnlock() registered, ok := store.Registrations[userID] return ok && registered } func (store *BasicStateStore) MarkRegistered(userID id.UserID) { store.registrationsLock.Lock() defer store.registrationsLock.Unlock() store.Registrations[userID] = true } func (store *BasicStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent { store.membersLock.RLock() members, ok := store.Members[roomID] store.membersLock.RUnlock() if !ok { members = make(map[id.UserID]*event.MemberEventContent) store.membersLock.Lock() store.Members[roomID] = members store.membersLock.Unlock() } return members } func (store *BasicStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership { return store.GetMember(roomID, userID).Membership } 
func (store *BasicStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent { member, ok := store.TryGetMember(roomID, userID) if !ok { member = &event.MemberEventContent{Membership: event.MembershipLeave} } return member } func (store *BasicStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) { store.membersLock.RLock() defer store.membersLock.RUnlock() members, membersOk := store.Members[roomID] if !membersOk { return } member, ok = members[userID] return } func (store *BasicStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool { return store.IsMembership(roomID, userID, "join") } func (store *BasicStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool { return store.IsMembership(roomID, userID, "join", "invite") } func (store *BasicStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool { membership := store.GetMembership(roomID, userID) for _, allowedMembership := range allowedMemberships { if allowedMembership == membership { return true } } return false } func (store *BasicStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { members = map[id.UserID]*event.MemberEventContent{ userID: {Membership: membership}, } } else { member, ok := members[userID] if !ok { members[userID] = &event.MemberEventContent{Membership: membership} } else { member.Membership = membership members[userID] = member } } store.Members[roomID] = members store.membersLock.Unlock() } func (store *BasicStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) { store.membersLock.Lock() members, ok := store.Members[roomID] if !ok { members = map[id.UserID]*event.MemberEventContent{ userID: member, } } else { members[userID] = member } store.Members[roomID] = members store.membersLock.Unlock() } func (store 
*BasicStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) { store.powerLevelsLock.Lock() store.PowerLevels[roomID] = levels store.powerLevelsLock.Unlock() } func (store *BasicStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) { store.powerLevelsLock.RLock() levels = store.PowerLevels[roomID] store.powerLevelsLock.RUnlock() return } func (store *BasicStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int { return store.GetPowerLevels(roomID).GetUserLevel(userID) } func (store *BasicStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int { return store.GetPowerLevels(roomID).GetEventLevel(eventType) } func (store *BasicStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool { return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType) } go-0.11.1/appservice/txnid.go000066400000000000000000000021031436100171500160070ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package appservice import "sync" type TransactionIDCache struct { array []string arrayPtr int hash map[string]struct{} lock sync.RWMutex } func NewTransactionIDCache(size int) *TransactionIDCache { return &TransactionIDCache{ array: make([]string, size), hash: make(map[string]struct{}), } } func (txnIDC *TransactionIDCache) IsProcessed(txnID string) bool { txnIDC.lock.RLock() _, exists := txnIDC.hash[txnID] txnIDC.lock.RUnlock() return exists } func (txnIDC *TransactionIDCache) MarkProcessed(txnID string) { txnIDC.lock.Lock() txnIDC.hash[txnID] = struct{}{} if txnIDC.array[txnIDC.arrayPtr] != "" { for i := 0; i < len(txnIDC.array)/8; i++ { delete(txnIDC.hash, txnIDC.array[txnIDC.arrayPtr+i]) txnIDC.array[txnIDC.arrayPtr+i] = "" } } txnIDC.array[txnIDC.arrayPtr] = txnID txnIDC.lock.Unlock() } go-0.11.1/appservice/websocket.go000066400000000000000000000252211436100171500166550ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package appservice import ( "context" "encoding/json" "errors" "fmt" "net/http" "net/url" "path/filepath" "strings" "sync" "sync/atomic" "time" "github.com/gorilla/websocket" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) type WebsocketRequest struct { ReqID int `json:"id,omitempty"` Command string `json:"command"` Data interface{} `json:"data"` Deadline time.Duration `json:"-"` } type WebsocketCommand struct { ReqID int `json:"id,omitempty"` Command string `json:"command"` Data json.RawMessage `json:"data"` } func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest { if wsc.ReqID == 0 { return nil } cmd := "response" if !ok { cmd = "error" } if err, isError := data.(error); isError { var errorData json.RawMessage var jsonErr error unwrappedErr := err var prefixMessage string for unwrappedErr != nil { errorData, jsonErr = json.Marshal(unwrappedErr) if errorData != nil && len(errorData) > 2 && jsonErr == nil { prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1) prefixMessage = strings.TrimRight(prefixMessage, ": ") break } unwrappedErr = errors.Unwrap(unwrappedErr) } if errorData != nil { if !gjson.GetBytes(errorData, "message").Exists() { errorData, _ = sjson.SetBytes(errorData, "message", err.Error()) } // else: marshaled error contains a message already } else { errorData, _ = sjson.SetBytes(nil, "message", err.Error()) } if len(prefixMessage) > 0 { errorData, _ = sjson.SetBytes(errorData, "prefix_message", prefixMessage) } data = errorData } return &WebsocketRequest{ ReqID: wsc.ReqID, Command: cmd, Data: data, } } type WebsocketTransaction struct { Status string `json:"status"` TxnID string `json:"txn_id"` Transaction } type WebsocketMessage struct { WebsocketTransaction WebsocketCommand } type MeowWebsocketCloseCode string const ( MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down" MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced" ) var ( ErrWebsocketManualStop = 
errors.New("the websocket was disconnected manually") ErrWebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection") ErrWebsocketUnknownError = errors.New("an unknown error occurred") ErrWebsocketNotConnected = errors.New("websocket not connected") ErrWebsocketClosed = errors.New("websocket closed before response received") ) func (mwcc MeowWebsocketCloseCode) String() string { switch mwcc { case MeowServerShuttingDown: return "the server is shutting down" case MeowConnectionReplaced: return "the connection was replaced by another client" default: return string(mwcc) } } type CloseCommand struct { Code int `json:"-"` Command string `json:"command"` Status MeowWebsocketCloseCode `json:"status"` } func (cc CloseCommand) Error() string { return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String()) } func parseCloseError(err error) error { closeError := &websocket.CloseError{} if !errors.As(err, &closeError) { return err } var closeCommand CloseCommand closeCommand.Code = closeError.Code closeCommand.Command = "disconnect" if len(closeError.Text) > 0 { jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand) if jsonErr != nil { return err } } if len(closeCommand.Status) == 0 { if closeCommand.Code == 4001 { closeCommand.Status = MeowConnectionReplaced } else if closeCommand.Code == websocket.CloseServiceRestart { closeCommand.Status = MeowServerShuttingDown } } return &closeCommand } func (as *AppService) HasWebsocket() bool { return as.ws != nil } func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error { ws := as.ws if cmd == nil { return nil } else if ws == nil { return ErrWebsocketNotConnected } as.wsWriteLock.Lock() defer as.wsWriteLock.Unlock() if cmd.Deadline == 0 { cmd.Deadline = 3 * time.Minute } _ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline)) return ws.WriteJSON(cmd) } func (as *AppService) clearWebsocketResponseWaiters() { as.websocketRequestsLock.Lock() for _, waiter := range 
as.websocketRequests { waiter <- &WebsocketCommand{Command: "__websocket_closed"} } as.websocketRequests = make(map[int]chan<- *WebsocketCommand) as.websocketRequestsLock.Unlock() } func (as *AppService) addWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) { as.websocketRequestsLock.Lock() as.websocketRequests[reqID] = waiter as.websocketRequestsLock.Unlock() } func (as *AppService) removeWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) { as.websocketRequestsLock.Lock() existingWaiter, ok := as.websocketRequests[reqID] if ok && existingWaiter == waiter { delete(as.websocketRequests, reqID) } close(waiter) as.websocketRequestsLock.Unlock() } type ErrorResponse struct { Code string `json:"code"` Message string `json:"message"` } func (er *ErrorResponse) Error() string { return fmt.Sprintf("%s: %s", er.Code, er.Message) } func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error { cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1)) respChan := make(chan *WebsocketCommand, 1) as.addWebsocketResponseWaiter(cmd.ReqID, respChan) defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan) err := as.SendWebsocket(cmd) if err != nil { return err } select { case resp := <-respChan: if resp.Command == "__websocket_closed" { return ErrWebsocketClosed } else if resp.Command == "error" { var respErr ErrorResponse err = json.Unmarshal(resp.Data, &respErr) if err != nil { return fmt.Errorf("failed to parse error JSON: %w", err) } return &respErr } else if response != nil { err = json.Unmarshal(resp.Data, &response) if err != nil { return fmt.Errorf("failed to parse response JSON: %w", err) } return nil } else { return nil } case <-ctx.Done(): return ctx.Err() } } func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) { as.Log.Warnfln("No handler for websocket command %s (%d)", cmd.Command, cmd.ReqID) return false, fmt.Errorf("unknown request type") } func 
(as *AppService) SetWebsocketCommandHandler(cmd string, handler WebsocketHandler) { as.websocketHandlersLock.Lock() as.websocketHandlers[cmd] = handler as.websocketHandlersLock.Unlock() } func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) { defer stopFunc(ErrWebsocketUnknownError) for { var msg WebsocketMessage err := ws.ReadJSON(&msg) if err != nil { as.Log.Debugln("Error reading from websocket:", err) stopFunc(parseCloseError(err)) return } if msg.Command == "" || msg.Command == "transaction" { if msg.TxnID == "" || !as.txnIDC.IsProcessed(msg.TxnID) { as.handleTransaction(msg.TxnID, &msg.Transaction) } else { as.Log.Debugfln("Ignoring duplicate transaction %s (%s)", msg.TxnID, msg.Transaction.ContentString()) } go func() { err = as.SendWebsocket(msg.MakeResponse(true, map[string]interface{}{"txn_id": msg.TxnID})) if err != nil { as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err) } }() } else if msg.Command == "connect" { as.Log.Debugln("Websocket connect confirmation received") } else if msg.Command == "response" || msg.Command == "error" { as.websocketRequestsLock.RLock() respChan, ok := as.websocketRequests[msg.ReqID] if ok { select { case respChan <- &msg.WebsocketCommand: default: as.Log.Warnfln("Failed to handle response to %d: channel didn't accept response", msg.ReqID) } } else { as.Log.Warnfln("Dropping response to %d: unknown request ID", msg.ReqID) } as.websocketRequestsLock.RUnlock() } else { as.websocketHandlersLock.RLock() handler, ok := as.websocketHandlers[msg.Command] as.websocketHandlersLock.RUnlock() if !ok { handler = as.unknownCommandHandler } go func() { okResp, data := handler(msg.WebsocketCommand) err = as.SendWebsocket(msg.MakeResponse(okResp, data)) if err != nil { as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err) } else if okResp { as.Log.Debugfln("Sent success response to %s %d", msg.Command, msg.ReqID) } else { as.Log.Debugfln("Sent error 
response to %s %d", msg.Command, msg.ReqID) } }() } } } func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error { parsed, err := url.Parse(baseURL) if err != nil { return fmt.Errorf("failed to parse URL: %w", err) } parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync") if parsed.Scheme == "http" { parsed.Scheme = "ws" } else if parsed.Scheme == "https" { parsed.Scheme = "wss" } ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{ "Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)}, "User-Agent": []string{as.BotClient().UserAgent}, "X-Mautrix-Process-ID": []string{as.ProcessID}, "X-Mautrix-Websocket-Version": []string{"3"}, }) if resp != nil && resp.StatusCode >= 400 { var errResp Error err = json.NewDecoder(resp.Body).Decode(&errResp) if err != nil { return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode) } else { return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message) } } else if err != nil { return fmt.Errorf("failed to open websocket: %w", err) } if as.StopWebsocket != nil { as.StopWebsocket(ErrWebsocketOverridden) } closeChan := make(chan error) closeChanOnce := sync.Once{} stopFunc := func(err error) { closeChanOnce.Do(func() { closeChan <- err }) } as.ws = ws as.StopWebsocket = stopFunc as.PrepareWebsocket() as.Log.Debugln("Appservice transaction websocket connected") go as.consumeWebsocket(stopFunc, ws) if onConnect != nil { onConnect() } closeErr := <-closeChan if as.ws == ws { as.clearWebsocketResponseWaiters() as.ws = nil } err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, "")) if err != nil && !errors.Is(err, websocket.ErrCloseSent) { as.Log.Warnln("Error writing close message to websocket:", err) } err = ws.Close() if err != nil { as.Log.Warnln("Error closing websocket:", err) } return closeErr } 
go-0.11.1/client.go000066400000000000000000001644431436100171500140160ustar00rootroot00000000000000// Package mautrix implements the Matrix Client-Server API. // // Specification can be found at https://spec.matrix.org/v1.2/client-server-api/ package mautrix import ( "bytes" "context" "encoding/json" "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" "os" "strconv" "sync/atomic" "time" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" ) type Logger interface { Debugfln(message string, args ...interface{}) } // StubLogger is an implementation of Logger that does nothing type StubLogger struct{} func (sl *StubLogger) Debugfln(message string, args ...interface{}) {} func (sl *StubLogger) Warnfln(message string, args ...interface{}) {} var stubLogger = &StubLogger{} type WarnLogger interface { Logger Warnfln(message string, args ...interface{}) } type Stringifiable interface { String() string } // Client represents a Matrix client. type Client struct { HomeserverURL *url.URL // The base homeserver URL UserID id.UserID // The user ID of the client. Used for forming HTTP paths which use the client's user ID. DeviceID id.DeviceID // The device ID of the client. AccessToken string // The access_token for the client. UserAgent string // The value for the User-Agent header Client *http.Client // The underlying HTTP client which will be used to make HTTP requests. Syncer Syncer // The thing which can process /sync responses Store Storer // The thing which can store rooms/tokens/ids Logger Logger SyncPresence event.Presence StreamSyncMinAge time.Duration // Number of times that mautrix will retry any HTTP request // if the request fails entirely or returns a HTTP gateway error (502-504) DefaultHTTPRetries int // Set to true to disable automatically sleeping on 429 errors. IgnoreRateLimit bool txnID int32 // The ?user_id= query parameter for application services. This must be set *prior* to calling a method. 
// If this is empty, no user_id parameter will be sent. // See https://spec.matrix.org/v1.2/application-service-api/#identity-assertion AppServiceUserID id.UserID syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time. } type ClientWellKnown struct { Homeserver HomeserverInfo `json:"m.homeserver"` IdentityServer IdentityServerInfo `json:"m.identity_server"` } type HomeserverInfo struct { BaseURL string `json:"base_url"` } type IdentityServerInfo struct { BaseURL string `json:"base_url"` } // DiscoverClientAPI resolves the client API URL from a Matrix server name. // Use ParseUserID to extract the server name from a user ID. // https://spec.matrix.org/v1.2/client-server-api/#server-discovery func DiscoverClientAPI(serverName string) (*ClientWellKnown, error) { wellKnownURL := url.URL{ Scheme: "https", Host: serverName, Path: "/.well-known/matrix/client", } req, err := http.NewRequest("GET", wellKnownURL.String(), nil) if err != nil { return nil, err } req.Header.Set("Accept", "application/json") req.Header.Set("User-Agent", DefaultUserAgent+" .well-known fetcher") client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode == http.StatusNotFound { return nil, nil } data, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } var wellKnown ClientWellKnown err = json.Unmarshal(data, &wellKnown) if err != nil { return nil, errors.New(".well-known response not JSON") } return &wellKnown, nil } // SetCredentials sets the user ID and access token on this client instance. // // Deprecated: use the StoreCredentials field in ReqLogin instead. func (cli *Client) SetCredentials(userID id.UserID, accessToken string) { cli.AccessToken = accessToken cli.UserID = userID } // ClearCredentials removes the user ID and access token on this client instance. 
func (cli *Client) ClearCredentials() { cli.AccessToken = "" cli.UserID = "" cli.DeviceID = "" } // Sync starts syncing with the provided Homeserver. If Sync() is called twice then the first sync will be stopped and the // error will be nil. // // This function will block until a fatal /sync error occurs, so it should almost always be started as a new goroutine. // Fatal sync errors can be caused by: // - The failure to create a filter. // - Client.Syncer.OnFailedSync returning an error in response to a failed sync. // - Client.Syncer.ProcessResponse returning an error. // If you wish to continue retrying in spite of these fatal errors, call Sync() again. func (cli *Client) Sync() error { return cli.SyncWithContext(context.Background()) } func (cli *Client) SyncWithContext(ctx context.Context) error { // Mark the client as syncing. // We will keep syncing until the syncing state changes. Either because // Sync is called or StopSync is called. syncingID := cli.incrementSyncingID() nextBatch := cli.Store.LoadNextBatch(cli.UserID) filterID := cli.Store.LoadFilterID(cli.UserID) if filterID == "" { filterJSON := cli.Syncer.GetFilterJSON(cli.UserID) resFilter, err := cli.CreateFilter(filterJSON) if err != nil { return err } filterID = resFilter.FilterID cli.Store.SaveFilterID(cli.UserID, filterID) } lastSuccessfulSync := time.Now().Add(-cli.StreamSyncMinAge - 1*time.Hour) for { streamResp := false if cli.StreamSyncMinAge > 0 && time.Since(lastSuccessfulSync) > cli.StreamSyncMinAge { cli.Logger.Debugfln("Last sync is old, will stream next response") streamResp = true } resSync, err := cli.FullSyncRequest(ReqSync{ Timeout: 30000, Since: nextBatch, FilterID: filterID, FullState: false, SetPresence: cli.SyncPresence, Context: ctx, StreamResponse: streamResp, }) if err != nil { if ctx.Err() != nil { return ctx.Err() } duration, err2 := cli.Syncer.OnFailedSync(resSync, err) if err2 != nil { return err2 } time.Sleep(duration) continue } lastSuccessfulSync = time.Now() // Check 
that the syncing state hasn't changed // Either because we've stopped syncing or another sync has been started. // We discard the response from our sync. if cli.getSyncingID() != syncingID { return nil } // Save the token now *before* processing it. This means it's possible // to not process some events, but it means that we won't get constantly stuck processing // a malformed/buggy event which keeps making us panic. cli.Store.SaveNextBatch(cli.UserID, resSync.NextBatch) if err = cli.Syncer.ProcessResponse(resSync, nextBatch); err != nil { return err } nextBatch = resSync.NextBatch } } func (cli *Client) incrementSyncingID() uint32 { return atomic.AddUint32(&cli.syncingID, 1) } func (cli *Client) getSyncingID() uint32 { return atomic.LoadUint32(&cli.syncingID) } // StopSync stops the ongoing sync started by Sync. func (cli *Client) StopSync() { // Advance the syncing state so that any running Syncs will terminate. cli.incrementSyncingID() } const logBodyContextKey = "fi.mau.mautrix.log_body" const logRequestIDContextKey = "fi.mau.mautrix.request_id" func (cli *Client) LogRequest(req *http.Request) { if cli.Logger == stubLogger { return } body, ok := req.Context().Value(logBodyContextKey).(string) reqID, _ := req.Context().Value(logRequestIDContextKey).(int) if ok && len(body) > 0 { cli.Logger.Debugfln("req #%d: %s %s %s", reqID, req.Method, req.URL.String(), body) } else { cli.Logger.Debugfln("req #%d: %s %s", reqID, req.Method, req.URL.String()) } } func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) { return cli.MakeFullRequest(FullRequest{Method: method, URL: httpURL, RequestJSON: reqBody, ResponseJSON: resBody}) } type ClientResponseHandler = func(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) type FullRequest struct { Method string URL string Headers http.Header RequestJSON interface{} RequestBytes []byte RequestBody io.Reader RequestLength int64 ResponseJSON 
interface{} Context context.Context MaxAttempts int SensitiveContent bool Handler ClientResponseHandler } var requestID int32 var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes" func (params *FullRequest) compileRequest() (*http.Request, error) { var logBody string reqBody := params.RequestBody if params.Context == nil { params.Context = context.Background() } if params.RequestJSON != nil { jsonStr, err := json.Marshal(params.RequestJSON) if err != nil { return nil, HTTPError{ Message: "failed to marshal JSON", WrappedError: err, } } if params.SensitiveContent && !logSensitiveContent { logBody = "" } else { logBody = string(jsonStr) } reqBody = bytes.NewReader(jsonStr) } else if params.RequestBytes != nil { logBody = fmt.Sprintf("<%d bytes>", len(params.RequestBytes)) reqBody = bytes.NewReader(params.RequestBytes) params.RequestLength = int64(len(params.RequestBytes)) } else if params.RequestLength > 0 && params.RequestBody != nil { logBody = fmt.Sprintf("<%d bytes>", params.RequestLength) } ctx := context.WithValue(params.Context, logBodyContextKey, logBody) reqID := atomic.AddInt32(&requestID, 1) ctx = context.WithValue(ctx, logRequestIDContextKey, int(reqID)) req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody) if err != nil { return nil, HTTPError{ Message: "failed to create request", WrappedError: err, } } if params.Headers != nil { req.Header = params.Headers } if params.RequestJSON != nil { req.Header.Set("Content-Type", "application/json") } if params.RequestLength > 0 && params.RequestBody != nil { req.ContentLength = params.RequestLength } return req, nil } // MakeFullRequest makes a JSON HTTP request to the given URL. // If "resBody" is not nil, the response body will be json.Unmarshalled into it. // // Returns the HTTP body as bytes on 2xx with a nil error. Returns an error if the response is not 2xx along // with the HTTP body bytes if it got that far. 
This error is an HTTPError which includes the returned // HTTP status code and possibly a RespError as the WrappedError, if the HTTP body could be decoded as a RespError. func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) { if params.MaxAttempts == 0 { params.MaxAttempts = 1 + cli.DefaultHTTPRetries } req, err := params.compileRequest() if err != nil { return nil, err } if params.Handler == nil { params.Handler = cli.handleNormalResponse } req.Header.Set("User-Agent", cli.UserAgent) if len(cli.AccessToken) > 0 { req.Header.Set("Authorization", "Bearer "+cli.AccessToken) } return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler) } func (cli *Client) logWarning(format string, args ...interface{}) { warnLogger, ok := cli.Logger.(WarnLogger) if ok { warnLogger.Warnfln(format, args...) } else { cli.Logger.Debugfln(format, args...) } } func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { reqID, _ := req.Context().Value(logRequestIDContextKey).(int) if req.Body != nil { if req.GetBody == nil { cli.logWarning("Failed to get new body to retry request #%d: GetBody is nil", reqID) return nil, cause } var err error req.Body, err = req.GetBody() if err != nil { cli.logWarning("Failed to get new body to retry request #%d: %v", reqID, err) return nil, cause } } cli.logWarning("Request #%d failed: %v, retrying in %d seconds", reqID, cause, int(backoff.Seconds())) time.Sleep(backoff) return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler) } func (cli *Client) readRequestBody(req *http.Request, res *http.Response) ([]byte, error) { contents, err := ioutil.ReadAll(res.Body) if err != nil { return nil, HTTPError{ Request: req, Response: res, Message: "failed to read response body", WrappedError: err, } } return contents, nil } func (cli *Client) closeTemp(file 
*os.File) { _ = file.Close() err := os.Remove(file.Name()) if err != nil { cli.logWarning("Failed to remove temp file %s: %v", file.Name(), err) } } func (cli *Client) streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { file, err := ioutil.TempFile("", "mautrix-response-") if err != nil { cli.logWarning("Failed to create temporary file: %v", err) _, err = cli.handleNormalResponse(req, res, responseJSON) return nil, err } defer cli.closeTemp(file) if _, err = io.Copy(file, res.Body); err != nil { return nil, fmt.Errorf("failed to copy response to file: %w", err) } else if _, err = file.Seek(0, 0); err != nil { return nil, fmt.Errorf("failed to seek to beginning of response file: %w", err) } else if err = json.NewDecoder(file).Decode(responseJSON); err != nil { return nil, fmt.Errorf("failed to unmarshal response body: %w", err) } else { return nil, nil } } func (cli *Client) handleNormalResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { if contents, err := cli.readRequestBody(req, res); err != nil { return nil, err } else if responseJSON == nil { return contents, nil } else if err = json.Unmarshal(contents, &responseJSON); err != nil { return nil, HTTPError{ Request: req, Response: res, Message: "failed to unmarshal response body", ResponseBody: string(contents), WrappedError: err, } } else { return contents, nil } } func (cli *Client) handleResponseError(req *http.Request, res *http.Response) ([]byte, error) { contents, err := cli.readRequestBody(req, res) if err != nil { return contents, err } respErr := &RespError{} if _ = json.Unmarshal(contents, respErr); respErr.ErrCode == "" { respErr = nil } return contents, HTTPError{ Request: req, Response: res, RespError: respErr, } } // parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After. 
func (cli *Client) parseBackoffFromResponse(res *http.Response, now time.Time, fallback time.Duration) time.Duration { retryAfterHeaderValue := res.Header.Get("Retry-After") if retryAfterHeaderValue == "" { return fallback } if t, err := time.Parse(http.TimeFormat, retryAfterHeaderValue); err == nil { return t.Sub(now) } if seconds, err := strconv.Atoi(retryAfterHeaderValue); err == nil { return time.Duration(seconds) * time.Second } cli.logWarning(`Failed to parse Retry-After header value "%s"`, retryAfterHeaderValue) return fallback } func (cli *Client) shouldRetry(res *http.Response) bool { return res.StatusCode == http.StatusBadGateway || res.StatusCode == http.StatusServiceUnavailable || res.StatusCode == http.StatusGatewayTimeout || (res.StatusCode == http.StatusTooManyRequests && !cli.IgnoreRateLimit) } func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) { cli.LogRequest(req) res, err := cli.Client.Do(req) if res != nil { defer res.Body.Close() } if err != nil { if retries > 0 { return cli.doRetry(req, err, retries, backoff, responseJSON, handler) } return nil, HTTPError{ Request: req, Response: res, Message: "request error", WrappedError: err, } } if retries > 0 && cli.shouldRetry(res) { if res.StatusCode == http.StatusTooManyRequests { backoff = cli.parseBackoffFromResponse(res, time.Now(), backoff) } return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler) } if res.StatusCode < 200 || res.StatusCode >= 300 { return cli.handleResponseError(req, res) } return handler(req, res, responseJSON) } // Whoami gets the user ID of the current user. 
See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami func (cli *Client) Whoami() (resp *RespWhoami, err error) { urlPath := cli.BuildClientURL("v3", "account", "whoami") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } // CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter func (cli *Client) CreateFilter(filter *Filter) (resp *RespCreateFilter, err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter") _, err = cli.MakeRequest("POST", urlPath, filter, &resp) return } // SyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync func (cli *Client) SyncRequest(timeout int, since, filterID string, fullState bool, setPresence event.Presence, ctx context.Context) (resp *RespSync, err error) { return cli.FullSyncRequest(ReqSync{ Timeout: timeout, Since: since, FilterID: filterID, FullState: fullState, SetPresence: setPresence, Context: ctx, }) } type ReqSync struct { Timeout int Since string FilterID string FullState bool SetPresence event.Presence Context context.Context StreamResponse bool } func (req *ReqSync) BuildQuery() map[string]string { query := map[string]string{ "timeout": strconv.Itoa(req.Timeout), } if req.Since != "" { query["since"] = req.Since } if req.FilterID != "" { query["filter"] = req.FilterID } if req.SetPresence != "" { query["set_presence"] = string(req.SetPresence) } if req.FullState { query["full_state"] = "true" } return query } // FullSyncRequest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "sync"}, req.BuildQuery()) fullReq := FullRequest{ Method: http.MethodGet, URL: urlPath, ResponseJSON: &resp, Context: req.Context, // We don't want automatic retries 
for SyncRequest, the Sync() wrapper handles those. MaxAttempts: 1, } if req.StreamResponse { fullReq.Handler = cli.streamResponse } start := time.Now() _, err = cli.MakeFullRequest(fullReq) duration := time.Now().Sub(start) timeout := time.Duration(req.Timeout) * time.Millisecond buffer := 10 * time.Second if req.Since == "" { buffer = 1 * time.Minute } if err == nil && duration > timeout+buffer { cli.logWarning("Sync request (%s) took %s with timeout %s", req.Since, duration, timeout) } return } func (cli *Client) register(url string, req *ReqRegister) (resp *RespRegister, uiaResp *RespUserInteractive, err error) { var bodyBytes []byte bodyBytes, err = cli.MakeFullRequest(FullRequest{ Method: http.MethodPost, URL: url, RequestJSON: req, SensitiveContent: len(req.Password) > 0, }) if err != nil { httpErr, ok := err.(HTTPError) // if response has a 401 status, but doesn't have the errcode field, it's probably a UIA response. if ok && httpErr.IsStatus(http.StatusUnauthorized) && httpErr.RespError == nil { err = json.Unmarshal(bodyBytes, &uiaResp) } } else { // body should be RespRegister err = json.Unmarshal(bodyBytes, &resp) } return } // Register makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // // Registers with kind=user. For kind=guest, see RegisterGuest. func (cli *Client) Register(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { u := cli.BuildClientURL("v3", "register") return cli.register(u, req) } // RegisterGuest makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register // with kind=guest. // // For kind=user, see Register. 
func (cli *Client) RegisterGuest(req *ReqRegister) (*RespRegister, *RespUserInteractive, error) { query := map[string]string{ "kind": "guest", } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "register"}, query) return cli.register(u, req) } // RegisterDummy performs m.login.dummy registration according to https://spec.matrix.org/v1.2/client-server-api/#dummy-auth // // Only a username and password need to be provided on the ReqRegister struct. Most local/developer homeservers will allow registration // this way. If the homeserver does not, an error is returned. // // This does not set credentials on the client instance. See SetCredentials() instead. // // res, err := cli.RegisterDummy(&mautrix.ReqRegister{ // Username: "alice", // Password: "wonderland", // }) // if err != nil { // panic(err) // } // token := res.AccessToken func (cli *Client) RegisterDummy(req *ReqRegister) (*RespRegister, error) { res, uia, err := cli.Register(req) if err != nil && uia == nil { return nil, err } else if uia == nil { return nil, errors.New("server did not return user-interactive auth flows") } else if !uia.HasSingleStageFlow(AuthTypeDummy) { return nil, errors.New("server does not support m.login.dummy") } req.Auth = BaseAuthData{Type: AuthTypeDummy, Session: uia.Session} res, _, err = cli.Register(req) if err != nil { return nil, err } return res, nil } // GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login func (cli *Client) GetLoginFlows() (resp *RespLoginFlows, err error) { urlPath := cli.BuildClientURL("v3", "login") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } // Login a user to the homeserver according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) { _, err = cli.MakeFullRequest(FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "login"), RequestJSON: 
req, ResponseJSON: &resp, SensitiveContent: len(req.Password) > 0 || len(req.Token) > 0, }) if req.StoreCredentials && err == nil { cli.DeviceID = resp.DeviceID cli.AccessToken = resp.AccessToken cli.UserID = resp.UserID cli.Logger.Debugfln("Stored credentials for %s/%s after login", cli.UserID, cli.DeviceID) } if req.StoreHomeserverURL && err == nil && resp.WellKnown != nil && len(resp.WellKnown.Homeserver.BaseURL) > 0 { var urlErr error cli.HomeserverURL, urlErr = url.Parse(resp.WellKnown.Homeserver.BaseURL) if urlErr != nil { cli.logWarning("Failed to parse homeserver URL '%s' in login response: %v", resp.WellKnown.Homeserver.BaseURL, urlErr) } else { cli.Logger.Debugfln("Updated homeserver URL to %s after login", cli.HomeserverURL.String()) } } return } // Logout the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) Logout() (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout") _, err = cli.MakeRequest("POST", urlPath, nil, &resp) return } // LogoutAll logs out all the devices of the current user. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logoutall // This does not clear the credentials from the client instance. See ClearCredentials() instead. func (cli *Client) LogoutAll() (resp *RespLogout, err error) { urlPath := cli.BuildClientURL("v3", "logout", "all") _, err = cli.MakeRequest("POST", urlPath, nil, &resp) return } // Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions func (cli *Client) Versions() (resp *RespVersions, err error) { urlPath := cli.BuildClientURL("versions") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } // JoinRoom joins the client to a room ID or alias. 
See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3joinroomidoralias // // If serverName is specified, this will be added as a query param to instruct the homeserver to join via that server. If content is specified, it will // be JSON encoded and used as the request body. func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{}) (resp *RespJoinRoom, err error) { var urlPath string if serverName != "" { urlPath = cli.BuildURLWithQuery(ClientURLPath{"v3", "join", roomIDorAlias}, map[string]string{ "server_name": serverName, }) } else { urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) } _, err = cli.MakeRequest("POST", urlPath, content, &resp) return } // JoinRoomByID joins the client to a room ID. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin // // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. // It's mostly intended for bridges and other things where it's already certain that the server is in the room. func (cli *Client) JoinRoomByID(roomID id.RoomID) (resp *RespJoinRoom, err error) { _, err = cli.MakeRequest("POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) return } // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetDisplayName(mxid id.UserID) (resp *RespUserDisplayName, err error) { urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } // GetOwnDisplayName returns the user's display name. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname func (cli *Client) GetOwnDisplayName() (resp *RespUserDisplayName, err error) { return cli.GetDisplayName(cli.UserID) } // SetDisplayName sets the user's profile display name. 
See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseriddisplayname func (cli *Client) SetDisplayName(displayName string) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "displayname") s := struct { DisplayName string `json:"displayname"` }{displayName} _, err = cli.MakeRequest("PUT", urlPath, &s, nil) return } // GetAvatarURL gets the avatar URL of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetAvatarURL(mxid id.UserID) (url id.ContentURI, err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "avatar_url") s := struct { AvatarURL id.ContentURI `json:"avatar_url"` }{} _, err = cli.MakeRequest("GET", urlPath, nil, &s) if err != nil { return } url = s.AvatarURL return } // GetOwnAvatarURL gets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseridavatar_url func (cli *Client) GetOwnAvatarURL() (url id.ContentURI, err error) { return cli.GetAvatarURL(cli.UserID) } // SetAvatarURL sets the user's avatar URL. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3profileuseridavatar_url func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) { urlPath := cli.BuildClientURL("v3", "profile", cli.UserID, "avatar_url") s := struct { AvatarURL string `json:"avatar_url"` }{url.String()} _, err = cli.MakeRequest("PUT", urlPath, &s, nil) if err != nil { return err } return nil } // GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype func (cli *Client) GetAccountData(name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) _, err = cli.MakeRequest("GET", urlPath, nil, output) return } // SetAccountData sets the user's account data of this type. 
See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype func (cli *Client) SetAccountData(name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) _, err = cli.MakeRequest("PUT", urlPath, &data, nil) if err != nil { return err } return nil } // GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype func (cli *Client) GetRoomAccountData(roomID id.RoomID, name string, output interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) _, err = cli.MakeRequest("GET", urlPath, nil, output) return } // SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype func (cli *Client) SetRoomAccountData(roomID id.RoomID, name string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) _, err = cli.MakeRequest("PUT", urlPath, &data, nil) if err != nil { return err } return nil } type ReqSendEvent struct { Timestamp int64 TransactionID string } // SendMessageEvent sends a message event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. 
func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, extra ...ReqSendEvent) (resp *RespSendEvent, err error) { var req ReqSendEvent if len(extra) > 0 { req = extra[0] } var txnID string if len(req.TransactionID) > 0 { txnID = req.TransactionID } else { txnID = cli.TxnID() } queryParams := map[string]string{} if req.Timestamp > 0 { queryParams["ts"] = strconv.FormatInt(req.Timestamp, 10) } urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} urlPath := cli.BuildURLWithQuery(urlData, queryParams) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) return } // SendStateEvent sends a state event into a room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) return } // SendMassagedStateEvent sends a state event into a room with a custom timestamp. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidstateeventtypestatekey // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. 
func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (resp *RespSendEvent, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ "ts": strconv.FormatInt(ts, 10), }) _, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp) return } // SendText sends an m.room.message event into the given room with a msgtype of m.text // See https://spec.matrix.org/v1.2/client-server-api/#mtext func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, error) { return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: text, }) } // SendImage sends an m.room.message event into the given room with a msgtype of m.image // See https://spec.matrix.org/v1.2/client-server-api/#mimage // // Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent func (cli *Client) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) { return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgImage, Body: body, URL: url.CUString(), }) } // SendVideo sends an m.room.message event into the given room with a msgtype of m.video // See https://spec.matrix.org/v1.2/client-server-api/#mvideo // // Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent func (cli *Client) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) { return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgVideo, Body: body, URL: url.CUString(), }) } // SendNotice sends an m.room.message event into the given room with a msgtype of m.notice // See 
https://spec.matrix.org/v1.2/client-server-api/#mnotice func (cli *Client) SendNotice(roomID id.RoomID, text string) (*RespSendEvent, error) { return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgNotice, Body: text, }) } func (cli *Client) SendReaction(roomID id.RoomID, eventID id.EventID, reaction string) (*RespSendEvent, error) { return cli.SendMessageEvent(roomID, event.EventReaction, &event.ReactionEventContent{ RelatesTo: event.RelatesTo{ EventID: eventID, Type: event.RelAnnotation, Key: reaction, }, }) } // RedactEvent redacts the given event. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...ReqRedact) (resp *RespSendEvent, err error) { req := ReqRedact{} if len(extra) > 0 { req = extra[0] } if req.Extra == nil { req.Extra = make(map[string]interface{}) } if len(req.Reason) > 0 { req.Extra["reason"] = req.Reason } var txnID string if len(req.TxnID) > 0 { txnID = req.TxnID } else { txnID = cli.TxnID() } urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID) _, err = cli.MakeRequest("PUT", urlPath, req.Extra, &resp) return } // CreateRoom creates a new Matrix room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom // resp, err := cli.CreateRoom(&mautrix.ReqCreateRoom{ // Preset: "public_chat", // }) // fmt.Println("Room:", resp.RoomID) func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err error) { urlPath := cli.BuildClientURL("v3", "createRoom") _, err = cli.MakeRequest("POST", urlPath, req, &resp) return } // LeaveRoom leaves the given room. 
See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp *RespLeaveRoom, err error) { req := &ReqLeave{} if len(optionalReq) == 1 { req = optionalReq[0] } else if len(optionalReq) > 1 { panic("invalid number of arguments to LeaveRoom") } u := cli.BuildClientURL("v3", "rooms", roomID, "leave") _, err = cli.MakeRequest("POST", u, req, &resp) return } // ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget func (cli *Client) ForgetRoom(roomID id.RoomID) (resp *RespForgetRoom, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "forget") _, err = cli.MakeRequest("POST", u, struct{}{}, &resp) return } // InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") _, err = cli.MakeRequest("POST", u, req, &resp) return } // InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 func (cli *Client) InviteUserByThirdParty(roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "invite") _, err = cli.MakeRequest("POST", u, req, &resp) return } // KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "kick") _, err = cli.MakeRequest("POST", u, req, &resp) return } // BanUser bans a user from a room. 
See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "ban") _, err = cli.MakeRequest("POST", u, req, &resp) return } // UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "unban") _, err = cli.MakeRequest("POST", u, req, &resp) return } // UserTyping sets the typing status of the user. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidtypinguserid func (cli *Client) UserTyping(roomID id.RoomID, typing bool, timeout int64) (resp *RespTyping, err error) { req := ReqTyping{Typing: typing, Timeout: timeout} u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID) _, err = cli.MakeRequest("PUT", u, req, &resp) return } // GetPresence gets the presence of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus func (cli *Client) GetPresence(userID id.UserID) (resp *RespPresence, err error) { resp = new(RespPresence) u := cli.BuildClientURL("v3", "presence", userID, "status") _, err = cli.MakeRequest("GET", u, nil, resp) return } // GetOwnPresence gets the user's presence. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus func (cli *Client) GetOwnPresence() (resp *RespPresence, err error) { return cli.GetPresence(cli.UserID) } func (cli *Client) SetPresence(status event.Presence) (err error) { req := ReqPresence{Presence: status} u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") _, err = cli.MakeRequest("PUT", u, req, nil) return } // StateEvent gets a single state event in a room. 
It will attempt to JSON unmarshal into the given "outContent" struct with // the HTTP response body, or return an error. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey func (cli *Client) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) _, err = cli.MakeRequest("GET", u, nil, outContent) return } // parseRoomStateArray parses a JSON array as a stream and stores the events inside it in a room state map. func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) { response := make(RoomStateMap) responsePtr := responseJSON.(*map[event.Type]map[string]*event.Event) *responsePtr = response dec := json.NewDecoder(res.Body) arrayStart, err := dec.Token() if err != nil { return nil, err } else if arrayStart != json.Delim('[') { return nil, fmt.Errorf("expected array start, got %+v", arrayStart) } for i := 1; dec.More(); i++ { var evt *event.Event err = dec.Decode(&evt) if err != nil { return nil, fmt.Errorf("failed to parse state array item #%d: %v", i, err) } _ = evt.Content.ParseRaw(evt.Type) subMap, ok := response[evt.Type] if !ok { subMap = make(map[string]*event.Event) response[evt.Type] = subMap } subMap[*evt.StateKey] = evt } arrayEnd, err := dec.Token() if err != nil { return nil, err } else if arrayEnd != json.Delim(']') { return nil, fmt.Errorf("expected array end, got %+v", arrayStart) } return nil, nil } // State gets all state in a room. 
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstate func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) { _, err = cli.MakeFullRequest(FullRequest{ Method: http.MethodGet, URL: cli.BuildClientURL("v3", "rooms", roomID, "state"), ResponseJSON: &stateMap, Handler: parseRoomStateArray, }) return } // UploadLink uploads an HTTP URL and then returns an MXC URI. func (cli *Client) UploadLink(link string) (*RespMediaUpload, error) { res, err := cli.Client.Get(link) if res != nil { defer res.Body.Close() } if err != nil { return nil, err } return cli.Upload(res.Body, res.Header.Get("Content-Type"), res.ContentLength) } func (cli *Client) GetDownloadURL(mxcURL id.ContentURI) string { return cli.BuildURL(MediaURLPath{"v3", "download", mxcURL.Homeserver, mxcURL.FileID}) } func (cli *Client) Download(mxcURL id.ContentURI) (io.ReadCloser, error) { resp, err := cli.Client.Get(cli.GetDownloadURL(mxcURL)) if err != nil { return nil, err } return resp.Body, nil } func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) { resp, err := cli.Download(mxcURL) if err != nil { return nil, err } defer resp.Close() return ioutil.ReadAll(resp) } // UnstableCreateMXC creates a blank Matrix content URI to allow uploading the content asynchronously later. // See https://github.com/matrix-org/matrix-spec-proposals/pull/2246 func (cli *Client) UnstableCreateMXC() (*RespCreateMXC, error) { u, _ := url.Parse(cli.BuildURL(MediaURLPath{"unstable", "fi.mau.msc2246", "create"})) var m RespCreateMXC _, err := cli.MakeFullRequest(FullRequest{ Method: http.MethodPost, URL: u.String(), ResponseJSON: &m, }) return &m, err } // UnstableUploadAsync creates a blank content URI with UnstableCreateMXC, starts uploading the data in the background // and returns the created MXC immediately. See https://github.com/matrix-org/matrix-spec-proposals/pull/2246 for more info. 
func (cli *Client) UnstableUploadAsync(req ReqUploadMedia) (*RespCreateMXC, error) { resp, err := cli.UnstableCreateMXC() if err != nil { return nil, err } req.UnstableMXC = resp.ContentURI go func() { _, err = cli.UploadMedia(req) if err != nil { cli.logWarning("Failed to upload %s: %v", req.UnstableMXC, err) } }() return resp, nil } func (cli *Client) UploadBytes(data []byte, contentType string) (*RespMediaUpload, error) { return cli.UploadBytesWithName(data, contentType, "") } func (cli *Client) UploadBytesWithName(data []byte, contentType, fileName string) (*RespMediaUpload, error) { return cli.UploadMedia(ReqUploadMedia{ ContentBytes: data, ContentType: contentType, FileName: fileName, }) } // Upload uploads the given data to the content repository and returns an MXC URI. // // Deprecated: UploadMedia should be used instead. func (cli *Client) Upload(content io.Reader, contentType string, contentLength int64) (*RespMediaUpload, error) { return cli.UploadMedia(ReqUploadMedia{ Content: content, ContentLength: contentLength, ContentType: contentType, }) } type ReqUploadMedia struct { ContentBytes []byte Content io.Reader ContentLength int64 ContentType string FileName string // UnstableMXC specifies an existing MXC URI which doesn't have content yet to upload into. // See https://github.com/matrix-org/matrix-spec-proposals/pull/2246 for more info. UnstableMXC id.ContentURI } // UploadMedia uploads the given data to the content repository and returns an MXC URI. 
// See https://spec.matrix.org/v1.2/client-server-api/#post_matrixmediav3upload func (cli *Client) UploadMedia(data ReqUploadMedia) (*RespMediaUpload, error) { u, _ := url.Parse(cli.BuildURL(MediaURLPath{"v3", "upload"})) method := http.MethodPost if !data.UnstableMXC.IsEmpty() { u, _ = url.Parse(cli.BuildURL(MediaURLPath{"unstable", "fi.mau.msc2246", "upload", data.UnstableMXC.Homeserver, data.UnstableMXC.FileID})) method = http.MethodPut } if len(data.FileName) > 0 { q := u.Query() q.Set("filename", data.FileName) u.RawQuery = q.Encode() } var headers http.Header if len(data.ContentType) > 0 { headers = http.Header{"Content-Type": []string{data.ContentType}} } var m RespMediaUpload _, err := cli.MakeFullRequest(FullRequest{ Method: method, URL: u.String(), Headers: headers, RequestBytes: data.ContentBytes, RequestBody: data.Content, RequestLength: data.ContentLength, ResponseJSON: &m, }) return &m, err } // GetURLPreview asks the homeserver to fetch a preview for a given URL. // // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) { reqURL := cli.BuildURLWithQuery(MediaURLPath{"v3", "preview_url"}, map[string]string{ "url": url, }) var output RespPreviewURL _, err := cli.MakeRequest(http.MethodGet, reqURL, nil, &output) return &output, err } // JoinedMembers returns a map of joined room members. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidjoined_members // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined members in a room. 
func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err error) { u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") _, err = cli.MakeRequest("GET", u, nil, &resp) return } func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembers, err error) { var extra ReqMembers if len(req) > 0 { extra = req[0] } query := map[string]string{} if len(extra.At) > 0 { query["at"] = extra.At } if len(extra.Membership) > 0 { query["membership"] = string(extra.Membership) } if len(extra.NotMembership) > 0 { query["not_membership"] = string(extra.NotMembership) } u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) _, err = cli.MakeRequest("GET", u, nil, &resp) return } // JoinedRooms returns a list of rooms which the client is joined to. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3joined_rooms // // In general, usage of this API is discouraged in favour of /sync, as calling this API can race with incoming membership changes. // This API is primarily designed for application services which may want to efficiently look up joined rooms. func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) { u := cli.BuildClientURL("v3", "joined_rooms") _, err = cli.MakeRequest("GET", u, nil, &resp) return } // Messages returns a list of message and state events for a room. It uses // pagination query parameters to paginate history in the room. 
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages func (cli *Client) Messages(roomID id.RoomID, from, to string, dir rune, filter *FilterPart, limit int) (resp *RespMessages, err error) { query := map[string]string{ "from": from, "dir": string(dir), } if filter != nil { filterJSON, err := json.Marshal(filter) if err != nil { return nil, err } query["filter"] = string(filterJSON) } if to != "" { query["to"] = to } if limit != 0 { query["limit"] = strconv.Itoa(limit) } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } // Context returns a number of events that happened just before and after the // specified event. It use pagination query parameters to paginate history in // the room. // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidcontexteventid func (cli *Client) Context(roomID id.RoomID, eventID id.EventID, filter *FilterPart, limit int) (resp *RespContext, err error) { query := map[string]string{} if filter != nil { filterJSON, err := json.Marshal(filter) if err != nil { return nil, err } query["filter"] = string(filterJSON) } if limit != 0 { query["limit"] = strconv.Itoa(limit) } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) GetEvent(roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) MarkRead(roomID id.RoomID, eventID id.EventID) (err error) { return cli.MarkReadWithContent(roomID, eventID, struct{}{}) } // MarkReadWithContent sends a read receipt including custom data. // N.B. This is not (yet) a part of the spec, normal servers will drop any extra content. 
func (cli *Client) MarkReadWithContent(roomID id.RoomID, eventID id.EventID, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", "m.read", eventID) _, err = cli.MakeRequest("POST", urlPath, &content, nil) return } func (cli *Client) SetReadMarkers(roomID id.RoomID, content interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers") _, err = cli.MakeRequest("POST", urlPath, &content, nil) return } func (cli *Client) AddTag(roomID id.RoomID, tag string, order float64) error { var tagData event.Tag if order == order { tagData.Order = json.Number(strconv.FormatFloat(order, 'e', -1, 64)) } return cli.AddTagWithCustomData(roomID, tag, tagData) } func (cli *Client) AddTagWithCustomData(roomID id.RoomID, tag string, data interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest("PUT", urlPath, data, nil) return } func (cli *Client) GetTags(roomID id.RoomID) (tags event.TagEventContent, err error) { err = cli.GetTagsWithCustomData(roomID, &tags) return } func (cli *Client) GetTagsWithCustomData(roomID id.RoomID, resp interface{}) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) RemoveTag(roomID id.RoomID, tag string) (err error) { urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) _, err = cli.MakeRequest("DELETE", urlPath, nil, nil) return } // Deprecated: Synapse may not handle setting m.tag directly properly, so you should use the Add/RemoveTag methods instead. func (cli *Client) SetTags(roomID id.RoomID, tags event.Tags) (err error) { return cli.SetRoomAccountData(roomID, "m.tag", map[string]event.Tags{ "tags": tags, }) } // TurnServer returns turn server details and credentials for the client to use when initiating calls. 
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver func (cli *Client) TurnServer() (resp *RespTurnServer, err error) { urlPath := cli.BuildClientURL("v3", "voip", "turnServer") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) CreateAlias(alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) _, err = cli.MakeRequest("PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) return } func (cli *Client) ResolveAlias(alias id.RoomAlias) (resp *RespAliasResolve, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) DeleteAlias(alias id.RoomAlias) (resp *RespAliasDelete, err error) { urlPath := cli.BuildClientURL("v3", "directory", "room", alias) _, err = cli.MakeRequest("DELETE", urlPath, nil, &resp) return } func (cli *Client) GetAliases(roomID id.RoomID) (resp *RespAliasList, err error) { urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) UploadKeys(req *ReqUploadKeys) (resp *RespUploadKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "upload") _, err = cli.MakeRequest("POST", urlPath, req, &resp) return } func (cli *Client) QueryKeys(req *ReqQueryKeys) (resp *RespQueryKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "query") _, err = cli.MakeRequest("POST", urlPath, req, &resp) return } func (cli *Client) ClaimKeys(req *ReqClaimKeys) (resp *RespClaimKeys, err error) { urlPath := cli.BuildClientURL("v3", "keys", "claim") _, err = cli.MakeRequest("POST", urlPath, req, &resp) return } func (cli *Client) GetKeyChanges(from, to string) (resp *RespKeyChanges, err error) { urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "keys", "changes"}, map[string]string{ "from": from, "to": to, }) _, err = 
cli.MakeRequest("POST", urlPath, nil, &resp) return } func (cli *Client) SendToDevice(eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) _, err = cli.MakeRequest("PUT", urlPath, req, &resp) return } func (cli *Client) GetDevicesInfo() (resp *RespDevicesInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices") _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) GetDeviceInfo(deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) return } func (cli *Client) SetDeviceInfo(deviceID id.DeviceID, req *ReqDeviceInfo) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest("PUT", urlPath, req, nil) return err } func (cli *Client) DeleteDevice(deviceID id.DeviceID, req *ReqDeleteDevice) error { urlPath := cli.BuildClientURL("v3", "devices", deviceID) _, err := cli.MakeRequest("DELETE", urlPath, req, nil) return err } func (cli *Client) DeleteDevices(req *ReqDeleteDevices) error { urlPath := cli.BuildClientURL("v3", "delete_devices") _, err := cli.MakeRequest("DELETE", urlPath, req, nil) return err } type UIACallback = func(*RespUserInteractive) interface{} // UploadCrossSigningKeys uploads the given cross-signing keys to the server. // Because the endpoint requires user-interactive authentication a callback must be provided that, // given the UI auth parameters, produces the required result (or nil to end the flow). 
func (cli *Client) UploadCrossSigningKeys(keys *UploadCrossSigningKeysReq, uiaCallback UIACallback) error { content, err := cli.MakeFullRequest(FullRequest{ Method: http.MethodPost, URL: cli.BuildClientURL("v3", "keys", "device_signing", "upload"), RequestJSON: keys, SensitiveContent: keys.Auth != nil, }) if respErr, ok := err.(HTTPError); ok && respErr.IsStatus(http.StatusUnauthorized) { // try again with UI auth var uiAuthResp RespUserInteractive if err := json.Unmarshal(content, &uiAuthResp); err != nil { return fmt.Errorf("failed to decode UIA response: %w", err) } auth := uiaCallback(&uiAuthResp) if auth != nil { keys.Auth = auth return cli.UploadCrossSigningKeys(keys, uiaCallback) } } return err } func (cli *Client) UploadSignatures(req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload") _, err = cli.MakeRequest("POST", urlPath, req, &resp) return } // GetPushRules returns the push notification rules for the global scope. func (cli *Client) GetPushRules() (*pushrules.PushRuleset, error) { return cli.GetScopedPushRules("global") } // GetScopedPushRules returns the push notification rules for the given scope. func (cli *Client) GetScopedPushRules(scope string) (resp *pushrules.PushRuleset, err error) { u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope)) // client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash. 
u.Path += "/" _, err = cli.MakeRequest("GET", u.String(), nil, &resp) return } func (cli *Client) GetPushRule(scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) _, err = cli.MakeRequest("GET", urlPath, nil, &resp) if resp != nil { resp.Type = kind } return } func (cli *Client) DeletePushRule(scope string, kind pushrules.PushRuleType, ruleID string) error { urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) _, err := cli.MakeRequest("DELETE", urlPath, nil, nil) return err } func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID string, req *ReqPutPushRule) error { query := make(map[string]string) if len(req.After) > 0 { query["after"] = req.After } if len(req.Before) > 0 { query["before"] = req.Before } urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query) _, err := cli.MakeRequest("PUT", urlPath, req, nil) return err } // BatchSend sends a batch of historical events into a room. This is only available for appservices. // // See https://github.com/matrix-org/matrix-doc/pull/2716 for more info. func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) { path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"} query := map[string]string{ "prev_event_id": req.PrevEventID.String(), } if len(req.BatchID) > 0 { query["batch_id"] = req.BatchID.String() } _, err = cli.MakeRequest("POST", cli.BuildURLWithQuery(path, query), req, &resp) return } // TxnID returns the next transaction ID. 
func (cli *Client) TxnID() string { txnID := atomic.AddInt32(&cli.txnID, 1) return fmt.Sprintf("mautrix-go_%d_%d", time.Now().UnixNano(), txnID) } // NewClient creates a new Matrix Client ready for syncing func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Client, error) { hsURL, err := parseAndNormalizeBaseURL(homeserverURL) if err != nil { return nil, err } return &Client{ AccessToken: accessToken, UserAgent: DefaultUserAgent, HomeserverURL: hsURL, UserID: userID, Client: &http.Client{Timeout: 180 * time.Second}, Syncer: NewDefaultSyncer(), Logger: stubLogger, // By default, use an in-memory store which will never save filter ids / next batch tokens to disk. // The client will work with this storer: it just won't remember across restarts. // In practice, a database backend should be used. Store: NewInMemoryStore(), }, nil } go-0.11.1/client_internal_test.go000066400000000000000000000030501436100171500167330ustar00rootroot00000000000000package mautrix import ( "fmt" "net/http" "testing" "time" ) type testLogger struct { StubLogger lastLogged string } func (tl *testLogger) Warnfln(message string, args ...interface{}) { tl.lastLogged = fmt.Sprintf(message, args...) 
} func TestBackoffFromResponse(t *testing.T) { now := time.Now().Truncate(time.Second) defaultBackoff := time.Duration(123) for name, tt := range map[string]struct { headerValue string expected time.Duration expectedLog string }{ "AsDate": { headerValue: now.In(time.UTC).Add(5 * time.Hour).Format(http.TimeFormat), expected: time.Duration(5) * time.Hour, expectedLog: "", }, "AsSeconds": { headerValue: "12345", expected: time.Duration(12345) * time.Second, expectedLog: "", }, "Missing": { headerValue: "", expected: defaultBackoff, expectedLog: "", }, "Bad": { headerValue: "invalid", expected: defaultBackoff, expectedLog: `Failed to parse Retry-After header value "invalid"`, }, } { t.Run(name, func(t *testing.T) { logger := &testLogger{} c := &Client{Logger: logger} actual := c.parseBackoffFromResponse( &http.Response{ Header: http.Header{ "Retry-After": []string{tt.headerValue}, }, }, now, time.Duration(123), ) if actual != tt.expected { t.Fatalf("Backoff duration output mismatch, expected %s, got %s", tt.expected, actual) } if logger.lastLogged != tt.expectedLog { t.Fatalf(`Log line mismatch, expected "%s", got "%s"`, tt.expectedLog, logger.lastLogged) } }) } } go-0.11.1/crypto/000077500000000000000000000000001436100171500135155ustar00rootroot00000000000000go-0.11.1/crypto/account.go000066400000000000000000000054521436100171500155060ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) type OlmAccount struct { Internal olm.Account signingKey id.SigningKey identityKey id.IdentityKey Shared bool } func NewOlmAccount() *OlmAccount { return &OlmAccount{ Internal: *olm.NewAccount(), } } func (account *OlmAccount) Keys() (id.SigningKey, id.IdentityKey) { if len(account.signingKey) == 0 || len(account.identityKey) == 0 { account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.signingKey, account.identityKey } func (account *OlmAccount) SigningKey() id.SigningKey { if len(account.signingKey) == 0 { account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.signingKey } func (account *OlmAccount) IdentityKey() id.IdentityKey { if len(account.identityKey) == 0 { account.signingKey, account.identityKey = account.Internal.IdentityKeys() } return account.identityKey } func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID) *mautrix.DeviceKeys { deviceKeys := &mautrix.DeviceKeys{ UserID: userID, DeviceID: deviceID, Algorithms: []id.Algorithm{id.AlgorithmMegolmV1, id.AlgorithmOlmV1}, Keys: map[id.DeviceKeyID]string{ id.NewDeviceKeyID(id.KeyAlgorithmCurve25519, deviceID): string(account.IdentityKey()), id.NewDeviceKeyID(id.KeyAlgorithmEd25519, deviceID): string(account.SigningKey()), }, } signature, err := account.Internal.SignJSON(deviceKeys) if err != nil { panic(err) } deviceKeys.Signatures = mautrix.Signatures{ userID: { id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, }, } return deviceKeys } func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID, currentOTKCount int) map[id.KeyID]mautrix.OneTimeKey { newCount := int(account.Internal.MaxNumberOfOneTimeKeys()/2) - currentOTKCount if newCount > 0 { account.Internal.GenOneTimeKeys(uint(newCount)) } oneTimeKeys := make(map[id.KeyID]mautrix.OneTimeKey) // TODO do 
we need unsigned curve25519 one-time keys at all? // this just signs all of them for keyID, key := range account.Internal.OneTimeKeys() { key := mautrix.OneTimeKey{Key: key} signature, _ := account.Internal.SignJSON(key) key.Signatures = mautrix.Signatures{ userID: { id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, }, } key.IsSigned = true oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key } account.Internal.MarkKeysAsPublished() return oneTimeKeys } go-0.11.1/crypto/attachment/000077500000000000000000000000001436100171500156455ustar00rootroot00000000000000go-0.11.1/crypto/attachment/attachments.go000066400000000000000000000163721436100171500205200ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package attachment import ( "crypto/aes" "crypto/cipher" "crypto/sha256" "encoding/base64" "errors" "hash" "io" "maunium.net/go/mautrix/crypto/utils" ) var ( HashMismatch = errors.New("mismatching SHA-256 digest") UnsupportedVersion = errors.New("unsupported Matrix file encryption version") UnsupportedAlgorithm = errors.New("unsupported JWK encryption algorithm") InvalidKey = errors.New("failed to decode key") InvalidInitVector = errors.New("failed to decode initialization vector") InvalidHash = errors.New("failed to decode SHA-256 hash") ReaderClosed = errors.New("encrypting reader was already closed") ) var ( keyBase64Length = base64.RawURLEncoding.EncodedLen(utils.AESCTRKeyLength) ivBase64Length = base64.RawStdEncoding.EncodedLen(utils.AESCTRIVLength) hashBase64Length = base64.RawStdEncoding.EncodedLen(utils.SHAHashLength) ) type JSONWebKey struct { Key string `json:"k"` Algorithm string `json:"alg"` Extractable bool `json:"ext"` KeyType string `json:"kty"` KeyOps []string `json:"key_ops"` } type EncryptedFileHashes 
struct { SHA256 string `json:"sha256"` } type decodedKeys struct { key [utils.AESCTRKeyLength]byte iv [utils.AESCTRIVLength]byte sha256 [utils.SHAHashLength]byte } type EncryptedFile struct { Key JSONWebKey `json:"key"` InitVector string `json:"iv"` Hashes EncryptedFileHashes `json:"hashes"` Version string `json:"v"` decoded *decodedKeys } func NewEncryptedFile() *EncryptedFile { key, iv := utils.GenAttachmentA256CTR() return &EncryptedFile{ Key: JSONWebKey{ Key: base64.RawURLEncoding.EncodeToString(key[:]), Algorithm: "A256CTR", Extractable: true, KeyType: "oct", KeyOps: []string{"encrypt", "decrypt"}, }, InitVector: base64.RawStdEncoding.EncodeToString(iv[:]), Version: "v2", decoded: &decodedKeys{key: key, iv: iv}, } } func (ef *EncryptedFile) decodeKeys(includeHash bool) error { if ef.decoded != nil { return nil } else if len(ef.Key.Key) != keyBase64Length { return InvalidKey } else if len(ef.InitVector) != ivBase64Length { return InvalidInitVector } else if includeHash && len(ef.Hashes.SHA256) != hashBase64Length { return InvalidHash } ef.decoded = &decodedKeys{} _, err := base64.RawURLEncoding.Decode(ef.decoded.key[:], []byte(ef.Key.Key)) if err != nil { return InvalidKey } _, err = base64.RawStdEncoding.Decode(ef.decoded.iv[:], []byte(ef.InitVector)) if err != nil { return InvalidInitVector } if includeHash { _, err = base64.RawStdEncoding.Decode(ef.decoded.sha256[:], []byte(ef.Hashes.SHA256)) if err != nil { return InvalidHash } } return nil } // Encrypt encrypts the given data, updates the SHA256 hash in the EncryptedFile struct and returns the ciphertext. // // Deprecated: this makes a copy for the ciphertext, which means 2x memory usage. EncryptInPlace is recommended. func (ef *EncryptedFile) Encrypt(plaintext []byte) []byte { ciphertext := make([]byte, len(plaintext)) copy(ciphertext, plaintext) ef.EncryptInPlace(ciphertext) return ciphertext } // EncryptInPlace encrypts the given data in-place (i.e. 
the provided data is overridden with the ciphertext) // and updates the SHA256 hash in the EncryptedFile struct. func (ef *EncryptedFile) EncryptInPlace(data []byte) { ef.decodeKeys(false) utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) checksum := sha256.Sum256(data) ef.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(checksum[:]) } type encryptingReader struct { stream cipher.Stream hash hash.Hash source io.Reader file *EncryptedFile closed bool isDecrypting bool } func (r *encryptingReader) Read(dst []byte) (n int, err error) { if r.closed { return 0, ReaderClosed } else if r.isDecrypting && r.file.decoded == nil { if err = r.file.PrepareForDecryption(); err != nil { return } } n, err = r.source.Read(dst) r.stream.XORKeyStream(dst[:n], dst[:n]) r.hash.Write(dst[:n]) return } func (r *encryptingReader) Close() (err error) { closer, ok := r.source.(io.ReadCloser) if ok { err = closer.Close() } if r.isDecrypting { var downloadedChecksum [utils.SHAHashLength]byte r.hash.Sum(downloadedChecksum[:]) if downloadedChecksum != r.file.decoded.sha256 { return HashMismatch } } else { r.file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString(r.hash.Sum(nil)) } r.closed = true return } // EncryptStream wraps the given io.Reader in order to encrypt the data. // // The Close() method of the returned io.ReadCloser must be called for the SHA256 hash // in the EncryptedFile struct to be updated. The metadata is not valid before the hash // is filled. func (ef *EncryptedFile) EncryptStream(reader io.Reader) io.ReadCloser { ef.decodeKeys(false) block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ stream: cipher.NewCTR(block, ef.decoded.iv[:]), hash: sha256.New(), source: reader, file: ef, } } // Decrypt decrypts the given data and returns the plaintext. // // Deprecated: this makes a copy for the plaintext data, which means 2x memory usage. DecryptInPlace is recommended. 
func (ef *EncryptedFile) Decrypt(ciphertext []byte) ([]byte, error) { plaintext := make([]byte, len(ciphertext)) copy(plaintext, ciphertext) return plaintext, ef.DecryptInPlace(plaintext) } // PrepareForDecryption checks that the version and algorithm are supported and decodes the base64 keys // // DecryptStream will call this with the first Read() call if this hasn't been called manually. // // DecryptInPlace will always call this automatically, so calling this manually is not necessary when using that function. func (ef *EncryptedFile) PrepareForDecryption() error { if ef.Version != "v2" { return UnsupportedVersion } else if ef.Key.Algorithm != "A256CTR" { return UnsupportedAlgorithm } else if err := ef.decodeKeys(true); err != nil { return err } return nil } // DecryptInPlace decrypts the given data in-place (i.e. the provided data is overridden with the plaintext). func (ef *EncryptedFile) DecryptInPlace(data []byte) error { if err := ef.PrepareForDecryption(); err != nil { return err } else if ef.decoded.sha256 != sha256.Sum256(data) { return HashMismatch } else { utils.XorA256CTR(data, ef.decoded.key, ef.decoded.iv) return nil } } // DecryptStream wraps the given io.Reader in order to decrypt the data. // // The first Read call will check the algorithm and decode keys, so it might return an error before actually reading anything. // If you want to validate the file before opening the stream, call PrepareForDecryption manually and check for errors. // // The Close call will validate the hash and return an error if it doesn't match. // In this case, the written data should be considered compromised and should not be used further. 
func (ef *EncryptedFile) DecryptStream(reader io.Reader) io.ReadCloser { block, _ := aes.NewCipher(ef.decoded.key[:]) return &encryptingReader{ stream: cipher.NewCTR(block, ef.decoded.iv[:]), hash: sha256.New(), source: reader, file: ef, } } go-0.11.1/crypto/attachment/attachments_test.go000066400000000000000000000047051436100171500215540ustar00rootroot00000000000000package attachment import ( "encoding/base64" "encoding/json" "testing" "github.com/stretchr/testify/assert" ) const helloWorldCiphertext = ":6\xc7O1yR\x06\xe8\xcf]" const helloWorldRawFile = `{ "v": "v2", "key": { "kty": "oct", "alg": "A256CTR", "ext": true, "k": "35XNdmWKOpn6UYS82Y83wEY8LagwQZHX2X0kAFW7sdg", "key_ops": [ "encrypt", "decrypt" ] }, "iv": "DOtPz8bC3qgAAAAAAAAAAA", "hashes": { "sha256": "rO+040ZhUxbpbmIS9GUuMSen4NPKFxMzqOUJeemM8mk" } }` const random32Bytes = "\x85\xb4\x16/\xcaO\x1d\xe6\x7f\x95\xeb\xdb+g\x11\xb1\x81\x1a\xafY\x00\x1dq!h{\x81F\xaa\xd7A\x00" func parseHelloWorld() *EncryptedFile { file := &EncryptedFile{} _ = json.Unmarshal([]byte(helloWorldRawFile), file) return file } func TestDecryptHelloWorld(t *testing.T) { file := parseHelloWorld() data := []byte(helloWorldCiphertext) err := file.DecryptInPlace(data) assert.NoError(t, err, "failed to decrypt file") assert.Equal(t, "hello world", string(data), "unexpected decrypt output") } func TestEncryptHelloWorld(t *testing.T) { file := parseHelloWorld() data := []byte("hello world") file.EncryptInPlace(data) assert.Equal(t, helloWorldCiphertext, string(data), "unexpected encrypt output") } func TestUnsupportedVersion(t *testing.T) { file := parseHelloWorld() file.Version = "foo" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) assert.ErrorIs(t, err, UnsupportedVersion) } func TestUnsupportedAlgorithm(t *testing.T) { file := parseHelloWorld() file.Key.Algorithm = "bar" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) assert.ErrorIs(t, err, UnsupportedAlgorithm) } func TestHashMismatch(t *testing.T) { file := 
parseHelloWorld() file.Hashes.SHA256 = base64.RawStdEncoding.EncodeToString([]byte(random32Bytes)) err := file.DecryptInPlace([]byte(helloWorldCiphertext)) assert.ErrorIs(t, err, HashMismatch) } func TestTooLongHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "TG9yZW0gaXBzdW0gZG9sb3Igc2l0IGFtZXQsIGNvbnNlY3RldHVlciBhZGlwaXNjaW5nIGVsaXQuIFNlZCBwb3N1ZXJlIGludGVyZHVtIHNlbS4gUXVpc3F1ZSBsaWd1bGEgZXJvcyB1bGxhbWNvcnBlciBxdWlzLCBsYWNpbmlhIHF1aXMgZmFjaWxpc2lzIHNlZCBzYXBpZW4uCg" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) assert.ErrorIs(t, err, InvalidHash) } func TestTooShortHash(t *testing.T) { file := parseHelloWorld() file.Hashes.SHA256 = "5/Gy1JftyyQ" err := file.DecryptInPlace([]byte(helloWorldCiphertext)) assert.ErrorIs(t, err, InvalidHash) } go-0.11.1/crypto/canonicaljson/000077500000000000000000000000001436100171500163365ustar00rootroot00000000000000go-0.11.1/crypto/canonicaljson/LICENSE000066400000000000000000000236761436100171500173610ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS go-0.11.1/crypto/canonicaljson/README.md000066400000000000000000000005201436100171500176120ustar00rootroot00000000000000# canonicaljson This is a Go package to produce Matrix [Canonical JSON](https://matrix.org/docs/spec/appendices#canonical-json). It is essentially just [json.go](https://github.com/matrix-org/gomatrixserverlib/blob/master/json.go) from gomatrixserverlib without all the other files that are completely useless for non-server use cases. 
go-0.11.1/crypto/canonicaljson/json.go000066400000000000000000000200141436100171500176330ustar00rootroot00000000000000/* Copyright 2016-2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package canonicaljson import ( "encoding/binary" "fmt" "sort" "unicode/utf8" "github.com/tidwall/gjson" ) // CanonicalJSON re-encodes the JSON in a canonical encoding. The encoding is // the shortest possible encoding using integer values with sorted object keys. // https://matrix.org/docs/spec/appendices#canonical-json func CanonicalJSON(input []byte) ([]byte, error) { if !gjson.Valid(string(input)) { return nil, fmt.Errorf("invalid json") } return CanonicalJSONAssumeValid(input), nil } // CanonicalJSONAssumeValid is the same as CanonicalJSON, but assumes the // input is valid JSON func CanonicalJSONAssumeValid(input []byte) []byte { input = CompactJSON(input, make([]byte, 0, len(input))) return SortJSON(input, make([]byte, 0, len(input))) } // SortJSON reencodes the JSON with the object keys sorted by lexicographically // by codepoint. The input must be valid JSON. func SortJSON(input, output []byte) []byte { result := gjson.ParseBytes(input) return sortJSONValue(result, input, output) } // sortJSONValue takes a gjson.Result and sorts it. inputJSON must be the // raw JSON bytes that gjson.Result points to. 
func sortJSONValue(input gjson.Result, inputJSON, output []byte) []byte { if input.IsArray() { return sortJSONArray(input, inputJSON, output) } if input.IsObject() { return sortJSONObject(input, inputJSON, output) } // If its neither an object nor an array then there is no sub structure // to sort, so just append the raw bytes. return append(output, input.Raw...) } // sortJSONArray takes a gjson.Result and sorts it, assuming its an array. // inputJSON must be the raw JSON bytes that gjson.Result points to. func sortJSONArray(input gjson.Result, inputJSON, output []byte) []byte { sep := byte('[') // Iterate over each value in the array and sort it. input.ForEach(func(_, value gjson.Result) bool { output = append(output, sep) sep = ',' output = sortJSONValue(value, inputJSON, output) return true // keep iterating }) if sep == '[' { // If sep is still '[' then the array was empty and we never wrote the // initial '[', so we write it now along with the closing ']'. output = append(output, '[', ']') } else { // Otherwise we end the array by writing a single ']' output = append(output, ']') } return output } // sortJSONObject takes a gjson.Result and sorts it, assuming its an object. // inputJSON must be the raw JSON bytes that gjson.Result points to. 
func sortJSONObject(input gjson.Result, inputJSON, output []byte) []byte { type entry struct { key string // The parsed key string rawKey string // The raw, unparsed key JSON string value gjson.Result } var entries []entry // Iterate over each key/value pair and add it to a slice // that we can sort input.ForEach(func(key, value gjson.Result) bool { entries = append(entries, entry{ key: key.String(), rawKey: key.Raw, value: value, }) return true // keep iterating }) // Sort the slice based on the *parsed* key sort.Slice(entries, func(a, b int) bool { return entries[a].key < entries[b].key }) sep := byte('{') for _, entry := range entries { output = append(output, sep) sep = ',' // Append the raw unparsed JSON key, *not* the parsed key output = append(output, entry.rawKey...) output = append(output, ':') output = sortJSONValue(entry.value, inputJSON, output) } if sep == '{' { // If sep is still '{' then the object was empty and we never wrote the // initial '{', so we write it now along with the closing '}'. output = append(output, '{', '}') } else { // Otherwise we end the object by writing a single '}' output = append(output, '}') } return output } // CompactJSON makes the encoded JSON as small as possible by removing // whitespace and unneeded unicode escapes func CompactJSON(input, output []byte) []byte { var i int for i < len(input) { c := input[i] i++ // The valid whitespace characters are all less than or equal to SPACE 0x20. // The valid non-white characters are all greater than SPACE 0x20. // So we can check for whitespace by comparing against SPACE 0x20. if c <= ' ' { // Skip over whitespace. continue } // Add the non-whitespace character to the output. output = append(output, c) if c == '"' { // We are inside a string. for i < len(input) { c = input[i] i++ // Check if this is an escape sequence. 
if c == '\\' { escape := input[i] i++ if escape == 'u' { // If this is a unicode escape then we need to handle it specially output, i = compactUnicodeEscape(input, output, i) } else if escape == '/' { // JSON does not require escaping '/', but allows encoders to escape it as a special case. // Since the escape isn't required we remove it. output = append(output, escape) } else { // All other permitted escapes are single charater escapes that are already in their shortest form. output = append(output, '\\', escape) } } else { output = append(output, c) } if c == '"' { break } } } } return output } // compactUnicodeEscape unpacks a 4 byte unicode escape starting at index. // If the escape is a surrogate pair then decode the 6 byte \uXXXX escape // that follows. Returns the output slice and a new input index. func compactUnicodeEscape(input, output []byte, index int) ([]byte, int) { const ( ESCAPES = "uuuuuuuubtnufruuuuuuuuuuuuuuuuuu" HEX = "0123456789ABCDEF" ) // If there aren't enough bytes to decode the hex escape then return. if len(input)-index < 4 { return output, len(input) } // Decode the 4 hex digits. c := readHexDigits(input[index:]) index += 4 if c < ' ' { // If the character is less than SPACE 0x20 then it will need escaping. escape := ESCAPES[c] output = append(output, '\\', escape) if escape == 'u' { output = append(output, '0', '0', byte('0'+(c>>4)), HEX[c&0xF]) } } else if c == '\\' || c == '"' { // Otherwise the character only needs escaping if it is a QUOTE '"' or BACKSLASH '\\'. output = append(output, '\\', byte(c)) } else if c < 0xD800 || c >= 0xE000 { // If the character isn't a surrogate pair then encoded it directly as UTF-8. var buffer [4]byte n := utf8.EncodeRune(buffer[:], rune(c)) output = append(output, buffer[:n]...) } else { // Otherwise the escaped character was the first part of a UTF-16 style surrogate pair. // The next 6 bytes MUST be a '\uXXXX'. // If there aren't enough bytes to decode the hex escape then return. 
if len(input)-index < 6 { return output, len(input) } // Decode the 4 hex digits from the '\uXXXX'. surrogate := readHexDigits(input[index+2:]) index += 6 // Reconstruct the UCS4 codepoint from the surrogates. codepoint := 0x10000 + (((c & 0x3FF) << 10) | (surrogate & 0x3FF)) // Encode the charater as UTF-8. var buffer [4]byte n := utf8.EncodeRune(buffer[:], rune(codepoint)) output = append(output, buffer[:n]...) } return output, index } // Read 4 hex digits from the input slice. // Taken from https://github.com/NegativeMjark/indolentjson-rust/blob/8b959791fe2656a88f189c5d60d153be05fe3deb/src/readhex.rs#L21 func readHexDigits(input []byte) uint32 { hex := binary.BigEndian.Uint32(input) // subtract '0' hex -= 0x30303030 // strip the higher bits, maps 'a' => 'A' hex &= 0x1F1F1F1F mask := hex & 0x10101010 // subtract 'A' - 10 - '9' - 9 = 7 from the letters. hex -= mask >> 1 hex += mask >> 4 // collect the nibbles hex |= hex >> 4 hex &= 0xFF00FF hex |= hex >> 8 return hex & 0xFFFF } go-0.11.1/crypto/canonicaljson/json_test.go000066400000000000000000000055721436100171500207060ustar00rootroot00000000000000/* Copyright 2016-2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package canonicaljson import ( "testing" ) func testSortJSON(t *testing.T, input, want string) { got := SortJSON([]byte(input), nil) // Squash out the whitespace before comparing the JSON in case SortJSON had inserted whitespace. 
if string(CompactJSON(got, nil)) != want { t.Errorf("SortJSON(%q): want %q got %q", input, want, got) } } func TestSortJSON(t *testing.T) { testSortJSON(t, `[{"b":"two","a":1}]`, `[{"a":1,"b":"two"}]`) testSortJSON(t, `{"B":{"4":4,"3":3},"A":{"1":1,"2":2}}`, `{"A":{"1":1,"2":2},"B":{"3":3,"4":4}}`) testSortJSON(t, `[true,false,null]`, `[true,false,null]`) testSortJSON(t, `[9007199254740991]`, `[9007199254740991]`) testSortJSON(t, "\t\n[9007199254740991]", `[9007199254740991]`) } func testCompactJSON(t *testing.T, input, want string) { got := string(CompactJSON([]byte(input), nil)) if got != want { t.Errorf("CompactJSON(%q): want %q got %q", input, want, got) } } func TestCompactJSON(t *testing.T) { testCompactJSON(t, "{ }", "{}") input := `["\u0000\u0001\u0002\u0003\u0004\u0005\u0006\u0007"]` want := input testCompactJSON(t, input, want) input = `["\u0008\u0009\u000A\u000B\u000C\u000D\u000E\u000F"]` want = `["\b\t\n\u000B\f\r\u000E\u000F"]` testCompactJSON(t, input, want) input = `["\u0010\u0011\u0012\u0013\u0014\u0015\u0016\u0017"]` want = input testCompactJSON(t, input, want) input = `["\u0018\u0019\u001A\u001B\u001C\u001D\u001E\u001F"]` want = input testCompactJSON(t, input, want) testCompactJSON(t, `["\u0061\u005C\u0042\u0022"]`, `["a\\B\""]`) testCompactJSON(t, `["\u0120"]`, "[\"\u0120\"]") testCompactJSON(t, `["\u0FFF"]`, "[\"\u0FFF\"]") testCompactJSON(t, `["\u1820"]`, "[\"\u1820\"]") testCompactJSON(t, `["\uFFFF"]`, "[\"\uFFFF\"]") testCompactJSON(t, `["\uD842\uDC20"]`, "[\"\U00020820\"]") testCompactJSON(t, `["\uDBFF\uDFFF"]`, "[\"\U0010FFFF\"]") testCompactJSON(t, `["\"\\\/"]`, `["\"\\/"]`) } func testReadHex(t *testing.T, input string, want uint32) { got := readHexDigits([]byte(input)) if want != got { t.Errorf("readHexDigits(%q): want 0x%x got 0x%x", input, want, got) } } func TestReadHex(t *testing.T) { testReadHex(t, "0123", 0x0123) testReadHex(t, "4567", 0x4567) testReadHex(t, "89AB", 0x89AB) testReadHex(t, "CDEF", 0xCDEF) testReadHex(t, "89ab", 
0x89AB) testReadHex(t, "cdef", 0xCDEF) } go-0.11.1/crypto/cross_sign_key.go000066400000000000000000000110621436100171500170650ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) // CrossSigningKeysCache holds the three cross-signing keys for the current user. type CrossSigningKeysCache struct { MasterKey *olm.PkSigning SelfSigningKey *olm.PkSigning UserSigningKey *olm.PkSigning } func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { return &CrossSigningPublicKeysCache{ MasterKey: cskc.MasterKey.PublicKey, SelfSigningKey: cskc.SelfSigningKey.PublicKey, UserSigningKey: cskc.UserSigningKey.PublicKey, } } type CrossSigningSeeds struct { MasterKey []byte SelfSigningKey []byte UserSigningKey []byte } func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { return CrossSigningSeeds{ MasterKey: mach.CrossSigningKeys.MasterKey.Seed, SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed, UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed, } } func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) { var keysCache CrossSigningKeysCache if keysCache.MasterKey, err = olm.NewPkSigningFromSeed(keys.MasterKey); err != nil { return } if keysCache.SelfSigningKey, err = olm.NewPkSigningFromSeed(keys.SelfSigningKey); err != nil { return } if keysCache.UserSigningKey, err = olm.NewPkSigningFromSeed(keys.UserSigningKey); err != nil { return } mach.Log.Trace("Got cross-signing keys: Master `%v` Self-signing `%v` User-signing `%v`", keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey) mach.CrossSigningKeys = &keysCache 
mach.crossSigningPubkeys = keysCache.PublicKeys() return } // GenerateCrossSigningKeys generates new cross-signing keys. func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) { var keysCache CrossSigningKeysCache var err error if keysCache.MasterKey, err = olm.NewPkSigning(); err != nil { return nil, fmt.Errorf("failed to generate master key: %w", err) } if keysCache.SelfSigningKey, err = olm.NewPkSigning(); err != nil { return nil, fmt.Errorf("failed to generate self-signing key: %w", err) } if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil { return nil, fmt.Errorf("failed to generate user-signing key: %w", err) } mach.Log.Debug("Generated cross-signing keys: Master: `%v` Self-signing: `%v` User-signing: `%v`", keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey) return &keysCache, nil } // PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server. 
func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { userID := mach.Client.UserID masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String()) masterKey := mautrix.CrossSigningKeys{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageMaster}, Keys: map[id.KeyID]id.Ed25519{ masterKeyID: keys.MasterKey.PublicKey, }, } selfKey := mautrix.CrossSigningKeys{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning}, Keys: map[id.KeyID]id.Ed25519{ id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey.String()): keys.SelfSigningKey.PublicKey, }, } selfSig, err := keys.MasterKey.SignJSON(selfKey) if err != nil { return fmt.Errorf("failed to sign self-signing key: %w", err) } selfKey.Signatures = map[id.UserID]map[id.KeyID]string{ userID: { masterKeyID: selfSig, }, } userKey := mautrix.CrossSigningKeys{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageUserSigning}, Keys: map[id.KeyID]id.Ed25519{ id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey.String()): keys.UserSigningKey.PublicKey, }, } userSig, err := keys.MasterKey.SignJSON(userKey) if err != nil { return fmt.Errorf("failed to sign user-signing key: %w", err) } userKey.Signatures = map[id.UserID]map[id.KeyID]string{ userID: { masterKeyID: userSig, }, } err = mach.Client.UploadCrossSigningKeys(&mautrix.UploadCrossSigningKeysReq{ Master: masterKey, SelfSigning: selfKey, UserSigning: userKey, }, uiaCallback) if err != nil { return err } mach.CrossSigningKeys = keys mach.crossSigningPubkeys = keys.PublicKeys() return nil } go-0.11.1/crypto/cross_sign_pubkey.go000066400000000000000000000042741436100171500176030ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto

import (
	"fmt"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/id"
)

// CrossSigningPublicKeysCache holds the public halves of a user's three
// cross-signing keys. SelfSigningKey and UserSigningKey may be left empty
// when the corresponding key is not known.
type CrossSigningPublicKeysCache struct {
	MasterKey      id.Ed25519
	SelfSigningKey id.Ed25519
	UserSigningKey id.Ed25519
}

// GetOwnCrossSigningPublicKeys returns the current user's cross-signing public
// keys and caches them on the machine.
//
// It returns the cached value if one exists. Otherwise it derives the keys from
// the locally held private keys, and finally falls back to
// GetCrossSigningPublicKeys. A lookup error is logged and results in nil. A nil
// result is not cached, so the next call tries again.
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCache {
	if mach.crossSigningPubkeys != nil {
		return mach.crossSigningPubkeys
	}
	if mach.CrossSigningKeys != nil {
		mach.crossSigningPubkeys = mach.CrossSigningKeys.PublicKeys()
		return mach.crossSigningPubkeys
	}
	cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
	if err != nil {
		mach.Log.Error("Failed to get own cross-signing public keys: %v", err)
		return nil
	}
	mach.crossSigningPubkeys = cspk
	return mach.crossSigningPubkeys
}

// GetCrossSigningPublicKeys returns the cross-signing public keys of userID.
//
// If the crypto store has a master key for the user, the stored keys are
// returned. Otherwise the keys are queried from the server. The query result
// is not written back to the store here. It returns (nil, nil) when the server
// response has no master key or no self-signing key for the user.
func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigningPublicKeysCache, error) {
	dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get keys from database: %w", err)
	}
	// Indexing a nil or empty map is safe and reports ok=false, so no length
	// check is needed first. Missing self/user-signing entries yield the zero
	// id.Ed25519.
	if masterKey, ok := dbKeys[id.XSUsageMaster]; ok {
		return &CrossSigningPublicKeysCache{
			MasterKey:      masterKey,
			SelfSigningKey: dbKeys[id.XSUsageSelfSigning],
			UserSigningKey: dbKeys[id.XSUsageUserSigning],
		}, nil
	}

	keys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{
		DeviceKeys: mautrix.DeviceKeysRequest{
			userID: mautrix.DeviceIDList{},
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to query keys: %w", err)
	}

	masterKeys, ok := keys.MasterKeys[userID]
	if !ok {
		return nil, nil
	}
	selfSigningKeys, ok := keys.SelfSigningKeys[userID]
	if !ok {
		return nil, nil
	}
	cspk := &CrossSigningPublicKeysCache{
		MasterKey:      masterKeys.FirstKey(),
		SelfSigningKey: selfSigningKeys.FirstKey(),
	}
	// The user-signing key is optional in the response.
	if userSigningKeys, ok := keys.UserSigningKeys[userID]; ok {
		cspk.UserSigningKey = userSigningKeys.FirstKey()
	}
	return cspk, nil
}
go-0.11.1/crypto/cross_sign_signing.go000066400000000000000000000203151436100171500177340ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "errors" "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( ErrCrossSigningKeysNotCached = errors.New("cross-signing private keys not in cache") ErrUserSigningKeyNotCached = errors.New("user-signing private key not in cache") ErrSelfSigningKeyNotCached = errors.New("self-signing private key not in cache") ErrSignatureUploadFail = errors.New("server-side failure uploading signatures") ErrCantSignOwnMasterKey = errors.New("signing your own master key is not allowed") ErrCantSignOtherDevice = errors.New("signing other users' devices is not allowed") ErrUserNotInQueryResponse = errors.New("could not find user in query keys response") ErrDeviceNotInQueryResponse = errors.New("could not find device in query keys response") ErrOlmAccountNotLoaded = errors.New("olm account has not been loaded") ErrCrossSigningMasterKeyNotFound = errors.New("cross-signing master key not found") ErrMasterKeyMACNotFound = errors.New("found cross-signing master key, but didn't find corresponding MAC in verification request") ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC") ) func (mach *OlmMachine) fetchMasterKey(device *DeviceIdentity, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) { crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID) if err != nil { return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err) } masterKey, ok := crossSignKeys[id.XSUsageMaster] if !ok { return "", 
ErrCrossSigningMasterKeyNotFound } masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()) masterKeyMAC, ok := content.Mac[masterKeyID] if !ok { return masterKey, ErrMasterKeyMACNotFound } expectedMasterKeyMAC, _, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID, mach.Client.UserID, mach.Client.DeviceID, transactionID, masterKey, masterKeyID, content.Mac) if err != nil { return masterKey, fmt.Errorf("failed to calculate expected MAC for master key: %w", err) } if masterKeyMAC != expectedMasterKeyMAC { err = fmt.Errorf("%w: expected %s, got %s", ErrMismatchingMasterKeyMAC, expectedMasterKeyMAC, masterKeyMAC) } return masterKey, err } // SignUser creates a cross-signing signature for a user, stores it and uploads it to the server. func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error { if userID == mach.Client.UserID { return ErrCantSignOwnMasterKey } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.UserSigningKey == nil { return ErrUserSigningKeyNotCached } masterKeyObj := mautrix.ReqKeysSignatures{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageMaster}, Keys: map[id.KeyID]string{ id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()): masterKey.String(), }, } signature, err := mach.signAndUpload(masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey) if err != nil { return err } mach.Log.Trace("Signed master key of %s with user-signing key: `%v`", userID, signature) if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } return nil } // SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature. 
func (mach *OlmMachine) SignOwnMasterKey() error { if mach.CrossSigningKeys == nil { return ErrCrossSigningKeysNotCached } else if mach.account == nil { return ErrOlmAccountNotLoaded } userID := mach.Client.UserID deviceID := mach.Client.DeviceID masterKey := mach.CrossSigningKeys.MasterKey.PublicKey masterKeyObj := mautrix.ReqKeysSignatures{ UserID: userID, Usage: []id.CrossSigningUsage{id.XSUsageMaster}, Keys: map[id.KeyID]string{ id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()): masterKey.String(), }, } signature, err := mach.account.Internal.SignJSON(masterKeyObj) if err != nil { return fmt.Errorf("failed to sign JSON: %w", err) } masterKeyObj.Signatures = mautrix.Signatures{ userID: map[id.KeyID]string{ id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature, }, } mach.Log.Trace("Signed own master key with device %v: `%v`", deviceID, signature) resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ masterKey.String(): masterKeyObj, }, }) if err != nil { return fmt.Errorf("error while uploading signatures: %w", err) } else if len(resp.Failures) > 0 { return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures) } if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } return nil } // SignOwnDevice creates a cross-signing signature for a device belonging to the current user and uploads it to the server. 
func (mach *OlmMachine) SignOwnDevice(device *DeviceIdentity) error { if device.UserID != mach.Client.UserID { return ErrCantSignOtherDevice } else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.SelfSigningKey == nil { return ErrSelfSigningKeyNotCached } deviceKeys, err := mach.getFullDeviceKeys(device) if err != nil { return err } deviceKeyObj := mautrix.ReqKeysSignatures{ UserID: device.UserID, DeviceID: device.DeviceID, Algorithms: deviceKeys.Algorithms, Keys: make(map[id.KeyID]string), } for keyID, key := range deviceKeys.Keys { deviceKeyObj.Keys[id.KeyID(keyID)] = key } signature, err := mach.signAndUpload(deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey) if err != nil { return err } mach.Log.Trace("Signed own device %s with self-signing key: `%v`", device.UserID, device.DeviceID, signature) if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { return fmt.Errorf("error storing signature in crypto store: %w", err) } return nil } // getFullDeviceKeys gets the full device keys object for the given device. // This is used because we don't cache some of the details like list of algorithms and unsupported key types. 
func (mach *OlmMachine) getFullDeviceKeys(device *DeviceIdentity) (*mautrix.DeviceKeys, error) { devicesKeys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{ device.UserID: mautrix.DeviceIDList{device.DeviceID}, }, }) if err != nil { return nil, fmt.Errorf("error querying device keys for %s: %w", device.DeviceID, err) } userKeys, ok := devicesKeys.DeviceKeys[device.UserID] if !ok { return nil, ErrUserNotInQueryResponse } deviceKeys, ok := userKeys[device.DeviceID] if !ok { return nil, ErrDeviceNotInQueryResponse } _, err = mach.validateDevice(device.UserID, device.DeviceID, deviceKeys, device) return &deviceKeys, err } // signAndUpload signs the given key signatures object and uploads it to the server. func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { signature, err := key.SignJSON(req) if err != nil { return "", fmt.Errorf("failed to sign JSON: %w", err) } req.Signatures = mautrix.Signatures{ mach.Client.UserID: map[id.KeyID]string{ id.NewKeyID(id.KeyAlgorithmEd25519, key.PublicKey.String()): signature, }, } resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{ userID: map[string]mautrix.ReqKeysSignatures{ signedThing: req, }, }) if err != nil { return "", fmt.Errorf("error while uploading signatures: %w", err) } else if len(resp.Failures) > 0 { return "", fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures) } return signature, nil } go-0.11.1/crypto/cross_sign_ssss.go000066400000000000000000000100731436100171500172710ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto

import (
	"fmt"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/crypto/ssss"
	"maunium.net/go/mautrix/crypto/utils"
	"maunium.net/go/mautrix/event"
)

// FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine.
func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error {
	masterKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningMaster, key)
	if err != nil {
		return err
	}
	selfSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningSelf, key)
	if err != nil {
		return err
	}
	userSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningUser, key)
	if err != nil {
		return err
	}

	return mach.ImportCrossSigningKeys(CrossSigningSeeds{
		MasterKey:      masterKey[:],
		SelfSigningKey: selfSignKey[:],
		UserSigningKey: userSignKey[:],
	})
}

// retrieveDecryptXSigningKey retrieves the requested cross-signing key from SSSS and decrypts it using the given SSSS key.
//
// The decrypted secret must be exactly utils.AESCTRKeyLength bytes. A secret of
// any other length is rejected instead of being silently truncated or
// zero-padded into a bogus seed.
func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) {
	var decryptedKey [utils.AESCTRKeyLength]byte
	var encData ssss.EncryptedAccountDataEventContent

	// retrieve and parse the account data for this key type from SSSS
	err := mach.Client.GetAccountData(keyName.Type, &encData)
	if err != nil {
		return decryptedKey, err
	}

	decrypted, err := encData.Decrypt(keyName.Type, key)
	if err != nil {
		return decryptedKey, err
	}
	if len(decrypted) != len(decryptedKey) {
		return decryptedKey, fmt.Errorf("unexpected length of decrypted %s: got %d bytes, expected %d", keyName.Type, len(decrypted), len(decryptedKey))
	}
	copy(decryptedKey[:], decrypted)
	return decryptedKey, nil
}

// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
//
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
// is used. The base58-formatted recovery key is the first return parameter.
//
// The account password of the user is required for uploading keys to the server.
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphrase string) (string, error) { key, err := mach.SSSS.GenerateAndUploadKey(passphrase) if err != nil { return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err) } // generate the three cross-signing keys keysCache, err := mach.GenerateCrossSigningKeys() if err != nil { return "", err } recoveryKey := key.RecoveryKey() // Store the private keys in SSSS if err := mach.UploadCrossSigningKeysToSSSS(key, keysCache); err != nil { return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err) } // Publish cross-signing keys err = mach.PublishCrossSigningKeys(keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} { return &mautrix.ReqUIAuthLogin{ BaseAuthData: mautrix.BaseAuthData{ Type: mautrix.AuthTypePassword, Session: uiResp.Session, }, User: mach.Client.UserID.String(), Password: userPassword, } }) if err != nil { return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err) } err = mach.SSSS.SetDefaultKeyID(key.ID) if err != nil { return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) } return recoveryKey, nil } // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. 
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(key *ssss.Key, keys *CrossSigningKeysCache) error { if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { return err } if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { return err } if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { return err } return nil } go-0.11.1/crypto/cross_sign_store.go000066400000000000000000000056021436100171500174340ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) { for userID, userKeys := range crossSigningKeys { currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) if err != nil { mach.Log.Error("Error fetching current cross-signing keys of user %v: %v", userID, err) } if currentKeys != nil { for curKeyUsage, curKey := range currentKeys { // got a new key with the same usage as an existing key for _, newKeyUsage := range userKeys.Usage { if newKeyUsage == curKeyUsage { if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.String())]; !ok { // old key is not in the new key map so we drop signatures made by it if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey); err != nil { mach.Log.Error("Error deleting old signatures: %v", err) } else { mach.Log.Debug("Dropped %v signatures made by key `%v` (%v) as it has been replaced", count, curKey, 
curKeyUsage) } } } } } } for _, key := range userKeys.Keys { for _, usage := range userKeys.Usage { mach.Log.Debug("Storing cross-signing key for %v: %v (type %v)", userID, key, usage) if err := mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil { mach.Log.Error("Error storing cross-signing key: %v", err) } } for signUserID, keySigs := range userKeys.Signatures { for signKeyID, signature := range keySigs { _, signKeyName := signKeyID.Parse() signingKey := id.Ed25519(signKeyName) // if the signer is one of this user's own devices, find the key from the key ID if signUserID == userID { ownDeviceID := id.DeviceID(signKeyName) if ownDeviceKeys, ok := deviceKeys[userID][ownDeviceID]; ok { signingKey = ownDeviceKeys.Keys.GetEd25519(ownDeviceID) mach.Log.Debug("Treating %v as the device name", signKeyName) } } mach.Log.Debug("Verifying with key %v of user %v", signingKey, signUserID) if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { mach.Log.Error("Error while verifying cross-signing keys: %v", err) } else { if verified { mach.Log.Debug("Cross-signing keys verified") mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature) } else { mach.Log.Error("Cross-signing keys verification unsuccessful", err) } } } } } } } go-0.11.1/crypto/cross_sign_test.go000066400000000000000000000111131436100171500172510ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "database/sql" "testing" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" sqlUpgrade "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/id" ) func getOlmMachine(t *testing.T) *OlmMachine { db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") if err != nil { t.Fatalf("Error opening db: %v", err) } sqlUpgrade.Upgrade(db, "sqlite3") sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{}) userID := id.UserID("@mautrix") mk, _ := olm.NewPkSigning() ssk, _ := olm.NewPkSigning() usk, _ := olm.NewPkSigning() sqlStore.PutCrossSigningKey(userID, id.XSUsageMaster, mk.PublicKey) sqlStore.PutCrossSigningKey(userID, id.XSUsageSelfSigning, ssk.PublicKey) sqlStore.PutCrossSigningKey(userID, id.XSUsageUserSigning, usk.PublicKey) return &OlmMachine{ CryptoStore: sqlStore, CrossSigningKeys: &CrossSigningKeysCache{ MasterKey: mk, SelfSigningKey: ssk, UserSigningKey: usk, }, Client: &mautrix.Client{ UserID: userID, }, Log: emptyLogger{}, } } func TestTrustOwnDevice(t *testing.T) { m := getOlmMachine(t) ownDevice := &DeviceIdentity{ UserID: m.Client.UserID, DeviceID: "device", SigningKey: id.Ed25519("deviceKey"), } if m.IsDeviceTrusted(ownDevice) { t.Error("Own device trusted while it shouldn't be") } m.CryptoStore.PutSignature(ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, ownDevice.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") m.CryptoStore.PutSignature(ownDevice.UserID, ownDevice.SigningKey, ownDevice.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "sig2") if !m.IsUserTrusted(ownDevice.UserID) { t.Error("Own user not trusted while they should be") } if !m.IsDeviceTrusted(ownDevice) { t.Error("Own device not trusted while it should be") } } func TestTrustOtherUser(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") if m.IsUserTrusted(otherUser) { t.Error("Other user trusted while they shouldn't be") } theirMasterKey, _ := 
olm.NewPkSigning() m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") // sign them with self-signing instead of user-signing key m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.SelfSigningKey.PublicKey, "invalid_sig") if m.IsUserTrusted(otherUser) { t.Error("Other user trusted before their master key has been signed with our user-signing key") } m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if !m.IsUserTrusted(otherUser) { t.Error("Other user not trusted while they should be") } } func TestTrustOtherDevice(t *testing.T) { m := getOlmMachine(t) otherUser := id.UserID("@user") theirDevice := &DeviceIdentity{ UserID: otherUser, DeviceID: "theirDevice", SigningKey: id.Ed25519("theirDeviceKey"), } if m.IsUserTrusted(otherUser) { t.Error("Other user trusted while they shouldn't be") } if m.IsDeviceTrusted(theirDevice) { t.Error("Other device trusted while it shouldn't be") } theirMasterKey, _ := olm.NewPkSigning() m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageMaster, theirMasterKey.PublicKey) theirSSK, _ := olm.NewPkSigning() m.CryptoStore.PutCrossSigningKey(otherUser, id.XSUsageSelfSigning, theirSSK.PublicKey) m.CryptoStore.PutSignature(m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.MasterKey.PublicKey, "sig1") m.CryptoStore.PutSignature(otherUser, theirMasterKey.PublicKey, m.Client.UserID, m.CrossSigningKeys.UserSigningKey.PublicKey, "sig2") if !m.IsUserTrusted(otherUser) { t.Error("Other user not trusted while they should be") } m.CryptoStore.PutSignature(otherUser, theirSSK.PublicKey, otherUser, theirMasterKey.PublicKey, "sig3") if m.IsDeviceTrusted(theirDevice) { t.Error("Other 
device trusted before it has been signed with user's SSK") } m.CryptoStore.PutSignature(otherUser, theirDevice.SigningKey, otherUser, theirSSK.PublicKey, "sig4") if !m.IsDeviceTrusted(theirDevice) { t.Error("Other device not trusted while it should be") } } go-0.11.1/crypto/cross_sign_validation.go000066400000000000000000000063511436100171500204340ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "maunium.net/go/mautrix/id" ) // IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing. func (mach *OlmMachine) IsDeviceTrusted(device *DeviceIdentity) bool { userID := device.UserID if device.Trust == TrustStateVerified { return true } else if device.Trust == TrustStateBlacklisted { return false } if !mach.IsUserTrusted(userID) { return false } theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) if err != nil { mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", userID, err) return false } theirMSK, ok := theirKeys[id.XSUsageMaster] if !ok { mach.Log.Error("Master key of user %v not found", userID) return false } theirSSK, ok := theirKeys[id.XSUsageSelfSigning] if !ok { mach.Log.Error("Self-signing key of user %v not found", userID) return false } sskSigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirSSK, userID, theirMSK) if err != nil { mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err) return false } if !sskSigExists { mach.Log.Warn("Self-signing key of user %v is not signed by their master key", userID) return false } deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(userID, device.SigningKey, userID, theirSSK) if err != nil { 
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err) return false } return deviceSigExists } // IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key. // In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user. func (mach *OlmMachine) IsUserTrusted(userID id.UserID) bool { csPubkeys := mach.GetOwnCrossSigningPublicKeys() if csPubkeys == nil { return false } if userID == mach.Client.UserID { return true } // first we verify our user-signing key sskSigs, err := mach.CryptoStore.GetSignaturesForKeyBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID) if err != nil { mach.Log.Error("Error retrieving our self-singing key signatures: %v", err) return false } if _, ok := sskSigs[csPubkeys.MasterKey]; !ok { // our user-signing key was not signed by our master key return false } theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID) if err != nil { mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", userID, err) return false } theirMskKey, ok := theirKeys[id.XSUsageMaster] if !ok { mach.Log.Error("Master key of user %v not found", userID) return false } sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey, mach.Client.UserID, csPubkeys.UserSigningKey) if err != nil { mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err) return false } return sigExists } go-0.11.1/crypto/decryptmegolm.go000066400000000000000000000103661436100171500167250ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "encoding/json" "errors" "fmt" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( IncorrectEncryptedContentType = errors.New("event content is not instance of *event.EncryptedEventContent") NoSessionFound = errors.New("failed to decrypt megolm event: no session with given ID found") DuplicateMessageIndex = errors.New("duplicate megolm message index") WrongRoom = errors.New("encrypted megolm event is not intended for this room") DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match") ) type megolmEvent struct { RoomID id.RoomID `json:"room_id"` Type event.Type `json:"type"` Content event.Content `json:"content"` } // DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2 func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmMegolmV1 { return nil, UnsupportedAlgorithm } sess, err := mach.CryptoStore.GetGroupSession(evt.RoomID, content.SenderKey, content.SessionID) if err != nil { return nil, fmt.Errorf("failed to get group session: %w", err) } else if sess == nil { return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID) } plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext) if err != nil { return nil, fmt.Errorf("failed to decrypt megolm event: %w", err) } else if !mach.CryptoStore.ValidateMessageIndex(content.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp) { return nil, DuplicateMessageIndex } var verified bool ownSigningKey, ownIdentityKey := mach.account.Keys() if content.DeviceID == mach.Client.DeviceID && sess.SigningKey == ownSigningKey && content.SenderKey == ownIdentityKey { verified = true } else { device, err := mach.GetOrFetchDevice(evt.Sender, content.DeviceID) if err != nil { // We don't 
want to throw these errors as the message can still be decrypted. mach.Log.Debug("Failed to get device %s/%s to verify session %s: %v", evt.Sender, content.DeviceID, sess.ID(), err) // TODO maybe store the info that the device is deleted? } else if mach.IsDeviceTrusted(device) && len(sess.ForwardingChains) == 0 { // For some reason, matrix-nio had a comment saying not to events decrypted using a forwarded key as verified. if device.SigningKey != sess.SigningKey || device.IdentityKey != content.SenderKey { return nil, DeviceKeyMismatch } verified = true } } megolmEvt := &megolmEvent{} err = json.Unmarshal(plaintext, &megolmEvt) if err != nil { return nil, fmt.Errorf("failed to parse megolm payload: %w", err) } else if megolmEvt.RoomID != evt.RoomID { return nil, WrongRoom } megolmEvt.Type.Class = evt.Type.Class err = megolmEvt.Content.ParseRaw(megolmEvt.Type) if err != nil { if event.IsUnsupportedContentType(err) { mach.Log.Warn("Unsupported event type %s in encrypted event %s", megolmEvt.Type.Repr(), evt.ID) } else { return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err) } } if content.RelatesTo != nil { relatable, ok := megolmEvt.Content.Parsed.(event.Relatable) if ok { if relatable.OptionalGetRelatesTo() == nil { relatable.SetRelatesTo(content.RelatesTo) } else { mach.Log.Trace("Not overriding relation data in %s, as encrypted payload already has it", evt.ID) } } else { mach.Log.Warn("Encrypted event %s has relation data, but content type %T (%s) doesn't support it", evt.ID, megolmEvt.Content.Parsed, megolmEvt.Type.String()) } } megolmEvt.Type.Class = evt.Type.Class return &event.Event{ Sender: evt.Sender, Type: megolmEvt.Type, Timestamp: evt.Timestamp, ID: evt.ID, RoomID: evt.RoomID, Content: megolmEvt.Content, Unsigned: evt.Unsigned, Mautrix: event.MautrixInfo{ Verified: verified, }, }, nil } go-0.11.1/crypto/decryptolm.go000066400000000000000000000225711436100171500162350ustar00rootroot00000000000000// Copyright (c) 2021 Tulir 
Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "encoding/json" "errors" "fmt" "time" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( UnsupportedAlgorithm = errors.New("unsupported event encryption algorithm") NotEncryptedForMe = errors.New("olm event doesn't contain ciphertext for this device") UnsupportedOlmMessageType = errors.New("unsupported olm message type") DecryptionFailedWithMatchingSession = errors.New("decryption failed with matching session") DecryptionFailedForNormalMessage = errors.New("decryption failed for normal message") SenderMismatch = errors.New("mismatched sender in olm payload") RecipientMismatch = errors.New("mismatched recipient in olm payload") RecipientKeyMismatch = errors.New("mismatched recipient key in olm payload") ) // DecryptedOlmEvent represents an event that was decrypted from an event encrypted with the m.olm.v1.curve25519-aes-sha2 algorithm. 
type DecryptedOlmEvent struct { Source *event.Event `json:"-"` SenderKey id.SenderKey `json:"-"` Sender id.UserID `json:"sender"` SenderDevice id.DeviceID `json:"sender_device"` Keys OlmEventKeys `json:"keys"` Recipient id.UserID `json:"recipient"` RecipientKeys OlmEventKeys `json:"recipient_keys"` Type event.Type `json:"type"` Content event.Content `json:"content"` } func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*DecryptedOlmEvent, error) { content, ok := evt.Content.Parsed.(*event.EncryptedEventContent) if !ok { return nil, IncorrectEncryptedContentType } else if content.Algorithm != id.AlgorithmOlmV1 { return nil, UnsupportedAlgorithm } ownContent, ok := content.OlmCiphertext[mach.account.IdentityKey()] if !ok { return nil, NotEncryptedForMe } decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body, traceID) if err != nil { return nil, err } decrypted.Source = evt return decrypted, nil } type OlmEventKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) (*DecryptedOlmEvent, error) { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { return nil, UnsupportedOlmMessageType } endTimeTrace := mach.timeTrace("decrypting olm ciphertext", traceID, 5*time.Second) plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext, traceID) endTimeTrace() if err != nil { return nil, err } defer mach.timeTrace("parsing decrypted olm event", traceID, time.Second)() var olmEvt DecryptedOlmEvent err = json.Unmarshal(plaintext, &olmEvt) if err != nil { return nil, fmt.Errorf("failed to parse olm payload: %w", err) } if sender != olmEvt.Sender { return nil, SenderMismatch } else if mach.Client.UserID != olmEvt.Recipient { return nil, RecipientMismatch } else if mach.account.SigningKey() != 
olmEvt.RecipientKeys.Ed25519 { return nil, RecipientKeyMismatch } err = olmEvt.Content.ParseRaw(olmEvt.Type) if err != nil && !event.IsUnsupportedContentType(err) { return nil, fmt.Errorf("failed to parse content of olm payload event: %w", err) } olmEvt.SenderKey = senderKey return &olmEvt, nil } func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) { endTimeTrace := mach.timeTrace("waiting for olm lock", traceID, 5*time.Second) mach.olmLock.Lock() endTimeTrace() defer mach.olmLock.Unlock() plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(senderKey, olmType, ciphertext, traceID) if err != nil { if err == DecryptionFailedWithMatchingSession { mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey) go mach.unwedgeDevice(sender, senderKey) } return nil, fmt.Errorf("failed to decrypt olm event: %w", err) } if plaintext != nil { // Decryption successful return plaintext, nil } // Decryption failed with every known session or no known sessions, let's try to create a new session. // // New sessions can only be created if it's a prekey message, we can't decrypt the message // if it isn't one at this point in time anymore, so return early. 
if olmType != id.OlmMsgTypePreKey { go mach.unwedgeDevice(sender, senderKey) return nil, DecryptionFailedForNormalMessage } mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey) endTimeTrace = mach.timeTrace("creating inbound olm session", traceID, time.Second) session, err := mach.createInboundSession(senderKey, ciphertext) endTimeTrace() if err != nil { go mach.unwedgeDevice(sender, senderKey) return nil, fmt.Errorf("failed to create new session from prekey message: %w", err) } mach.Log.Debug("Created inbound olm session %s for %s/%s: %s", session.ID(), sender, senderKey, session.Describe()) endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting prekey olm message with %s/%s", senderKey, session.ID()), traceID, time.Second) plaintext, err = session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { go mach.unwedgeDevice(sender, senderKey) return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err) } endTimeTrace = mach.timeTrace(fmt.Sprintf("updating new session %s/%s in database", senderKey, session.ID()), traceID, time.Second) err = mach.CryptoStore.UpdateSession(senderKey, session) endTimeTrace() if err != nil { mach.Log.Warn("Failed to update new olm session in crypto store after decrypting: %v", err) } return plaintext, nil } func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) { endTimeTrace := mach.timeTrace(fmt.Sprintf("getting sessions with %s", senderKey), traceID, time.Second) sessions, err := mach.CryptoStore.GetSessions(senderKey) endTimeTrace() if err != nil { return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err) } for _, session := range sessions { if olmType == id.OlmMsgTypePreKey { endTimeTrace = mach.timeTrace(fmt.Sprintf("checking if prekey olm message matches session %s/%s", senderKey, session.ID()), traceID, time.Second) matches, err 
:= session.Internal.MatchesInboundSession(ciphertext) endTimeTrace() if err != nil { return nil, fmt.Errorf("failed to check if ciphertext matches inbound session: %w", err) } else if !matches { continue } } mach.Log.Trace("Trying to decrypt olm message from %s with session %s: %s", senderKey, session.ID(), session.Describe()) endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting olm message with %s/%s", senderKey, session.ID()), traceID, time.Second) plaintext, err := session.Decrypt(ciphertext, olmType) endTimeTrace() if err != nil { if olmType == id.OlmMsgTypePreKey { return nil, DecryptionFailedWithMatchingSession } } else { endTimeTrace = mach.timeTrace(fmt.Sprintf("updating session %s/%s in database", senderKey, session.ID()), traceID, time.Second) err = mach.CryptoStore.UpdateSession(senderKey, session) endTimeTrace() if err != nil { mach.Log.Warn("Failed to update olm session in crypto store after decrypting: %v", err) } mach.Log.Trace("Decrypted olm message from %s with session %s", senderKey, session.ID()) return plaintext, nil } } return nil, nil } func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext string) (*OlmSession, error) { session, err := mach.account.NewInboundSessionFrom(senderKey, ciphertext) if err != nil { return nil, err } mach.saveAccount() err = mach.CryptoStore.AddSession(senderKey, session) if err != nil { mach.Log.Error("Failed to store created inbound session: %v", err) } return session, nil } const MinUnwedgeInterval = 1 * time.Hour func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) { mach.recentlyUnwedgedLock.Lock() prevUnwedge, ok := mach.recentlyUnwedged[senderKey] delta := time.Now().Sub(prevUnwedge) if ok && delta < MinUnwedgeInterval { mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta) mach.recentlyUnwedgedLock.Unlock() return } mach.recentlyUnwedged[senderKey] = time.Now() mach.recentlyUnwedgedLock.Unlock() 
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey) if err != nil { mach.Log.Error("Failed to find device info by identity key: %v", err) return } else if deviceIdentity == nil { mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey) return } mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey) mach.devicesToUnwedgeLock.Lock() mach.devicesToUnwedge[senderKey] = true mach.devicesToUnwedgeLock.Unlock() err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{}) if err != nil { mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err) } } go-0.11.1/crypto/devicelist.go000066400000000000000000000154121436100171500162020ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "errors" "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) var ( MismatchingDeviceID = errors.New("mismatching device ID in parameter and keys object") MismatchingUserID = errors.New("mismatching user ID in parameter and keys object") MismatchingSigningKey = errors.New("received update for device with different signing key") NoSigningKeyFound = errors.New("didn't find ed25519 signing key") NoIdentityKeyFound = errors.New("didn't find curve25519 identity key") InvalidKeySignature = errors.New("invalid signature on device keys") ) func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*DeviceIdentity { return mach.fetchKeys([]id.UserID{user}, "", true)[user] } func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*DeviceIdentity) { req := &mautrix.ReqQueryKeys{ DeviceKeys: mautrix.DeviceKeysRequest{}, Timeout: 10 * 1000, Token: sinceToken, } if !includeUntracked { users = mach.CryptoStore.FilterTrackedUsers(users) } if len(users) == 0 { return } for _, userID := range users { req.DeviceKeys[userID] = mautrix.DeviceIDList{} } mach.Log.Trace("Querying keys for %v", users) resp, err := mach.Client.QueryKeys(req) if err != nil { mach.Log.Warn("Failed to query keys: %v", err) return } for server, err := range resp.Failures { mach.Log.Warn("Query keys failure for %s: %v", server, err) } mach.Log.Trace("Query key result received with %d users", len(resp.DeviceKeys)) data = make(map[id.UserID]map[id.DeviceID]*DeviceIdentity) for userID, devices := range resp.DeviceKeys { delete(req.DeviceKeys, userID) newDevices := make(map[id.DeviceID]*DeviceIdentity) existingDevices, err := mach.CryptoStore.GetDevices(userID) if err != nil { mach.Log.Warn("Failed to get existing devices for %s: %v", userID, err) existingDevices = make(map[id.DeviceID]*DeviceIdentity) } mach.Log.Trace("Updating devices for %s, got %d devices, have 
%d in store", userID, len(devices), len(existingDevices)) changed := false for deviceID, deviceKeys := range devices { existing, ok := existingDevices[deviceID] if !ok { // New device changed = true } mach.Log.Trace("Validating device %s of %s", deviceID, userID) newDevice, err := mach.validateDevice(userID, deviceID, deviceKeys, existing) if err != nil { mach.Log.Error("Failed to validate device %s of %s: %v", deviceID, userID, err) } else if newDevice != nil { newDevices[deviceID] = newDevice for signerUserID, signerKeys := range deviceKeys.Signatures { for signerKey, signature := range signerKeys { // verify and save self-signing key signature for each device if selfSignKeys, ok := resp.SelfSigningKeys[signerUserID]; ok { for _, pubKey := range selfSignKeys.Keys { if selfSigs, ok := deviceKeys.Signatures[signerUserID]; !ok { continue } else if _, ok := selfSigs[id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]; !ok { continue } if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified { if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())] mach.Log.Trace("Verified self-signing signature for device %v: `%v`", deviceID, signature) mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature) } } else { mach.Log.Warn("Could not verify device self-signing signatures: %v", err) } } } // save signature of device made by its own device signing key if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature) } } } } } mach.Log.Trace("Storing new device list for %s containing %d devices", userID, len(newDevices)) err = mach.CryptoStore.PutDevices(userID, newDevices) if err != nil { mach.Log.Warn("Failed to update device list for %s: %v", userID, err) } data[userID] = 
newDevices changed = changed || len(newDevices) != len(existingDevices) if changed { mach.OnDevicesChanged(userID) } } for userID := range req.DeviceKeys { mach.Log.Warn("Didn't get any keys for user %s", userID) } mach.storeCrossSigningKeys(resp.MasterKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(resp.SelfSigningKeys, resp.DeviceKeys) mach.storeCrossSigningKeys(resp.UserSigningKeys, resp.DeviceKeys) return data } // OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms. // // This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does // not need to be called manually. func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) { for _, roomID := range mach.StateStore.FindSharedRooms(userID) { mach.Log.Debug("Devices of %s changed, invalidating group session for %s", userID, roomID) err := mach.CryptoStore.RemoveOutboundGroupSession(roomID) if err != nil { mach.Log.Warn("Failed to invalidate outbound group session of %s on device change for %s: %v", roomID, userID, err) } } } func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, deviceKeys mautrix.DeviceKeys, existing *DeviceIdentity) (*DeviceIdentity, error) { if deviceID != deviceKeys.DeviceID { return nil, MismatchingDeviceID } else if userID != deviceKeys.UserID { return nil, MismatchingUserID } signingKey := deviceKeys.Keys.GetEd25519(deviceID) identityKey := deviceKeys.Keys.GetCurve25519(deviceID) if signingKey == "" { return nil, NoSigningKeyFound } else if identityKey == "" { return nil, NoIdentityKeyFound } if existing != nil && existing.SigningKey != signingKey { return existing, MismatchingSigningKey } ok, err := olm.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) if err != nil { return existing, fmt.Errorf("failed to verify signature: %w", err) } else if !ok { return existing, InvalidKeySignature } name, ok := 
deviceKeys.Unsigned["device_display_name"].(string) if !ok { name = string(deviceID) } return &DeviceIdentity{ UserID: userID, DeviceID: deviceID, IdentityKey: identityKey, SigningKey: signingKey, Trust: TrustStateUnset, Name: name, Deleted: false, }, nil } go-0.11.1/crypto/encryptmegolm.go000066400000000000000000000262031436100171500167340ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "encoding/json" "errors" "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( AlreadyShared = errors.New("group session already shared") NoGroupSession = errors.New("no group session created") ) func getRelatesTo(content interface{}) *event.RelatesTo { contentStruct, ok := content.(*event.Content) if ok { content = contentStruct.Parsed } relatable, ok := content.(event.Relatable) if ok { return relatable.OptionalGetRelatesTo() } return nil } type rawMegolmEvent struct { RoomID id.RoomID `json:"room_id"` Type event.Type `json:"type"` Content interface{} `json:"content"` } // IsShareError returns true if the error is caused by the lack of an outgoing megolm session and can be solved with OlmMachine.ShareGroupSession func IsShareError(err error) bool { return err == SessionExpired || err == SessionNotShared || err == NoGroupSession } // EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm. // // If you use the event.Content struct, make sure you pass a pointer to the struct, // as JSON serialization will not work correctly otherwise. 
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) { mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID) session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { return nil, fmt.Errorf("failed to get outbound group session: %w", err) } else if session == nil { return nil, NoGroupSession } plaintext, err := json.Marshal(&rawMegolmEvent{ RoomID: roomID, Type: evtType, Content: content, }) if err != nil { return nil, err } ciphertext, err := session.Encrypt(plaintext) if err != nil { return nil, err } err = mach.CryptoStore.UpdateOutboundGroupSession(session) if err != nil { mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", err) } return &event.EncryptedEventContent{ Algorithm: id.AlgorithmMegolmV1, SenderKey: mach.account.IdentityKey(), DeviceID: mach.Client.DeviceID, SessionID: session.ID(), MegolmCiphertext: ciphertext, RelatesTo: getRelatesTo(content), }, nil } func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession { session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID)) signingKey, idKey := mach.account.Keys() mach.createGroupSession(idKey, signingKey, roomID, session.ID(), session.Internal.Key(), "create") return session } type deviceSessionWrapper struct { session *OlmSession identity *DeviceIdentity } // ShareGroupSession shares a group session for a specific room with all the devices of the given user list. // // For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent. 
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error { mach.Log.Debug("Sharing group session for room %s to %v", roomID, users) session, err := mach.CryptoStore.GetOutboundGroupSession(roomID) if err != nil { return fmt.Errorf("failed to get previous outbound group session: %w", err) } else if session != nil && session.Shared && !session.Expired() { return AlreadyShared } if session == nil || session.Expired() { session = mach.newOutboundGroupSession(roomID) } withheldCount := 0 toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} olmSessions := make(map[id.UserID]map[id.DeviceID]deviceSessionWrapper) missingSessions := make(map[id.UserID]map[id.DeviceID]*DeviceIdentity) missingUserSessions := make(map[id.DeviceID]*DeviceIdentity) var fetchKeys []id.UserID for _, userID := range users { devices, err := mach.CryptoStore.GetDevices(userID) if err != nil { mach.Log.Error("Failed to get devices of %s", userID) } else if devices == nil { mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID) fetchKeys = append(fetchKeys, userID) } else if len(devices) == 0 { mach.Log.Trace("%s has no devices, skipping", userID) } else { mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID) toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content) olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper) mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions) mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID) withheldCount += 
len(toDeviceWithheld.Messages[userID]) if len(missingUserSessions) > 0 { missingSessions[userID] = missingUserSessions missingUserSessions = make(map[id.DeviceID]*DeviceIdentity) } if len(toDeviceWithheld.Messages[userID]) == 0 { delete(toDeviceWithheld.Messages, userID) } } } if len(fetchKeys) > 0 { mach.Log.Trace("Fetching missing keys for %v", fetchKeys) for userID, devices := range mach.fetchKeys(fetchKeys, "", true) { mach.Log.Trace("Got %d device keys for %s", len(devices), userID) missingSessions[userID] = devices } } if len(missingSessions) > 0 { mach.Log.Trace("Creating missing outbound sessions") err = mach.createOutboundSessions(missingSessions) if err != nil { mach.Log.Error("Failed to create missing outbound sessions: %v", err) } } for userID, devices := range missingSessions { if len(devices) == 0 { // No missing sessions continue } output, ok := olmSessions[userID] if !ok { output = make(map[id.DeviceID]deviceSessionWrapper) olmSessions[userID] = output } withheld, ok := toDeviceWithheld.Messages[userID] if !ok { withheld = make(map[id.DeviceID]*event.Content) toDeviceWithheld.Messages[userID] = withheld } mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID) mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil) mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID) withheldCount += len(toDeviceWithheld.Messages[userID]) if len(toDeviceWithheld.Messages[userID]) == 0 { delete(toDeviceWithheld.Messages, userID) } } err = mach.encryptAndSendGroupSession(session, olmSessions) if err != nil { return fmt.Errorf("failed to share group session: %w", err) } if len(toDeviceWithheld.Messages) > 0 { mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID) // TODO remove the next 4 
lines once clients support m.room_key.withheld _, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld) if err != nil { mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err) } _, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld) if err != nil { mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err) } } mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID) session.Shared = true return mach.CryptoStore.AddOutboundGroupSession(session) } func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error { mach.olmLock.Lock() defer mach.olmLock.Unlock() mach.Log.Trace("Encrypting group session %s for all found devices", session.ID()) deviceCount := 0 toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)} for userID, sessions := range olmSessions { if len(sessions) == 0 { continue } output := make(map[id.DeviceID]*event.Content) toDevice.Messages[userID] = output for deviceID, device := range sessions { mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID) content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent()) output[deviceID] = &event.Content{Parsed: content} deviceCount++ mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID) } } mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID()) _, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice) return err } func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*DeviceIdentity, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, 
missingOutput map[id.DeviceID]*DeviceIdentity) { for deviceID, device := range devices { userKey := UserDevice{UserID: userID, DeviceID: deviceID} if state := session.Users[userKey]; state != OGSNotShared { continue } else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID { session.Users[userKey] = OGSIgnored } else if device.Trust == TrustStateBlacklisted { mach.Log.Debug("Not encrypting group session %s for %s of %s: device is blacklisted", session.ID(), deviceID, userID) withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{ RoomID: session.RoomID, Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), SenderKey: mach.account.IdentityKey(), Code: event.RoomKeyWithheldBlacklisted, Reason: "Device is blacklisted", }} session.Users[userKey] = OGSIgnored } else if !mach.AllowUnverifiedDevices && !mach.IsDeviceTrusted(device) { mach.Log.Debug("Not encrypting group session %s for %s of %s: device is not verified", session.ID(), deviceID, userID) withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{ RoomID: session.RoomID, Algorithm: id.AlgorithmMegolmV1, SessionID: session.ID(), SenderKey: mach.account.IdentityKey(), Code: event.RoomKeyWithheldUnverified, Reason: "This device does not encrypt messages for unverified devices", }} session.Users[userKey] = OGSIgnored } else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil { mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err) } else if deviceSession == nil { mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID) if missingOutput != nil { missingOutput[deviceID] = device } } else { output[deviceID] = deviceSessionWrapper{ session: deviceSession, identity: device, } session.Users[userKey] = OGSAlreadyShared } } } go-0.11.1/crypto/encryptolm.go000066400000000000000000000076511436100171500162510ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This 
Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "encoding/json" "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *DeviceIdentity, evtType event.Type, content event.Content) *event.EncryptedEventContent { evt := &DecryptedOlmEvent{ Sender: mach.Client.UserID, SenderDevice: mach.Client.DeviceID, Keys: OlmEventKeys{Ed25519: mach.account.SigningKey()}, Recipient: recipient.UserID, RecipientKeys: OlmEventKeys{Ed25519: recipient.SigningKey}, Type: evtType, Content: content, } plaintext, err := json.Marshal(evt) if err != nil { panic(err) } mach.Log.Trace("Encrypting olm message for %s with session %s: %s", recipient.IdentityKey, session.ID(), session.Describe()) msgType, ciphertext := session.Encrypt(plaintext) err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session) if err != nil { mach.Log.Warn("Failed to update olm session in crypto store after encrypting: %v", err) } return &event.EncryptedEventContent{ Algorithm: id.AlgorithmOlmV1, SenderKey: mach.account.IdentityKey(), OlmCiphertext: event.OlmCiphertexts{ recipient.IdentityKey: { Type: msgType, Body: string(ciphertext), }, }, } } func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool { if !mach.CryptoStore.HasSession(identityKey) { return true } mach.devicesToUnwedgeLock.Lock() _, shouldUnwedge := mach.devicesToUnwedge[identityKey] if shouldUnwedge { delete(mach.devicesToUnwedge, identityKey) } mach.devicesToUnwedgeLock.Unlock() return shouldUnwedge } func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.DeviceID]*DeviceIdentity) error { request := make(mautrix.OneTimeKeysRequest) for userID, devices := range input { request[userID] = 
make(map[id.DeviceID]id.KeyAlgorithm) for deviceID, identity := range devices { if mach.shouldCreateNewSession(identity.IdentityKey) { request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519 } } if len(request[userID]) == 0 { delete(request, userID) } } if len(request) == 0 { return nil } resp, err := mach.Client.ClaimKeys(&mautrix.ReqClaimKeys{ OneTimeKeys: request, Timeout: 10 * 1000, }) if err != nil { return fmt.Errorf("failed to claim keys: %w", err) } for userID, user := range resp.OneTimeKeys { for deviceID, oneTimeKeys := range user { var oneTimeKey mautrix.OneTimeKey var keyID id.KeyID for keyID, oneTimeKey = range oneTimeKeys { break } keyAlg, keyIndex := keyID.Parse() if keyAlg != id.KeyAlgorithmSignedCurve25519 { mach.Log.Warn("Unexpected key ID algorithm in one-time key response for %s of %s: %s", deviceID, userID, keyID) continue } identity := input[userID][deviceID] if ok, err := olm.VerifySignatureJSON(oneTimeKey, userID, deviceID.String(), identity.SigningKey); err != nil { mach.Log.Error("Failed to verify signature for %s of %s: %v", deviceID, userID, err) } else if !ok { mach.Log.Warn("Invalid signature for %s of %s", deviceID, userID) } else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil { mach.Log.Error("Failed to create outbound session for %s of %s: %v", deviceID, userID, err) } else { wrapped := wrapSession(sess) err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped) if err != nil { mach.Log.Error("Failed to store created session for %s of %s: %v", deviceID, userID, err) } else { mach.Log.Debug("Created new Olm session with %s/%s (OTK ID: %d)", userID, deviceID, keyIndex) } } } } return nil } go-0.11.1/crypto/keyexport.go000066400000000000000000000130751436100171500161040ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/rand" "crypto/sha256" "crypto/sha512" "encoding/base64" "encoding/binary" "encoding/json" "fmt" "math" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) type SenderClaimedKeys struct { Ed25519 id.Ed25519 `json:"ed25519"` } type ExportedSession struct { Algorithm id.Algorithm `json:"algorithm"` ForwardingChains []string `json:"forwarding_curve25519_key_chain"` RoomID id.RoomID `json:"room_id"` SenderKey id.SenderKey `json:"sender_key"` SenderClaimedKeys SenderClaimedKeys `json:"sender_claimed_keys"` SessionID id.SessionID `json:"session_id"` SessionKey string `json:"session_key"` } // The default number of pbkdf2 rounds to use when exporting keys const defaultPassphraseRounds = 100000 const exportPrefix = "-----BEGIN MEGOLM SESSION DATA-----\n" const exportSuffix = "-----END MEGOLM SESSION DATA-----\n" // Only version 0x01 is currently specified in the spec const exportVersion1 = 0x01 // The standard for wrapping base64 is 76 bytes const exportLineLengthLimit = 76 // Byte count for version + salt + iv + number of rounds const exportHeaderLength = 1 + 16 + 16 + 4 // SHA-256 hash length const exportHashLength = 32 func computeKey(passphrase string, salt []byte, rounds int) (encryptionKey, hashKey []byte) { key := pbkdf2.Key([]byte(passphrase), salt, rounds, 64, sha512.New) encryptionKey = key[:32] hashKey = key[32:] return } func makeExportIV() []byte { iv := make([]byte, 16) _, err := rand.Read(iv) if err != nil { panic(olm.NotEnoughGoRandom) } // Set bit 63 to zero iv[7] &= 0b11111110 return iv } func makeExportKeys(passphrase string) (encryptionKey, hashKey, salt, iv []byte) { salt = make([]byte, 16) _, err := rand.Read(salt) if err != nil { panic(olm.NotEnoughGoRandom) } encryptionKey, hashKey = computeKey(passphrase, salt, 
defaultPassphraseRounds) iv = makeExportIV() return } func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error) { export := make([]ExportedSession, len(sessions)) for i, session := range sessions { key, err := session.Internal.Export(session.Internal.FirstKnownIndex()) if err != nil { return nil, fmt.Errorf("failed to export session: %w", err) } export[i] = ExportedSession{ Algorithm: id.AlgorithmMegolmV1, ForwardingChains: session.ForwardingChains, RoomID: session.RoomID, SenderKey: session.SenderKey, SenderClaimedKeys: SenderClaimedKeys{}, SessionID: session.ID(), SessionKey: key, } } return export, nil } func exportSessionsJSON(sessions []*InboundGroupSession) ([]byte, error) { exportedSessions, err := exportSessions(sessions) if err != nil { return nil, err } return json.Marshal(exportedSessions) } func min(a, b int) int { if a > b { return b } return a } func formatKeyExportData(data []byte) []byte { base64Data := make([]byte, base64.StdEncoding.EncodedLen(len(data))) base64.StdEncoding.Encode(base64Data, data) // Prefix + data and newline for each 76 characters of data + suffix outputLength := len(exportPrefix) + len(base64Data) + int(math.Ceil(float64(len(base64Data))/exportLineLengthLimit)) + len(exportSuffix) var buf bytes.Buffer buf.Grow(outputLength) buf.WriteString(exportPrefix) for ptr := 0; ptr < len(base64Data); ptr += exportLineLengthLimit { buf.Write(base64Data[ptr:min(ptr+exportLineLengthLimit, len(base64Data))]) buf.WriteRune('\n') } buf.WriteString(exportSuffix) if buf.Len() != buf.Cap() || buf.Len() != outputLength { panic(fmt.Errorf("unexpected length %d / %d / %d", buf.Len(), buf.Cap(), outputLength)) } return buf.Bytes() } // ExportKeys exports the given Megolm sessions with the format specified in the Matrix spec. 
// See https://spec.matrix.org/v1.2/client-server-api/#key-exports func ExportKeys(passphrase string, sessions []*InboundGroupSession) ([]byte, error) { // Make all the keys necessary for exporting encryptionKey, hashKey, salt, iv := makeExportKeys(passphrase) // Export all the given sessions and put them in JSON unencryptedData, err := exportSessionsJSON(sessions) if err != nil { return nil, err } // The export data consists of: // 1 byte of export format version // 16 bytes of salt // 16 bytes of IV (initialization vector) // 4 bytes of the number of rounds // the encrypted export data // 32 bytes of the hash of all the data above exportData := make([]byte, exportHeaderLength+len(unencryptedData)+exportHashLength) dataWithoutHashLength := len(exportData) - exportHashLength // Create the header for the export data exportData[0] = exportVersion1 copy(exportData[1:17], salt) copy(exportData[17:33], iv) binary.BigEndian.PutUint32(exportData[33:37], defaultPassphraseRounds) // Encrypt data with AES-256-CTR block, _ := aes.NewCipher(encryptionKey) cipher.NewCTR(block, iv).XORKeyStream(exportData[exportHeaderLength:dataWithoutHashLength], unencryptedData) // Hash all the data with HMAC-SHA256 and put it at the end mac := hmac.New(sha256.New, hashKey) mac.Write(exportData[:dataWithoutHashLength]) mac.Sum(exportData[:dataWithoutHashLength]) // Format the export (prefix, base64'd exportData, suffix) and return return formatKeyExportData(exportData), nil } go-0.11.1/crypto/keyimport.go000066400000000000000000000125501436100171500160720ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "bytes" "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/binary" "encoding/json" "errors" "fmt" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) var ( ErrMissingExportPrefix = errors.New("invalid Matrix key export: missing prefix") ErrMissingExportSuffix = errors.New("invalid Matrix key export: missing suffix") ErrUnsupportedExportVersion = errors.New("unsupported Matrix key export format version") ErrMismatchingExportHash = errors.New("mismatching hash; incorrect passphrase?") ErrInvalidExportedAlgorithm = errors.New("session has unknown algorithm") ErrMismatchingExportedSessionID = errors.New("imported session has different ID than expected") ) var exportPrefixBytes, exportSuffixBytes = []byte(exportPrefix), []byte(exportSuffix) func decodeKeyExport(data []byte) ([]byte, error) { // If the valid prefix and suffix aren't there, it's probably not a Matrix key export if !bytes.HasPrefix(data, exportPrefixBytes) { return nil, ErrMissingExportPrefix } else if !bytes.HasSuffix(data, exportSuffixBytes) { return nil, ErrMissingExportSuffix } // Remove the prefix and suffix, we don't care about them anymore data = data[len(exportPrefix) : len(data)-len(exportSuffix)] // Allocate space for the decoded data. 
Ignore newlines when counting the length exportData := make([]byte, base64.StdEncoding.DecodedLen(len(data)-bytes.Count(data, []byte{'\n'}))) n, err := base64.StdEncoding.Decode(exportData, data) if err != nil { return nil, err } return exportData[:n], nil } func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession, error) { if exportData[0] != exportVersion1 { return nil, ErrUnsupportedExportVersion } // Get all the different parts of the export salt := exportData[1:17] iv := exportData[17:33] passphraseRounds := binary.BigEndian.Uint32(exportData[33:37]) dataWithoutHashLength := len(exportData) - exportHashLength encryptedData := exportData[exportHeaderLength:dataWithoutHashLength] hash := exportData[dataWithoutHashLength:] // Compute the encryption and hash keys from the passphrase and salt encryptionKey, hashKey := computeKey(passphrase, salt, int(passphraseRounds)) // Compute and verify the hash. If it doesn't match, the passphrase is probably wrong mac := hmac.New(sha256.New, hashKey) mac.Write(exportData[:dataWithoutHashLength]) if !bytes.Equal(hash, mac.Sum(nil)) { return nil, ErrMismatchingExportHash } // Decrypt the export block, _ := aes.NewCipher(encryptionKey) unencryptedData := make([]byte, len(exportData)-exportHashLength-exportHeaderLength) cipher.NewCTR(block, iv).XORKeyStream(unencryptedData, encryptedData) // Parse the decrypted JSON var sessionsJSON []ExportedSession err := json.Unmarshal(unencryptedData, &sessionsJSON) if err != nil { return nil, fmt.Errorf("invalid export json: %w", err) } return sessionsJSON, nil } func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) { if session.Algorithm != id.AlgorithmMegolmV1 { return false, ErrInvalidExportedAlgorithm } igsInternal, err := olm.InboundGroupSessionImport([]byte(session.SessionKey)) if err != nil { return false, fmt.Errorf("failed to import session: %w", err) } else if igsInternal.ID() != session.SessionID { return false, 
ErrMismatchingExportedSessionID } igs := &InboundGroupSession{ Internal: *igsInternal, SigningKey: session.SenderClaimedKeys.Ed25519, SenderKey: session.SenderKey, RoomID: session.RoomID, // TODO should we add something here to mark the signing key as unverified like key requests do? ForwardingChains: session.ForwardingChains, } existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID()) if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() { // We already have an equivalent or better session in the store, so don't override it. return false, nil } err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs) if err != nil { return false, fmt.Errorf("failed to store imported session: %w", err) } mach.markSessionReceived(igs.ID()) return true, nil } // ImportKeys imports data that was exported with the format specified in the Matrix spec. // See https://spec.matrix.org/v1.2/client-server-api/#key-exports func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) { exportData, err := decodeKeyExport(data) if err != nil { return 0, 0, err } sessions, err := decryptKeyExport(passphrase, exportData) if err != nil { return 0, 0, err } count := 0 for _, session := range sessions { imported, err := mach.importExportedRoomKey(session) if err != nil { mach.Log.Warn("Failed to import Megolm session %s/%s from file: %v", session.RoomID, session.SessionID, err) } else if imported { mach.Log.Debug("Imported Megolm session %s/%s from file", session.RoomID, session.SessionID) count++ } else { mach.Log.Debug("Skipped Megolm session %s/%s: already in store", session.RoomID, session.SessionID) } } return count, len(sessions), nil } go-0.11.1/crypto/keysharing.go000066400000000000000000000253501436100171500162150ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
// If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

//go:build !nosas
// +build !nosas

package crypto

import (
	"context"

	"maunium.net/go/mautrix/crypto/olm"
	"maunium.net/go/mautrix/id"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
)

// KeyShareRejection describes how an incoming room key request is refused.
// When Code is non-empty, rejectKeyRequest sends Code and Reason to the requesting device in withheld events
// (both the standard and the legacy org.matrix event type). An empty Code means the request is dropped
// without any response.
type KeyShareRejection struct {
	Code   event.RoomKeyWithheldCode
	Reason string
}

// Predefined rejections. defaultAllowKeyShare returns them, and handleRoomKeyRequest uses the
// Unavailable/InternalError ones when the requested session can't be shared.
var (
	// Reject a key request without responding
	KeyShareRejectNoResponse = KeyShareRejection{}

	KeyShareRejectBlacklisted = KeyShareRejection{event.RoomKeyWithheldBlacklisted, "You have been blacklisted by this device"}
	KeyShareRejectUnverified  = KeyShareRejection{event.RoomKeyWithheldUnverified, "You have not been verified by this device"}
	KeyShareRejectOtherUser   = KeyShareRejection{event.RoomKeyWithheldUnauthorized, "This device does not share keys to other users"}
	KeyShareRejectUnavailable = KeyShareRejection{event.RoomKeyWithheldUnavailable, "Requested session ID not found on this device"}
	// Internal errors are reported with the same "unavailable" code as a missing session.
	KeyShareRejectInternalError = KeyShareRejection{event.RoomKeyWithheldUnavailable, "An internal error occurred while trying to share the requested session"}
)

// RequestRoomKey sends a key request for a room to the current user's devices. If the context is cancelled, then so is the key request.
// Returns a bool channel that will get notified either when the key is received or the request is cancelled.
//
// Deprecated: this only supports a single key request target, so the whole automatic cancelling feature isn't very useful.
func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, toDevice id.DeviceID, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (chan bool, error) { requestID := mach.Client.TxnID() keyResponseReceived := make(chan struct{}) mach.roomKeyRequestFilled.Store(sessionID, keyResponseReceived) err := mach.SendRoomKeyRequest(roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}}) if err != nil { return nil, err } resChan := make(chan bool, 1) go func() { select { case <-keyResponseReceived: // key request successful mach.Log.Debug("Key for session %v was received, cancelling other key requests", sessionID) resChan <- true case <-ctx.Done(): // if the context is done, key request was unsuccessful mach.Log.Debug("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID) resChan <- false } // send a message to all devices cancelling this key request mach.roomKeyRequestFilled.Delete(sessionID) cancelEvtContent := &event.Content{ Parsed: event.RoomKeyRequestEventContent{ Action: event.KeyRequestActionCancel, RequestID: requestID, RequestingDeviceID: mach.Client.DeviceID, }, } toDeviceCancel := &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ toUser: { toDevice: cancelEvtContent, }, }, } mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceCancel) }() return resChan, nil } // SendRoomKeyRequest sends a key request for the given key (identified by the room ID, sender key and session ID) to the given users. // // The request ID parameter is optional. If it's empty, a random ID will be generated. // // This function does not wait for the keys to arrive. You can use WaitForSession to wait for the session to // arrive (in any way, not just as a reply to this request). 
There's also RequestRoomKey which waits for a response // to the specific key request, but currently it only supports a single target device and is therefore deprecated. // A future function may properly support multiple targets and automatically canceling the other requests when receiving // the first response. func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error { if len(requestID) == 0 { requestID = mach.Client.TxnID() } requestEvent := &event.Content{ Parsed: &event.RoomKeyRequestEventContent{ Action: event.KeyRequestActionRequest, Body: event.RequestedKeyInfo{ Algorithm: id.AlgorithmMegolmV1, RoomID: roomID, SenderKey: senderKey, SessionID: sessionID, }, RequestID: requestID, RequestingDeviceID: mach.Client.DeviceID, }, } toDeviceReq := &mautrix.ReqSendToDevice{ Messages: make(map[id.UserID]map[id.DeviceID]*event.Content, len(users)), } for user, devices := range users { toDeviceReq.Messages[user] = make(map[id.DeviceID]*event.Content, len(devices)) for _, device := range devices { toDeviceReq.Messages[user][device] = requestEvent } } _, err := mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceReq) return err } func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool { if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" { mach.Log.Debug("Ignoring weird forwarded room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID) return false } igsInternal, err := olm.InboundGroupSessionImport([]byte(content.SessionKey)) if err != nil { mach.Log.Error("Failed to import inbound group session: %v", err) return false } else if igsInternal.ID() != content.SessionID { mach.Log.Warn("Mismatched session ID while creating inbound group session") return false } 
igs := &InboundGroupSession{ Internal: *igsInternal, SigningKey: evt.Keys.Ed25519, SenderKey: content.SenderKey, RoomID: content.RoomID, ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()), id: content.SessionID, } err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs) if err != nil { mach.Log.Error("Failed to store new inbound group session: %v", err) return false } mach.markSessionReceived(content.SessionID) mach.Log.Trace("Received forwarded inbound group session %s/%s/%s", content.RoomID, content.SenderKey, content.SessionID) return true } func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *DeviceIdentity, request event.RequestedKeyInfo) { if rejection.Code == "" { // If the rejection code is empty, it means don't share keys, but also don't tell the requester. return } content := event.RoomKeyWithheldEventContent{ RoomID: request.RoomID, Algorithm: request.Algorithm, SessionID: request.SessionID, SenderKey: request.SenderKey, Code: rejection.Code, Reason: rejection.Reason, } err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content) if err != nil { mach.Log.Warn("Failed to send key share rejection %s to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err) } err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content) if err != nil { mach.Log.Warn("Failed to send key share rejection %s (org.matrix.) 
to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err) } } func (mach *OlmMachine) defaultAllowKeyShare(device *DeviceIdentity, _ event.RequestedKeyInfo) *KeyShareRejection { if mach.Client.UserID != device.UserID { mach.Log.Debug("Ignoring key request from a different user (%s)", device.UserID) return &KeyShareRejectOtherUser } else if mach.Client.DeviceID == device.DeviceID { mach.Log.Debug("Ignoring key request from ourselves") return &KeyShareRejectNoResponse } else if device.Trust == TrustStateBlacklisted { mach.Log.Debug("Ignoring key request from blacklisted device %s", device.DeviceID) return &KeyShareRejectBlacklisted } else if mach.IsDeviceTrusted(device) { mach.Log.Debug("Accepting key request from verified device %s", device.DeviceID) return nil } else if mach.ShareKeysToUnverifiedDevices { mach.Log.Debug("Accepting key request from unverified device %s (ShareKeysToUnverifiedDevices is true)", device.DeviceID) return nil } else { mach.Log.Debug("Ignoring key request from unverified device %s", device.DeviceID) return &KeyShareRejectUnverified } } func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.RoomKeyRequestEventContent) { if content.Action != event.KeyRequestActionRequest { return } else if content.RequestingDeviceID == mach.Client.DeviceID && sender == mach.Client.UserID { mach.Log.Debug("Ignoring key request %s from ourselves", content.RequestID) return } mach.Log.Debug("Received key request %s for %s from %s/%s", content.RequestID, content.Body.SessionID, sender, content.RequestingDeviceID) device, err := mach.GetOrFetchDevice(sender, content.RequestingDeviceID) if err != nil { mach.Log.Error("Failed to fetch device %s/%s that requested keys: %v", sender, content.RequestingDeviceID, err) return } rejection := mach.AllowKeyShare(device, content.Body) if rejection != nil { mach.rejectKeyRequest(*rejection, device, content.Body) return } igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, 
content.Body.SenderKey, content.Body.SessionID) if err != nil { mach.Log.Error("Failed to fetch group session to forward to %s/%s: %v", device.UserID, device.DeviceID, err) mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) return } else if igs == nil { mach.Log.Warn("Didn't find group session %s to forward to %s/%s", content.Body.SessionID, device.UserID, device.DeviceID) mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body) return } exportedKey, err := igs.Internal.Export(igs.Internal.FirstKnownIndex()) if err != nil { mach.Log.Error("Failed to export session %s to forward to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err) mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body) return } forwardedRoomKey := event.Content{ Parsed: &event.ForwardedRoomKeyEventContent{ RoomKeyEventContent: event.RoomKeyEventContent{ Algorithm: id.AlgorithmMegolmV1, RoomID: igs.RoomID, SessionID: igs.ID(), SessionKey: exportedKey, }, SenderKey: content.Body.SenderKey, ForwardingKeyChain: igs.ForwardingChains, SenderClaimedKey: igs.SigningKey, }, } if err := mach.SendEncryptedToDevice(device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil { mach.Log.Error("Failed to send encrypted forwarded key %s to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err) } mach.Log.Debug("Sent encrypted forwarded key to device %s/%s for session %s", device.UserID, device.DeviceID, igs.ID()) } go-0.11.1/crypto/machine.go000066400000000000000000000475151436100171500154640ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "errors" "fmt" "sync" "time" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" ) // Logger is a simple logging struct for OlmMachine. // Implementations are recommended to use fmt.Sprintf and manually add a newline after the message. type Logger interface { Error(message string, args ...interface{}) Warn(message string, args ...interface{}) Debug(message string, args ...interface{}) Trace(message string, args ...interface{}) } // OlmMachine is the main struct for handling Matrix end-to-end encryption. type OlmMachine struct { Client *mautrix.Client SSSS *ssss.Machine Log Logger CryptoStore Store StateStore StateStore AllowUnverifiedDevices bool ShareKeysToUnverifiedDevices bool AllowKeyShare func(*DeviceIdentity, event.RequestedKeyInfo) *KeyShareRejection DefaultSASTimeout time.Duration // AcceptVerificationFrom determines whether the machine will accept verification requests from this device. AcceptVerificationFrom func(string, *DeviceIdentity, id.RoomID) (VerificationRequestResponse, VerificationHooks) account *OlmAccount roomKeyRequestFilled *sync.Map keyVerificationTransactionState *sync.Map keyWaiters map[id.SessionID]chan struct{} keyWaitersLock sync.Mutex devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedgeLock sync.Mutex recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedgedLock sync.Mutex olmLock sync.Mutex CrossSigningKeys *CrossSigningKeysCache crossSigningPubkeys *CrossSigningPublicKeysCache } // StateStore is used by OlmMachine to get room state information that's needed for encryption. type StateStore interface { // IsEncrypted returns whether a room is encrypted. IsEncrypted(id.RoomID) bool // GetEncryptionEvent returns the encryption event's content for an encrypted room. 
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent // FindSharedRooms returns the encrypted rooms that another user is also in for a user ID. FindSharedRooms(id.UserID) []id.RoomID } // NewOlmMachine creates an OlmMachine with the given client, logger and stores. func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateStore StateStore) *OlmMachine { mach := &OlmMachine{ Client: client, SSSS: ssss.NewSSSSMachine(client), Log: log, CryptoStore: cryptoStore, StateStore: stateStore, AllowUnverifiedDevices: true, ShareKeysToUnverifiedDevices: false, DefaultSASTimeout: 10 * time.Minute, AcceptVerificationFrom: func(string, *DeviceIdentity, id.RoomID) (VerificationRequestResponse, VerificationHooks) { // Reject requests by default. Users need to override this to return appropriate verification hooks. return RejectRequest, nil }, roomKeyRequestFilled: &sync.Map{}, keyVerificationTransactionState: &sync.Map{}, keyWaiters: make(map[id.SessionID]chan struct{}), devicesToUnwedge: make(map[id.IdentityKey]bool), recentlyUnwedged: make(map[id.IdentityKey]time.Time), } mach.AllowKeyShare = mach.defaultAllowKeyShare return mach } // Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created. // This must be called before using the machine. func (mach *OlmMachine) Load() (err error) { mach.account, err = mach.CryptoStore.GetAccount() if err != nil { return } if mach.account == nil { mach.account = NewOlmAccount() } return nil } func (mach *OlmMachine) saveAccount() { err := mach.CryptoStore.PutAccount(mach.account) if err != nil { mach.Log.Error("Failed to save account: %v", err) } } // FlushStore calls the Flush method of the CryptoStore. 
func (mach *OlmMachine) FlushStore() error { return mach.CryptoStore.Flush() } func (mach *OlmMachine) timeTrace(thing, trace string, expectedDuration time.Duration) func() { start := time.Now() return func() { duration := time.Now().Sub(start) if duration > expectedDuration { mach.Log.Warn("%s took %s (trace: %s)", thing, duration, trace) } } } func Fingerprint(signingKey id.SigningKey) string { spacedSigningKey := make([]byte, len(signingKey)+(len(signingKey)-1)/4) var ptr = 0 for i, chr := range signingKey { spacedSigningKey[ptr] = byte(chr) ptr++ if i%4 == 3 { spacedSigningKey[ptr] = ' ' ptr++ } } return string(spacedSigningKey) } // Fingerprint returns the fingerprint of the Olm account that can be used for non-interactive verification. func (mach *OlmMachine) Fingerprint() string { return Fingerprint(mach.account.SigningKey()) } // OwnIdentity returns this device's DeviceIdentity struct func (mach *OlmMachine) OwnIdentity() *DeviceIdentity { return &DeviceIdentity{ UserID: mach.Client.UserID, DeviceID: mach.Client.DeviceID, IdentityKey: mach.account.IdentityKey(), SigningKey: mach.account.SigningKey(), Trust: TrustStateVerified, Deleted: false, } } func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) { // ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent) ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent) ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent) ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent) ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent) ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent) ep.On(event.ToDeviceVerificationAccept, mach.HandleToDeviceEvent) ep.On(event.ToDeviceVerificationKey, mach.HandleToDeviceEvent) ep.On(event.ToDeviceVerificationMAC, mach.HandleToDeviceEvent) 
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent) ep.OnOTK(mach.HandleOTKCounts) ep.OnDeviceList(mach.HandleDeviceLists) mach.Log.Trace("Added listeners for encryption data coming from appservice transactions") } func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) { if len(dl.Changed) > 0 { traceID := time.Now().Format("15:04:05.000000") mach.Log.Trace("Device list changes in /sync: %v (trace: %s)", dl.Changed, traceID) mach.fetchKeys(dl.Changed, since, false) mach.Log.Trace("Finished handling device list changes (trace: %s)", traceID) } } func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) { if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Debug("Dropping OTK counts targeted to %s/%s (not us)", otkCount.UserID, otkCount.DeviceID) return } minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2 if otkCount.SignedCurve25519 < int(minCount) { traceID := time.Now().Format("15:04:05.000000") mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones... (trace: %s)", otkCount.SignedCurve25519, traceID) err := mach.ShareKeys(otkCount.SignedCurve25519) if err != nil { mach.Log.Error("Failed to share keys: %v (trace: %s)", err, traceID) } else { mach.Log.Debug("Successfully shared keys (trace: %s)", traceID) } } } // ProcessSyncResponse processes a single /sync response. 
// // This can be easily registered into a mautrix client using .OnSync(): // // client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse) func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool { mach.HandleDeviceLists(&resp.DeviceLists, since) for _, evt := range resp.ToDevice.Events { evt.Type.Class = event.ToDeviceEventType err := evt.Content.ParseRaw(evt.Type) if err != nil { mach.Log.Warn("Failed to parse to-device event of type %s: %v", evt.Type.Type, err) continue } mach.HandleToDeviceEvent(evt) } mach.HandleOTKCounts(&resp.DeviceOTKCount) return true } // HandleMemberEvent handles a single membership event. // // Currently this is not automatically called, so you must add a listener yourself: // // client.Syncer.(*mautrix.DefaultSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent) func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) { if !mach.StateStore.IsEncrypted(evt.RoomID) { return } content := evt.Content.AsMember() if content == nil { return } var prevContent *event.MemberEventContent if evt.Unsigned.PrevContent != nil { _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type) prevContent = evt.Unsigned.PrevContent.AsMember() } if prevContent == nil { prevContent = &event.MemberEventContent{Membership: "unknown"} } if prevContent.Membership == content.Membership || (prevContent.Membership == event.MembershipInvite && content.Membership == event.MembershipJoin) || (prevContent.Membership == event.MembershipBan && content.Membership == event.MembershipLeave) || (prevContent.Membership == event.MembershipLeave && content.Membership == event.MembershipBan) { return } mach.Log.Trace("Got membership state event in %s changing %s from %s to %s, invalidating group session", evt.RoomID, evt.GetStateKey(), prevContent.Membership, content.Membership) err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID) if err != nil { mach.Log.Warn("Failed to invalidate outbound group session of %s: %v", 
evt.RoomID, err) } } // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // don't need to add any custom handlers if you use that method. func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) { if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) { // TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID) return } traceID := time.Now().Format("15:04:05.000000") if evt.Type != event.ToDeviceEncrypted { mach.Log.Trace("Starting handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID) } switch content := evt.Content.Parsed.(type) { case *event.EncryptedEventContent: mach.Log.Debug("Handling encrypted to-device event from %s/%s (trace: %s)", evt.Sender, content.SenderKey, traceID) decryptedEvt, err := mach.decryptOlmEvent(evt, traceID) if err != nil { mach.Log.Error("Failed to decrypt to-device event: %v (trace: %s)", err, traceID) return } mach.Log.Trace("Successfully decrypted to-device from %s/%s into type %s (sender key: %s, trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, decryptedEvt.Type.String(), decryptedEvt.SenderKey, traceID) switch decryptedContent := decryptedEvt.Content.Parsed.(type) { case *event.RoomKeyEventContent: mach.receiveRoomKey(decryptedEvt, decryptedContent, traceID) mach.Log.Trace("Handled room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID) case *event.ForwardedRoomKeyEventContent: if mach.importForwardedRoomKey(decryptedEvt, decryptedContent) { if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok { // close channel to notify listener that the key was received close(ch.(chan struct{})) } } mach.Log.Trace("Handled forwarded room key event 
from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID) case *event.DummyEventContent: mach.Log.Debug("Received encrypted dummy event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID) default: mach.Log.Debug("Unhandled encrypted to-device event of type %s from %s/%s (trace: %s)", decryptedEvt.Type.String(), decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID) } return case *event.RoomKeyRequestEventContent: mach.handleRoomKeyRequest(evt.Sender, content) // verification cases case *event.VerificationStartEventContent: mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "") case *event.VerificationAcceptEventContent: mach.handleVerificationAccept(evt.Sender, content, content.TransactionID) case *event.VerificationKeyEventContent: mach.handleVerificationKey(evt.Sender, content, content.TransactionID) case *event.VerificationMacEventContent: mach.handleVerificationMAC(evt.Sender, content, content.TransactionID) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.TransactionID) case *event.VerificationRequestEventContent: mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "") case *event.RoomKeyWithheldEventContent: mach.handleRoomKeyWithheld(content) default: deviceID, _ := evt.Content.Raw["device_id"].(string) mach.Log.Trace("Unhandled to-device event of type %s from %s/%s (trace: %s)", evt.Type.Type, evt.Sender, deviceID, traceID) return } mach.Log.Trace("Finished handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID) } // GetOrFetchDevice attempts to retrieve the device identity for the given device from the store // and if it's not found it asks the server for it. 
func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) (*DeviceIdentity, error) { // get device identity device, err := mach.CryptoStore.GetDevice(userID, deviceID) if err != nil { return nil, fmt.Errorf("failed to get sender device from store: %w", err) } else if device != nil { return device, nil } // try to fetch if not found usersToDevices := mach.fetchKeys([]id.UserID{userID}, "", true) if devices, ok := usersToDevices[userID]; ok { if device, ok = devices[deviceID]; ok { return device, nil } return nil, fmt.Errorf("didn't get identity for device %s of %s", deviceID, userID) } return nil, fmt.Errorf("didn't get any devices for %s", userID) } // GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the // store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with // the given identity key. func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) { deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey) if err != nil || deviceIdentity != nil { return deviceIdentity, err } mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey) devices := mach.LoadDevices(userID) for _, device := range devices { if device.IdentityKey == identityKey { return device, nil } } return nil, nil } // SendEncryptedToDevice sends an Olm-encrypted event to the given user device. 
func (mach *OlmMachine) SendEncryptedToDevice(device *DeviceIdentity, evtType event.Type, content event.Content) error { if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*DeviceIdentity{ device.UserID: { device.DeviceID: device, }, }); err != nil { return err } mach.olmLock.Lock() defer mach.olmLock.Unlock() olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey) if err != nil { return err } if olmSess == nil { return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID) } encrypted := mach.encryptOlmEvent(olmSess, device, evtType, content) encryptedContent := &event.Content{Parsed: &encrypted} mach.Log.Debug("Sending encrypted to-device event of type %s to %s/%s (identity key: %s, olm session ID: %s)", evtType.Type, device.UserID, device.DeviceID, device.IdentityKey, olmSess.ID()) _, err = mach.Client.SendToDevice(event.ToDeviceEncrypted, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ device.UserID: { device.DeviceID: encryptedContent, }, }, }, ) return err } func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, traceID string) { igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey) if err != nil { mach.Log.Error("Failed to create inbound group session: %v", err) return } else if igs.ID() != sessionID { mach.Log.Warn("Mismatched session ID while creating inbound group session") return } err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs) if err != nil { mach.Log.Error("Failed to store new inbound group session: %v", err) return } mach.markSessionReceived(sessionID) mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID) } func (mach *OlmMachine) markSessionReceived(id id.SessionID) { mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[id] if ok { close(ch) 
delete(mach.keyWaiters, id) } mach.keyWaitersLock.Unlock() } // WaitForSession waits for the given Megolm session to arrive. func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool { mach.keyWaitersLock.Lock() ch, ok := mach.keyWaiters[sessionID] if !ok { ch := make(chan struct{}) mach.keyWaiters[sessionID] = ch } mach.keyWaitersLock.Unlock() select { case <-ch: return true case <-time.After(timeout): sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID) // Check if the session somehow appeared in the store without telling us // We accept withheld sessions as received, as then the decryption attempt will show the error. return sess != nil || errors.Is(err, ErrGroupSessionWithheld) } } func (mach *OlmMachine) receiveRoomKey(evt *DecryptedOlmEvent, content *event.RoomKeyEventContent, traceID string) { // TODO nio had a comment saying "handle this better" for the case where evt.Keys.Ed25519 is none? if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" { mach.Log.Debug("Ignoring weird room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID) return } mach.createGroupSession(evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, traceID) } func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEventContent) { if content.Algorithm != id.AlgorithmMegolmV1 { mach.Log.Debug("Non-megolm room key withheld event: %+v", content) return } err := mach.CryptoStore.PutWithheldGroupSession(*content) if err != nil { mach.Log.Error("Failed to save room key withheld event: %v", err) } } // ShareKeys uploads necessary keys to the server. // // If the Olm account hasn't been shared, the account keys will be uploaded. 
// If currentOTKCount is less than half of the limit (100 / 2 = 50), enough one-time keys will be uploaded so exactly // half of the limit is filled. func (mach *OlmMachine) ShareKeys(currentOTKCount int) error { var deviceKeys *mautrix.DeviceKeys if !mach.account.Shared { deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID) mach.Log.Trace("Going to upload initial account keys") } oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount) if len(oneTimeKeys) == 0 && deviceKeys == nil { mach.Log.Trace("No one-time keys nor device keys got when trying to share keys") return nil } req := &mautrix.ReqUploadKeys{ DeviceKeys: deviceKeys, OneTimeKeys: oneTimeKeys, } mach.Log.Trace("Uploading %d one-time keys", len(oneTimeKeys)) _, err := mach.Client.UploadKeys(req) if err != nil { return err } mach.account.Shared = true mach.saveAccount() return nil } go-0.11.1/crypto/machine_test.go000066400000000000000000000127421436100171500165150ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package crypto import ( "os" "testing" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type emptyLogger struct{} func (emptyLogger) Error(message string, args ...interface{}) {} func (emptyLogger) Warn(message string, args ...interface{}) {} func (emptyLogger) Debug(message string, args ...interface{}) {} func (emptyLogger) Trace(message string, args ...interface{}) {} type mockStateStore struct{} func (mockStateStore) IsEncrypted(id.RoomID) bool { return true } func (mockStateStore) GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent { return &event.EncryptionEventContent{ RotationPeriodMessages: 3, } } func (mockStateStore) FindSharedRooms(id.UserID) []id.RoomID { return []id.RoomID{"room1"} } func newMachine(t *testing.T, userID id.UserID) (*OlmMachine, string) { client, err := mautrix.NewClient("http://localhost", userID, "token") if err != nil { t.Fatalf("Error creating client: %v", err) } client.DeviceID = "device1" storeFileName := "gob_store_test_" + userID.String() + ".gob" gobStore, err := NewGobStore(storeFileName) if err != nil { os.Remove(storeFileName) t.Fatalf("Error creating Gob store: %v", err) } machine := NewOlmMachine(client, emptyLogger{}, gobStore, mockStateStore{}) if err := machine.Load(); err != nil { os.Remove(storeFileName) t.Fatalf("Error creating account: %v", err) } return machine, storeFileName } func TestOlmMachineOlmMegolmSessions(t *testing.T) { machineOut, storeFileNameOut := newMachine(t, "user1") defer os.Remove(storeFileNameOut) machineIn, storeFileNameIn := newMachine(t, "user2") defer os.Remove(storeFileNameIn) // generate OTKs for receiving machine otks := machineIn.account.getOneTimeKeys("user2", "device2", 0) var otk mautrix.OneTimeKey for _, otkTmp := range otks { // take first OTK otk = otkTmp break } // create outbound olm session for sending machine using OTK olmSession, err := machineOut.account.Internal.NewOutboundSession(machineIn.account.IdentityKey(), otk.Key) if err 
!= nil { t.Errorf("Failed to create outbound olm session: %v", err) } // store sender device identity in receiving machine store machineIn.CryptoStore.PutDevices("user1", map[id.DeviceID]*DeviceIdentity{ "device1": { UserID: "user1", DeviceID: "device1", IdentityKey: machineOut.account.IdentityKey(), SigningKey: machineOut.account.SigningKey(), }, }) // create & store outbound megolm session for sending the event later megolmOutSession := machineOut.newOutboundGroupSession("room1") megolmOutSession.Shared = true machineOut.CryptoStore.AddOutboundGroupSession(megolmOutSession) // encrypt m.room_key event with olm session deviceIdentity := &DeviceIdentity{ UserID: "user2", DeviceID: "device2", IdentityKey: machineIn.account.IdentityKey(), SigningKey: machineIn.account.SigningKey(), } wrapped := wrapSession(olmSession) content := machineOut.encryptOlmEvent(wrapped, deviceIdentity, event.ToDeviceRoomKey, megolmOutSession.ShareContent()) senderKey := machineOut.account.IdentityKey() signingKey := machineOut.account.SigningKey() for _, content := range content.OlmCiphertext { // decrypt olm ciphertext decrypted, err := machineIn.decryptAndParseOlmCiphertext("user1", senderKey, content.Type, content.Body, "test") if err != nil { t.Errorf("Error decrypting olm content: %v", err) } // store room key in new inbound group session decrypted.Content.ParseRaw(event.ToDeviceRoomKey) roomKeyEvt := decrypted.Content.AsRoomKey() igs, err := NewInboundGroupSession(senderKey, signingKey, "room1", roomKeyEvt.SessionKey) if err != nil { t.Errorf("Error creating inbound megolm session: %v", err) } if err = machineIn.CryptoStore.PutGroupSession("room1", senderKey, igs.ID(), igs); err != nil { t.Errorf("Error storing inbound megolm session: %v", err) } } // encrypt event with megolm session in sending machine eventContent := map[string]string{"hello": "world"} encryptedEvtContent, err := machineOut.EncryptMegolmEvent("room1", event.EventMessage, eventContent) if err != nil { 
t.Errorf("Error encrypting megolm event: %v", err) } if megolmOutSession.MessageCount != 1 { t.Errorf("Megolm outbound session message count is not 1 but %d", megolmOutSession.MessageCount) } encryptedEvt := &event.Event{ Content: event.Content{Parsed: encryptedEvtContent}, Type: event.EventEncrypted, ID: "event1", RoomID: "room1", Sender: "user1", } // decrypt event on receiving machine and confirm decryptedEvt, err := machineIn.DecryptMegolmEvent(encryptedEvt) if err != nil { t.Errorf("Error decrypting megolm event: %v", err) } if decryptedEvt.Type != event.EventMessage { t.Errorf("Expected event type %v, got %v", event.EventMessage, decryptedEvt.Type) } if decryptedEvt.Content.Raw["hello"] != "world" { t.Errorf("Expected event content %v, got %v", eventContent, decryptedEvt.Content.Raw) } machineOut.EncryptMegolmEvent("room1", event.EventMessage, eventContent) if megolmOutSession.Expired() { t.Error("Megolm outbound session expired before 3rd message") } machineOut.EncryptMegolmEvent("room1", event.EventMessage, eventContent) if !megolmOutSession.Expired() { t.Error("Megolm outbound session not expired after 3rd message") } } go-0.11.1/crypto/olm/000077500000000000000000000000001436100171500143045ustar00rootroot00000000000000go-0.11.1/crypto/olm/LICENSE000066400000000000000000000236761436100171500153270ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS go-0.11.1/crypto/olm/README.md000066400000000000000000000001131436100171500155560ustar00rootroot00000000000000# Go olm bindings Based on [Dhole/go-olm](https://github.com/Dhole/go-olm) go-0.11.1/crypto/olm/account.go000066400000000000000000000274601436100171500163000ustar00rootroot00000000000000package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" import ( "crypto/rand" "encoding/json" "unsafe" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) // Account stores a device account for end to end encrypted messaging. 
type Account struct { int *C.OlmAccount mem []byte } // AccountFromPickled loads an Account from a pickled base64 string. Decrypts // the Account using the supplied key. Returns error on failure. If the key // doesn't match the one used to encrypt the Account then the error will be // "BAD_ACCOUNT_KEY". If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". func AccountFromPickled(pickled, key []byte) (*Account, error) { if len(pickled) == 0 { return nil, EmptyInput } a := NewBlankAccount() return a, a.Unpickle(pickled, key) } func NewBlankAccount() *Account { memory := make([]byte, accountSize()) return &Account{ int: C.olm_account(unsafe.Pointer(&memory[0])), mem: memory, } } // NewAccount creates a new Account. func NewAccount() *Account { a := NewBlankAccount() random := make([]byte, a.createRandomLen()+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } r := C.olm_create_account( (*C.OlmAccount)(a.int), unsafe.Pointer(&random[0]), C.size_t(len(random))) if r == errorVal() { panic(a.lastError()) } else { return a } } // accountSize returns the size of an account object in bytes. func accountSize() uint { return uint(C.olm_account_size()) } // lastError returns an error describing the most recent error to happen to an // account. func (a *Account) lastError() error { return convertError(C.GoString(C.olm_account_last_error((*C.OlmAccount)(a.int)))) } // Clear clears the memory used to back this Account. func (a *Account) Clear() error { r := C.olm_clear_account((*C.OlmAccount)(a.int)) if r == errorVal() { return a.lastError() } else { return nil } } // pickleLen returns the number of bytes needed to store an Account. func (a *Account) pickleLen() uint { return uint(C.olm_pickle_account_length((*C.OlmAccount)(a.int))) } // createRandomLen returns the number of random bytes needed to create an // Account. 
func (a *Account) createRandomLen() uint { return uint(C.olm_create_account_random_length((*C.OlmAccount)(a.int))) } // identityKeysLen returns the size of the output buffer needed to hold the // identity keys. func (a *Account) identityKeysLen() uint { return uint(C.olm_account_identity_keys_length((*C.OlmAccount)(a.int))) } // signatureLen returns the length of an ed25519 signature encoded as base64. func (a *Account) signatureLen() uint { return uint(C.olm_account_signature_length((*C.OlmAccount)(a.int))) } // oneTimeKeysLen returns the size of the output buffer needed to hold the one // time keys. func (a *Account) oneTimeKeysLen() uint { return uint(C.olm_account_one_time_keys_length((*C.OlmAccount)(a.int))) } // genOneTimeKeysRandomLen returns the number of random bytes needed to // generate a given number of new one time keys. func (a *Account) genOneTimeKeysRandomLen(num uint) uint { return uint(C.olm_account_generate_one_time_keys_random_length( (*C.OlmAccount)(a.int), C.size_t(num))) } // Pickle returns an Account as a base64 string. Encrypts the Account using the // supplied key. 
func (a *Account) Pickle(key []byte) []byte { if len(key) == 0 { panic(NoKeyProvided) } pickled := make([]byte, a.pickleLen()) r := C.olm_pickle_account( (*C.OlmAccount)(a.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { panic(a.lastError()) } return pickled[:r] } func (a *Account) Unpickle(pickled, key []byte) error { if len(key) == 0 { return NoKeyProvided } r := C.olm_unpickle_account( (*C.OlmAccount)(a.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return a.lastError() } return nil } func (a *Account) GobEncode() ([]byte, error) { pickled := a.Pickle(pickleKey) length := unpaddedBase64.DecodedLen(len(pickled)) rawPickled := make([]byte, length) _, err := unpaddedBase64.Decode(rawPickled, pickled) return rawPickled, err } func (a *Account) GobDecode(rawPickled []byte) error { if a.int == nil { *a = *NewBlankAccount() } length := unpaddedBase64.EncodedLen(len(rawPickled)) pickled := make([]byte, length) unpaddedBase64.Encode(pickled, rawPickled) return a.Unpickle(pickled, pickleKey) } func (a *Account) MarshalJSON() ([]byte, error) { pickled := a.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) quotes[0] = '"' quotes[len(quotes)-1] = '"' copy(quotes[1:len(quotes)-1], pickled) return quotes, nil } func (a *Account) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString } if a.int == nil { *a = *NewBlankAccount() } return a.Unpickle(data[1:len(data)-1], pickleKey) } // IdentityKeysJSON returns the public parts of the identity keys for the Account. 
func (a *Account) IdentityKeysJSON() []byte { identityKeys := make([]byte, a.identityKeysLen()) r := C.olm_account_identity_keys( (*C.OlmAccount)(a.int), unsafe.Pointer(&identityKeys[0]), C.size_t(len(identityKeys))) if r == errorVal() { panic(a.lastError()) } else { return identityKeys } } // IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity // keys for the Account. func (a *Account) IdentityKeys() (id.Ed25519, id.Curve25519) { identityKeysJSON := a.IdentityKeysJSON() results := gjson.GetManyBytes(identityKeysJSON, "ed25519", "curve25519") return id.Ed25519(results[0].Str), id.Curve25519(results[1].Str) } // Sign returns the signature of a message using the ed25519 key for this // Account. func (a *Account) Sign(message []byte) []byte { if len(message) == 0 { panic(EmptyInput) } signature := make([]byte, a.signatureLen()) r := C.olm_account_sign( (*C.OlmAccount)(a.int), unsafe.Pointer(&message[0]), C.size_t(len(message)), unsafe.Pointer(&signature[0]), C.size_t(len(signature))) if r == errorVal() { panic(a.lastError()) } return signature } // SignJSON signs the given JSON object following the Matrix specification: // https://matrix.org/docs/spec/appendices#signing-json func (a *Account) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err } objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil } // OneTimeKeys returns the public parts of the unpublished one time keys for // the Account. // // The returned data is a struct with the single value "Curve25519", which is // itself an object mapping key id to base64-encoded Curve25519 key. 
For // example: // { // Curve25519: { // "AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo", // "AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU" // } // } func (a *Account) OneTimeKeys() map[string]id.Curve25519 { oneTimeKeysJSON := make([]byte, a.oneTimeKeysLen()) r := C.olm_account_one_time_keys( (*C.OlmAccount)(a.int), unsafe.Pointer(&oneTimeKeysJSON[0]), C.size_t(len(oneTimeKeysJSON))) if r == errorVal() { panic(a.lastError()) } var oneTimeKeys struct { Curve25519 map[string]id.Curve25519 `json:"curve25519"` } err := json.Unmarshal(oneTimeKeysJSON, &oneTimeKeys) if err != nil { panic(err) } return oneTimeKeys.Curve25519 } // MarkKeysAsPublished marks the current set of one time keys as being // published. func (a *Account) MarkKeysAsPublished() { C.olm_account_mark_keys_as_published((*C.OlmAccount)(a.int)) } // MaxNumberOfOneTimeKeys returns the largest number of one time keys this // Account can store. func (a *Account) MaxNumberOfOneTimeKeys() uint { return uint(C.olm_account_max_number_of_one_time_keys((*C.OlmAccount)(a.int))) } // GenOneTimeKeys generates a number of new one time keys. If the total number // of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old // keys are discarded. func (a *Account) GenOneTimeKeys(num uint) { random := make([]byte, a.genOneTimeKeysRandomLen(num)+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } r := C.olm_account_generate_one_time_keys( (*C.OlmAccount)(a.int), C.size_t(num), unsafe.Pointer(&random[0]), C.size_t(len(random))) if r == errorVal() { panic(a.lastError()) } } // NewOutboundSession creates a new out-bound session for sending messages to a // given curve25519 identityKey and oneTimeKey. Returns error on failure. 
If the // keys couldn't be decoded as base64 then the error will be "INVALID_BASE64" func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) { if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 { return nil, EmptyInput } s := NewBlankSession() random := make([]byte, s.createOutboundRandomLen()+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } r := C.olm_create_outbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), unsafe.Pointer(&([]byte(theirIdentityKey)[0])), C.size_t(len(theirIdentityKey)), unsafe.Pointer(&([]byte(theirOneTimeKey)[0])), C.size_t(len(theirOneTimeKey)), unsafe.Pointer(&random[0]), C.size_t(len(random))) if r == errorVal() { return nil, s.lastError() } return s, nil } // NewInboundSession creates a new in-bound session for sending/receiving // messages from an incoming PRE_KEY message. Returns error on failure. If // the base64 couldn't be decoded then the error will be "INVALID_BASE64". If // the message was for an unsupported protocol version then the error will be // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) { if len(oneTimeKeyMsg) == 0 { return nil, EmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), C.size_t(len(oneTimeKeyMsg))) if r == errorVal() { return nil, s.lastError() } return s, nil } // NewInboundSessionFrom creates a new in-bound session for sending/receiving // messages from an incoming PRE_KEY message. Returns error on failure. If // the base64 couldn't be decoded then the error will be "INVALID_BASE64". 
If // the message was for an unsupported protocol version then the error will be // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then then the // error will be "BAD_MESSAGE_FORMAT". If the message refers to an unknown one // time key then the error will be "BAD_MESSAGE_KEY_ID". func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { return nil, EmptyInput } s := NewBlankSession() r := C.olm_create_inbound_session_from( (*C.OlmSession)(s.int), (*C.OlmAccount)(a.int), unsafe.Pointer(&([]byte(theirIdentityKey)[0])), C.size_t(len(theirIdentityKey)), unsafe.Pointer(&([]byte(oneTimeKeyMsg)[0])), C.size_t(len(oneTimeKeyMsg))) if r == errorVal() { return nil, s.lastError() } return s, nil } // RemoveOneTimeKeys removes the one time keys that the session used from the // Account. Returns error on failure. If the Account doesn't have any // matching one time keys then the error will be "BAD_MESSAGE_KEY_ID". 
func (a *Account) RemoveOneTimeKeys(s *Session) error { r := C.olm_remove_one_time_keys( (*C.OlmAccount)(a.int), (*C.OlmSession)(s.int)) if r == errorVal() { return a.lastError() } return nil } go-0.11.1/crypto/olm/error.go000066400000000000000000000051011436100171500157610ustar00rootroot00000000000000package olm import ( "errors" "fmt" ) // Error codes from go-olm var ( EmptyInput = errors.New("empty input") NoKeyProvided = errors.New("no pickle key provided") NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand") SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device") InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // Error codes from olm code var ( NotEnoughRandom = errors.New("not enough entropy was supplied") OutputBufferTooSmall = errors.New("supplied output buffer is too small") BadMessageVersion = errors.New("the message version is unsupported") BadMessageFormat = errors.New("the message couldn't be decoded") BadMessageMAC = errors.New("the message couldn't be decrypted") BadMessageKeyID = errors.New("the message references an unknown key ID") InvalidBase64 = errors.New("the input base64 was invalid") BadAccountKey = errors.New("the supplied account key is invalid") UnknownPickleVersion = errors.New("the pickled object is too new") CorruptedPickle = errors.New("the pickled object couldn't be decoded") BadSessionKey = errors.New("attempt to initialise an inbound group session from an invalid session key") UnknownMessageIndex = errors.New("attempt to decode a message whose index is earlier than our earliest known session key") BadLegacyAccountPickle = errors.New("attempt to unpickle an account which uses pickle version 1") BadSignature = errors.New("received message had a bad signature") InputBufferTooSmall = errors.New("the input data was too small to be valid") ) var errorMap = map[string]error{ "NOT_ENOUGH_RANDOM": NotEnoughRandom, "OUTPUT_BUFFER_TOO_SMALL": 
OutputBufferTooSmall, "BAD_MESSAGE_VERSION": BadMessageVersion, "BAD_MESSAGE_FORMAT": BadMessageFormat, "BAD_MESSAGE_MAC": BadMessageMAC, "BAD_MESSAGE_KEY_ID": BadMessageKeyID, "INVALID_BASE64": InvalidBase64, "BAD_ACCOUNT_KEY": BadAccountKey, "UNKNOWN_PICKLE_VERSION": UnknownPickleVersion, "CORRUPTED_PICKLE": CorruptedPickle, "BAD_SESSION_KEY": BadSessionKey, "UNKNOWN_MESSAGE_INDEX": UnknownMessageIndex, "BAD_LEGACY_ACCOUNT_PICKLE": BadLegacyAccountPickle, "BAD_SIGNATURE": BadSignature, "INPUT_BUFFER_TOO_SMALL": InputBufferTooSmall, } func convertError(errCode string) error { err, ok := errorMap[errCode] if ok { return err } return fmt.Errorf("unknown error: %s", errCode) } go-0.11.1/crypto/olm/inboundgroupsession.go000066400000000000000000000235741436100171500207650ustar00rootroot00000000000000package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" import ( "unsafe" "maunium.net/go/mautrix/id" ) // InboundGroupSession stores an inbound encrypted messaging session for a // group. type InboundGroupSession struct { int *C.OlmInboundGroupSession mem []byte } // InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled // base64 string. Decrypts the InboundGroupSession using the supplied key. // Returns error on failure. If the key doesn't match the one used to encrypt // the InboundGroupSession then the error will be "BAD_SESSION_KEY". If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) { if len(pickled) == 0 { return nil, EmptyInput } lenKey := len(key) if lenKey == 0 { key = []byte(" ") } s := NewBlankInboundGroupSession() return s, s.Unpickle(pickled, key) } // NewInboundGroupSession creates a new inbound group session from a key // exported from OutboundGroupSession.Key(). Returns error on failure. // If the sessionKey is not valid base64 the error will be // "OLM_INVALID_BASE64". 
If the session_key is invalid the error will be // "OLM_BAD_SESSION_KEY". func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { return nil, EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_init_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&sessionKey[0]), C.size_t(len(sessionKey))) if r == errorVal() { return nil, s.lastError() } return s, nil } // InboundGroupSessionImport imports an inbound group session from a previous // export. Returns error on failure. If the sessionKey is not valid base64 // the error will be "OLM_INVALID_BASE64". If the session_key is invalid the // error will be "OLM_BAD_SESSION_KEY". func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) { if len(sessionKey) == 0 { return nil, EmptyInput } s := NewBlankInboundGroupSession() r := C.olm_import_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&sessionKey[0]), C.size_t(len(sessionKey))) if r == errorVal() { return nil, s.lastError() } return s, nil } // inboundGroupSessionSize is the size of an inbound group session object in // bytes. func inboundGroupSessionSize() uint { return uint(C.olm_inbound_group_session_size()) } // newInboundGroupSession initialises an empty InboundGroupSession. func NewBlankInboundGroupSession() *InboundGroupSession { memory := make([]byte, inboundGroupSessionSize()) return &InboundGroupSession{ int: C.olm_inbound_group_session(unsafe.Pointer(&memory[0])), mem: memory, } } // lastError returns an error describing the most recent error to happen to an // inbound group session. func (s *InboundGroupSession) lastError() error { return convertError(C.GoString(C.olm_inbound_group_session_last_error((*C.OlmInboundGroupSession)(s.int)))) } // Clear clears the memory used to back this InboundGroupSession. 
func (s *InboundGroupSession) Clear() error { r := C.olm_clear_inbound_group_session((*C.OlmInboundGroupSession)(s.int)) if r == errorVal() { return s.lastError() } return nil } // pickleLen returns the number of bytes needed to store an inbound group // session. func (s *InboundGroupSession) pickleLen() uint { return uint(C.olm_pickle_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) } // Pickle returns an InboundGroupSession as a base64 string. Encrypts the // InboundGroupSession using the supplied key. func (s *InboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { panic(s.lastError()) } return pickled[:r] } func (s *InboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { return NoKeyProvided } r := C.olm_unpickle_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return s.lastError() } return nil } func (s *InboundGroupSession) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := unpaddedBase64.DecodedLen(len(pickled)) rawPickled := make([]byte, length) _, err := unpaddedBase64.Decode(rawPickled, pickled) return rawPickled, err } func (s *InboundGroupSession) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() } length := unpaddedBase64.EncodedLen(len(rawPickled)) pickled := make([]byte, length) unpaddedBase64.Encode(pickled, rawPickled) return s.Unpickle(pickled, pickleKey) } func (s *InboundGroupSession) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) quotes[0] = '"' quotes[len(quotes)-1] = '"' 
copy(quotes[1:len(quotes)-1], pickled) return quotes, nil } func (s *InboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankInboundGroupSession() } return s.Unpickle(data[1:len(data)-1], pickleKey) } func clone(original []byte) []byte { clone := make([]byte, len(original)) copy(clone, original) return clone } // decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a // given message could decode to. The actual size could be different due to // padding. Returns error on failure. If the message base64 couldn't be // decoded then the error will be "INVALID_BASE64". If the message is for an // unsupported version of the protocol then the error will be // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error // will be "BAD_MESSAGE_FORMAT". func (s *InboundGroupSession) decryptMaxPlaintextLen(message []byte) (uint, error) { if len(message) == 0 { return 0, EmptyInput } // olm_group_decrypt_max_plaintext_length destroys the input, so we have to clone it message = clone(message) r := C.olm_group_decrypt_max_plaintext_length( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&message[0]), C.size_t(len(message))) if r == errorVal() { return 0, s.lastError() } return uint(r), nil } // Decrypt decrypts a message using the InboundGroupSession. Returns the the // plain-text and message index on success. Returns error on failure. If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". If the // message is for an unsupported version of the protocol then the error will be // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error // will be BAD_MESSAGE_FORMAT". If the MAC on the message was invalid then the // error will be "BAD_MESSAGE_MAC". 
If we do not have a session key // corresponding to the message's index (ie, it was sent before the session key // was shared with us) the error will be "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) { if len(message) == 0 { return nil, 0, EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message) if err != nil { return nil, 0, err } plaintext := make([]byte, decryptMaxPlaintextLen) var messageIndex uint32 r := C.olm_group_decrypt( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&message[0]), C.size_t(len(message)), (*C.uint8_t)(&plaintext[0]), C.size_t(len(plaintext)), (*C.uint32_t)(&messageIndex)) if r == errorVal() { return nil, 0, s.lastError() } return plaintext[:r], uint(messageIndex), nil } // sessionIdLen returns the number of bytes needed to store a session ID. func (s *InboundGroupSession) sessionIdLen() uint { return uint(C.olm_inbound_group_session_id_length((*C.OlmInboundGroupSession)(s.int))) } // ID returns a base64-encoded identifier for this session. func (s *InboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_inbound_group_session_id( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&sessionID[0]), C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } return id.SessionID(sessionID[:r]) } // FirstKnownIndex returns the first message index we know how to decrypt. func (s *InboundGroupSession) FirstKnownIndex() uint32 { return uint32(C.olm_inbound_group_session_first_known_index((*C.OlmInboundGroupSession)(s.int))) } // IsVerified check if the session has been verified as a valid session. (A // session is verified either because the original session share was signed, or // because we have subsequently successfully decrypted a message.) 
func (s *InboundGroupSession) IsVerified() uint { return uint(C.olm_inbound_group_session_is_verified((*C.OlmInboundGroupSession)(s.int))) } // exportLen returns the number of bytes needed to export an inbound group // session. func (s *InboundGroupSession) exportLen() uint { return uint(C.olm_export_inbound_group_session_length((*C.OlmInboundGroupSession)(s.int))) } // Export returns the base64-encoded ratchet key for this session, at the given // index, in a format which can be used by // InboundGroupSession.InboundGroupSessionImport(). Encrypts the // InboundGroupSession using the supplied key. Returns error on failure. // if we do not have a session key corresponding to the given index (ie, it was // sent before the session key was shared with us) the error will be // "OLM_UNKNOWN_MESSAGE_INDEX". func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) { key := make([]byte, s.exportLen()) r := C.olm_export_inbound_group_session( (*C.OlmInboundGroupSession)(s.int), (*C.uint8_t)(&key[0]), C.size_t(len(key)), C.uint32_t(messageIndex)) if r == errorVal() { return "", s.lastError() } return string(key[:r]), nil } go-0.11.1/crypto/olm/olm.go000066400000000000000000000015551436100171500154300ustar00rootroot00000000000000package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" import ( "encoding/base64" "maunium.net/go/mautrix/id" ) // Signatures is the data structure used to sign JSON objects. type Signatures map[id.UserID]map[id.DeviceKeyID]string // Version returns the version number of the olm library. func Version() (major, minor, patch uint8) { C.olm_get_library_version( (*C.uint8_t)(&major), (*C.uint8_t)(&minor), (*C.uint8_t)(&patch)) return } // errorVal returns the value that olm functions return if there was an error. 
func errorVal() C.size_t { return C.olm_error() } var unpaddedBase64 = base64.StdEncoding.WithPadding(base64.NoPadding) var pickleKey = []byte("maunium.net/go/mautrix/crypto/olm") // SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON. func SetPickleKey(key []byte) { pickleKey = key } go-0.11.1/crypto/olm/outboundgroupsession.go000066400000000000000000000160651436100171500211630ustar00rootroot00000000000000package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include import "C" import ( "crypto/rand" "unsafe" "maunium.net/go/mautrix/id" ) // OutboundGroupSession stores an outbound encrypted messaging session for a // group. type OutboundGroupSession struct { int *C.OlmOutboundGroupSession mem []byte } // OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled // base64 string. Decrypts the OutboundGroupSession using the supplied key. // Returns error on failure. If the key doesn't match the one used to encrypt // the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the // base64 couldn't be decoded then the error will be "INVALID_BASE64". func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) { if len(pickled) == 0 { return nil, EmptyInput } s := NewBlankOutboundGroupSession() return s, s.Unpickle(pickled, key) } // NewOutboundGroupSession creates a new outbound group session. func NewOutboundGroupSession() *OutboundGroupSession { s := NewBlankOutboundGroupSession() random := make([]byte, s.createRandomLen()+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } r := C.olm_init_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), (*C.uint8_t)(&random[0]), C.size_t(len(random))) if r == errorVal() { panic(s.lastError()) } return s } // outboundGroupSessionSize is the size of an outbound group session object in // bytes. 
func outboundGroupSessionSize() uint { return uint(C.olm_outbound_group_session_size()) } // newOutboundGroupSession initialises an empty OutboundGroupSession. func NewBlankOutboundGroupSession() *OutboundGroupSession { memory := make([]byte, outboundGroupSessionSize()) return &OutboundGroupSession{ int: C.olm_outbound_group_session(unsafe.Pointer(&memory[0])), mem: memory, } } // lastError returns an error describing the most recent error to happen to an // outbound group session. func (s *OutboundGroupSession) lastError() error { return convertError(C.GoString(C.olm_outbound_group_session_last_error((*C.OlmOutboundGroupSession)(s.int)))) } // Clear clears the memory used to back this OutboundGroupSession. func (s *OutboundGroupSession) Clear() error { r := C.olm_clear_outbound_group_session((*C.OlmOutboundGroupSession)(s.int)) if r == errorVal() { return s.lastError() } else { return nil } } // pickleLen returns the number of bytes needed to store an outbound group // session. func (s *OutboundGroupSession) pickleLen() uint { return uint(C.olm_pickle_outbound_group_session_length((*C.OlmOutboundGroupSession)(s.int))) } // Pickle returns an OutboundGroupSession as a base64 string. Encrypts the // OutboundGroupSession using the supplied key. 
func (s *OutboundGroupSession) Pickle(key []byte) []byte { if len(key) == 0 { panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { panic(s.lastError()) } return pickled[:r] } func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error { if len(key) == 0 { return NoKeyProvided } r := C.olm_unpickle_outbound_group_session( (*C.OlmOutboundGroupSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return s.lastError() } return nil } func (s *OutboundGroupSession) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := unpaddedBase64.DecodedLen(len(pickled)) rawPickled := make([]byte, length) _, err := unpaddedBase64.Decode(rawPickled, pickled) return rawPickled, err } func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() } length := unpaddedBase64.EncodedLen(len(rawPickled)) pickled := make([]byte, length) unpaddedBase64.Encode(pickled, rawPickled) return s.Unpickle(pickled, pickleKey) } func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) quotes[0] = '"' quotes[len(quotes)-1] = '"' copy(quotes[1:len(quotes)-1], pickled) return quotes, nil } func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error { if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankOutboundGroupSession() } return s.Unpickle(data[1:len(data)-1], pickleKey) } // createRandomLen returns the number of random bytes needed to create an // Account. 
func (s *OutboundGroupSession) createRandomLen() uint { return uint(C.olm_init_outbound_group_session_random_length((*C.OlmOutboundGroupSession)(s.int))) } // encryptMsgLen returns the size of the next message in bytes for the given // number of plain-text bytes. func (s *OutboundGroupSession) encryptMsgLen(plainTextLen int) uint { return uint(C.olm_group_encrypt_message_length((*C.OlmOutboundGroupSession)(s.int), C.size_t(plainTextLen))) } // Encrypt encrypts a message using the Session. Returns the encrypted message // as base64. func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte { if len(plaintext) == 0 { panic(EmptyInput) } message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_group_encrypt( (*C.OlmOutboundGroupSession)(s.int), (*C.uint8_t)(&plaintext[0]), C.size_t(len(plaintext)), (*C.uint8_t)(&message[0]), C.size_t(len(message))) if r == errorVal() { panic(s.lastError()) } return message[:r] } // sessionIdLen returns the number of bytes needed to store a session ID. func (s *OutboundGroupSession) sessionIdLen() uint { return uint(C.olm_outbound_group_session_id_length((*C.OlmOutboundGroupSession)(s.int))) } // ID returns a base64-encoded identifier for this session. func (s *OutboundGroupSession) ID() id.SessionID { sessionID := make([]byte, s.sessionIdLen()) r := C.olm_outbound_group_session_id( (*C.OlmOutboundGroupSession)(s.int), (*C.uint8_t)(&sessionID[0]), C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } return id.SessionID(sessionID[:r]) } // MessageIndex returns the message index for this session. Each message is // sent with an increasing index; this returns the index for the next message. func (s *OutboundGroupSession) MessageIndex() uint { return uint(C.olm_outbound_group_session_message_index((*C.OlmOutboundGroupSession)(s.int))) } // sessionKeyLen returns the number of bytes needed to store a session key. 
func (s *OutboundGroupSession) sessionKeyLen() uint { return uint(C.olm_outbound_group_session_key_length((*C.OlmOutboundGroupSession)(s.int))) } // Key returns the base64-encoded current ratchet key for this session. func (s *OutboundGroupSession) Key() string { sessionKey := make([]byte, s.sessionKeyLen()) r := C.olm_outbound_group_session_key( (*C.OlmOutboundGroupSession)(s.int), (*C.uint8_t)(&sessionKey[0]), C.size_t(len(sessionKey))) if r == errorVal() { panic(s.lastError()) } return string(sessionKey[:r]) } go-0.11.1/crypto/olm/pk.go000066400000000000000000000057461436100171500152610ustar00rootroot00000000000000package olm // #cgo LDFLAGS: -lolm -lstdc++ // #include // #include import "C" import ( "crypto/rand" "encoding/json" "unsafe" "github.com/tidwall/sjson" "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/id" ) // PkSigning stores a key pair for signing messages. type PkSigning struct { int *C.OlmPkSigning mem []byte PublicKey id.Ed25519 Seed []byte } func pkSigningSize() uint { return uint(C.olm_pk_signing_size()) } func pkSigningSeedLength() uint { return uint(C.olm_pk_signing_seed_length()) } func pkSigningPublicKeyLength() uint { return uint(C.olm_pk_signing_public_key_length()) } func pkSigningSignatureLength() uint { return uint(C.olm_pk_signature_length()) } func NewBlankPkSigning() *PkSigning { memory := make([]byte, pkSigningSize()) return &PkSigning{ int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), mem: memory, } } // Clear clears the underlying memory of a PkSigning object. func (p *PkSigning) Clear() { C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) } // NewPkSigningFromSeed creates a new PkSigning object using the given seed. 
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) { p := NewBlankPkSigning() p.Clear() pubKey := make([]byte, pkSigningPublicKeyLength()) if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { return nil, p.lastError() } p.PublicKey = id.Ed25519(pubKey) p.Seed = seed return p, nil } // NewPkSigning creates a new PkSigning object, containing a key pair for signing messages. func NewPkSigning() (*PkSigning, error) { // Generate the seed seed := make([]byte, pkSigningSeedLength()) _, err := rand.Read(seed) if err != nil { panic(NotEnoughGoRandom) } pk, err := NewPkSigningFromSeed(seed) return pk, err } // Sign creates a signature for the given message using this key. func (p *PkSigning) Sign(message []byte) ([]byte, error) { signature := make([]byte, pkSigningSignatureLength()) if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { return nil, p.lastError() } return signature, nil } // SignJSON creates a signature for the given object after encoding it to canonical JSON. func (p *PkSigning) SignJSON(obj interface{}) (string, error) { objJSON, err := json.Marshal(obj) if err != nil { return "", err } objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned") objJSON, _ = sjson.DeleteBytes(objJSON, "signatures") signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON)) if err != nil { return "", err } return string(signature), nil } // lastError returns the last error that happened in relation to this PkSigning object. 
func (p *PkSigning) lastError() error {
	return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int))))
}
go-0.11.1/crypto/olm/session.go000066400000000000000000000260101436100171500163150ustar00rootroot00000000000000package olm

// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
// #include <stdlib.h>
// #include <stdio.h>
// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak));
// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) {
// 	if (olm_session_describe) {
// 		olm_session_describe(session, buf, buflen);
// 	} else {
// 		sprintf(buf, "olm_session_describe not supported");
// 	}
// }
import "C"

import (
	"crypto/rand"
	"unsafe"

	"maunium.net/go/mautrix/id"
)

// Session stores an end to end encrypted messaging session.
type Session struct {
	// int points at the libolm session object that lives inside mem.
	int *C.OlmSession
	// mem is the Go-allocated backing buffer for the libolm object.
	mem []byte
}

// sessionSize is the size of a session object in bytes.
func sessionSize() uint {
	return uint(C.olm_session_size())
}

// SessionFromPickled loads a Session from a pickled base64 string. Decrypts
// the Session using the supplied key. Returns error on failure. If the key
// doesn't match the one used to encrypt the Session then the error will be
// "BAD_SESSION_KEY". If the base64 couldn't be decoded then the error will be
// "INVALID_BASE64".
// Note that the allocated Session is returned even when unpickling fails.
func SessionFromPickled(pickled, key []byte) (*Session, error) {
	if len(pickled) == 0 {
		return nil, EmptyInput
	}
	s := NewBlankSession()
	return s, s.Unpickle(pickled, key)
}

// NewBlankSession allocates backing memory and initialises an empty libolm
// session in it. The result must be unpickled or otherwise set up before use.
func NewBlankSession() *Session {
	memory := make([]byte, sessionSize())
	return &Session{
		int: C.olm_session(unsafe.Pointer(&memory[0])),
		mem: memory,
	}
}

// lastError returns an error describing the most recent error to happen to a
// session.
func (s *Session) lastError() error {
	return convertError(C.GoString(C.olm_session_last_error((*C.OlmSession)(s.int))))
}

// Clear clears the memory used to back this Session.
func (s *Session) Clear() error { r := C.olm_clear_session((*C.OlmSession)(s.int)) if r == errorVal() { return s.lastError() } return nil } // pickleLen returns the number of bytes needed to store a session. func (s *Session) pickleLen() uint { return uint(C.olm_pickle_session_length((*C.OlmSession)(s.int))) } // createOutboundRandomLen returns the number of random bytes needed to create // an outbound session. func (s *Session) createOutboundRandomLen() uint { return uint(C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))) } // idLen returns the length of the buffer needed to return the id for this // session. func (s *Session) idLen() uint { return uint(C.olm_session_id_length((*C.OlmSession)(s.int))) } // encryptRandomLen returns the number of random bytes needed to encrypt the // next message. func (s *Session) encryptRandomLen() uint { return uint(C.olm_encrypt_random_length((*C.OlmSession)(s.int))) } // encryptMsgLen returns the size of the next message in bytes for the given // number of plain-text bytes. func (s *Session) encryptMsgLen(plainTextLen int) uint { return uint(C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))) } // decryptMaxPlaintextLen returns the maximum number of bytes of plain-text a // given message could decode to. The actual size could be different due to // padding. Returns error on failure. If the message base64 couldn't be // decoded then the error will be "INVALID_BASE64". If the message is for an // unsupported version of the protocol then the error will be // "BAD_MESSAGE_VERSION". If the message couldn't be decoded then the error // will be "BAD_MESSAGE_FORMAT". 
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) { if len(message) == 0 { return 0, EmptyInput } r := C.olm_decrypt_max_plaintext_length( (*C.OlmSession)(s.int), C.size_t(msgType), unsafe.Pointer(C.CString(message)), C.size_t(len(message))) if r == errorVal() { return 0, s.lastError() } return uint(r), nil } // Pickle returns a Session as a base64 string. Encrypts the Session using the // supplied key. func (s *Session) Pickle(key []byte) []byte { if len(key) == 0 { panic(NoKeyProvided) } pickled := make([]byte, s.pickleLen()) r := C.olm_pickle_session( (*C.OlmSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { panic(s.lastError()) } return pickled[:r] } func (s *Session) Unpickle(pickled, key []byte) error { if len(key) == 0 { return NoKeyProvided } r := C.olm_unpickle_session( (*C.OlmSession)(s.int), unsafe.Pointer(&key[0]), C.size_t(len(key)), unsafe.Pointer(&pickled[0]), C.size_t(len(pickled))) if r == errorVal() { return s.lastError() } return nil } func (s *Session) GobEncode() ([]byte, error) { pickled := s.Pickle(pickleKey) length := unpaddedBase64.DecodedLen(len(pickled)) rawPickled := make([]byte, length) _, err := unpaddedBase64.Decode(rawPickled, pickled) return rawPickled, err } func (s *Session) GobDecode(rawPickled []byte) error { if s == nil || s.int == nil { *s = *NewBlankSession() } length := unpaddedBase64.EncodedLen(len(rawPickled)) pickled := make([]byte, length) unpaddedBase64.Encode(pickled, rawPickled) return s.Unpickle(pickled, pickleKey) } func (s *Session) MarshalJSON() ([]byte, error) { pickled := s.Pickle(pickleKey) quotes := make([]byte, len(pickled)+2) quotes[0] = '"' quotes[len(quotes)-1] = '"' copy(quotes[1:len(quotes)-1], pickled) return quotes, nil } func (s *Session) UnmarshalJSON(data []byte) error { if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' { return 
InputNotJSONString } if s == nil || s.int == nil { *s = *NewBlankSession() } return s.Unpickle(data[1:len(data)-1], pickleKey) } // Id returns an identifier for this Session. Will be the same for both ends // of the conversation. func (s *Session) ID() id.SessionID { sessionID := make([]byte, s.idLen()) r := C.olm_session_id( (*C.OlmSession)(s.int), unsafe.Pointer(&sessionID[0]), C.size_t(len(sessionID))) if r == errorVal() { panic(s.lastError()) } return id.SessionID(sessionID) } // HasReceivedMessage returns true if this session has received any message. func (s *Session) HasReceivedMessage() bool { switch C.olm_session_has_received_message((*C.OlmSession)(s.int)) { case 0: return false default: return true } } // MatchesInboundSession checks if the PRE_KEY message is for this in-bound // Session. This can happen if multiple messages are sent to this Account // before this Account sends a message in reply. Returns true if the session // matches. Returns false if the session does not match. Returns error on // failure. If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". If the message was for an unsupported protocol version // then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) { if len(oneTimeKeyMsg) == 0 { return false, EmptyInput } r := C.olm_matches_inbound_session( (*C.OlmSession)(s.int), unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), C.size_t(len(oneTimeKeyMsg))) if r == 1 { return true, nil } else if r == 0 { return false, nil } else { // if r == errorVal() return false, s.lastError() } } // MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound // Session. This can happen if multiple messages are sent to this Account // before this Account sends a message in reply. Returns true if the session // matches. Returns false if the session does not match. 
Returns error on // failure. If the base64 couldn't be decoded then the error will be // "INVALID_BASE64". If the message was for an unsupported protocol version // then the error will be "BAD_MESSAGE_VERSION". If the message couldn't be // decoded then then the error will be "BAD_MESSAGE_FORMAT". func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) { if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 { return false, EmptyInput } r := C.olm_matches_inbound_session_from( (*C.OlmSession)(s.int), unsafe.Pointer(&([]byte(theirIdentityKey))[0]), C.size_t(len(theirIdentityKey)), unsafe.Pointer(&([]byte(oneTimeKeyMsg))[0]), C.size_t(len(oneTimeKeyMsg))) if r == 1 { return true, nil } else if r == 0 { return false, nil } else { // if r == errorVal() return false, s.lastError() } } // EncryptMsgType returns the type of the next message that Encrypt will // return. Returns MsgTypePreKey if the message will be a PRE_KEY message. // Returns MsgTypeMsg if the message will be a normal message. Returns error // on failure. func (s *Session) EncryptMsgType() id.OlmMsgType { switch C.olm_encrypt_message_type((*C.OlmSession)(s.int)) { case C.size_t(id.OlmMsgTypePreKey): return id.OlmMsgTypePreKey case C.size_t(id.OlmMsgTypeMsg): return id.OlmMsgTypeMsg default: panic("olm_encrypt_message_type returned invalid result") } } // Encrypt encrypts a message using the Session. Returns the encrypted message // as base64. 
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { if len(plaintext) == 0 { panic(EmptyInput) } // Make the slice be at least length 1 random := make([]byte, s.encryptRandomLen()+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } messageType := s.EncryptMsgType() message := make([]byte, s.encryptMsgLen(len(plaintext))) r := C.olm_encrypt( (*C.OlmSession)(s.int), unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext)), unsafe.Pointer(&random[0]), C.size_t(len(random)), unsafe.Pointer(&message[0]), C.size_t(len(message))) if r == errorVal() { panic(s.lastError()) } return messageType, message[:r] } // Decrypt decrypts a message using the Session. Returns the the plain-text on // success. Returns error on failure. If the base64 couldn't be decoded then // the error will be "INVALID_BASE64". If the message is for an unsupported // version of the protocol then the error will be "BAD_MESSAGE_VERSION". If // the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT". // If the MAC on the message was invalid then the error will be // "BAD_MESSAGE_MAC". func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) { if len(message) == 0 { return nil, EmptyInput } decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType) if err != nil { return nil, err } plaintext := make([]byte, decryptMaxPlaintextLen) r := C.olm_decrypt( (*C.OlmSession)(s.int), C.size_t(msgType), unsafe.Pointer(&([]byte(message))[0]), C.size_t(len(message)), unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext))) if r == errorVal() { return nil, s.lastError() } return plaintext[:r], nil } // https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393 const maxDescribeSize = 600 // Describe generates a string describing the internal state of an olm session for debugging and logging purposes. 
func (s *Session) Describe() string {
	// Scratch buffer in C memory, released when Describe returns.
	desc := (*C.char)(C.malloc(C.size_t(maxDescribeSize)))
	defer C.free(unsafe.Pointer(desc))
	// meowlm_session_describe (defined in this file's cgo preamble) writes a
	// fixed "not supported" message when libolm lacks olm_session_describe.
	C.meowlm_session_describe(
		(*C.OlmSession)(s.int),
		desc,
		C.size_t(maxDescribeSize))
	return C.GoString(desc)
}
go-0.11.1/crypto/olm/utility.go000066400000000000000000000105761436100171500163470ustar00rootroot00000000000000package olm

// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
import "C"

import (
	"encoding/json"
	"fmt"
	"strings"
	"unsafe"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"

	"maunium.net/go/mautrix/crypto/canonicaljson"
	"maunium.net/go/mautrix/id"
)

// Utility stores the necessary state to perform hash and signature
// verification operations.
type Utility struct {
	// int points at the libolm utility object that lives inside mem.
	int *C.OlmUtility
	// mem is the Go-allocated backing buffer for the libolm object.
	mem []byte
}

// utilitySize returns the size of a utility object in bytes.
func utilitySize() uint {
	return uint(C.olm_utility_size())
}

// sha256Len returns the length of the buffer needed to hold the SHA-256 hash.
func (u *Utility) sha256Len() uint {
	return uint(C.olm_sha256_length((*C.OlmUtility)(u.int)))
}

// lastError returns an error describing the most recent error to happen to a
// utility.
func (u *Utility) lastError() error {
	return convertError(C.GoString(C.olm_utility_last_error((*C.OlmUtility)(u.int))))
}

// Clear clears the memory used to back this utility.
// Returns the libolm error if clearing fails.
func (u *Utility) Clear() error {
	r := C.olm_clear_utility((*C.OlmUtility)(u.int))
	if r == errorVal() {
		return u.lastError()
	}
	return nil
}

// NewUtility creates a new utility.
// The libolm object lives in a Go-allocated buffer kept alive by the mem field.
func NewUtility() *Utility {
	memory := make([]byte, utilitySize())
	return &Utility{
		int: C.olm_utility(unsafe.Pointer(&memory[0])),
		mem: memory,
	}
}

// Sha256 calculates the SHA-256 hash of the input and encodes it as base64.
func (u *Utility) Sha256(input string) string { if len(input) == 0 { panic(EmptyInput) } output := make([]byte, u.sha256Len()) r := C.olm_sha256( (*C.OlmUtility)(u.int), unsafe.Pointer(&([]byte(input)[0])), C.size_t(len(input)), unsafe.Pointer(&(output[0])), C.size_t(len(output))) if r == errorVal() { panic(u.lastError()) } return string(output) } // VerifySignature verifies an ed25519 signature. Returns true if the verification // suceeds or false otherwise. Returns error on failure. If the key was too // small then the error will be "INVALID_BASE64". func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) { if len(message) == 0 || len(key) == 0 || len(signature) == 0 { return false, EmptyInput } r := C.olm_ed25519_verify( (*C.OlmUtility)(u.int), unsafe.Pointer(&([]byte(key)[0])), C.size_t(len(key)), unsafe.Pointer(&([]byte(message)[0])), C.size_t(len(message)), unsafe.Pointer(&([]byte(signature)[0])), C.size_t(len(signature))) if r == errorVal() { err = u.lastError() if err == BadMessageMAC { err = nil } } else { ok = true } return ok, err } var gjsonEscaper = strings.NewReplacer( `\`, `\\`, ".", `\.`, "|", `\|`, "#", `\#`, "@", `\@`, "*", `\*`, "?", `\?`) func gjsonPath(path ...string) string { var result strings.Builder for i, part := range path { _, _ = gjsonEscaper.WriteString(&result, part) if i < len(path)-1 { result.WriteRune('.') } } return result.String() } // VerifySignatureJSON verifies the signature in the JSON object _obj following // the Matrix specification: // https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json // If the _obj is a struct, the `json` tags will be honored. 
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
	objJSON, err := json.Marshal(obj)
	if err != nil {
		return false, err
	}
	// Look up signatures.<userID>.ed25519:<keyName>, with each component escaped
	// so that dots in user IDs are not treated as path separators.
	sig := gjson.GetBytes(objJSON, gjsonPath("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
	if !sig.Exists() || sig.Type != gjson.String {
		return false, SignatureNotFound
	}
	// "unsigned" and "signatures" are excluded from the signed payload.
	objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
	if err != nil {
		return false, err
	}
	objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
	if err != nil {
		return false, err
	}
	objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON))
	return u.VerifySignature(objJSONString, key, sig.Str)
}

// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// This function is a wrapper over Utility.VerifySignatureJSON that creates and
// destroys the Utility object transparently.
// If the _obj is a struct, the `json` tags will be honored.
func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
	u := NewUtility()
	// The error from Clear is intentionally discarded.
	defer u.Clear()
	return u.VerifySignatureJSON(obj, userID, keyName, key)
}
go-0.11.1/crypto/olm/verification.go000066400000000000000000000067741436100171500173330ustar00rootroot00000000000000//go:build !nosas
// +build !nosas

package olm

// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
// #include <olm/sas.h>
import "C"

import (
	"crypto/rand"
	"unsafe"
)

// SAS stores an Olm Short Authentication String (SAS) object.
type SAS struct {
	// int points at the libolm SAS object that lives inside mem.
	int *C.OlmSAS
	// mem is the Go-allocated backing buffer for the libolm object.
	mem []byte
}

// NewBlankSAS initializes an empty SAS object.
// NewSAS additionally seeds it with randomness; a blank object is not ready for use.
func NewBlankSAS() *SAS {
	memory := make([]byte, sasSize())
	return &SAS{
		int: C.olm_sas(unsafe.Pointer(&memory[0])),
		mem: memory,
	}
}

// sasSize is the size of a SAS object in bytes.
func sasSize() uint { return uint(C.olm_sas_size()) } // sasRandomLength is the number of random bytes needed to create an SAS object. func (sas *SAS) sasRandomLength() uint { return uint(C.olm_create_sas_random_length(sas.int)) } // NewSAS creates a new SAS object. func NewSAS() *SAS { sas := NewBlankSAS() random := make([]byte, sas.sasRandomLength()+1) _, err := rand.Read(random) if err != nil { panic(NotEnoughGoRandom) } r := C.olm_create_sas( (*C.OlmSAS)(sas.int), unsafe.Pointer(&random[0]), C.size_t(len(random))) if r == errorVal() { panic(sas.lastError()) } else { return sas } } // clear clears the memory used to back an SAS object. func (sas *SAS) clear() uint { return uint(C.olm_clear_sas(sas.int)) } // lastError returns the most recent error to happen to an SAS object. func (sas *SAS) lastError() error { return convertError(C.GoString(C.olm_sas_last_error(sas.int))) } // pubkeyLength is the size of a public key in bytes. func (sas *SAS) pubkeyLength() uint { return uint(C.olm_sas_pubkey_length((*C.OlmSAS)(sas.int))) } // GetPubkey gets the public key for the SAS object. func (sas *SAS) GetPubkey() []byte { pubkey := make([]byte, sas.pubkeyLength()) r := C.olm_sas_get_pubkey( (*C.OlmSAS)(sas.int), unsafe.Pointer(&pubkey[0]), C.size_t(len(pubkey))) if r == errorVal() { panic(sas.lastError()) } return pubkey } // SetTheirKey sets the public key of the other user. func (sas *SAS) SetTheirKey(theirKey []byte) error { theirKeyCopy := make([]byte, len(theirKey)) copy(theirKeyCopy, theirKey) r := C.olm_sas_set_their_key( (*C.OlmSAS)(sas.int), unsafe.Pointer(&theirKeyCopy[0]), C.size_t(len(theirKeyCopy))) if r == errorVal() { return sas.lastError() } return nil } // GenerateBytes generates bytes to use for the short authentication string. 
func (sas *SAS) GenerateBytes(info []byte, count uint) ([]byte, error) { infoCopy := make([]byte, len(info)) copy(infoCopy, info) output := make([]byte, count) r := C.olm_sas_generate_bytes( (*C.OlmSAS)(sas.int), unsafe.Pointer(&infoCopy[0]), C.size_t(len(infoCopy)), unsafe.Pointer(&output[0]), C.size_t(len(output))) if r == errorVal() { return nil, sas.lastError() } return output, nil } // macLength is the size of a message authentication code generated by olm_sas_calculate_mac. func (sas *SAS) macLength() uint { return uint(C.olm_sas_mac_length((*C.OlmSAS)(sas.int))) } // CalculateMAC generates a message authentication code (MAC) based on the shared secret. func (sas *SAS) CalculateMAC(input []byte, info []byte) ([]byte, error) { inputCopy := make([]byte, len(input)) copy(inputCopy, input) infoCopy := make([]byte, len(info)) copy(infoCopy, info) mac := make([]byte, sas.macLength()) r := C.olm_sas_calculate_mac( (*C.OlmSAS)(sas.int), unsafe.Pointer(&inputCopy[0]), C.size_t(len(inputCopy)), unsafe.Pointer(&infoCopy[0]), C.size_t(len(infoCopy)), unsafe.Pointer(&mac[0]), C.size_t(len(mac))) if r == errorVal() { return nil, sas.lastError() } return mac, nil } go-0.11.1/crypto/sessions.go000066400000000000000000000123361436100171500157170ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "errors" "time" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( SessionNotShared = errors.New("session has not been shared") SessionExpired = errors.New("session has expired") ) // OlmSessionList is a list of OlmSessions. // It implements sort.Interface so that the session with recent successful decryptions comes first. 
type OlmSessionList []*OlmSession func (o OlmSessionList) Len() int { return len(o) } func (o OlmSessionList) Less(i, j int) bool { return o[i].LastDecryptedTime.After(o[j].LastEncryptedTime) } func (o OlmSessionList) Swap(i, j int) { o[i], o[j] = o[j], o[i] } type OlmSession struct { Internal olm.Session ExpirationMixin id id.SessionID } func (session *OlmSession) ID() id.SessionID { if session.id == "" { session.id = session.Internal.ID() } return session.id } func (session *OlmSession) Describe() string { return session.Internal.Describe() } func wrapSession(session *olm.Session) *OlmSession { return &OlmSession{ Internal: *session, ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), LastEncryptedTime: time.Now(), LastDecryptedTime: time.Now(), }, }, } } func (account *OlmAccount) NewInboundSessionFrom(senderKey id.Curve25519, ciphertext string) (*OlmSession, error) { session, err := account.Internal.NewInboundSessionFrom(senderKey, ciphertext) if err != nil { return nil, err } _ = account.Internal.RemoveOneTimeKeys(session) return wrapSession(session), nil } func (session *OlmSession) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) { session.LastEncryptedTime = time.Now() return session.Internal.Encrypt(plaintext) } func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]byte, error) { msg, err := session.Internal.Decrypt(ciphertext, msgType) if err == nil { session.LastDecryptedTime = time.Now() } return msg, err } type InboundGroupSession struct { Internal olm.InboundGroupSession SigningKey id.Ed25519 SenderKey id.Curve25519 RoomID id.RoomID ForwardingChains []string id id.SessionID } func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) { igs, err := olm.NewInboundGroupSession([]byte(sessionKey)) if err != nil { return nil, err } return &InboundGroupSession{ Internal: *igs, SigningKey: signingKey, SenderKey: 
senderKey, RoomID: roomID, ForwardingChains: nil, }, nil } func (igs *InboundGroupSession) ID() id.SessionID { if igs.id == "" { igs.id = igs.Internal.ID() } return igs.id } type OGSState int const ( OGSNotShared OGSState = iota OGSAlreadyShared OGSIgnored ) type UserDevice struct { UserID id.UserID DeviceID id.DeviceID } type OutboundGroupSession struct { Internal olm.OutboundGroupSession ExpirationMixin MaxMessages int MessageCount int Users map[UserDevice]OGSState RoomID id.RoomID Shared bool id id.SessionID content *event.RoomKeyEventContent } func NewOutboundGroupSession(roomID id.RoomID, encryptionContent *event.EncryptionEventContent) *OutboundGroupSession { ogs := &OutboundGroupSession{ Internal: *olm.NewOutboundGroupSession(), ExpirationMixin: ExpirationMixin{ TimeMixin: TimeMixin{ CreationTime: time.Now(), LastEncryptedTime: time.Now(), }, MaxAge: 7 * 24 * time.Hour, }, MaxMessages: 100, Shared: false, Users: make(map[UserDevice]OGSState), RoomID: roomID, } if encryptionContent != nil { if encryptionContent.RotationPeriodMillis != 0 { ogs.MaxAge = time.Duration(encryptionContent.RotationPeriodMillis) * time.Millisecond } if encryptionContent.RotationPeriodMessages != 0 { ogs.MaxMessages = encryptionContent.RotationPeriodMessages } } return ogs } func (ogs *OutboundGroupSession) ShareContent() event.Content { if ogs.content == nil { ogs.content = &event.RoomKeyEventContent{ Algorithm: id.AlgorithmMegolmV1, RoomID: ogs.RoomID, SessionID: ogs.ID(), SessionKey: ogs.Internal.Key(), } } return event.Content{Parsed: ogs.content} } func (ogs *OutboundGroupSession) ID() id.SessionID { if ogs.id == "" { ogs.id = ogs.Internal.ID() } return ogs.id } func (ogs *OutboundGroupSession) Expired() bool { return ogs.MessageCount >= ogs.MaxMessages || ogs.ExpirationMixin.Expired() } func (ogs *OutboundGroupSession) Encrypt(plaintext []byte) ([]byte, error) { if !ogs.Shared { return nil, SessionNotShared } else if ogs.Expired() { return nil, SessionExpired } 
ogs.MessageCount++ ogs.LastEncryptedTime = time.Now() return ogs.Internal.Encrypt(plaintext), nil } type TimeMixin struct { CreationTime time.Time LastEncryptedTime time.Time LastDecryptedTime time.Time } type ExpirationMixin struct { TimeMixin MaxAge time.Duration } func (exp *ExpirationMixin) Expired() bool { if exp.MaxAge == 0 { return false } return exp.CreationTime.Add(exp.MaxAge).Before(time.Now()) } go-0.11.1/crypto/sql_store.go000066400000000000000000000640161436100171500160660ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "database/sql" "database/sql/driver" "fmt" "strings" "sync" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var PostgresArrayWrapper func(interface{}) interface { driver.Valuer sql.Scanner } // SQLCryptoStore is an implementation of a crypto Store for a database backend. type SQLCryptoStore struct { DB *sql.DB Log Logger Dialect string AccountID string DeviceID id.DeviceID SyncToken string PickleKey []byte Account *OlmAccount olmSessionCache map[id.SenderKey]map[id.SessionID]*OlmSession olmSessionCacheLock sync.Mutex } var _ Store = (*SQLCryptoStore)(nil) // NewSQLCryptoStore initializes a new crypto Store using the given database, for a device's crypto material. // The stored material will be encrypted with the given key. 
func NewSQLCryptoStore(db *sql.DB, dialect string, accountID string, deviceID id.DeviceID, pickleKey []byte, log Logger) *SQLCryptoStore {
	store := &SQLCryptoStore{
		DB:        db,
		Log:       log,
		Dialect:   dialect,
		AccountID: accountID,
		DeviceID:  deviceID,
		PickleKey: pickleKey,
	}
	store.olmSessionCache = make(map[id.SenderKey]map[id.SessionID]*OlmSession)
	return store
}

// CreateTables runs every pending schema migration against the store's database.
func (store *SQLCryptoStore) CreateTables() error {
	return sql_store_upgrade.Upgrade(store.DB, store.Dialect)
}

// Flush is a no-op: every write already goes straight to the database.
func (store *SQLCryptoStore) Flush() error {
	return nil
}

// PutNextBatch remembers nextBatch in memory and persists it for the current
// account. A failed write is only logged as a warning.
func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
	store.SyncToken = nextBatch
	if _, err := store.DB.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID); err != nil {
		store.Log.Warn("Failed to store sync token: %v", err)
	}
}

// GetNextBatch returns the sync token for the current account. It loads the
// token from the database only while the in-memory copy is empty, and logs
// any error other than a missing row.
func (store *SQLCryptoStore) GetNextBatch() string {
	if store.SyncToken != "" {
		return store.SyncToken
	}
	row := store.DB.QueryRow("SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID)
	if err := row.Scan(&store.SyncToken); err != nil && err != sql.ErrNoRows {
		store.Log.Warn("Failed to scan sync token: %v", err)
	}
	return store.SyncToken
}

// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error { store.Account = account bytes := account.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec(` INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, account=excluded.account, account_id=excluded.account_id `, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID) if err != nil { store.Log.Warn("Failed to store account: %v", err) } return nil } // GetAccount retrieves an OlmAccount from the database. func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) { if store.Account == nil { row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) acc := &OlmAccount{Internal: *olm.NewBlankAccount()} var accountBytes []byte err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } err = acc.Internal.Unpickle(accountBytes, store.PickleKey) if err != nil { return nil, err } store.Account = acc } return store.Account, nil } // HasSession returns whether there is an Olm session for the given sender key. func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool { store.olmSessionCacheLock.Lock() cache, ok := store.olmSessionCache[key] store.olmSessionCacheLock.Unlock() if ok && len(cache) > 0 { return true } var sessionID id.SessionID err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1", key, store.AccountID).Scan(&sessionID) if err == sql.ErrNoRows { return false } return len(sessionID) > 0 } // GetSessions returns all the known Olm sessions for a sender key. 
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) { rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC", key, store.AccountID) if err != nil { return nil, err } list := OlmSessionList{} store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() cache := store.getOlmSessionCache(key) for rows.Next() { sess := OlmSession{Internal: *olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID err := rows.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) if err != nil { return nil, err } else if existing, ok := cache[sessionID]; ok { list = append(list, existing) } else { err = sess.Internal.Unpickle(sessionBytes, store.PickleKey) if err != nil { return nil, err } list = append(list, &sess) cache[sess.ID()] = &sess } } return list, nil } func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.SessionID]*OlmSession { data, ok := store.olmSessionCache[key] if !ok { data = make(map[id.SessionID]*OlmSession) store.olmSessionCache[key] = data } return data } // GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID. 
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1", key, store.AccountID) sess := OlmSession{Internal: *olm.NewBlankSession()} var sessionBytes []byte var sessionID id.SessionID err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } cache := store.getOlmSessionCache(key) if oldSess, ok := cache[sessionID]; ok { return oldSess, nil } else if err = sess.Internal.Unpickle(sessionBytes, store.PickleKey); err != nil { return nil, err } else { cache[sessionID] = &sess return &sess, nil } } // AddSession persists an Olm session for a sender in the database. func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error { store.olmSessionCacheLock.Lock() defer store.olmSessionCacheLock.Unlock() sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)", session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID) store.getOlmSessionCache(key)[session.ID()] = session return err } // UpdateSession replaces the Olm session for a sender in the database. 
func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5", sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID) return err } // PutGroupSession stores an inbound Megolm group session for a room, sender and session. func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) forwardingChains := strings.Join(session.ForwardingChains, ",") _, err := store.DB.Exec(` INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (session_id, account_id) DO UPDATE SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains `, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID) return err } // GetGroupSession retrieves an inbound Megolm group session for a room, sender and session. 
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { var signingKey, forwardingChains, withheldCode sql.NullString var sessionBytes []byte err := store.DB.QueryRow(` SELECT signing_key, session, forwarding_chains, withheld_code FROM crypto_megolm_inbound_session WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } else if withheldCode.Valid { return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String) } igs := olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) if err != nil { return nil, err } return &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: senderKey, RoomID: roomID, ForwardingChains: strings.Split(forwardingChains.String, ","), }, nil } func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { _, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)", content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID) return err } func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { var code, reason sql.NullString err := store.DB.QueryRow(` SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`, roomID, senderKey, sessionID, store.AccountID, ).Scan(&code, &reason) if err == sql.ErrNoRows { return nil, nil } else if err != nil || !code.Valid { return nil, err } 
return &event.RoomKeyWithheldEventContent{ RoomID: roomID, Algorithm: id.AlgorithmMegolmV1, SessionID: sessionID, SenderKey: senderKey, Code: event.RoomKeyWithheldCode(code.String), Reason: reason.String, }, nil } func (store *SQLCryptoStore) scanGroupSessionList(rows *sql.Rows) (result []*InboundGroupSession) { for rows.Next() { var roomID id.RoomID var signingKey, senderKey, forwardingChains sql.NullString var sessionBytes []byte err := rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains) if err != nil { store.Log.Warn("Failed to scan row: %v", err) continue } igs := olm.NewBlankInboundGroupSession() err = igs.Unpickle(sessionBytes, store.PickleKey) if err != nil { store.Log.Warn("Failed to unpickle session: %v", err) continue } result = append(result, &InboundGroupSession{ Internal: *igs, SigningKey: id.Ed25519(signingKey.String), SenderKey: id.Curve25519(senderKey.String), RoomID: roomID, ForwardingChains: strings.Split(forwardingChains.String, ","), }) } return } func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { rows, err := store.DB.Query(` SELECT room_id, signing_key, sender_key, session, forwarding_chains FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2`, roomID, store.AccountID, ) if err == sql.ErrNoRows { return []*InboundGroupSession{}, nil } else if err != nil { return nil, err } return store.scanGroupSessionList(rows), nil } func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) { rows, err := store.DB.Query(` SELECT room_id, signing_key, sender_key, session, forwarding_chains FROM crypto_megolm_inbound_session WHERE account_id=$2`, store.AccountID, ) if err == sql.ErrNoRows { return []*InboundGroupSession{}, nil } else if err != nil { return nil, err } return store.scanGroupSessionList(rows), nil } // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. 
func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec(` INSERT INTO crypto_megolm_outbound_session (room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) ON CONFLICT (account_id, room_id) DO UPDATE SET session_id=excluded.session_id, session=excluded.session, shared=excluded.shared, max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age, created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id `, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID) return err } // UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID. func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error { sessionBytes := session.Internal.Pickle(store.PickleKey) _, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6", sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID) return err } // GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID. 
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) { var ogs OutboundGroupSession var sessionBytes []byte err := store.DB.QueryRow(` SELECT session, shared, max_messages, message_count, max_age, created_at, last_used FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`, roomID, store.AccountID, ).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } intOGS := olm.NewBlankOutboundGroupSession() err = intOGS.Unpickle(sessionBytes, store.PickleKey) if err != nil { return nil, err } ogs.Internal = *intOGS ogs.RoomID = roomID return &ogs, nil } // RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID. func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error { _, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2", roomID, store.AccountID) return err } // ValidateMessageIndex returns whether the given event information match the ones stored in the database // for the given sender key, session ID and index. // If the event information was not yet stored, it's stored now. 
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool { var resultEventID id.EventID var resultTimestamp int64 err := store.DB.QueryRow( `SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`, senderKey, sessionID, index, ).Scan(&resultEventID, &resultTimestamp) if err == sql.ErrNoRows { _, err := store.DB.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`, senderKey, sessionID, index, eventID, timestamp) if err != nil { store.Log.Warn("Failed to store message index: %v", err) } return true } else if err != nil { store.Log.Warn("Failed to scan message index: %v", err) return true } if resultEventID != eventID || resultTimestamp != timestamp { return false } return true } // GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID. func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*DeviceIdentity, error) { var ignore id.UserID err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID) if err != nil { return nil, err } data := make(map[id.DeviceID]*DeviceIdentity) for rows.Next() { var identity DeviceIdentity err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) if err != nil { return nil, err } identity.UserID = userID data[identity.DeviceID] = &identity } return data, nil } // GetDevice returns the device dentity for a given user and device ID. 
func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*DeviceIdentity, error) { var identity DeviceIdentity err := store.DB.QueryRow(` SELECT identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND device_id=$2`, userID, deviceID, ).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } identity.UserID = userID identity.DeviceID = deviceID return &identity, nil } // FindDeviceByKey finds a specific device by its sender key. func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) { var identity DeviceIdentity err := store.DB.QueryRow(` SELECT device_id, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND identity_key=$2`, userID, identityKey, ).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name) if err != nil { if err == sql.ErrNoRows { return nil, nil } return nil, err } identity.UserID = userID identity.IdentityKey = identityKey return &identity, nil } // PutDevice stores a single device for a user, replacing it if it exists already. func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *DeviceIdentity) error { _, err := store.DB.Exec(` INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (user_id, device_id) DO UPDATE SET identity_key=excluded.identity_key, signing_key=excluded.signing_key, trust=excluded.trust, deleted=excluded.deleted, name=excluded.name`, userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name) return err } // PutDevices stores the device identity information for the given user ID. 
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*DeviceIdentity) error { tx, err := store.DB.Begin() if err != nil { return err } _, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) if err != nil { return fmt.Errorf("failed to add user to tracked users list: %w", err) } _, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID) if err != nil { _ = tx.Rollback() return fmt.Errorf("failed to delete old devices: %w", err) } if len(devices) == 0 { err = tx.Commit() if err != nil { return fmt.Errorf("failed to commit changes (no devices added): %w", err) } return nil } deviceBatchLen := 5 // how many devices will be inserted per query deviceIDs := make([]id.DeviceID, 0, len(devices)) for deviceID := range devices { deviceIDs = append(deviceIDs, deviceID) } for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen { var batchDevices []id.DeviceID if batchDeviceIdx+deviceBatchLen < len(deviceIDs) { batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen] } else { batchDevices = deviceIDs[batchDeviceIdx:] } values := make([]interface{}, 1, len(devices)*6+1) values[0] = userID valueStrings := make([]string, 0, len(devices)) i := 2 for _, deviceID := range batchDevices { identity := devices[deviceID] values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name) valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5)) i += 6 } valueString := strings.Join(valueStrings, ",") _, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...) 
if err != nil { _ = tx.Rollback() return fmt.Errorf("failed to insert new devices: %w", err) } } err = tx.Commit() if err != nil { return fmt.Errorf("failed to commit changes: %w", err) } return nil } // FilterTrackedUsers finds all of the user IDs out of the given ones for which the database contains identity information. func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID { var rows *sql.Rows var err error if store.Dialect == "postgres" && PostgresArrayWrapper != nil { rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users)) } else { queryString := make([]string, len(users)) params := make([]interface{}, len(users)) for i, user := range users { queryString[i] = fmt.Sprintf("$%d", i+1) params[i] = user } rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...) } if err != nil { store.Log.Warn("Failed to filter tracked users: %v", err) return users } var ptr int for rows.Next() { err = rows.Scan(&users[ptr]) if err != nil { store.Log.Warn("Failed to scan tracked user ID: %v", err) } else { ptr++ } } return users[:ptr] } // PutCrossSigningKey stores a cross-signing key of some user along with its usage. func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error { _, err := store.DB.Exec(` INSERT INTO crypto_cross_signing_keys (user_id, usage, key) VALUES ($1, $2, $3) ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key `, userID, usage, key) return err } // GetCrossSigningKeys retrieves a user's stored cross-signing keys. 
func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.Ed25519, error) { rows, err := store.DB.Query("SELECT usage, key FROM crypto_cross_signing_keys WHERE user_id=$1", userID) if err != nil { return nil, err } data := make(map[id.CrossSigningUsage]id.Ed25519) for rows.Next() { var usage id.CrossSigningUsage var key id.Ed25519 err := rows.Scan(&usage, &key) if err != nil { return nil, err } data[usage] = key } return data, nil } // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error { _, err := store.DB.Exec(` INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5) ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature `, signedUserID, signedKey, signerUserID, signerKey, signature) return err } // GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer. func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) { rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID) if err != nil { return nil, err } data := make(map[id.Ed25519]string) for rows.Next() { var signerKey id.Ed25519 var signature string err := rows.Scan(&signerKey, &signature) if err != nil { return nil, err } data[signerKey] = signature } return data, nil } // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. 
func (store *SQLCryptoStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) { sigs, err := store.GetSignaturesForKeyBy(userID, key, signerID) if err != nil { return false, err } _, ok := sigs[signerKey] return ok, nil } // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) { res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key) if err != nil { return 0, err } count, err := res.RowsAffected() if err != nil { return 0, err } return count, nil } go-0.11.1/crypto/sql_store_upgrade/000077500000000000000000000000001436100171500172375ustar00rootroot00000000000000go-0.11.1/crypto/sql_store_upgrade/upgrade.go000066400000000000000000000265571436100171500212340ustar00rootroot00000000000000package sql_store_upgrade import ( "database/sql" "errors" "fmt" "strings" ) type upgradeFunc func(*sql.Tx, string) error var ErrUnknownDialect = errors.New("unknown dialect") var Upgrades = [...]upgradeFunc{ func(tx *sql.Tx, _ string) error { for _, query := range []string{ `CREATE TABLE IF NOT EXISTS crypto_account ( device_id VARCHAR(255) PRIMARY KEY, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account bytea NOT NULL )`, `CREATE TABLE IF NOT EXISTS crypto_message_index ( sender_key CHAR(43), session_id CHAR(43), "index" INTEGER, event_id VARCHAR(255) NOT NULL, timestamp BIGINT NOT NULL, PRIMARY KEY (sender_key, session_id, "index") )`, `CREATE TABLE IF NOT EXISTS crypto_tracked_user ( user_id VARCHAR(255) PRIMARY KEY )`, `CREATE TABLE IF NOT EXISTS crypto_device ( user_id VARCHAR(255), device_id VARCHAR(255), identity_key CHAR(43) NOT NULL, signing_key CHAR(43) NOT NULL, trust SMALLINT NOT NULL, deleted BOOLEAN NOT NULL, name VARCHAR(255) NOT NULL, PRIMARY KEY (user_id, 
device_id) )`, `CREATE TABLE IF NOT EXISTS crypto_olm_session ( session_id CHAR(43) PRIMARY KEY, sender_key CHAR(43) NOT NULL, session bytea NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL )`, `CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session ( session_id CHAR(43) PRIMARY KEY, sender_key CHAR(43) NOT NULL, signing_key CHAR(43) NOT NULL, room_id VARCHAR(255) NOT NULL, session bytea NOT NULL, forwarding_chains bytea NOT NULL )`, `CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session ( room_id VARCHAR(255) PRIMARY KEY, session_id CHAR(43) NOT NULL UNIQUE, session bytea NOT NULL, shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL )`, } { if _, err := tx.Exec(query); err != nil { return err } } return nil }, func(tx *sql.Tx, dialect string) error { if dialect == "postgres" { tablesToPkeys := map[string][]string{ "crypto_account": {}, "crypto_olm_session": {"session_id"}, "crypto_megolm_inbound_session": {"session_id"}, "crypto_megolm_outbound_session": {"room_id"}, } for tableName, pkeys := range tablesToPkeys { // add account_id to primary key pkeyStr := strings.Join(append(pkeys, "account_id"), ", ") for _, query := range []string{ fmt.Sprintf("ALTER TABLE %s ADD COLUMN account_id VARCHAR(255)", tableName), fmt.Sprintf("UPDATE %s SET account_id=''", tableName), fmt.Sprintf("ALTER TABLE %s ALTER COLUMN account_id SET NOT NULL", tableName), fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s_pkey", tableName, tableName), fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s_pkey PRIMARY KEY (%s)", tableName, tableName, pkeyStr), } { if _, err := tx.Exec(query); err != nil { return err } } } } else if dialect == "sqlite3" { tableCols := map[string]string{ "crypto_account": ` account_id VARCHAR(255) NOT NULL, device_id VARCHAR(255) NOT NULL, shared BOOLEAN NOT NULL, sync_token TEXT NOT NULL, account BLOB NOT NULL, PRIMARY KEY 
(account_id) `, "crypto_olm_session": ` account_id VARCHAR(255) NOT NULL, session_id CHAR(43) NOT NULL, sender_key CHAR(43) NOT NULL, session BLOB NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, session_id) `, "crypto_megolm_inbound_session": ` account_id VARCHAR(255) NOT NULL, session_id CHAR(43) NOT NULL, sender_key CHAR(43) NOT NULL, signing_key CHAR(43) NOT NULL, room_id VARCHAR(255) NOT NULL, session BLOB NOT NULL, forwarding_chains BLOB NOT NULL, PRIMARY KEY (account_id, session_id) `, "crypto_megolm_outbound_session": ` account_id VARCHAR(255) NOT NULL, room_id VARCHAR(255) NOT NULL, session_id CHAR(43) NOT NULL UNIQUE, session BLOB NOT NULL, shared BOOLEAN NOT NULL, max_messages INTEGER NOT NULL, message_count INTEGER NOT NULL, max_age BIGINT NOT NULL, created_at timestamp NOT NULL, last_used timestamp NOT NULL, PRIMARY KEY (account_id, room_id) `, } for tableName, cols := range tableCols { // re-create tables with account_id column and new pkey and re-insert rows for _, query := range []string{ fmt.Sprintf("ALTER TABLE %s RENAME TO old_%s", tableName, tableName), fmt.Sprintf("CREATE TABLE %s (%s)", tableName, cols), fmt.Sprintf("INSERT INTO %s SELECT '', * FROM old_%s", tableName, tableName), fmt.Sprintf("DROP TABLE old_%s", tableName), } { if _, err := tx.Exec(query); err != nil { return err } } } } else { return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect) } return nil }, func(tx *sql.Tx, dialect string) error { if dialect == "postgres" { alters := [...]string{ "ADD COLUMN withheld_code VARCHAR(255)", "ADD COLUMN withheld_reason TEXT", "ALTER COLUMN signing_key DROP NOT NULL", "ALTER COLUMN session DROP NOT NULL", "ALTER COLUMN forwarding_chains DROP NOT NULL", } for _, alter := range alters { _, err := tx.Exec(fmt.Sprintf("ALTER TABLE crypto_megolm_inbound_session %s", alter)) if err != nil { return err } } } else if dialect == "sqlite3" { _, err := tx.Exec("ALTER TABLE crypto_megolm_inbound_session 
RENAME TO old_crypto_megolm_inbound_session") if err != nil { return err } _, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session ( account_id VARCHAR(255) NOT NULL, session_id CHAR(43) NOT NULL, sender_key CHAR(43) NOT NULL, signing_key CHAR(43), room_id VARCHAR(255) NOT NULL, session BLOB, forwarding_chains BLOB, withheld_code VARCHAR(255), withheld_reason TEXT, PRIMARY KEY (account_id, session_id) )`) if err != nil { return err } _, err = tx.Exec(`INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id) SELECT * FROM old_crypto_megolm_inbound_session`) if err != nil { return err } _, err = tx.Exec("DROP TABLE old_crypto_megolm_inbound_session") if err != nil { return err } } else { return fmt.Errorf("%w (%s)", ErrUnknownDialect, dialect) } return nil }, func(tx *sql.Tx, dialect string) error { if _, err := tx.Exec( `CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys ( user_id VARCHAR(255) NOT NULL, usage VARCHAR(20) NOT NULL, key CHAR(43) NOT NULL, PRIMARY KEY (user_id, usage) )`, ); err != nil { return err } if _, err := tx.Exec( `CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures ( signed_user_id VARCHAR(255) NOT NULL, signed_key VARCHAR(255) NOT NULL, signer_user_id VARCHAR(255) NOT NULL, signer_key VARCHAR(255) NOT NULL, signature CHAR(88) NOT NULL, PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) )`, ); err != nil { return err } return nil }, func(tx *sql.Tx, dialect string) error { if dialect == "sqlite3" { // SQLite doesn't enforce varchar sizes anyway return nil } alters := [...]string{ `ALTER TABLE crypto_account ALTER COLUMN device_id TYPE TEXT`, `ALTER TABLE crypto_account ALTER COLUMN account_id TYPE TEXT`, `ALTER TABLE crypto_device ALTER COLUMN user_id TYPE TEXT`, `ALTER TABLE crypto_device ALTER COLUMN device_id TYPE TEXT`, `ALTER TABLE crypto_device ALTER COLUMN name TYPE TEXT`, `ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN 
room_id TYPE TEXT`, `ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN account_id TYPE TEXT`, `ALTER TABLE crypto_megolm_inbound_session ALTER COLUMN withheld_code TYPE TEXT`, `ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN room_id TYPE TEXT`, `ALTER TABLE crypto_megolm_outbound_session ALTER COLUMN account_id TYPE TEXT`, `ALTER TABLE crypto_message_index ALTER COLUMN event_id TYPE TEXT`, `ALTER TABLE crypto_olm_session ALTER COLUMN account_id TYPE TEXT`, `ALTER TABLE crypto_tracked_user ALTER COLUMN user_id TYPE TEXT`, `ALTER TABLE crypto_cross_signing_keys ALTER COLUMN user_id TYPE TEXT`, `ALTER TABLE crypto_cross_signing_keys ALTER COLUMN usage TYPE TEXT`, `ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_user_id TYPE TEXT`, `ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signed_key TYPE TEXT`, `ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_user_id TYPE TEXT`, `ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signer_key TYPE TEXT`, `ALTER TABLE crypto_cross_signing_signatures ALTER COLUMN signature TYPE TEXT`, } for _, alter := range alters { _, err := tx.Exec(alter) if err != nil { return err } } return nil }, func(tx *sql.Tx, dialect string) error { _, err := tx.Exec("ALTER TABLE crypto_olm_session RENAME COLUMN last_used TO last_decrypted") if err != nil { return err } _, err = tx.Exec("ALTER TABLE crypto_olm_session ADD COLUMN last_encrypted timestamp") if err != nil { return err } _, err = tx.Exec("UPDATE crypto_olm_session SET last_encrypted=last_decrypted") if err != nil { return err } if dialect == "postgres" { // This is too hard to do on sqlite, so let's just do it on postgres _, err = tx.Exec("ALTER TABLE crypto_olm_session ALTER COLUMN last_encrypted SET NOT NULL") if err != nil { return err } } return nil }, } // GetVersion returns the current version of the DB schema. 
func GetVersion(db *sql.DB) (int, error) { _, err := db.Exec("CREATE TABLE IF NOT EXISTS crypto_version (version INTEGER)") if err != nil { return -1, err } version := 0 row := db.QueryRow("SELECT version FROM crypto_version LIMIT 1") if row != nil { _ = row.Scan(&version) } return version, nil } // SetVersion sets the schema version in a running DB transaction. func SetVersion(tx *sql.Tx, version int) error { _, err := tx.Exec("DELETE FROM crypto_version") if err != nil { return err } _, err = tx.Exec("INSERT INTO crypto_version (version) VALUES ($1)", version) return err } // Upgrade upgrades the database from the current to the latest version available. func Upgrade(db *sql.DB, dialect string) error { version, err := GetVersion(db) if err != nil { return err } // perform migrations starting with #version for ; version < len(Upgrades); version++ { tx, err := db.Begin() if err != nil { return err } // run each migrate func migrateFunc := Upgrades[version] err = migrateFunc(tx, dialect) if err != nil { _ = tx.Rollback() return err } // also update the version in this tx if err = SetVersion(tx, version+1); err != nil { return err } if err = tx.Commit(); err != nil { return err } } return nil } go-0.11.1/crypto/ssss/000077500000000000000000000000001436100171500145105ustar00rootroot00000000000000go-0.11.1/crypto/ssss/client.go000066400000000000000000000072001436100171500163140ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package ssss import ( "fmt" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" ) // Machine contains utility methods for interacting with SSSS data on the server. 
type Machine struct { Client *mautrix.Client } func NewSSSSMachine(client *mautrix.Client) *Machine { return &Machine{ Client: client, } } type DefaultSecretStorageKeyContent struct { KeyID string `json:"key"` } // GetDefaultKeyID retrieves the default key ID for this account from SSSS. func (mach *Machine) GetDefaultKeyID() (string, error) { var data DefaultSecretStorageKeyContent err := mach.Client.GetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &data) if err != nil { if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" { return "", ErrNoDefaultKeyAccountDataEvent } return "", fmt.Errorf("failed to get default key account data from server: %w", err) } if len(data.KeyID) == 0 { return "", ErrNoKeyFieldInAccountDataEvent } return data.KeyID, nil } // SetDefaultKeyID sets the default key ID for this account on the server. func (mach *Machine) SetDefaultKeyID(keyID string) error { return mach.Client.SetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID}) } // GetKeyData gets the details about the given key ID. func (mach *Machine) GetKeyData(keyID string) (keyData *KeyMetadata, err error) { keyData = &KeyMetadata{id: keyID} err = mach.Client.GetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) return } // SetKeyData stores SSSS key metadata on the server. func (mach *Machine) SetKeyData(keyID string, keyData *KeyMetadata) error { return mach.Client.SetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData) } // GetDefaultKeyData gets the details about the default key ID (see GetDefaultKeyID). 
func (mach *Machine) GetDefaultKeyData() (keyID string, keyData *KeyMetadata, err error) { keyID, err = mach.GetDefaultKeyID() if err != nil { return } keyData, err = mach.GetKeyData(keyID) return } // GetDecryptedAccountData gets the account data event with the given event type and decrypts it using the given key. func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]byte, error) { var encData EncryptedAccountDataEventContent err := mach.Client.GetAccountData(eventType.Type, &encData) if err != nil { return nil, err } return encData.Decrypt(eventType.Type, key) } // SetEncryptedAccountData encrypts the given data with the given keys and stores it on the server. func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, keys ...*Key) error { if len(keys) == 0 { return ErrNoKeyGiven } encrypted := make(map[string]EncryptedKeyData, len(keys)) for _, key := range keys { encrypted[key.ID] = key.Encrypt(eventType.Type, data) } return mach.Client.SetAccountData(eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted}) } // GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server. func (mach *Machine) GenerateAndUploadKey(passphrase string) (key *Key, err error) { key, err = NewKey(passphrase) if err != nil { return nil, fmt.Errorf("failed to generate new key: %w", err) } err = mach.SetKeyData(key.ID, key.Metadata) if err != nil { err = fmt.Errorf("failed to upload key: %w", err) } return key, err } go-0.11.1/crypto/ssss/key.go000066400000000000000000000100761436100171500156330ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package ssss

import (
	"crypto/hmac"
	"crypto/rand"
	"encoding/base64"
	"fmt"
	"strings"

	"maunium.net/go/mautrix/crypto/utils"
)

// Key represents a SSSS private key and related metadata.
type Key struct {
	ID       string       `json:"-"`
	Key      []byte       `json:"-"`
	Metadata *KeyMetadata `json:"-"`
}

// NewKey generates a new SSSS key, optionally based on the given passphrase.
//
// Errors are only returned if crypto/rand runs out of randomness.
func NewKey(passphrase string) (*Key, error) {
	// We don't support any other algorithms currently.
	keyData := KeyMetadata{Algorithm: AlgorithmAESHMACSHA2}

	var ssssKey []byte
	if len(passphrase) > 0 {
		// There's a passphrase. We need to generate a salt for it, set the metadata
		// and then compute the key using the passphrase and the metadata.
		saltBytes := make([]byte, 24)
		if _, err := rand.Read(saltBytes); err != nil {
			return nil, fmt.Errorf("failed to get random bytes for salt: %w", err)
		}
		keyData.Passphrase = &PassphraseMetadata{
			Algorithm:  PassphraseAlgorithmPBKDF2,
			Iterations: 500000,
			Salt:       base64.StdEncoding.EncodeToString(saltBytes),
			Bits:       256,
		}
		var err error
		ssssKey, err = keyData.Passphrase.GetKey(passphrase)
		if err != nil {
			return nil, fmt.Errorf("failed to get key from passphrase: %w", err)
		}
	} else {
		// No passphrase, just generate a random key
		ssssKey = make([]byte, 32)
		if _, err := rand.Read(ssssKey); err != nil {
			return nil, fmt.Errorf("failed to get random bytes for key: %w", err)
		}
	}

	// Generate a random ID for the key. It's what identifies the key in account data.
	keyIDBytes := make([]byte, 24)
	if _, err := rand.Read(keyIDBytes); err != nil {
		return nil, fmt.Errorf("failed to get random bytes for key ID: %w", err)
	}

	// We store a certain hash in the key metadata so that clients can check if the user entered the correct key.
	var ivBytes [utils.AESCTRIVLength]byte
	if _, err := rand.Read(ivBytes[:]); err != nil {
		return nil, fmt.Errorf("failed to get random bytes for IV: %w", err)
	}
	keyData.IV = base64.StdEncoding.EncodeToString(ivBytes[:])
	keyData.MAC = keyData.calculateHash(ssssKey)

	return &Key{
		Key:      ssssKey,
		ID:       base64.StdEncoding.EncodeToString(keyIDBytes),
		Metadata: &keyData,
	}, nil
}

// RecoveryKey gets the recovery key for this SSSS key.
func (key *Key) RecoveryKey() string {
	return utils.EncodeBase58RecoveryKey(key.Key)
}

// Encrypt encrypts the given data with this key.
//
// The data is base64-encoded before encryption; the ciphertext, IV and MAC are returned as padded base64.
func (key *Key) Encrypt(eventType string, data []byte) EncryptedKeyData {
	aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType)

	iv := utils.GenA256CTRIV()
	payload := make([]byte, base64.StdEncoding.EncodedLen(len(data)))
	base64.StdEncoding.Encode(payload, data)
	utils.XorA256CTR(payload, aesKey, iv)

	return EncryptedKeyData{
		Ciphertext: base64.StdEncoding.EncodeToString(payload),
		IV:         base64.StdEncoding.EncodeToString(iv[:]),
		MAC:        utils.HMACSHA256B64(payload, hmacKey),
	}
}

// decodeOptionallyPaddedBase64 decodes standard-alphabet base64 whether or not the trailing "="
// padding is present, so both padded and unpadded encodings are accepted.
func decodeOptionallyPaddedBase64(input string) ([]byte, error) {
	return base64.RawStdEncoding.DecodeString(strings.TrimRight(input, "="))
}

// Decrypt decrypts the given encrypted data with this key.
//
// The MAC is checked before anything is decrypted. ErrKeyDataMACMismatch is returned when it
// doesn't match, e.g. for the wrong key or the wrong event type. The IV, ciphertext and decrypted
// payload are accepted as base64 with or without padding. Malformed base64 in any of them is
// reported as a wrapped decoding error.
func (key *Key) Decrypt(eventType string, data EncryptedKeyData) ([]byte, error) {
	decodedIV, err := decodeOptionallyPaddedBase64(data.IV)
	if err != nil {
		return nil, fmt.Errorf("failed to decode IV: %w", err)
	}
	var ivBytes [utils.AESCTRIVLength]byte
	copy(ivBytes[:], decodedIV)

	payload, err := decodeOptionallyPaddedBase64(data.Ciphertext)
	if err != nil {
		return nil, fmt.Errorf("failed to decode ciphertext: %w", err)
	}

	// derive the AES and HMAC keys for the requested event type using the SSSS key
	aesKey, hmacKey := utils.DeriveKeysSHA256(key.Key, eventType)

	// Compare the stored MAC with the one calculated from the ciphertext. Padding is ignored on both
	// sides, and the comparison is constant-time so it doesn't leak how many leading bytes matched.
	calcMac := utils.HMACSHA256B64(payload, hmacKey)
	if !hmac.Equal([]byte(strings.TrimRight(data.MAC, "=")), []byte(strings.TrimRight(calcMac, "="))) {
		return nil, ErrKeyDataMACMismatch
	}

	utils.XorA256CTR(payload, aesKey, ivBytes)
	decrypted, err := decodeOptionallyPaddedBase64(string(payload))
	if err != nil {
		return nil, fmt.Errorf("failed to decode decrypted payload: %w", err)
	}
	return decrypted, nil
}
go-0.11.1/crypto/ssss/key_test.go000066400000000000000000000051171436100171500166720ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package ssss_test

import (
	"encoding/json"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"

	"maunium.net/go/mautrix/crypto/ssss"
	"maunium.net/go/mautrix/event"
)

const key1CrossSigningMasterKey = `
{
  "encrypted": {
    "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0": {
      "iv": "BpKP9nQJTE9jrsAssoxPqQ==",
      "ciphertext": "fNRiiiidezjerTgV+G6pUtmeF3izzj5re/mVvY0hO2kM6kYGrxLuIu2ej80=",
      "mac": "/gWGDGMyOLmbJp+aoSLh5JxCs0AdS6nAhjzpe+9G2Q0="
    }
  }
}
`

var key1CrossSigningMasterKeyDecrypted = []byte{
	0x68, 0xf9, 0x7f, 0xd1, 0x92, 0x2e, 0xec, 0xf6,
	0xb8, 0x2b, 0xb8, 0x90, 0xd2, 0x4d, 0x06, 0x52,
	0x98, 0x4e, 0x7a, 0x1d, 0x70, 0x3b, 0x9e, 0x86,
	0x7b, 0x7e, 0xba, 0xf7, 0xfe, 0xb9, 0x5b, 0x6f,
}

func getEncryptedMasterKey() *ssss.EncryptedAccountDataEventContent {
	var eadec ssss.EncryptedAccountDataEventContent
	err := json.Unmarshal([]byte(key1CrossSigningMasterKey), &eadec)
	if err != nil {
		panic(err)
	}
	return &eadec
}

func TestKey_Decrypt_Success(t *testing.T) {
	key := getKey1()
	emk := getEncryptedMasterKey()
	decrypted, err := emk.Decrypt(event.AccountDataCrossSigningMaster.Type, key)
	assert.NoError(t, err)
	assert.Equal(t, key1CrossSigningMasterKeyDecrypted, decrypted)
}

func TestKey_Decrypt_WrongKey(t *testing.T) {
	key := getKey2()
	emk := getEncryptedMasterKey()
	decrypted, err := emk.Decrypt(event.AccountDataCrossSigningMaster.Type, key)
	assert.True(t, errors.Is(err, ssss.ErrNotEncryptedForKey), "unexpected error %v", err)
	assert.Nil(t, decrypted)
}

func TestKey_Decrypt_FakeKey(t *testing.T) {
	key := getKey2()
	key.ID = key1ID
	emk := getEncryptedMasterKey()
	decrypted, err := emk.Decrypt(event.AccountDataCrossSigningMaster.Type, key)
	assert.True(t, errors.Is(err, ssss.ErrKeyDataMACMismatch), "unexpected error %v", err)
	assert.Nil(t, decrypted)
}

func TestKey_Decrypt_WrongType(t *testing.T) {
	key := getKey1()
	emk := getEncryptedMasterKey()
	decrypted, err := emk.Decrypt(event.AccountDataCrossSigningSelf.Type, key)
	assert.True(t, errors.Is(err, ssss.ErrKeyDataMACMismatch), "unexpected error %v", err)
	assert.Nil(t, decrypted)
}

func TestKey_Encrypt(t *testing.T) {
	key1 := getKey1()

	var evtType = "net.maunium.data"
	var data = []byte{0xde, 0xad, 0xbe, 0xef}

	encrypted := key1.Encrypt(evtType, data)
	decrypted, err := key1.Decrypt(evtType, encrypted)
	assert.NoError(t, err)
	assert.Equal(t, data, decrypted)
}
go-0.11.1/crypto/ssss/meta.go000066400000000000000000000061151436100171500157700ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package ssss

import (
	"encoding/base64"
	"fmt"
	"strings"

	"maunium.net/go/mautrix/crypto/utils"
)

// KeyMetadata represents server-side metadata about a SSSS key. The metadata can be used to get
// the actual SSSS key from a passphrase or recovery key.
type KeyMetadata struct {
	id string

	Algorithm Algorithm `json:"algorithm"`

	IV  string `json:"iv"`
	MAC string `json:"mac"`

	Passphrase *PassphraseMetadata `json:"passphrase,omitempty"`
}

// VerifyPassphrase verifies that the given passphrase is valid and returns the computed SSSS key.
//
// Errors from deriving the key are returned as-is; for example, ErrNoPassphrase is returned when
// this key has no passphrase metadata. ErrIncorrectSSSSKey is returned when the derived key does
// not match the stored MAC.
func (kd *KeyMetadata) VerifyPassphrase(passphrase string) (*Key, error) {
	ssssKey, err := kd.Passphrase.GetKey(passphrase)
	if err != nil {
		return nil, err
	} else if !kd.VerifyKey(ssssKey) {
		return nil, ErrIncorrectSSSSKey
	}

	return &Key{
		ID:       kd.id,
		Key:      ssssKey,
		Metadata: kd,
	}, nil
}

// VerifyRecoveryKey verifies that the given recovery key is valid and returns the decoded SSSS key.
//
// ErrInvalidRecoveryKey is returned when the recovery key can't be decoded. ErrIncorrectSSSSKey is
// returned when the decoded key does not match the stored MAC.
func (kd *KeyMetadata) VerifyRecoveryKey(recoveryKey string) (*Key, error) {
	ssssKey := utils.DecodeBase58RecoveryKey(recoveryKey)
	if ssssKey == nil {
		return nil, ErrInvalidRecoveryKey
	} else if !kd.VerifyKey(ssssKey) {
		return nil, ErrIncorrectSSSSKey
	}

	return &Key{
		ID:       kd.id,
		Key:      ssssKey,
		Metadata: kd,
	}, nil
}

// VerifyKey verifies the SSSS key is valid by calculating and comparing its MAC.
func (kd *KeyMetadata) VerifyKey(key []byte) bool { return strings.ReplaceAll(kd.MAC, "=", "") == strings.ReplaceAll(kd.calculateHash(key), "=", "") } // calculateHash calculates the hash used for checking if the key is entered correctly as described // in the spec: https://matrix.org/docs/spec/client_server/unstable#m-secret-storage-v1-aes-hmac-sha2 func (kd *KeyMetadata) calculateHash(key []byte) string { aesKey, hmacKey := utils.DeriveKeysSHA256(key, "") var ivBytes [utils.AESCTRIVLength]byte _, _ = base64.StdEncoding.Decode(ivBytes[:], []byte(kd.IV)) cipher := utils.XorA256CTR(make([]byte, utils.AESCTRKeyLength), aesKey, ivBytes) return utils.HMACSHA256B64(cipher, hmacKey) } // PassphraseMetadata represents server-side metadata about a SSSS key passphrase. type PassphraseMetadata struct { Algorithm PassphraseAlgorithm `json:"algorithm"` Iterations int `json:"iterations"` Salt string `json:"salt"` Bits int `json:"bits"` } // GetKey gets the SSSS key from the passphrase. func (pd *PassphraseMetadata) GetKey(passphrase string) ([]byte, error) { if pd == nil { return nil, ErrNoPassphrase } if pd.Algorithm != PassphraseAlgorithmPBKDF2 { return nil, fmt.Errorf("%w: %s", ErrUnsupportedPassphraseAlgorithm, pd.Algorithm) } bits := 256 if pd.Bits != 0 { bits = pd.Bits } return utils.PBKDF2SHA512([]byte(passphrase), []byte(pd.Salt), pd.Iterations, bits), nil } go-0.11.1/crypto/ssss/meta_test.go000066400000000000000000000066641436100171500170400ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package ssss_test

import (
	"encoding/json"
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"

	"maunium.net/go/mautrix/crypto/ssss"
)

// Fixture: a key that has both a recovery key and passphrase metadata.
const (
	key1Meta = `
{
  "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
  "passphrase": {
    "algorithm": "m.pbkdf2",
    "iterations": 500000,
    "salt": "y863BOoqOadgDp8S3FtHXikDJEalsQ7d"
  },
  "iv": "xxkTK0L4UzxgAFkQ6XPwsw==",
  "mac": "MEhooO0ZhFJNxUhvRMSxBnJfL20wkLgle3ocY0ee/eA="
}
`
	key1ID          = "gEJqbfSEMnP5JXXcukpXEX1l0aI3MDs0"
	key1RecoveryKey = "EsTE s92N EtaX s2h6 VQYF 9Kao tHYL mkyL GKMh isZb KJ4E tvoC"
	key1Passphrase  = "correct horse battery staple"
)

// Fixture: a key that only has a recovery key (no passphrase metadata).
const (
	key2Meta = `
{
  "algorithm": "m.secret_storage.v1.aes-hmac-sha2",
  "iv": "O0BOvTqiIAYjC+RMcyHfWw==",
  "mac": "7k6OruQlWg0UmQjxGZ0ad4Q6DdwkgnoI7G6X3IjBYtI="
}
`
	key2ID          = "NVe5vK6lZS9gEMQLJw0yqkzmE5Mr7dLv"
	key2RecoveryKey = "EsUC xSxt XJgQ dz19 8WBZ rHdE GZo7 ybsn EFmG Y5HY MDAG GNWe"
)

// parseKeyMeta unmarshals a KeyMetadata JSON fixture, panicking if the fixture is malformed.
func parseKeyMeta(raw string) *ssss.KeyMetadata {
	meta := new(ssss.KeyMetadata)
	if err := json.Unmarshal([]byte(raw), meta); err != nil {
		panic(err)
	}
	return meta
}

// unlockKey verifies recoveryKey against meta and stamps keyID on the result, panicking on failure.
func unlockKey(meta *ssss.KeyMetadata, recoveryKey, keyID string) *ssss.Key {
	key, err := meta.VerifyRecoveryKey(recoveryKey)
	if err != nil {
		panic(err)
	}
	key.ID = keyID
	return key
}

func getKey1Meta() *ssss.KeyMetadata {
	return parseKeyMeta(key1Meta)
}

func getKey1() *ssss.Key {
	return unlockKey(getKey1Meta(), key1RecoveryKey, key1ID)
}

func getKey2Meta() *ssss.KeyMetadata {
	return parseKeyMeta(key2Meta)
}

func getKey2() *ssss.Key {
	return unlockKey(getKey2Meta(), key2RecoveryKey, key2ID)
}

// assertUnlocked checks that verification succeeded and yielded the key behind wantRecoveryKey.
func assertUnlocked(t *testing.T, key *ssss.Key, err error, wantRecoveryKey string) {
	t.Helper()
	assert.NoError(t, err)
	assert.NotNil(t, key)
	assert.Equal(t, wantRecoveryKey, key.RecoveryKey())
}

// assertRejected checks that verification failed with wantErr and produced no key.
func assertRejected(t *testing.T, key *ssss.Key, err, wantErr error, msgFormat string) {
	t.Helper()
	assert.True(t, errors.Is(err, wantErr), msgFormat, err)
	assert.Nil(t, key)
}

func TestKeyMetadata_VerifyRecoveryKey_Correct(t *testing.T) {
	unlocked, verifyErr := getKey1Meta().VerifyRecoveryKey(key1RecoveryKey)
	assertUnlocked(t, unlocked, verifyErr, key1RecoveryKey)
}

func TestKeyMetadata_VerifyRecoveryKey_Correct2(t *testing.T) {
	unlocked, verifyErr := getKey2Meta().VerifyRecoveryKey(key2RecoveryKey)
	assertUnlocked(t, unlocked, verifyErr, key2RecoveryKey)
}

func TestKeyMetadata_VerifyRecoveryKey_Invalid(t *testing.T) {
	unlocked, verifyErr := getKey1Meta().VerifyRecoveryKey("foo")
	assertRejected(t, unlocked, verifyErr, ssss.ErrInvalidRecoveryKey, "unexpected error: %v")
}

func TestKeyMetadata_VerifyRecoveryKey_Incorrect(t *testing.T) {
	unlocked, verifyErr := getKey1Meta().VerifyRecoveryKey(key2RecoveryKey)
	assertRejected(t, unlocked, verifyErr, ssss.ErrIncorrectSSSSKey, "unexpected error: %v")
}

func TestKeyMetadata_VerifyPassphrase_Correct(t *testing.T) {
	unlocked, verifyErr := getKey1Meta().VerifyPassphrase(key1Passphrase)
	assertUnlocked(t, unlocked, verifyErr, key1RecoveryKey)
}

func TestKeyMetadata_VerifyPassphrase_Incorrect(t *testing.T) {
	unlocked, verifyErr := getKey1Meta().VerifyPassphrase("incorrect horse battery staple")
	assertRejected(t, unlocked, verifyErr, ssss.ErrIncorrectSSSSKey, "unexpected error %v")
}

func TestKeyMetadata_VerifyPassphrase_NotSet(t *testing.T) {
	unlocked, verifyErr := getKey2Meta().VerifyPassphrase("hmm")
	assertRejected(t, unlocked, verifyErr, ssss.ErrNoPassphrase, "unexpected error %v")
}
go-0.11.1/crypto/ssss/types.go000066400000000000000000000054601436100171500162100ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package ssss

import (
	"errors"
	"fmt"
	"reflect"

	"maunium.net/go/mautrix/event"
)

// Errors about the account's default SSSS key ID being unavailable. The more specific ones wrap
// ErrNoDefaultKeyID, so errors.Is(err, ErrNoDefaultKeyID) matches all of them.
var (
	// ErrNoDefaultKeyID is the common parent of the default-key lookup errors below.
	ErrNoDefaultKeyID = errors.New("could not find default key ID")
	// ErrNoDefaultKeyAccountDataEvent means the default key account data event does not exist.
	ErrNoDefaultKeyAccountDataEvent = fmt.Errorf("%w: no %s event in account data", ErrNoDefaultKeyID, event.AccountDataSecretStorageDefaultKey.Type)
	// ErrNoKeyFieldInAccountDataEvent means the default key account data event has an empty key field.
	ErrNoKeyFieldInAccountDataEvent = fmt.Errorf("%w: missing key field in account data event", ErrNoDefaultKeyID)
)

// ErrNoKeyGiven is returned when asked to encrypt account data without any keys.
var ErrNoKeyGiven = errors.New("must provide at least one key to encrypt for")

// Errors returned while decrypting secrets or verifying SSSS keys.
var (
	// ErrNotEncryptedForKey means the encrypted data has no entry for the given key's ID.
	ErrNotEncryptedForKey = errors.New("data is not encrypted for given key ID")
	// ErrKeyDataMACMismatch means the MAC of the encrypted data did not verify.
	ErrKeyDataMACMismatch = errors.New("key data MAC mismatch")
	// ErrNoPassphrase means the key metadata has no passphrase information.
	ErrNoPassphrase = errors.New("no passphrase data has been set for the default key")
	// ErrUnsupportedPassphraseAlgorithm means the passphrase uses a KDF this package can't compute.
	ErrUnsupportedPassphraseAlgorithm = errors.New("unsupported passphrase KDF algorithm")
	// ErrIncorrectSSSSKey means a key was derived or decoded but doesn't match the stored MAC.
	ErrIncorrectSSSSKey = errors.New("incorrect SSSS key")
	// ErrInvalidRecoveryKey means the recovery key string could not be decoded at all.
	ErrInvalidRecoveryKey = errors.New("invalid recovery key")
)

// Algorithm is the identifier for an SSSS encryption algorithm.
type Algorithm string

// Known SSSS encryption algorithms.
const (
	// AlgorithmAESHMACSHA2 is the current main algorithm.
	AlgorithmAESHMACSHA2 Algorithm = "m.secret_storage.v1.aes-hmac-sha2"
	// AlgorithmCurve25519AESSHA2 is the old algorithm.
	AlgorithmCurve25519AESSHA2 Algorithm = "m.secret_storage.v1.curve25519-aes-sha2"
)

// PassphraseAlgorithm is the identifier for an algorithm used to derive a key from a passphrase for SSSS.
type PassphraseAlgorithm string const ( // PassphraseAlgorithmPBKDF2 is the current main algorithm PassphraseAlgorithmPBKDF2 PassphraseAlgorithm = "m.pbkdf2" ) type EncryptedKeyData struct { Ciphertext string `json:"ciphertext"` IV string `json:"iv"` MAC string `json:"mac"` } type EncryptedAccountDataEventContent struct { Encrypted map[string]EncryptedKeyData `json:"encrypted"` } func (ed *EncryptedAccountDataEventContent) Decrypt(eventType string, key *Key) ([]byte, error) { keyEncData, ok := ed.Encrypted[key.ID] if !ok { return nil, ErrNotEncryptedForKey } return key.Decrypt(eventType, keyEncData) } func init() { encryptedContent := reflect.TypeOf(&EncryptedAccountDataEventContent{}) event.TypeMap[event.AccountDataCrossSigningMaster] = encryptedContent event.TypeMap[event.AccountDataCrossSigningSelf] = encryptedContent event.TypeMap[event.AccountDataCrossSigningUser] = encryptedContent event.TypeMap[event.AccountDataSecretStorageDefaultKey] = reflect.TypeOf(&DefaultSecretStorageKeyContent{}) event.TypeMap[event.AccountDataSecretStorageKey] = reflect.TypeOf(&KeyMetadata{}) } go-0.11.1/crypto/store.go000066400000000000000000000502161436100171500152040ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "encoding/gob" "errors" "fmt" "os" "sort" "sync" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // TrustState determines how trusted a device is. 
type TrustState int const ( TrustStateUnset TrustState = iota TrustStateVerified TrustStateBlacklisted TrustStateIgnored ) func (ts TrustState) String() string { switch ts { case TrustStateUnset: return "unverified" case TrustStateVerified: return "verified" case TrustStateBlacklisted: return "blacklisted" case TrustStateIgnored: return "ignored" default: return "" } } // DeviceIdentity contains the identity details of a device and some additional info. type DeviceIdentity struct { UserID id.UserID DeviceID id.DeviceID IdentityKey id.Curve25519 SigningKey id.Ed25519 Trust TrustState Deleted bool Name string } func (device *DeviceIdentity) Fingerprint() string { return Fingerprint(device.SigningKey) } var ErrGroupSessionWithheld = errors.New("group session has been withheld") // Store is used by OlmMachine to store Olm and Megolm sessions, user device lists and message indices. // // General implementation details: // * Get methods should not return errors if the requested data does not exist in the store, they should simply return nil. // * Update methods may assume that the pointer is the same as what has earlier been added to or fetched from the store. type Store interface { // Flush ensures that everything in the store is persisted to disk. // This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately. Flush() error // PutAccount updates the OlmAccount in the store. PutAccount(*OlmAccount) error // GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount. GetAccount() (*OlmAccount, error) // AddSession inserts an Olm session into the store. AddSession(id.SenderKey, *OlmSession) error // HasSession returns whether or not the store has an Olm session with the given sender key. HasSession(id.SenderKey) bool // GetSessions returns all Olm sessions in the store with the given sender key. 
GetSessions(id.SenderKey) (OlmSessionList, error) // GetLatestSession returns the session with the highest session ID (lexiographically sorting). // It's usually safe to return the most recently added session if sorting by session ID is too difficult. GetLatestSession(id.SenderKey) (*OlmSession, error) // UpdateSession updates a session that has previously been inserted with AddSession. UpdateSession(id.SenderKey, *OlmSession) error // PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted // with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace // sessions inserted with this call. PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error // GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld // (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the // ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details. GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error) // PutWithheldGroupSession tells the store that a specific Megolm session was withheld. PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error // GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession. GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key // export files. Unlike GetGroupSession, this should not return any errors about withheld keys. GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error) // GetGroupSessionsForRoom gets all the inbound Megolm sessions in the store. This is used for creating key export // files. 
Unlike GetGroupSession, this should not return any errors about withheld keys. GetAllGroupSessions() ([]*InboundGroupSession, error) // AddOutboundGroupSession inserts the given outbound Megolm session into the store. // // The store should index inserted sessions by the RoomID field to support getting and removing sessions. // There will only be one outbound session per room ID at a time. AddOutboundGroupSession(*OutboundGroupSession) error // UpdateOutboundGroupSession updates the given outbound Megolm session in the store. UpdateOutboundGroupSession(*OutboundGroupSession) error // GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store. GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error) // RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID. RemoveOutboundGroupSession(id.RoomID) error // ValidateMessageIndex validates that the given message details aren't from a replay attack. // // Implementations should store a map from (senderKey, sessionID, index) to (eventID, timestamp), then use that map // to check whether or not the message index is valid: // // * If the map key doesn't exist, the given values should be stored and this should return true. // * If the map key exists and the stored values match the given values, this should return true. // * If the map key exists, but the stored values do not match the given values, this should return false. ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool // GetDevices returns a map from device ID to DeviceIdentity containing all devices of a given user. GetDevices(id.UserID) (map[id.DeviceID]*DeviceIdentity, error) // GetDevice returns a specific device of a given user. GetDevice(id.UserID, id.DeviceID) (*DeviceIdentity, error) // PutDevice stores a single device for a user, replacing it if it exists already. 
PutDevice(id.UserID, *DeviceIdentity) error // PutDevices overrides the stored device list for the given user with the given list. PutDevices(id.UserID, map[id.DeviceID]*DeviceIdentity) error // FindDeviceByKey finds a specific device by its identity key. FindDeviceByKey(id.UserID, id.IdentityKey) (*DeviceIdentity, error) // FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists // have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty. FilterTrackedUsers([]id.UserID) []id.UserID // PutCrossSigningKey stores a cross-signing key of some user along with its usage. PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error // GetCrossSigningKeys retrieves a user's stored cross-signing keys. GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.Ed25519, error) // PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key. PutSignature(id.UserID, id.Ed25519, id.UserID, id.Ed25519, string) error // GetSignaturesForKeyBy returns the signatures for a cross-signing or device key by the given signer. GetSignaturesForKeyBy(id.UserID, id.Ed25519, id.UserID) (map[id.Ed25519]string, error) // IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer. IsKeySignedBy(id.UserID, id.Ed25519, id.UserID, id.Ed25519) (bool, error) // DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted. DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error) } type messageIndexKey struct { SenderKey id.SenderKey SessionID id.SessionID Index uint } type messageIndexValue struct { EventID id.EventID Timestamp int64 } // GobStore is a simple Store implementation that dumps everything into a .gob file. // // Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended. 
type GobStore struct { lock sync.RWMutex path string Account *OlmAccount Sessions map[id.SenderKey]OlmSessionList GroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession WithheldGroupSessions map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent OutGroupSessions map[id.RoomID]*OutboundGroupSession MessageIndices map[messageIndexKey]messageIndexValue Devices map[id.UserID]map[id.DeviceID]*DeviceIdentity CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.Ed25519 KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string } var _ Store = (*GobStore)(nil) // NewGobStore creates a new GobStore that saves everything to the given file. // // Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended. func NewGobStore(path string) (*GobStore, error) { gs := &GobStore{ path: path, Sessions: make(map[id.SenderKey]OlmSessionList), GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession), WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent), OutGroupSessions: make(map[id.RoomID]*OutboundGroupSession), MessageIndices: make(map[messageIndexKey]messageIndexValue), Devices: make(map[id.UserID]map[id.DeviceID]*DeviceIdentity), CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.Ed25519), KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string), } return gs, gs.load() } func (gs *GobStore) save() error { file, err := os.OpenFile(gs.path, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return err } err = gob.NewEncoder(file).Encode(gs) _ = file.Close() return err } func (gs *GobStore) load() error { file, err := os.OpenFile(gs.path, os.O_RDONLY, 0600) if err != nil { if os.IsNotExist(err) { return nil } return err } err = gob.NewDecoder(file).Decode(gs) _ = file.Close() return err } func (gs *GobStore) Flush() error { 
gs.lock.Lock() err := gs.save() gs.lock.Unlock() return err } func (gs *GobStore) GetAccount() (*OlmAccount, error) { return gs.Account, nil } func (gs *GobStore) PutAccount(account *OlmAccount) error { gs.lock.Lock() gs.Account = account err := gs.save() gs.lock.Unlock() return err } func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) { gs.lock.Lock() sessions, ok := gs.Sessions[senderKey] if !ok { sessions = []*OlmSession{} gs.Sessions[senderKey] = sessions } gs.lock.Unlock() return sessions, nil } func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) error { gs.lock.Lock() sessions, _ := gs.Sessions[senderKey] gs.Sessions[senderKey] = append(sessions, session) sort.Sort(gs.Sessions[senderKey]) err := gs.save() gs.lock.Unlock() return err } func (gs *GobStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error { // we don't need to do anything here because the session is a pointer and already stored in our map return gs.save() } func (gs *GobStore) HasSession(senderKey id.SenderKey) bool { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() return ok && len(sessions) > 0 && !sessions[0].Expired() } func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) { gs.lock.RLock() sessions, ok := gs.Sessions[senderKey] gs.lock.RUnlock() if !ok || len(sessions) == 0 { return nil, nil } return sessions[0], nil } func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession { room, ok := gs.GroupSessions[roomID] if !ok { room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession) gs.GroupSessions[roomID] = room } sender, ok := room[senderKey] if !ok { sender = make(map[id.SessionID]*InboundGroupSession) room[senderKey] = sender } return sender } func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error { gs.lock.Lock() 
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs err := gs.save() gs.lock.Unlock() return err } func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) { gs.lock.Lock() session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID] if !ok { withheld, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] gs.lock.Unlock() if ok { return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheld.Code) } return nil, nil } gs.lock.Unlock() return session, nil } func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent { room, ok := gs.WithheldGroupSessions[roomID] if !ok { room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent) gs.WithheldGroupSessions[roomID] = room } sender, ok := room[senderKey] if !ok { sender = make(map[id.SessionID]*event.RoomKeyWithheldEventContent) room[senderKey] = sender } return sender } func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error { gs.lock.Lock() gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content err := gs.save() gs.lock.Unlock() return err } func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) { gs.lock.Lock() session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID] gs.lock.Unlock() if !ok { return nil, nil } return session, nil } func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) { gs.lock.Lock() defer gs.lock.Unlock() room, ok := gs.GroupSessions[roomID] if !ok { return []*InboundGroupSession{}, nil } var result []*InboundGroupSession for _, sessions := range room { for _, session := range sessions { result = append(result, session) } } return result, nil } func (gs *GobStore) 
GetAllGroupSessions() ([]*InboundGroupSession, error) {
	// Returns every inbound group session across all rooms and sender keys.
	// Nothing is modified here, so a read lock is sufficient.
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	var result []*InboundGroupSession
	for _, room := range gs.GroupSessions {
		for _, sessions := range room {
			for _, session := range sessions {
				result = append(result, session)
			}
		}
	}
	return result, nil
}

// AddOutboundGroupSession stores session as the outbound group session for its
// room, replacing any previous one, and saves the store.
func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
	gs.lock.Lock()
	defer gs.lock.Unlock()
	gs.OutGroupSessions[session.RoomID] = session
	return gs.save()
}

// UpdateOutboundGroupSession saves the store after an outbound group session was
// modified. Sessions are kept by pointer, so the map already holds the updated
// value and only a save is needed.
func (gs *GobStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
	// Every other caller of save in this store holds the write lock; do the same
	// here so saving cannot race with concurrent map writers.
	gs.lock.Lock()
	defer gs.lock.Unlock()
	return gs.save()
}

// GetOutboundGroupSession returns the outbound group session for roomID, or nil
// (with a nil error) if there is none.
func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
	gs.lock.RLock()
	session, ok := gs.OutGroupSessions[roomID]
	gs.lock.RUnlock()
	if !ok {
		return nil, nil
	}
	return session, nil
}

// RemoveOutboundGroupSession deletes the outbound group session for roomID and
// saves the store. Removing a session that does not exist is not an error.
func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
	gs.lock.Lock()
	defer gs.lock.Unlock()
	session, ok := gs.OutGroupSessions[roomID]
	if !ok || session == nil {
		return nil
	}
	delete(gs.OutGroupSessions, roomID)
	// Persist the removal like every other mutating method in this store does.
	return gs.save()
}

// ValidateMessageIndex checks a (sender key, session ID, message index) triple
// against previously recorded messages, presumably to detect replayed megolm
// messages. The first time a triple is seen, its event ID and timestamp are
// recorded and true is returned. Afterwards, true is returned only if both
// match the recorded values.
func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
	gs.lock.Lock()
	defer gs.lock.Unlock()
	key := messageIndexKey{
		SenderKey: senderKey,
		SessionID: sessionID,
		Index:     index,
	}
	val, ok := gs.MessageIndices[key]
	if !ok {
		gs.MessageIndices[key] = messageIndexValue{
			EventID:   eventID,
			Timestamp: timestamp,
		}
		// Saving is best-effort here: the bool result only reports validity, so a
		// save error is deliberately not surfaced.
		_ = gs.save()
		return true
	}
	if val.EventID != eventID || val.Timestamp != timestamp {
		return false
	}
	return true
}

// GetDevices returns the known devices of userID. A user that is not tracked
// yields a nil map and a nil error.
func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*DeviceIdentity, error) {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	// Indexing a map with a missing key already yields nil.
	return gs.Devices[userID], nil
}

// GetDevice returns the device deviceID of userID, or nil if either the user or
// the device is unknown.
func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*DeviceIdentity, error) {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	devices, ok := gs.Devices[userID]
	if !ok {
		return nil, nil
	}
	device, ok := devices[deviceID]
	if !ok {
		return nil, nil
	}
	return device, nil
}

// FindDeviceByKey returns the device of userID whose identity key equals
// identityKey, or nil if none matches.
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*DeviceIdentity, error) {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	devices, ok := gs.Devices[userID]
	if !ok {
		return nil, nil
	}
	for _, device := range devices {
		if device.IdentityKey == identityKey {
			return device, nil
		}
	}
	return nil, nil
}

// PutDevice stores or replaces a single device of userID and saves the store.
func (gs *GobStore) PutDevice(userID id.UserID, device *DeviceIdentity) error {
	gs.lock.Lock()
	defer gs.lock.Unlock()
	devices, ok := gs.Devices[userID]
	if !ok {
		devices = make(map[id.DeviceID]*DeviceIdentity)
		gs.Devices[userID] = devices
	}
	devices[device.DeviceID] = device
	return gs.save()
}

// PutDevices replaces the whole device list of userID and saves the store.
// The given map is stored as-is (not copied), so callers should not mutate it
// afterwards.
func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*DeviceIdentity) error {
	gs.lock.Lock()
	defer gs.lock.Unlock()
	gs.Devices[userID] = devices
	return gs.save()
}

// FilterTrackedUsers returns the users that have an entry in the device store.
// It filters in place: the returned slice shares, and the input slice's contents
// are overwritten in, the same backing array.
func (gs *GobStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	var ptr int
	for _, userID := range users {
		if _, ok := gs.Devices[userID]; ok {
			users[ptr] = userID
			ptr++
		}
	}
	return users[:ptr]
}

// PutCrossSigningKey stores the cross-signing key of userID for the given usage
// and saves the store.
func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
	// The maps are written below, so the exclusive lock is required; a read lock
	// would allow concurrent map writes.
	gs.lock.Lock()
	defer gs.lock.Unlock()
	userKeys, ok := gs.CrossSigningKeys[userID]
	if !ok {
		userKeys = make(map[id.CrossSigningUsage]id.Ed25519)
		gs.CrossSigningKeys[userID] = userKeys
	}
	userKeys[usage] = key
	return gs.save()
}

// GetCrossSigningKeys returns the cross-signing keys of userID by usage, or an
// empty map if none are stored. The returned map is the store's own map, so
// callers must not modify it.
func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.Ed25519, error) {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	keys, ok := gs.CrossSigningKeys[userID]
	if !ok {
		return map[id.CrossSigningUsage]id.Ed25519{}, nil
	}
	return keys, nil
}

// PutSignature records that signerKey of signerUserID signed signedKey of
// signedUserID, and saves the store.
func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
	// The nested maps are created and written below, so the exclusive lock is required.
	gs.lock.Lock()
	defer gs.lock.Unlock()
	signedUserSigs, ok := gs.KeySignatures[signedUserID]
	if !ok {
		signedUserSigs = make(map[id.Ed25519]map[id.UserID]map[id.Ed25519]string)
		gs.KeySignatures[signedUserID] = signedUserSigs
	}
	signaturesForKey, ok := signedUserSigs[signedKey]
	if !ok {
		signaturesForKey = make(map[id.UserID]map[id.Ed25519]string)
		signedUserSigs[signedKey] = signaturesForKey
	}
	signedByUser, ok := signaturesForKey[signerUserID]
	if !ok {
		signedByUser = make(map[id.Ed25519]string)
		signaturesForKey[signerUserID] = signedByUser
	}
	signedByUser[signerKey] = signature
	return gs.save()
}

// GetSignaturesForKeyBy returns the signatures made by signerID's keys over key
// of userID, keyed by signer key. An empty map is returned if there are none.
func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
	gs.lock.RLock()
	defer gs.lock.RUnlock()
	userKeys, ok := gs.KeySignatures[userID]
	if !ok {
		return map[id.Ed25519]string{}, nil
	}
	sigsForKey, ok := userKeys[key]
	if !ok {
		return map[id.Ed25519]string{}, nil
	}
	sigsBySigner, ok := sigsForKey[signerID]
	if !ok {
		return map[id.Ed25519]string{}, nil
	}
	return sigsBySigner, nil
}

// IsKeySignedBy reports whether key of userID has a stored signature made by
// signerKey of signerID.
func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
	sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
	if err != nil {
		return false, err
	}
	_, ok := sigs[signerKey]
	return ok, nil
}

// DropSignaturesByKey deletes every stored signature made by key of userID and
// returns how many were deleted. The store is saved only if something was
// deleted. If saving fails, the in-memory deletions have already happened, so
// the count is returned together with the error.
func (gs *GobStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
	// Entries are deleted below, so the exclusive lock is required.
	gs.lock.Lock()
	defer gs.lock.Unlock()
	var count int64
	for _, userSigs := range gs.KeySignatures {
		for _, keySigs := range userSigs {
			if signedBySigner, ok := keySigs[userID]; ok {
				if _, ok := signedBySigner[key]; ok {
					count++
					delete(signedBySigner, key)
				}
			}
		}
	}
	if count > 0 {
		if err := gs.save(); err != nil {
			return count, err
		}
	}
	return count, nil
}
go-0.11.1/crypto/store_test.go000066400000000000000000000204221436100171500162370ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0.
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "database/sql" "os" "strconv" "testing" _ "github.com/mattn/go-sqlite3" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/id" ) const olmSessID = "sJlikQQKXp7UQjmS9/lyZCNUVJ2AmKyHbufPBaC7tpk" const olmPickled = "L6cdv3JYO9OzhXbcjNSwl7ldN5bDvwmGyin+hISePETE6bO71DIlhqTC9YIhg21RDqRPH2HNl1MCyCw0hEXICWQyeJ9S7JLie" + "5PYxhqSSaTYaybvlvw34jvuSgEx0iotM6WNuWu5ocrsOo5Ye/3Nz7lBvxaw2rpS0jZnn7eV1n9GbINZk4YEVWrHOn7OxYfaGECJHDeAk/ameStiy" + "o1Gru0a/cmR0O3oKMyYnlXir0jS7oETMCsWk59GeVlz++j4aK0FK4g8/3fCMmLDXSatFjE9hoWDmeRwal58Y+XwX76Te/PiWtrFrinvCDEQJcZTa" + "qcCwp6sZrgLbmfBUBb0zJCogCmYw8m2" const groupSession = "9ZbsRqJuETbjnxPpKv29n3dubP/m5PSLbr9I9CIWS2O86F/Og1JZXhqT+4fA5tovoPfdpk5QLh7PfDyjmgOcO9sSA37maJyzCy6Ap+uBZLAXp6VLJ0mjSvxi+PAbzGKDMqpn+pa+oeEIH6SFPG/2GGDSRoXVi5fttAClCIoav5RflWiMypKqnQRfkZR2Gx8glOaBiTzAd7m0X6XGfYIPol41JUIHfBLuJBfXQ0Uu5GScV4eKUWdJP2J6zzC2Hx8cZAhiBBzAza0CbGcnUK+YJXMYaJg92HiIo++l317LlsYUJ/P+gKOLafYR9/l8bAzxH7j5s31PnRs7mD1Bl6G1LFM+dPsGXUOLx6PlvlTlYYM/opai0uKKzT0Wk6zPoq9fN/smlXEPBtKlw2fqcytL4gOF0MrBPEca" func getCryptoStores(t *testing.T) (map[string]Store, func()) { db, err := sql.Open("sqlite3", ":memory:?_busy_timeout=5000") if err != nil { t.Fatalf("Error opening db: %v", err) } sqlStore := NewSQLCryptoStore(db, "sqlite3", "accid", id.DeviceID("dev"), []byte("test"), emptyLogger{}) if err = sqlStore.CreateTables(); err != nil { t.Fatalf("Error creating tables: %v", err) } os.Remove("gob_store_test.gob") gobStore, err := NewGobStore("gob_store_test.gob") if err != nil { t.Fatalf("Error creating Gob store: %v", err) } return map[string]Store{ "sql": sqlStore, "gob": gobStore, }, func() { os.Remove("gob_store_test.gob") } } func TestPutNextBatch(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() store := stores["sql"].(*SQLCryptoStore) store.PutNextBatch("batch1") if batch := store.GetNextBatch(); 
batch != "batch1" { t.Errorf("Expected batch1, got %v", batch) } } func TestPutAccount(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() store.PutAccount(acc) retrieved, err := store.GetAccount() if err != nil { t.Fatalf("Error retrieving account: %v", err) } if acc.IdentityKey() != retrieved.IdentityKey() { t.Errorf("Stored identity key %v, got %v", acc.IdentityKey(), retrieved.IdentityKey()) } if acc.SigningKey() != retrieved.SigningKey() { t.Errorf("Stored signing key %v, got %v", acc.SigningKey(), retrieved.SigningKey()) } }) } } func TestValidateMessageIndex(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() if !store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1000) { t.Error("First message not validated successfully") } if store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1001) { t.Error("First message validated successfully after changing timestamp") } if store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event2", 0, 1000) { t.Error("First message validated successfully after changing event ID") } if !store.ValidateMessageIndex(acc.IdentityKey(), "sess1", "event1", 0, 1000) { t.Error("First message not validated successfully for a second time") } }) } } func TestStoreOlmSession(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { if store.HasSession(olmSessID) { t.Error("Found Olm session before inserting it") } olmInternal, err := olm.SessionFromPickled([]byte(olmPickled), []byte("test")) if err != nil { t.Fatalf("Error creating internal Olm session: %v", err) } olmSess := OlmSession{ id: olmSessID, Internal: *olmInternal, } err = store.AddSession(olmSessID, &olmSess) if err != nil { 
t.Errorf("Error storing Olm session: %v", err) } if !store.HasSession(olmSessID) { t.Error("Not found Olm session after inserting it") } retrieved, err := store.GetLatestSession(olmSessID) if err != nil { t.Errorf("Failed retrieving Olm session: %v", err) } if retrieved.ID() != olmSessID { t.Errorf("Expected session ID to be %v, got %v", olmSessID, retrieved.ID()) } if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != olmPickled { t.Error("Pickled Olm session does not match original") } }) } } func TestStoreMegolmSession(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { acc := NewOlmAccount() internal, err := olm.InboundGroupSessionFromPickled([]byte(groupSession), []byte("test")) if err != nil { t.Fatalf("Error creating internal inbound group session: %v", err) } igs := &InboundGroupSession{ Internal: *internal, SigningKey: acc.SigningKey(), SenderKey: acc.IdentityKey(), RoomID: "room1", } err = store.PutGroupSession("room1", acc.IdentityKey(), igs.ID(), igs) if err != nil { t.Errorf("Error storing inbound group session: %v", err) } retrieved, err := store.GetGroupSession("room1", acc.IdentityKey(), igs.ID()) if err != nil { t.Errorf("Error retrieving inbound group session: %v", err) } if pickled := string(retrieved.Internal.Pickle([]byte("test"))); pickled != groupSession { t.Error("Pickled inbound group session does not match original") } }) } } func TestStoreOutboundMegolmSession(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { sess, err := store.GetOutboundGroupSession("room1") if sess != nil { t.Error("Got outbound session before inserting") } if err != nil { t.Errorf("Error retrieving outbound session: %v", err) } outbound := NewOutboundGroupSession("room1", nil) err = store.AddOutboundGroupSession(outbound) if err != nil { t.Errorf("Error 
inserting outbound session: %v", err) } sess, err = store.GetOutboundGroupSession("room1") if sess == nil { t.Error("Did not get outbound session after inserting") } if err != nil { t.Errorf("Error retrieving outbound session: %v", err) } err = store.RemoveOutboundGroupSession("room1") if err != nil { t.Errorf("Error deleting outbound session: %v", err) } sess, err = store.GetOutboundGroupSession("room1") if sess != nil { t.Error("Got outbound session after deleting") } if err != nil { t.Errorf("Error retrieving outbound session: %v", err) } }) } } func TestStoreDevices(t *testing.T) { stores, cleanup := getCryptoStores(t) defer cleanup() for storeName, store := range stores { t.Run(storeName, func(t *testing.T) { deviceMap := make(map[id.DeviceID]*DeviceIdentity) for i := 0; i < 17; i++ { iStr := strconv.Itoa(i) acc := NewOlmAccount() deviceMap[id.DeviceID("dev"+iStr)] = &DeviceIdentity{ UserID: "user1", DeviceID: id.DeviceID("dev" + iStr), IdentityKey: acc.IdentityKey(), SigningKey: acc.SigningKey(), } } err := store.PutDevices("user1", deviceMap) if err != nil { t.Errorf("Error string devices: %v", err) } devs, err := store.GetDevices("user1") if err != nil { t.Errorf("Error getting devices: %v", err) } if len(devs) != 17 { t.Errorf("Stored 17 devices, got back %v", len(devs)) } if devs["dev0"].IdentityKey != deviceMap["dev0"].IdentityKey { t.Errorf("First device identity key does not match") } if devs["dev16"].IdentityKey != deviceMap["dev16"].IdentityKey { t.Errorf("Last device identity key does not match") } filtered := store.FilterTrackedUsers([]id.UserID{"user0", "user1", "user2"}) if len(filtered) != 1 || filtered[0] != "user1" { t.Errorf("Expected to get 'user1' from filter, got %v", filtered) } }) } } go-0.11.1/crypto/utils/000077500000000000000000000000001436100171500146555ustar00rootroot00000000000000go-0.11.1/crypto/utils/utils.go000066400000000000000000000073001436100171500163440ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // 
// This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package utils import ( "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/sha256" "crypto/sha512" "encoding/base64" "math/rand" "strings" "golang.org/x/crypto/hkdf" "golang.org/x/crypto/pbkdf2" "maunium.net/go/mautrix/util/base58" ) const ( // AESCTRKeyLength is the length of the AES256-CTR key used. AESCTRKeyLength = 32 // AESCTRIVLength is the length of the AES256-CTR IV used. AESCTRIVLength = 16 // HMACKeyLength is the length of the HMAC key used. HMACKeyLength = 32 // SHAHashLength is the length of the SHA hash used. SHAHashLength = 32 ) // XorA256CTR encrypts the input with the keystream generated by the AES256-CTR algorithm with the given arguments. func XorA256CTR(source []byte, key [AESCTRKeyLength]byte, iv [AESCTRIVLength]byte) []byte { block, _ := aes.NewCipher(key[:]) cipher.NewCTR(block, iv[:]).XORKeyStream(source, source) return source } // GenAttachmentA256CTR generates a new random AES256-CTR key and IV suitable for encrypting attachments. func GenAttachmentA256CTR() (key [AESCTRKeyLength]byte, iv [AESCTRIVLength]byte) { _, err := rand.Read(key[:]) if err != nil { panic(err) } // The last 8 bytes of the IV act as the counter in AES-CTR, which means they're left empty here _, err = rand.Read(iv[:8]) if err != nil { panic(err) } return } // GenA256CTRIV generates a random IV for AES256-CTR with the last bit set to zero. func GenA256CTRIV() (iv [AESCTRIVLength]byte) { _, err := rand.Read(iv[:]) if err != nil { panic(err) } iv[8] &= 0x7F return } // DeriveKeysSHA256 derives an AES and a HMAC key from the given recovery key. 
func DeriveKeysSHA256(key []byte, name string) ([AESCTRKeyLength]byte, [HMACKeyLength]byte) { var zeroBytes [32]byte derivedHkdf := hkdf.New(sha256.New, key[:], zeroBytes[:], []byte(name)) var aesKey [AESCTRKeyLength]byte var hmacKey [HMACKeyLength]byte derivedHkdf.Read(aesKey[:]) derivedHkdf.Read(hmacKey[:]) return aesKey, hmacKey } // PBKDF2SHA512 generates a key of the given bit-length using the given passphrase, salt and iteration count. func PBKDF2SHA512(password []byte, salt []byte, iters int, keyLenBits int) []byte { return pbkdf2.Key(password, salt, iters, keyLenBits/8, sha512.New) } // DecodeBase58RecoveryKey recovers the secret storage from a recovery key. func DecodeBase58RecoveryKey(recoveryKey string) []byte { noSpaces := strings.ReplaceAll(recoveryKey, " ", "") decoded := base58.Decode(noSpaces) if len(decoded) != AESCTRKeyLength+3 { // AESCTRKeyLength bytes key and 3 bytes prefix / parity return nil } var parity byte for _, b := range decoded[:34] { parity ^= b } if parity != decoded[34] || decoded[0] != 0x8B || decoded[1] != 1 { return nil } return decoded[2:34] } // EncodeBase58RecoveryKey recovers the secret storage from a recovery key. func EncodeBase58RecoveryKey(key []byte) string { var inputBytes [35]byte copy(inputBytes[2:34], key[:]) inputBytes[0] = 0x8B inputBytes[1] = 1 var parity byte for _, b := range inputBytes[:34] { parity ^= b } inputBytes[34] = parity recoveryKey := base58.Encode(inputBytes[:]) var spacedKey string for i, c := range recoveryKey { if i > 0 && i%4 == 0 { spacedKey += " " } spacedKey += string(c) } return spacedKey } // HMACSHA256B64 calculates the base64 of the SHA256 hmac of the input with the given key. 
func HMACSHA256B64(input []byte, hmacKey [HMACKeyLength]byte) string { h := hmac.New(sha256.New, hmacKey[:]) h.Write(input) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } go-0.11.1/crypto/utils/utils_test.go000066400000000000000000000057201436100171500174070ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package utils import ( "encoding/base64" "testing" ) func TestAES256Ctr(t *testing.T) { expected := "Hello world" key, iv := GenAttachmentA256CTR() enc := XorA256CTR([]byte(expected), key, iv) dec := XorA256CTR(enc, key, iv) if string(dec) != expected { t.Errorf("Expected decrypted using generated key/iv to be `%v`, got %v", expected, string(dec)) } var key2 [AESCTRKeyLength]byte var iv2 [AESCTRIVLength]byte for i := 0; i < AESCTRKeyLength; i++ { key2[i] = byte(i) } for i := 0; i < AESCTRIVLength; i++ { iv2[i] = byte(i) + 32 } dec2 := XorA256CTR([]byte{0x29, 0xc3, 0xff, 0x02, 0x21, 0xaf, 0x67, 0x73, 0x6e, 0xad, 0x9d}, key2, iv2) if string(dec2) != expected { t.Errorf("Expected decrypted using constant key/iv to be `%v`, got %v", expected, string(dec2)) } } func TestPBKDF(t *testing.T) { salt := make([]byte, 16) for i := 0; i < 16; i++ { salt[i] = byte(i) } key := PBKDF2SHA512([]byte("Hello world"), salt, 1000, 256) expected := "ffk9YdbVE1cgqOWgDaec0lH+rJzO+MuCcxpIn3Z6D0E=" keyB64 := base64.StdEncoding.EncodeToString([]byte(key)) if keyB64 != expected { t.Errorf("Expected base64 of generated key to be `%v`, got `%v`", expected, keyB64) } } func TestDecodeSSSSKey(t *testing.T) { recoveryKey := "EsTL 2cTx 9Qy1 8TVd qGsn GDrD i5dT EEuX Qz8U P7hi Z7uu U8wZ" decoded := DecodeBase58RecoveryKey(recoveryKey) expected := "QCFDrXZYLEFnwf4NikVm62rYGJS2mNBEmAWLC3CgNPw=" decodedB64 := base64.StdEncoding.EncodeToString(decoded[:]) if expected != 
decodedB64 { t.Errorf("Expected decoded recovery key b64 to be `%v`, got `%v`", expected, decodedB64) } if encoded := EncodeBase58RecoveryKey(decoded); encoded != recoveryKey { t.Errorf("Expected recovery key to be `%v`, got `%v`", recoveryKey, encoded) } } func TestKeyDerivationAndHMAC(t *testing.T) { recoveryKey := "EsUG Ddi6 e1Cm F4um g38u JN72 d37v Q2ry qCf2 rKgL E2MQ ZQz6" decoded := DecodeBase58RecoveryKey(recoveryKey) aesKey, hmacKey := DeriveKeysSHA256(decoded[:], "m.cross_signing.master") ciphertextBytes, err := base64.StdEncoding.DecodeString("Fx16KlJ9vkd3Dd6CafIq5spaH5QmK5BALMzbtFbQznG2j1VARKK+klc4/Qo=") if err != nil { t.Error(err) } calcMac := HMACSHA256B64(ciphertextBytes, hmacKey) expectedMac := "0DABPNIZsP9iTOh1o6EM0s7BfHHXb96dN7Eca88jq2E=" if calcMac != expectedMac { t.Errorf("Expected MAC `%v`, got `%v`", expectedMac, calcMac) } var ivBytes [AESCTRIVLength]byte decodedIV, _ := base64.StdEncoding.DecodeString("zxT/W5LpZ0Q819pfju6hZw==") copy(ivBytes[:], decodedIV) decrypted := string(XorA256CTR(ciphertextBytes, aesKey, ivBytes)) expectedDec := "Ec8eZDyvVkO3EDsEG6ej5c0cCHnX7PINqFXZjnaTV2s=" if expectedDec != decrypted { t.Errorf("Expected decrypted text to be `%v`, got `%v`", expectedDec, decrypted) } } go-0.11.1/crypto/verification.go000066400000000000000000001047121436100171500165330ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
//go:build !nosas
// +build !nosas

package crypto

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math/rand"
	"sort"
	"strconv"
	"strings"
	"sync"
	"time"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/crypto/canonicaljson"
	"maunium.net/go/mautrix/crypto/olm"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

var (
	// ErrUnknownUserForTransaction is wrapped into the error returned by getTransactionState
	// when a verification message comes from a user other than the one the stored transaction
	// was started with.
	ErrUnknownUserForTransaction = errors.New("unknown user for transaction")
	// ErrTransactionAlreadyExists presumably reports a verification transaction ID that is already
	// in use. NOTE(review): it is not referenced in this part of the file — confirm its callers.
	ErrTransactionAlreadyExists = errors.New("transaction already exists")
	// ErrUnknownTransaction is returned when a key verification message is received with an unknown transaction ID.
	ErrUnknownTransaction = errors.New("unknown transaction")
	// ErrUnknownVerificationMethod is returned when the verification method in a received m.key.verification.start is unknown.
	ErrUnknownVerificationMethod = errors.New("unknown verification method")
)

// VerificationHooks are the callbacks an application supplies to take part in a SAS verification.
type VerificationHooks interface {
	// VerifySASMatch receives the generated SAS and its method, as well as the device that is being verified.
	// It returns whether the given SAS match with the SAS displayed on other device.
	VerifySASMatch(otherDevice *DeviceIdentity, sas SASData) bool
	// VerificationMethods returns the list of supported verification methods in order of preference.
	// It must contain at least the decimal method.
	VerificationMethods() []VerificationMethod
	// OnCancel is presumably called when a verification is cancelled, either by us (cancelledByUs)
	// or by the other side, with the human-readable reason and its cancel code — confirm at call sites.
	OnCancel(cancelledByUs bool, reason string, reasonCode event.VerificationCancelCode)
	// OnSuccess is presumably called once a verification completes successfully — confirm at call sites.
	OnSuccess()
}

// VerificationRequestResponse is the decision returned by AcceptVerificationFrom for an
// incoming verification start that we did not initiate.
type VerificationRequestResponse int

const (
	// AcceptRequest continues the verification by replying with an accept message.
	AcceptRequest VerificationRequestResponse = iota
	// RejectRequest cancels the verification with the "Not accepted by user" reason.
	RejectRequest
	// IgnoreRequest drops the request without replying; only a debug message is logged.
	IgnoreRequest
)

// sendToOneDevice sends a to-device event to a single device.
func (mach *OlmMachine) sendToOneDevice(userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error { _, err := mach.Client.SendToDevice(eventType, &mautrix.ReqSendToDevice{ Messages: map[id.UserID]map[id.DeviceID]*event.Content{ userID: { deviceID: { Parsed: content, }, }, }, }) return err } func (mach *OlmMachine) getPKAndKeysMAC(sas *olm.SAS, sendingUser id.UserID, sendingDevice id.DeviceID, receivingUser id.UserID, receivingDevice id.DeviceID, transactionID string, signingKey id.SigningKey, mainKeyID id.KeyID, keys map[id.KeyID]string) (string, string, error) { sasInfo := "MATRIX_KEY_VERIFICATION_MAC" + sendingUser.String() + sendingDevice.String() + receivingUser.String() + receivingDevice.String() + transactionID // get key IDs from key map keyIDStrings := make([]string, len(keys)) i := 0 for keyID := range keys { keyIDStrings[i] = keyID.String() i++ } sort.Sort(sort.StringSlice(keyIDStrings)) keyIDString := strings.Join(keyIDStrings, ",") pubKeyMac, err := sas.CalculateMAC([]byte(signingKey), []byte(sasInfo+mainKeyID.String())) if err != nil { return "", "", err } mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac)) keysMac, err := sas.CalculateMAC([]byte(keyIDString), []byte(sasInfo+"KEY_IDS")) if err != nil { return "", "", err } mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac)) return string(pubKeyMac), string(keysMac), nil } // verificationState holds all the information needed for the state of a SAS verification with another device. 
type verificationState struct { sas *olm.SAS otherDevice *DeviceIdentity initiatedByUs bool verificationStarted bool keyReceived bool sasMatched chan bool commitment string startEventCanonical string chosenSASMethod VerificationMethod hooks VerificationHooks extendTimeout context.CancelFunc inRoomID id.RoomID lock sync.Mutex } // getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch. func (mach *OlmMachine) getTransactionState(transactionID string, userID id.UserID) (*verificationState, error) { verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID) if !ok { _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction) return nil, ErrUnknownTransaction } verState := verStateInterface.(*verificationState) if verState.otherDevice.UserID != userID { reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID) if verState.inRoomID == "" { _ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch) } else { _ = mach.SendInRoomSASVerificationCancel(verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch) } mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID) } return verState, nil } // handleVerificationStart handles an incoming m.key.verification.start message. // It initializes the state for this SAS verification process and stores it. 
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) { mach.Log.Debug("Received verification start from %v", content.FromDevice) otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice) if err != nil { mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID) return } warnAndCancel := func(logReason, cancelReason string) { mach.Log.Warn("Canceling verification transaction %v as it %s", transactionID, logReason) if inRoomID == "" { _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } else { _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod) } } switch { case content.Method != event.VerificationMethodSAS: warnAndCancel("is not SAS", "Only SAS method is supported") case !content.SupportsKeyAgreementProtocol(event.KeyAgreementCurve25519HKDFSHA256): warnAndCancel("does not support key agreement protocol curve25519-hkdf-sha256", "Only curve25519-hkdf-sha256 key agreement protocol is supported") case !content.SupportsHashMethod(event.VerificationHashSHA256): warnAndCancel("does not support SHA256 hashing", "Only SHA256 hashing is supported") case !content.SupportsMACMethod(event.HKDFHMACSHA256): warnAndCancel("does not support MAC method hkdf-hmac-sha256", "Only hkdf-hmac-sha256 MAC method is supported") case !content.SupportsSASMethod(event.SASDecimal): warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported") default: mach.actuallyStartVerification(userID, content, otherDevice, transactionID, timeout, inRoomID) } } func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *event.VerificationStartEventContent, otherDevice *DeviceIdentity, transactionID string, timeout time.Duration, 
inRoomID id.RoomID) { if inRoomID != "" && transactionID != "" { verState, err := mach.getTransactionState(transactionID, userID) if err != nil { mach.Log.Error("Failed to get transaction state for in-room verification %s start: %v", transactionID, err) _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error") return } mach.timeoutAfter(verState, transactionID, timeout) sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString) err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) if err != nil { mach.Log.Error("Error accepting in-room SAS verification: %v", err) } verState.chosenSASMethod = sasMethods[0] verState.verificationStarted = true return } resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID) if resp == AcceptRequest { sasMethods := commonSASMethods(hooks, content.ShortAuthenticationString) if len(sasMethods) == 0 { mach.Log.Error("No common SAS methods: %v", content.ShortAuthenticationString) if inRoomID == "" { _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } else { _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod) } return } verState := &verificationState{ sas: olm.NewSAS(), otherDevice: otherDevice, initiatedByUs: false, verificationStarted: true, keyReceived: false, sasMatched: make(chan bool, 1), hooks: hooks, chosenSASMethod: sasMethods[0], inRoomID: inRoomID, } verState.lock.Lock() defer verState.lock.Unlock() _, loaded := mach.keyVerificationTransactionState.LoadOrStore(userID.String()+":"+transactionID, verState) if loaded { // transaction already exists mach.Log.Error("Transaction %v already exists, canceling", transactionID) if 
inRoomID == "" { _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } else { _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage) } return } mach.timeoutAfter(verState, transactionID, timeout) var err error if inRoomID == "" { err = mach.SendSASVerificationAccept(userID, content, verState.sas.GetPubkey(), sasMethods) } else { err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods) } if err != nil { mach.Log.Error("Error accepting SAS verification: %v", err) } } else if resp == RejectRequest { mach.Log.Debug("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) var err error if inRoomID == "" { err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } if err != nil { mach.Log.Error("Error canceling SAS verification: %v", err) } } else { mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) } } func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID string, timeout time.Duration) { timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout) verState.extendTimeout = timeoutCancel go func() { mapKey := verState.otherDevice.UserID.String() + ":" + transactionID for { <-timeoutCtx.Done() // when timeout context is done verState.lock.Lock() // if transaction not active anymore, return if _, ok := mach.keyVerificationTransactionState.Load(mapKey); !ok { verState.lock.Unlock() 
return
			}
			if timeoutCtx.Err() == context.DeadlineExceeded {
				// if deadline exceeded cancel due to timeout
				mach.keyVerificationTransactionState.Delete(mapKey)
				_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
				mach.Log.Warn("Verification transaction %v is canceled due to timing out", transactionID)
				verState.lock.Unlock()
				return
			}
			// otherwise the cancel func was called, so the timeout is reset
			mach.Log.Debug("Extending timeout for transaction %v", transactionID)
			timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), timeout)
			verState.extendTimeout = timeoutCancel
			verState.lock.Unlock()
		}
	}()
}

// handleVerificationAccept handles an incoming m.key.verification.accept message.
// It continues the SAS verification process by sending the SAS key message to the other device.
//
// An accept is only expected for a transaction we initiated that has not been accepted yet; anything else
// cancels the transaction as an unexpected message. The accept must also pick the only key agreement
// protocol, hash and MAC method offered here, and share at least one SAS method with the local hooks.
// Otherwise the transaction is canceled with an unknown-method code.
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
	mach.Log.Debug("Received verification accept for transaction %v", transactionID)
	verState, err := mach.getTransactionState(transactionID, userID)
	if err != nil {
		mach.Log.Error("Error getting transaction state: %v", err)
		return
	}
	verState.lock.Lock()
	defer verState.lock.Unlock()
	// extendTimeout is the cancel func installed by timeoutAfter; calling it makes the timeout goroutine restart its timer.
	verState.extendTimeout()

	if !verState.initiatedByUs || verState.verificationStarted {
		// unexpected accept at this point
		mach.Log.Warn("Unexpected verification accept message for transaction %v", transactionID)
		mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
		_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
		return
	}

	// Methods come back in the order of the local hooks, so index 0 below is our preferred method.
	sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
	if content.KeyAgreementProtocol != event.KeyAgreementCurve25519HKDFSHA256 ||
		content.Hash != event.VerificationHashSHA256 ||
		content.MessageAuthenticationCode != event.HKDFHMACSHA256 ||
		len(sasMethods) == 0 {

		mach.Log.Warn("Canceling verification transaction %v due to unknown parameter", transactionID)
		mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
		_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
		return
	}

	key := verState.sas.GetPubkey()
	// The commitment is checked against the other device's key once it arrives (see handleVerificationKey).
	verState.commitment = content.Commitment
	verState.chosenSASMethod = sasMethods[0]
	verState.verificationStarted = true

	if verState.inRoomID == "" {
		err = mach.SendSASVerificationKey(userID, verState.otherDevice.DeviceID, transactionID, string(key))
	} else {
		err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
	}
	if err != nil {
		mach.Log.Error("Error sending SAS key to other device: %v", err)
		return
	}
}

// handleVerificationKey handles an incoming m.key.verification.key message.
// It stores the other device's public key in order to acquire the SAS shared secret.
//
// If we initiated the transaction, the key is checked against the commitment received in the accept
// message. Otherwise our own key is sent back now. Either way the SAS is then generated with the chosen
// method and handed to hooks.VerifySASMatch in a separate goroutine. Its result is fed to sasCompared.
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
	mach.Log.Debug("Got verification key for transaction %v: %v", transactionID, content.Key)
	verState, err := mach.getTransactionState(transactionID, userID)
	if err != nil {
		mach.Log.Error("Error getting transaction state: %v", err)
		return
	}
	verState.lock.Lock()
	defer verState.lock.Unlock()
	verState.extendTimeout()

	device := verState.otherDevice

	if !verState.verificationStarted || verState.keyReceived {
		// unexpected key at this point
		mach.Log.Warn("Unexpected verification key message for transaction %v", transactionID)
		mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
		_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
		return
	}

	if err := verState.sas.SetTheirKey([]byte(content.Key)); err != nil {
		mach.Log.Error("Error setting other device's key: %v", err)
		return
	}

	verState.keyReceived = true

	if verState.initiatedByUs {
		// verify commitment string from accept message now:
		// it must equal SHA-256(their key + canonical JSON of the start event we sent)
		expectedCommitment := olm.NewUtility().Sha256(content.Key + verState.startEventCanonical)
		mach.Log.Debug("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
		if expectedCommitment != verState.commitment {
			mach.Log.Warn("Canceling verification transaction %v due to commitment mismatch", transactionID)
			mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
			_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
			return
		}
	} else {
		// if verification was initiated by other device, send out our key now
		key := verState.sas.GetPubkey()
		if verState.inRoomID == "" {
			err = mach.SendSASVerificationKey(userID, device.DeviceID, transactionID, string(key))
		} else {
			err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
		}
		if err != nil {
			mach.Log.Error("Error sending SAS key to other device: %v", err)
			return
		}
	}

	// compare the SAS keys in a new goroutine and, when the verification is complete, send out the MAC
	// The SAS input is ordered by role (initiator first, then acceptor), so pick the sides accordingly.
	var initUserID, acceptUserID id.UserID
	var initDeviceID, acceptDeviceID id.DeviceID
	var initKey, acceptKey string
	if verState.initiatedByUs {
		initUserID = mach.Client.UserID
		initDeviceID = mach.Client.DeviceID
		initKey = string(verState.sas.GetPubkey())
		acceptUserID = device.UserID
		acceptDeviceID = device.DeviceID
		acceptKey = content.Key
	} else {
		initUserID = device.UserID
		initDeviceID = device.DeviceID
		initKey = content.Key
		acceptUserID = mach.Client.UserID
		acceptDeviceID = mach.Client.DeviceID
		acceptKey = string(verState.sas.GetPubkey())
	}
	// use the preferred SAS method to generate a SAS
	sasMethod := verState.chosenSASMethod
	sas, err := sasMethod.GetVerificationSAS(initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID, verState.sas)
	if err != nil {
		mach.Log.Error("Error generating SAS (method %v): %v", sasMethod.Type(), err)
		return
	}
	mach.Log.Debug("Generated SAS (%v): %v", sasMethod.Type(), sas)
	// The user comparison may take arbitrarily long, so it must not run while verState.lock is held.
	go func() {
		result := verState.hooks.VerifySASMatch(device, sas)
		mach.sasCompared(result, transactionID, verState)
	}()
}

// sasCompared is called asynchronously once the user has compared the SAS.
// It pushes the comparison result into verState.sasMatched, where handleVerificationMAC waits for it. If the
// SAS matched, our MAC is also sent to the other device. The cancellation on a mismatch is performed by
// handleVerificationMAC after it receives false from the channel, not here.
func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verState *verificationState) {
	verState.lock.Lock()
	defer verState.lock.Unlock()
	verState.extendTimeout()
	if didMatch {
		// NOTE(review): sasMatched is created with a buffer of 1 in the constructors visible here, so this send
		// does not block while the lock is held — confirm for any other constructor.
		verState.sasMatched <- true
		var err error
		if verState.inRoomID == "" {
			err = mach.SendSASVerificationMAC(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
		} else {
			err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
		}
		if err != nil {
			mach.Log.Error("Error sending verification MAC to other device: %v", err)
		}
	} else {
		verState.sasMatched <- false
	}
}

// handleVerificationMAC handles an incoming m.key.verification.mac message.
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
	mach.Log.Debug("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
	verState, err := mach.getTransactionState(transactionID, userID)
	if err != nil {
		mach.Log.Error("Error getting transaction state: %v", err)
		return
	}
	verState.lock.Lock()
	defer verState.lock.Unlock()
	verState.extendTimeout()

	device := verState.otherDevice

	// we are done with this SAS verification in all cases so we forget about it
	mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)

	if !verState.verificationStarted || !verState.keyReceived {
		// unexpected MAC at this point
		mach.Log.Warn("Unexpected MAC message for transaction %v", transactionID)
		_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
		return
	}

	// do this in another goroutine as the match result might take a long time to arrive
	// NOTE(review): sasMatched is only written by sasCompared; if the user never answers VerifySASMatch this
	// goroutine stays blocked on the receive below — confirm whether that is acceptable.
	go func() {
		matched := <-verState.sasMatched
		// The outer handler's deferred Unlock has normally run by now; re-acquire before touching the state.
		verState.lock.Lock()
		defer verState.lock.Unlock()

		if !matched {
			mach.Log.Warn("SAS do not match! Canceling transaction %v", transactionID)
			_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
			return
		}

		keyID := id.NewKeyID(id.KeyAlgorithmEd25519, device.DeviceID.String())

		// Recompute both MACs from the other device's point of view (they are the sender here) so they can be
		// compared with what was received.
		expectedPKMAC, expectedKeysMAC, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID,
			mach.Client.UserID, mach.Client.DeviceID, transactionID, device.SigningKey, keyID, content.Mac)
		if err != nil {
			mach.Log.Error("Error generating MAC to match with received MAC: %v", err)
			return
		}

		mach.Log.Debug("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
		if content.Keys != expectedKeysMAC {
			mach.Log.Warn("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
			_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
			return
		}

		mach.Log.Debug("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
		if content.Mac[keyID] != expectedPKMAC {
			mach.Log.Warn("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
			_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
			return
		}

		// we can finally trust this device
		device.Trust = TrustStateVerified
		err = mach.CryptoStore.PutDevice(device.UserID, device)
		if err != nil {
			mach.Log.Warn("Failed to put device after verifying: %v", err)
		}

		// With cached cross-signing keys, also sign the verified device (own user) or the other user's master key.
		if mach.CrossSigningKeys != nil {
			if device.UserID == mach.Client.UserID {
				err := mach.SignOwnDevice(device)
				if err != nil {
					mach.Log.Error("Failed to cross-sign own device %s: %v", device.DeviceID, err)
				} else {
					mach.Log.Debug("Cross-signed own device %v after SAS verification", device.DeviceID)
				}
			} else {
				masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
				if err != nil {
					mach.Log.Warn("Failed to fetch %s's master key: %v", device.UserID, err)
				} else {
					if err := mach.SignUser(device.UserID, masterKey); err != nil {
						mach.Log.Error("Failed to cross-sign master key of %s: %v", device.UserID, err)
					} else {
						mach.Log.Debug("Cross-signed master key of %v after SAS verification", device.UserID)
					}
				}
			}
		} else {
			// TODO ask user to unlock cross-signing keys?
			mach.Log.Debug("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
		}

		mach.Log.Debug("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)

		verState.hooks.OnSuccess()
	}()
}

// handleVerificationCancel handles an incoming m.key.verification.cancel message.
// It cancels the verification process for the given reason.
// The local OnCancel hook (if the transaction is known) is invoked asynchronously with byUs=false.
func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *event.VerificationCancelEventContent, transactionID string) {
	// make sure to not reply with a cancel to not cause a loop of cancel messages
	// this verification will get canceled even if the senders do not match
	verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID)
	if ok {
		go verStateInterface.(*verificationState).hooks.OnCancel(false, content.Reason, content.Code)
	}

	mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
	mach.Log.Warn("SAS verification %v was canceled by %v with reason: %v (%v)", transactionID, userID, content.Reason, content.Code)
}

// handleVerificationRequest handles an incoming m.key.verification.request message.
// Requests that support SAS are passed to AcceptVerificationFrom, which decides whether to accept, reject or ignore them.
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) { mach.Log.Debug("Received verification request from %v", content.FromDevice) otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice) if err != nil { mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID) return } if !content.SupportsVerificationMethod(event.VerificationMethodSAS) { mach.Log.Warn("Canceling verification transaction %v as SAS is not supported", transactionID) if inRoomID == "" { _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } else { _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod) } return } resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID) if resp == AcceptRequest { mach.Log.Debug("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { _, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } else { if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil { mach.Log.Error("Error sending in-room SAS verification ready: %v", err) } if mach.Client.UserID < otherDevice.UserID { // up to us to send the start message _, err = mach.newInRoomSASVerificationWithInner(inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout) } } if err != nil { mach.Log.Error("Error accepting SAS verification request: %v", err) } } else if resp == RejectRequest { mach.Log.Debug("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) if inRoomID == "" { _ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, 
transactionID, "Not accepted by user", event.VerificationCancelByUser) } else { _ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser) } } else { mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID) } } // NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout, // a generated transaction ID and support for both emoji and decimal SAS methods. func (mach *OlmMachine) NewSimpleSASVerificationWith(device *DeviceIdentity, hooks VerificationHooks) (string, error) { return mach.NewSASVerificationWith(device, hooks, "", mach.DefaultSASTimeout) } // NewSASVerificationWith starts the SAS verification process with another device. // If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction.. // If the transaction ID is empty, a new one is generated. 
func (mach *OlmMachine) NewSASVerificationWith(device *DeviceIdentity, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { if transactionID == "" { transactionID = strconv.Itoa(rand.Int()) } mach.Log.Debug("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID) verState := &verificationState{ sas: olm.NewSAS(), otherDevice: device, initiatedByUs: true, verificationStarted: false, keyReceived: false, sasMatched: make(chan bool, 1), hooks: hooks, } verState.lock.Lock() defer verState.lock.Unlock() startEvent, err := mach.SendSASVerificationStart(device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } payload, err := json.Marshal(startEvent) if err != nil { return "", err } canonical, err := canonicaljson.CanonicalJSON(payload) if err != nil { return "", err } verState.startEventCanonical = string(canonical) _, loaded := mach.keyVerificationTransactionState.LoadOrStore(device.UserID.String()+":"+transactionID, verState) if loaded { return "", ErrTransactionAlreadyExists } mach.timeoutAfter(verState, transactionID, timeout) return transactionID, nil } // CancelSASVerification is used by the user to cancel a SAS verification process with the given reason. 
func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, reason string) error { mapKey := userID.String() + ":" + transactionID verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey) if !ok { return ErrUnknownTransaction } verState := verStateInterface.(*verificationState) verState.lock.Lock() defer verState.lock.Unlock() mach.Log.Trace("User canceled verification transaction %v with reason: %v", transactionID, reason) mach.keyVerificationTransactionState.Delete(mapKey) return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser) } // SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code. func (mach *OlmMachine) SendSASVerificationCancel(userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ TransactionID: transactionID, Reason: reason, Code: code, } return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationCancel, content) } // SendSASVerificationStart is used to manually send the SAS verification start message to another device. 
func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() } content := &event.VerificationStartEventContent{ FromDevice: mach.Client.DeviceID, TransactionID: transactionID, Method: event.VerificationMethodSAS, KeyAgreementProtocols: []event.KeyAgreementProtocol{event.KeyAgreementCurve25519HKDFSHA256}, Hashes: []event.VerificationHashMethod{event.VerificationHashSHA256}, MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, ShortAuthenticationString: sasMethods, } return content, mach.sendToOneDevice(toUserID, toDeviceID, event.ToDeviceVerificationStart, content) } // SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event. func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) if err := mach.SendSASVerificationCancel(fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod } payload, err := json.Marshal(startEvent) if err != nil { return err } canonical, err := canonicaljson.CanonicalJSON(payload) if err != nil { return err } hash := olm.NewUtility().Sha256(string(publicKey) + string(canonical)) sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() } content := &event.VerificationAcceptEventContent{ TransactionID: startEvent.TransactionID, Method: event.VerificationMethodSAS, KeyAgreementProtocol: 
event.KeyAgreementCurve25519HKDFSHA256, Hash: event.VerificationHashSHA256, MessageAuthenticationCode: event.HKDFHMACSHA256, ShortAuthenticationString: sasMethods, Commitment: hash, } return mach.sendToOneDevice(fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content) } func (mach *OlmMachine) callbackAndCancelSASVerification(verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error { go verState.hooks.OnCancel(true, reason, code) return mach.SendSASVerificationCancel(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code) } // SendSASVerificationKey sends the ephemeral public key for a device to the partner device. func (mach *OlmMachine) SendSASVerificationKey(userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ TransactionID: transactionID, Key: key, } return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationKey, content) } // SendSASVerificationMAC is use the MAC of a device's key to the partner device. 
func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() keyIDsMap := map[id.KeyID]string{keyID: ""} macMap := make(map[id.KeyID]string) if mach.CrossSigningKeys != nil { masterKey := mach.CrossSigningKeys.MasterKey.PublicKey masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()) // add master key ID to key map keyIDsMap[masterKeyID] = "" masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap) if err != nil { mach.Log.Error("Error generating master key MAC: %v", err) } else { mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC) macMap[masterKeyID] = masterKeyMAC } } pubKeyMac, keysMac, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, signingKey, keyID, keyIDsMap) if err != nil { return err } mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac) mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac) macMap[keyID] = pubKeyMac content := &event.VerificationMacEventContent{ TransactionID: transactionID, Keys: keysMac, Mac: macMap, } return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationMAC, content) } func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod { methods := make([]VerificationMethod, 0) for _, hookMethod := range hooks.VerificationMethods() { for _, otherMethod := range otherDeviceMethods { if hookMethod.Type() == otherMethod { methods = append(methods, hookMethod) break } } } return methods } go-0.11.1/crypto/verification_in_room.go000066400000000000000000000320631436100171500202540ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to 
the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package crypto import ( "encoding/json" "errors" "time" "maunium.net/go/mautrix/crypto/canonicaljson" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) var ( ErrNoVerificationFromDevice = errors.New("from_device field is empty") ErrNoVerificationMethods = errors.New("verification method list is empty") ErrNoRelatesTo = errors.New("missing m.relates_to info") ) // ProcessInRoomVerification is a callback that is to be called when a client receives a message // related to in-room verification. // // Currently this is not automatically called, so you must add the listener yourself. // Note that in-room verification events are wrapped in m.room.encrypted, but this expects the decrypted events. func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error { if evt.Sender == mach.Client.UserID { // nothing to do if the message is our own return nil } if relatable, ok := evt.Content.Parsed.(event.Relatable); !ok || relatable.OptionalGetRelatesTo() == nil { return ErrNoRelatesTo } switch content := evt.Content.Parsed.(type) { case *event.MessageEventContent: if content.MsgType == event.MsgVerificationRequest { if content.FromDevice == "" { return ErrNoVerificationFromDevice } if content.Methods == nil { return ErrNoVerificationMethods } newContent := &event.VerificationRequestEventContent{ FromDevice: content.FromDevice, Methods: content.Methods, Timestamp: evt.Timestamp, TransactionID: evt.ID.String(), } mach.handleVerificationRequest(evt.Sender, newContent, evt.ID.String(), evt.RoomID) } case *event.VerificationStartEventContent: mach.handleVerificationStart(evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID) case *event.VerificationReadyEventContent: mach.handleInRoomVerificationReady(evt.Sender, evt.RoomID, content, 
content.RelatesTo.EventID.String()) case *event.VerificationAcceptEventContent: mach.handleVerificationAccept(evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationKeyEventContent: mach.handleVerificationKey(evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationMacEventContent: mach.handleVerificationMAC(evt.Sender, content, content.RelatesTo.EventID.String()) case *event.VerificationCancelEventContent: mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String()) } return nil } // SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code. func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error { content := &event.VerificationCancelEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Reason: reason, Code: code, To: userID, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationCancel, content) if err != nil { return err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user. 
func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) { content := &event.MessageEventContent{ MsgType: event.MsgVerificationRequest, FromDevice: mach.Client.DeviceID, Methods: []event.VerificationMethod{event.VerificationMethodSAS}, To: toUserID, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.EventMessage, content) if err != nil { return "", err } resp, err := mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) if err != nil { return "", err } return resp.EventID.String(), nil } // SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user. func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transactionID string) error { content := &event.VerificationReadyEventContent{ FromDevice: mach.Client.DeviceID, Methods: []event.VerificationMethod{event.VerificationMethodSAS}, RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationReady, content) if err != nil { return err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user. 
func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) { sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() } content := &event.VerificationStartEventContent{ FromDevice: mach.Client.DeviceID, RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Method: event.VerificationMethodSAS, KeyAgreementProtocols: []event.KeyAgreementProtocol{event.KeyAgreementCurve25519HKDFSHA256}, Hashes: []event.VerificationHashMethod{event.VerificationHashSHA256}, MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256}, ShortAuthenticationString: sasMethods, To: toUserID, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationStart, content) if err != nil { return nil, err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return content, err } // SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event. 
func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error { if startEvent.Method != event.VerificationMethodSAS { reason := "Unknown verification method: " + string(startEvent.Method) if err := mach.SendInRoomSASVerificationCancel(roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil { return err } return ErrUnknownVerificationMethod } payload, err := json.Marshal(startEvent) if err != nil { return err } canonical, err := canonicaljson.CanonicalJSON(payload) if err != nil { return err } hash := olm.NewUtility().Sha256(string(publicKey) + string(canonical)) sasMethods := make([]event.SASMethod, len(methods)) for i, method := range methods { sasMethods[i] = method.Type() } content := &event.VerificationAcceptEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Method: event.VerificationMethodSAS, KeyAgreementProtocol: event.KeyAgreementCurve25519HKDFSHA256, Hash: event.VerificationHashSHA256, MessageAuthenticationCode: event.HKDFHMACSHA256, ShortAuthenticationString: sasMethods, Commitment: hash, To: fromUser, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationAccept, content) if err != nil { return err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification. 
func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id.UserID, transactionID string, key string) error { content := &event.VerificationKeyEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)}, Key: key, To: userID, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationKey, content) if err != nil { return err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return err } // SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification. func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error { keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String()) signingKey := mach.account.SigningKey() keyIDsMap := map[id.KeyID]string{keyID: ""} macMap := make(map[id.KeyID]string) if mach.CrossSigningKeys != nil { masterKey := mach.CrossSigningKeys.MasterKey.PublicKey masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.String()) // add master key ID to key map keyIDsMap[masterKeyID] = "" masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap) if err != nil { mach.Log.Error("Error generating master key MAC: %v", err) } else { mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC) macMap[masterKeyID] = masterKeyMAC } } pubKeyMac, keysMac, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID, userID, deviceID, transactionID, signingKey, keyID, keyIDsMap) if err != nil { return err } mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac) mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac) macMap[keyID] = pubKeyMac content := &event.VerificationMacEventContent{ RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: 
id.EventID(transactionID)}, Keys: keysMac, Mac: macMap, To: userID, } encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationMAC, content) if err != nil { return err } _, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted) return err } // NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room. // It returns the generated transaction ID. func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) { return mach.newInRoomSASVerificationWithInner(inRoomID, &DeviceIdentity{UserID: userID}, hooks, "", timeout) } func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *DeviceIdentity, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) { mach.Log.Debug("Starting new in-room verification transaction user %v", device.UserID) request := transactionID == "" if request { var err error // get new transaction ID from the request message event ID transactionID, err = mach.SendInRoomSASVerificationRequest(inRoomID, device.UserID, hooks.VerificationMethods()) if err != nil { return "", err } } verState := &verificationState{ sas: olm.NewSAS(), otherDevice: device, initiatedByUs: true, verificationStarted: false, keyReceived: false, sasMatched: make(chan bool, 1), hooks: hooks, inRoomID: inRoomID, } verState.lock.Lock() defer verState.lock.Unlock() if !request { // start in-room verification startEvent, err := mach.SendInRoomSASVerificationStart(inRoomID, device.UserID, transactionID, hooks.VerificationMethods()) if err != nil { return "", err } payload, err := json.Marshal(startEvent) if err != nil { return "", err } canonical, err := canonicaljson.CanonicalJSON(payload) if err != nil { return "", err } verState.startEventCanonical = string(canonical) } 
mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState) mach.timeoutAfter(verState, transactionID, timeout) return transactionID, nil } func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) { device, err := mach.GetOrFetchDevice(userID, content.FromDevice) if err != nil { mach.Log.Error("Error fetching device %v of user %v: %v", content.FromDevice, userID, err) return } verState, err := mach.getTransactionState(transactionID, userID) if err != nil { mach.Log.Error("Error getting transaction state: %v", err) return } //mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID) if mach.Client.UserID < userID { // up to us to send the start message verState.lock.Lock() mach.newInRoomSASVerificationWithInner(roomID, device, verState.hooks, transactionID, 10*time.Minute) verState.lock.Unlock() } } go-0.11.1/crypto/verification_sas_methods.go000066400000000000000000000131241436100171500211200ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. //go:build !nosas // +build !nosas package crypto import ( "fmt" "maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // SASData contains the data that users need to verify. type SASData interface { Type() event.SASMethod } // VerificationMethod describes a method for generating a SAS. type VerificationMethod interface { // GetVerificationSAS uses the user, device ID and key of the user who initiated the verification transaction, // the user, device ID and key of the user who accepted, the transaction ID and the SAS object to generate a SAS. // The SAS can be any type, such as an array of numbers or emojis. 
GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, transactionID string, sas *olm.SAS) (SASData, error) // Type returns the type of this SAS method Type() event.SASMethod } const sasInfoFormat = "MATRIX_KEY_VERIFICATION_SAS|%s|%s|%s|%s|%s|%s|%s" // VerificationMethodDecimal describes the decimal SAS method. type VerificationMethodDecimal struct{} // DecimalSASData contains the verification numbers for the decimal SAS method. type DecimalSASData [3]uint // Type returns the decimal SAS method type. func (DecimalSASData) Type() event.SASMethod { return event.SASDecimal } // GetVerificationSAS generates the three numbers that need to match with the other device for a verification to be valid. func (VerificationMethodDecimal) GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, transactionID string, sas *olm.SAS) (SASData, error) { sasInfo := fmt.Sprintf(sasInfoFormat, initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID) sasBytes, err := sas.GenerateBytes([]byte(sasInfo), 5) if err != nil { return DecimalSASData{0, 0, 0}, err } numbers := DecimalSASData{ (uint(sasBytes[0])<<5 | uint(sasBytes[1])>>3) + 1000, (uint(sasBytes[1]&0x7)<<10 | uint(sasBytes[2])<<2 | uint(sasBytes[3]>>6)) + 1000, (uint(sasBytes[3]&0x3F)<<7 | uint(sasBytes[4])>>1) + 1000, } return numbers, nil } // Type returns the decimal SAS method type. 
func (VerificationMethodDecimal) Type() event.SASMethod { return event.SASDecimal } var allEmojis = [...]VerificationEmoji{ {'๐Ÿถ', "Dog"}, {'๐Ÿฑ', "Cat"}, {'๐Ÿฆ', "Lion"}, {'๐ŸŽ', "Horse"}, {'๐Ÿฆ„', "Unicorn"}, {'๐Ÿท', "Pig"}, {'๐Ÿ˜', "Elephant"}, {'๐Ÿฐ', "Rabbit"}, {'๐Ÿผ', "Panda"}, {'๐Ÿ“', "Rooster"}, {'๐Ÿง', "Penguin"}, {'๐Ÿข', "Turtle"}, {'๐ŸŸ', "Fish"}, {'๐Ÿ™', "Octopus"}, {'๐Ÿฆ‹', "Butterfly"}, {'๐ŸŒท', "Flower"}, {'๐ŸŒณ', "Tree"}, {'๐ŸŒต', "Cactus"}, {'๐Ÿ„', "Mushroom"}, {'๐ŸŒ', "Globe"}, {'๐ŸŒ™', "Moon"}, {'โ˜', "Cloud"}, {'๐Ÿ”ฅ', "Fire"}, {'๐ŸŒ', "Banana"}, {'๐ŸŽ', "Apple"}, {'๐Ÿ“', "Strawberry"}, {'๐ŸŒฝ', "Corn"}, {'๐Ÿ•', "Pizza"}, {'๐ŸŽ‚', "Cake"}, {'โค', "Heart"}, {'๐Ÿ˜€', "Smiley"}, {'๐Ÿค–', "Robot"}, {'๐ŸŽฉ', "Hat"}, {'๐Ÿ‘“', "Glasses"}, {'๐Ÿ”ง', "Spanner"}, {'๐ŸŽ…', "Santa"}, {'๐Ÿ‘', "Thumbs Up"}, {'โ˜‚', "Umbrella"}, {'โŒ›', "Hourglass"}, {'โฐ', "Clock"}, {'๐ŸŽ', "Gift"}, {'๐Ÿ’ก', "Light Bulb"}, {'๐Ÿ“•', "Book"}, {'โœ', "Pencil"}, {'๐Ÿ“Ž', "Paperclip"}, {'โœ‚', "Scissors"}, {'๐Ÿ”’', "Lock"}, {'๐Ÿ”‘', "Key"}, {'๐Ÿ”จ', "Hammer"}, {'โ˜Ž', "Telephone"}, {'๐Ÿ', "Flag"}, {'๐Ÿš‚', "Train"}, {'๐Ÿšฒ', "Bicycle"}, {'โœˆ', "Aeroplane"}, {'๐Ÿš€', "Rocket"}, {'๐Ÿ†', "Trophy"}, {'โšฝ', "Ball"}, {'๐ŸŽธ', "Guitar"}, {'๐ŸŽบ', "Trumpet"}, {'๐Ÿ””', "Bell"}, {'โš“', "Anchor"}, {'๐ŸŽง', "Headphones"}, {'๐Ÿ“', "Folder"}, {'๐Ÿ“Œ', "Pin"}, } // VerificationEmoji describes an emoji that might be sent for verifying devices. type VerificationEmoji struct { Emoji rune Description string } func (vm VerificationEmoji) GetEmoji() rune { return vm.Emoji } func (vm VerificationEmoji) GetDescription() string { return vm.Description } // EmojiSASData contains the verification emojis for the emoji SAS method. type EmojiSASData [7]VerificationEmoji // Type returns the emoji SAS method type. func (EmojiSASData) Type() event.SASMethod { return event.SASEmoji } // VerificationMethodEmoji describes the emoji SAS method. 
type VerificationMethodEmoji struct{} // GetVerificationSAS generates the three numbers that need to match with the other device for a verification to be valid. func (VerificationMethodEmoji) GetVerificationSAS(initUserID id.UserID, initDeviceID id.DeviceID, initKey string, acceptUserID id.UserID, acceptDeviceID id.DeviceID, acceptKey string, transactionID string, sas *olm.SAS) (SASData, error) { sasInfo := fmt.Sprintf(sasInfoFormat, initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID) var emojis EmojiSASData sasBytes, err := sas.GenerateBytes([]byte(sasInfo), 6) if err != nil { return emojis, err } sasNum := uint64(sasBytes[0])<<40 | uint64(sasBytes[1])<<32 | uint64(sasBytes[2])<<24 | uint64(sasBytes[3])<<16 | uint64(sasBytes[4])<<8 | uint64(sasBytes[5]) for i := 0; i < len(emojis); i++ { // take nth group of 6 bits emojiIdx := (sasNum >> uint(48-(i+1)*6)) & 0x3F emoji := allEmojis[emojiIdx] emojis[i] = emoji } return emojis, nil } // Type returns the emoji SAS method type. func (VerificationMethodEmoji) Type() event.SASMethod { return event.SASEmoji } go-0.11.1/error.go000066400000000000000000000120511436100171500136540ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package mautrix import ( "encoding/json" "errors" "fmt" "net/http" ) // Common error codes from https://matrix.org/docs/spec/client_server/latest#api-standards // // Can be used with errors.Is() to check the response code without casting the error: // err := client.Sync() // if errors.Is(err, MUnknownToken) { // // logout // } var ( // Forbidden access, e.g. joining a room without permission, failed login. MForbidden = RespError{ErrCode: "M_FORBIDDEN"} // The access token specified was not recognised. 
MUnknownToken = RespError{ErrCode: "M_UNKNOWN_TOKEN"} // No access token was specified for the request. MMissingToken = RespError{ErrCode: "M_MISSING_TOKEN"} // Request contained valid JSON, but it was malformed in some way, e.g. missing required keys, invalid values for keys. MBadJSON = RespError{ErrCode: "M_BAD_JSON"} // Request did not contain valid JSON. MNotJSON = RespError{ErrCode: "M_NOT_JSON"} // No resource was found for this request. MNotFound = RespError{ErrCode: "M_NOT_FOUND"} // Too many requests have been sent in a short period of time. Wait a while then try again. MLimitExceeded = RespError{ErrCode: "M_LIMIT_EXCEEDED"} // The user ID associated with the request has been deactivated. // Typically for endpoints that prove authentication, such as /login. MUserDeactivated = RespError{ErrCode: "M_USER_DEACTIVATED"} // Encountered when trying to register a user ID which has been taken. MUserInUse = RespError{ErrCode: "M_USER_IN_USE"} // Encountered when trying to register a user ID which is not valid. MInvalidUsername = RespError{ErrCode: "M_INVALID_USERNAME"} // Sent when the room alias given to the createRoom API is already in use. MRoomInUse = RespError{ErrCode: "M_ROOM_IN_USE"} // The state change requested cannot be performed, such as attempting to unban a user who is not banned. MBadState = RespError{ErrCode: "M_BAD_STATE"} // The request or entity was too large. MTooLarge = RespError{ErrCode: "M_TOO_LARGE"} // The resource being requested is reserved by an application service, or the application service making the request has not created the resource. MExclusive = RespError{ErrCode: "M_EXCLUSIVE"} // The client's request to create a room used a room version that the server does not support. MUnsupportedRoomVersion = RespError{ErrCode: "M_UNSUPPORTED_ROOM_VERSION"} // The client attempted to join a room that has a version the server does not support. // Inspect the room_version property of the error response for the room's version. 
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"} ) // HTTPError An HTTP Error response, which may wrap an underlying native Go Error. type HTTPError struct { Request *http.Request Response *http.Response ResponseBody string WrappedError error RespError *RespError Message string } func (e HTTPError) Is(err error) bool { return (e.RespError != nil && errors.Is(e.RespError, err)) || (e.WrappedError != nil && errors.Is(e.WrappedError, err)) } func (e HTTPError) IsStatus(code int) bool { return e.Response != nil && e.Response.StatusCode == code } func (e HTTPError) Error() string { if e.WrappedError != nil { return fmt.Sprintf("%s: %v", e.Message, e.WrappedError) } else if e.RespError != nil { return fmt.Sprintf("failed to %s %s: %s (HTTP %d): %s", e.Request.Method, e.Request.URL.Path, e.RespError.ErrCode, e.Response.StatusCode, e.RespError.Err) } else { msg := fmt.Sprintf("failed to %s %s: %s", e.Request.Method, e.Request.URL.Path, e.Response.Status) if len(e.ResponseBody) > 0 { msg = fmt.Sprintf("%s\n%s", msg, e.ResponseBody) } return msg } } func (e HTTPError) Unwrap() error { if e.WrappedError != nil { return e.WrappedError } else if e.RespError != nil { return *e.RespError } return nil } // RespError is the standard JSON error response from Homeservers. It also implements the Golang "error" interface. 
// See https://spec.matrix.org/v1.2/client-server-api/#api-standards type RespError struct { ErrCode string Err string ExtraData map[string]interface{} } func (e *RespError) UnmarshalJSON(data []byte) error { err := json.Unmarshal(data, &e.ExtraData) if err != nil { return err } e.ErrCode, _ = e.ExtraData["errcode"].(string) e.Err, _ = e.ExtraData["error"].(string) return nil } func (e *RespError) MarshalJSON() ([]byte, error) { if e.ExtraData == nil { e.ExtraData = make(map[string]interface{}) } e.ExtraData["errcode"] = e.ErrCode e.ExtraData["error"] = e.Err return json.Marshal(&e.ExtraData) } // Error returns the errcode and error message. func (e RespError) Error() string { return e.ErrCode + ": " + e.Err } func (e RespError) Is(err error) bool { e2, ok := err.(RespError) if !ok { return false } if e.ErrCode == "M_UNKNOWN" && e2.ErrCode == "M_UNKNOWN" { return e.Err == e2.Err } return e2.ErrCode == e.ErrCode } go-0.11.1/event/000077500000000000000000000000001436100171500133165ustar00rootroot00000000000000go-0.11.1/event/accountdata.go000066400000000000000000000025141436100171500161350ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package event import ( "encoding/json" "maunium.net/go/mautrix/id" ) // TagEventContent represents the content of a m.tag room account data event. // https://spec.matrix.org/v1.2/client-server-api/#mtag type TagEventContent struct { Tags Tags `json:"tags"` } type Tags map[string]Tag type Tag struct { Order json.Number `json:"order,omitempty"` } // DirectChatsEventContent represents the content of a m.direct account data event. 
// https://spec.matrix.org/v1.2/client-server-api/#mdirect
//
// Keys are user IDs; each value is the list of room IDs recorded for that user.
type DirectChatsEventContent map[id.UserID][]id.RoomID

// FullyReadEventContent represents the content of a m.fully_read account data event.
// https://spec.matrix.org/v1.2/client-server-api/#mfully_read
type FullyReadEventContent struct {
	EventID id.EventID `json:"event_id"`
}

// IgnoredUserListEventContent represents the content of a m.ignored_user_list account data event.
// https://spec.matrix.org/v1.2/client-server-api/#mignored_user_list
type IgnoredUserListEventContent struct {
	// Keyed by the ignored user's ID.
	IgnoredUsers map[id.UserID]IgnoredUser `json:"ignored_users"`
}

// IgnoredUser is the per-user value in IgnoredUserListEventContent.IgnoredUsers.
type IgnoredUser struct {
	// This is an empty object
}
go-0.11.1/event/content.go000066400000000000000000000376751436100171500153410ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"encoding/gob"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
)

// TypeMap is a mapping from event type to the content struct type.
// This is used by Content.ParseRaw() for creating the correct type of struct.
var TypeMap = map[Type]reflect.Type{
	// Values are non-pointer struct types; ParseRaw stores a pointer to a fresh
	// instance (reflect.New) in Content.Parsed.

	// State events
	StateMember:            reflect.TypeOf(MemberEventContent{}),
	StatePowerLevels:       reflect.TypeOf(PowerLevelsEventContent{}),
	StateCanonicalAlias:    reflect.TypeOf(CanonicalAliasEventContent{}),
	StateRoomName:          reflect.TypeOf(RoomNameEventContent{}),
	StateRoomAvatar:        reflect.TypeOf(RoomAvatarEventContent{}),
	StateServerACL:         reflect.TypeOf(ServerACLEventContent{}),
	StateTopic:             reflect.TypeOf(TopicEventContent{}),
	StateTombstone:         reflect.TypeOf(TombstoneEventContent{}),
	StateCreate:            reflect.TypeOf(CreateEventContent{}),
	StateJoinRules:         reflect.TypeOf(JoinRulesEventContent{}),
	StateHistoryVisibility: reflect.TypeOf(HistoryVisibilityEventContent{}),
	StateGuestAccess:       reflect.TypeOf(GuestAccessEventContent{}),
	StatePinnedEvents:      reflect.TypeOf(PinnedEventsEventContent{}),
	StatePolicyRoom:        reflect.TypeOf(ModPolicyContent{}),
	StatePolicyServer:      reflect.TypeOf(ModPolicyContent{}),
	StatePolicyUser:        reflect.TypeOf(ModPolicyContent{}),
	StateEncryption:        reflect.TypeOf(EncryptionEventContent{}),
	StateBridge:            reflect.TypeOf(BridgeEventContent{}),
	StateHalfShotBridge:    reflect.TypeOf(BridgeEventContent{}),
	StateSpaceParent:       reflect.TypeOf(SpaceParentEventContent{}),
	StateSpaceChild:        reflect.TypeOf(SpaceChildEventContent{}),

	// Message events
	EventMessage:   reflect.TypeOf(MessageEventContent{}),
	EventSticker:   reflect.TypeOf(MessageEventContent{}),
	EventEncrypted: reflect.TypeOf(EncryptedEventContent{}),
	EventRedaction: reflect.TypeOf(RedactionEventContent{}),
	EventReaction:  reflect.TypeOf(ReactionEventContent{}),

	// Account data
	AccountDataRoomTags:        reflect.TypeOf(TagEventContent{}),
	AccountDataDirectChats:     reflect.TypeOf(DirectChatsEventContent{}),
	AccountDataFullyRead:       reflect.TypeOf(FullyReadEventContent{}),
	AccountDataIgnoredUserList: reflect.TypeOf(IgnoredUserListEventContent{}),

	// Ephemeral events
	EphemeralEventTyping:   reflect.TypeOf(TypingEventContent{}),
	EphemeralEventReceipt:  reflect.TypeOf(ReceiptEventContent{}),
	EphemeralEventPresence: reflect.TypeOf(PresenceEventContent{}),

	// In-room verification
	InRoomVerificationStart:  reflect.TypeOf(VerificationStartEventContent{}),
	InRoomVerificationReady:  reflect.TypeOf(VerificationReadyEventContent{}),
	InRoomVerificationAccept: reflect.TypeOf(VerificationAcceptEventContent{}),
	InRoomVerificationKey:    reflect.TypeOf(VerificationKeyEventContent{}),
	InRoomVerificationMAC:    reflect.TypeOf(VerificationMacEventContent{}),
	InRoomVerificationCancel: reflect.TypeOf(VerificationCancelEventContent{}),

	// To-device events
	ToDeviceRoomKey:                  reflect.TypeOf(RoomKeyEventContent{}),
	ToDeviceForwardedRoomKey:         reflect.TypeOf(ForwardedRoomKeyEventContent{}),
	ToDeviceRoomKeyRequest:           reflect.TypeOf(RoomKeyRequestEventContent{}),
	ToDeviceEncrypted:                reflect.TypeOf(EncryptedEventContent{}),
	ToDeviceRoomKeyWithheld:          reflect.TypeOf(RoomKeyWithheldEventContent{}),
	ToDeviceDummy:                    reflect.TypeOf(DummyEventContent{}),
	ToDeviceVerificationStart:        reflect.TypeOf(VerificationStartEventContent{}),
	ToDeviceVerificationAccept:       reflect.TypeOf(VerificationAcceptEventContent{}),
	ToDeviceVerificationKey:          reflect.TypeOf(VerificationKeyEventContent{}),
	ToDeviceVerificationMAC:          reflect.TypeOf(VerificationMacEventContent{}),
	ToDeviceVerificationCancel:       reflect.TypeOf(VerificationCancelEventContent{}),
	ToDeviceVerificationRequest:      reflect.TypeOf(VerificationRequestEventContent{}),
	ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),

	// VoIP call events
	CallInvite:       reflect.TypeOf(CallInviteEventContent{}),
	CallCandidates:   reflect.TypeOf(CallCandidatesEventContent{}),
	CallAnswer:       reflect.TypeOf(CallAnswerEventContent{}),
	CallReject:       reflect.TypeOf(CallRejectEventContent{}),
	CallSelectAnswer: reflect.TypeOf(CallSelectAnswerEventContent{}),
	CallNegotiate:    reflect.TypeOf(CallNegotiateEventContent{}),
	CallHangup:       reflect.TypeOf(CallHangupEventContent{}),
}

// Content stores the content of a Matrix event.
//
// By default, the raw JSON bytes are stored in VeryRaw and parsed into a map[string]interface{} in the Raw field.
// Additionally, you can call ParseRaw with the correct event type to parse the (VeryRaw) content into a nicer struct, // which you can then access from Parsed or via the helper functions. // // When being marshaled into JSON, the data in Parsed will be marshaled first and then recursively merged // with the data in Raw. Values in Raw are preferred, but nested objects will be recursed into before merging, // rather than overriding the whole object with the one in Raw). // If one of them is nil, the only the other is used. If both (Parsed and Raw) are nil, VeryRaw is used instead. type Content struct { VeryRaw json.RawMessage Raw map[string]interface{} Parsed interface{} } type Relatable interface { GetRelatesTo() *RelatesTo OptionalGetRelatesTo() *RelatesTo SetRelatesTo(rel *RelatesTo) } func (content *Content) UnmarshalJSON(data []byte) error { content.VeryRaw = data err := json.Unmarshal(data, &content.Raw) return err } func (content *Content) MarshalJSON() ([]byte, error) { if content.Raw == nil { if content.Parsed == nil { if content.VeryRaw == nil { return []byte("{}"), nil } return content.VeryRaw, nil } return json.Marshal(content.Parsed) } else if content.Parsed != nil { // TODO this whole thing is incredibly hacky // It needs to produce JSON, where: // * content.Parsed is applied after content.Raw // * MarshalJSON() is respected inside content.Parsed // * Custom field inside nested objects of content.Raw are preserved, // even if content.Parsed contains the higher-level objects. 
// * content.Raw is not modified unparsed, err := json.Marshal(content.Parsed) if err != nil { return nil, err } var rawParsed map[string]interface{} err = json.Unmarshal(unparsed, &rawParsed) if err != nil { return nil, err } output := make(map[string]interface{}) for key, value := range content.Raw { output[key] = value } mergeMaps(output, rawParsed) return json.Marshal(output) } return json.Marshal(content.Raw) } func IsUnsupportedContentType(err error) bool { return errors.Is(err, ErrUnsupportedContentType) } var ErrContentAlreadyParsed = errors.New("content is already parsed") var ErrUnsupportedContentType = errors.New("unsupported event type") func (content *Content) ParseRaw(evtType Type) error { if content.Parsed != nil { return ErrContentAlreadyParsed } structType, ok := TypeMap[evtType] if !ok { return fmt.Errorf("%w %s", ErrUnsupportedContentType, evtType.Repr()) } content.Parsed = reflect.New(structType).Interface() return json.Unmarshal(content.VeryRaw, &content.Parsed) } func mergeMaps(into, from map[string]interface{}) { for key, newValue := range from { existingValue, ok := into[key] if !ok { into[key] = newValue continue } existingValueMap, okEx := existingValue.(map[string]interface{}) newValueMap, okNew := newValue.(map[string]interface{}) if okEx && okNew { mergeMaps(existingValueMap, newValueMap) } else { into[key] = newValue } } } func init() { gob.Register(&MemberEventContent{}) gob.Register(&PowerLevelsEventContent{}) gob.Register(&CanonicalAliasEventContent{}) gob.Register(&EncryptionEventContent{}) gob.Register(&BridgeEventContent{}) gob.Register(&SpaceChildEventContent{}) gob.Register(&SpaceParentEventContent{}) gob.Register(&RoomNameEventContent{}) gob.Register(&RoomAvatarEventContent{}) gob.Register(&TopicEventContent{}) gob.Register(&TombstoneEventContent{}) gob.Register(&CreateEventContent{}) gob.Register(&JoinRulesEventContent{}) gob.Register(&HistoryVisibilityEventContent{}) gob.Register(&GuestAccessEventContent{}) 
gob.Register(&PinnedEventsEventContent{}) gob.Register(&MessageEventContent{}) gob.Register(&MessageEventContent{}) gob.Register(&EncryptedEventContent{}) gob.Register(&RedactionEventContent{}) gob.Register(&ReactionEventContent{}) gob.Register(&TagEventContent{}) gob.Register(&DirectChatsEventContent{}) gob.Register(&FullyReadEventContent{}) gob.Register(&IgnoredUserListEventContent{}) gob.Register(&TypingEventContent{}) gob.Register(&ReceiptEventContent{}) gob.Register(&PresenceEventContent{}) gob.Register(&RoomKeyEventContent{}) gob.Register(&ForwardedRoomKeyEventContent{}) gob.Register(&RoomKeyRequestEventContent{}) gob.Register(&RoomKeyWithheldEventContent{}) } // Helper cast functions below func (content *Content) AsMember() *MemberEventContent { casted, ok := content.Parsed.(*MemberEventContent) if !ok { return &MemberEventContent{} } return casted } func (content *Content) AsPowerLevels() *PowerLevelsEventContent { casted, ok := content.Parsed.(*PowerLevelsEventContent) if !ok { return &PowerLevelsEventContent{} } return casted } func (content *Content) AsCanonicalAlias() *CanonicalAliasEventContent { casted, ok := content.Parsed.(*CanonicalAliasEventContent) if !ok { return &CanonicalAliasEventContent{} } return casted } func (content *Content) AsRoomName() *RoomNameEventContent { casted, ok := content.Parsed.(*RoomNameEventContent) if !ok { return &RoomNameEventContent{} } return casted } func (content *Content) AsRoomAvatar() *RoomAvatarEventContent { casted, ok := content.Parsed.(*RoomAvatarEventContent) if !ok { return &RoomAvatarEventContent{} } return casted } func (content *Content) AsTopic() *TopicEventContent { casted, ok := content.Parsed.(*TopicEventContent) if !ok { return &TopicEventContent{} } return casted } func (content *Content) AsTombstone() *TombstoneEventContent { casted, ok := content.Parsed.(*TombstoneEventContent) if !ok { return &TombstoneEventContent{} } return casted } func (content *Content) AsCreate() *CreateEventContent { 
casted, ok := content.Parsed.(*CreateEventContent) if !ok { return &CreateEventContent{} } return casted } func (content *Content) AsJoinRules() *JoinRulesEventContent { casted, ok := content.Parsed.(*JoinRulesEventContent) if !ok { return &JoinRulesEventContent{} } return casted } func (content *Content) AsHistoryVisibility() *HistoryVisibilityEventContent { casted, ok := content.Parsed.(*HistoryVisibilityEventContent) if !ok { return &HistoryVisibilityEventContent{} } return casted } func (content *Content) AsGuestAccess() *GuestAccessEventContent { casted, ok := content.Parsed.(*GuestAccessEventContent) if !ok { return &GuestAccessEventContent{} } return casted } func (content *Content) AsPinnedEvents() *PinnedEventsEventContent { casted, ok := content.Parsed.(*PinnedEventsEventContent) if !ok { return &PinnedEventsEventContent{} } return casted } func (content *Content) AsEncryption() *EncryptionEventContent { casted, ok := content.Parsed.(*EncryptionEventContent) if !ok { return &EncryptionEventContent{} } return casted } func (content *Content) AsBridge() *BridgeEventContent { casted, ok := content.Parsed.(*BridgeEventContent) if !ok { return &BridgeEventContent{} } return casted } func (content *Content) AsSpaceChild() *SpaceChildEventContent { casted, ok := content.Parsed.(*SpaceChildEventContent) if !ok { return &SpaceChildEventContent{} } return casted } func (content *Content) AsSpaceParent() *SpaceParentEventContent { casted, ok := content.Parsed.(*SpaceParentEventContent) if !ok { return &SpaceParentEventContent{} } return casted } func (content *Content) AsMessage() *MessageEventContent { casted, ok := content.Parsed.(*MessageEventContent) if !ok { return &MessageEventContent{} } return casted } func (content *Content) AsEncrypted() *EncryptedEventContent { casted, ok := content.Parsed.(*EncryptedEventContent) if !ok { return &EncryptedEventContent{} } return casted } func (content *Content) AsRedaction() *RedactionEventContent { casted, ok := 
content.Parsed.(*RedactionEventContent) if !ok { return &RedactionEventContent{} } return casted } func (content *Content) AsReaction() *ReactionEventContent { casted, ok := content.Parsed.(*ReactionEventContent) if !ok { return &ReactionEventContent{} } return casted } func (content *Content) AsTag() *TagEventContent { casted, ok := content.Parsed.(*TagEventContent) if !ok { return &TagEventContent{} } return casted } func (content *Content) AsDirectChats() *DirectChatsEventContent { casted, ok := content.Parsed.(*DirectChatsEventContent) if !ok { return &DirectChatsEventContent{} } return casted } func (content *Content) AsFullyRead() *FullyReadEventContent { casted, ok := content.Parsed.(*FullyReadEventContent) if !ok { return &FullyReadEventContent{} } return casted } func (content *Content) AsIgnoredUserList() *IgnoredUserListEventContent { casted, ok := content.Parsed.(*IgnoredUserListEventContent) if !ok { return &IgnoredUserListEventContent{} } return casted } func (content *Content) AsTyping() *TypingEventContent { casted, ok := content.Parsed.(*TypingEventContent) if !ok { return &TypingEventContent{} } return casted } func (content *Content) AsReceipt() *ReceiptEventContent { casted, ok := content.Parsed.(*ReceiptEventContent) if !ok { return &ReceiptEventContent{} } return casted } func (content *Content) AsPresence() *PresenceEventContent { casted, ok := content.Parsed.(*PresenceEventContent) if !ok { return &PresenceEventContent{} } return casted } func (content *Content) AsRoomKey() *RoomKeyEventContent { casted, ok := content.Parsed.(*RoomKeyEventContent) if !ok { return &RoomKeyEventContent{} } return casted } func (content *Content) AsForwardedRoomKey() *ForwardedRoomKeyEventContent { casted, ok := content.Parsed.(*ForwardedRoomKeyEventContent) if !ok { return &ForwardedRoomKeyEventContent{} } return casted } func (content *Content) AsRoomKeyRequest() *RoomKeyRequestEventContent { casted, ok := content.Parsed.(*RoomKeyRequestEventContent) if !ok { 
return &RoomKeyRequestEventContent{} } return casted } func (content *Content) AsRoomKeyWithheld() *RoomKeyWithheldEventContent { casted, ok := content.Parsed.(*RoomKeyWithheldEventContent) if !ok { return &RoomKeyWithheldEventContent{} } return casted } func (content *Content) AsCallInvite() *CallInviteEventContent { casted, ok := content.Parsed.(*CallInviteEventContent) if !ok { return &CallInviteEventContent{} } return casted } func (content *Content) AsCallCandidates() *CallCandidatesEventContent { casted, ok := content.Parsed.(*CallCandidatesEventContent) if !ok { return &CallCandidatesEventContent{} } return casted } func (content *Content) AsCallAnswer() *CallAnswerEventContent { casted, ok := content.Parsed.(*CallAnswerEventContent) if !ok { return &CallAnswerEventContent{} } return casted } func (content *Content) AsCallReject() *CallRejectEventContent { casted, ok := content.Parsed.(*CallRejectEventContent) if !ok { return &CallRejectEventContent{} } return casted } func (content *Content) AsCallSelectAnswer() *CallSelectAnswerEventContent { casted, ok := content.Parsed.(*CallSelectAnswerEventContent) if !ok { return &CallSelectAnswerEventContent{} } return casted } func (content *Content) AsCallNegotiate() *CallNegotiateEventContent { casted, ok := content.Parsed.(*CallNegotiateEventContent) if !ok { return &CallNegotiateEventContent{} } return casted } func (content *Content) AsCallHangup() *CallHangupEventContent { casted, ok := content.Parsed.(*CallHangupEventContent) if !ok { return &CallHangupEventContent{} } return casted } func (content *Content) AsModPolicy() *ModPolicyContent { casted, ok := content.Parsed.(*ModPolicyContent) if !ok { return &ModPolicyContent{} } return casted } go-0.11.1/event/encryption.go000066400000000000000000000126341436100171500160450ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package event import ( "encoding/json" "maunium.net/go/mautrix/id" ) // EncryptionEventContent represents the content of a m.room.encryption state event. // https://spec.matrix.org/v1.2/client-server-api/#mroomencryption type EncryptionEventContent struct { // The encryption algorithm to be used to encrypt messages sent in this room. Must be 'm.megolm.v1.aes-sha2'. Algorithm id.Algorithm `json:"algorithm"` // How long the session should be used before changing it. 604800000 (a week) is the recommended default. RotationPeriodMillis int64 `json:"rotation_period_ms,omitempty"` // How many messages should be sent before changing the session. 100 is the recommended default. RotationPeriodMessages int `json:"rotation_period_msgs,omitempty"` } // EncryptedEventContent represents the content of a m.room.encrypted message event. // https://spec.matrix.org/v1.2/client-server-api/#mroomencrypted // // Note that sender_key and device_id are deprecated in Megolm events as of https://github.com/matrix-org/matrix-spec-proposals/pull/3700 type EncryptedEventContent struct { Algorithm id.Algorithm `json:"algorithm"` SenderKey id.SenderKey `json:"sender_key,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` // Only present for Megolm events SessionID id.SessionID `json:"session_id,omitempty"` // Only present for Megolm events Ciphertext json.RawMessage `json:"ciphertext"` MegolmCiphertext []byte `json:"-"` OlmCiphertext OlmCiphertexts `json:"-"` RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` } type OlmCiphertexts map[id.Curve25519]struct { Body string `json:"body"` Type id.OlmMsgType `json:"type"` } type serializableEncryptedEventContent EncryptedEventContent func (content *EncryptedEventContent) UnmarshalJSON(data []byte) error { err := json.Unmarshal(data, (*serializableEncryptedEventContent)(content)) if err != nil { return err } switch content.Algorithm 
{ case id.AlgorithmOlmV1: content.OlmCiphertext = make(OlmCiphertexts) return json.Unmarshal(content.Ciphertext, &content.OlmCiphertext) case id.AlgorithmMegolmV1: if len(content.Ciphertext) == 0 || content.Ciphertext[0] != '"' || content.Ciphertext[len(content.Ciphertext)-1] != '"' { return id.InputNotJSONString } content.MegolmCiphertext = content.Ciphertext[1 : len(content.Ciphertext)-1] } return nil } func (content *EncryptedEventContent) MarshalJSON() ([]byte, error) { var err error switch content.Algorithm { case id.AlgorithmOlmV1: content.Ciphertext, err = json.Marshal(content.OlmCiphertext) case id.AlgorithmMegolmV1: content.Ciphertext = make([]byte, len(content.MegolmCiphertext)+2) content.Ciphertext[0] = '"' content.Ciphertext[len(content.Ciphertext)-1] = '"' copy(content.Ciphertext[1:len(content.Ciphertext)-1], content.MegolmCiphertext) } if err != nil { return nil, err } return json.Marshal((*serializableEncryptedEventContent)(content)) } // RoomKeyEventContent represents the content of a m.room_key to_device event. // https://spec.matrix.org/v1.2/client-server-api/#mroom_key type RoomKeyEventContent struct { Algorithm id.Algorithm `json:"algorithm"` RoomID id.RoomID `json:"room_id"` SessionID id.SessionID `json:"session_id"` SessionKey string `json:"session_key"` } // ForwardedRoomKeyEventContent represents the content of a m.forwarded_room_key to_device event. // https://spec.matrix.org/v1.2/client-server-api/#mforwarded_room_key type ForwardedRoomKeyEventContent struct { RoomKeyEventContent SenderKey id.SenderKey `json:"sender_key"` SenderClaimedKey id.Ed25519 `json:"sender_claimed_ed25519_key"` ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"` } type KeyRequestAction string const ( KeyRequestActionRequest = "request" KeyRequestActionCancel = "request_cancellation" ) // RoomKeyRequestEventContent represents the content of a m.room_key_request to_device event. 
// https://spec.matrix.org/v1.2/client-server-api/#mroom_key_request
type RoomKeyRequestEventContent struct {
	Body               RequestedKeyInfo `json:"body"`
	Action             KeyRequestAction `json:"action"`
	RequestingDeviceID id.DeviceID      `json:"requesting_device_id"`
	RequestID          string           `json:"request_id"`
}

// RequestedKeyInfo identifies the session a room key request is asking for.
type RequestedKeyInfo struct {
	Algorithm id.Algorithm `json:"algorithm"`
	RoomID    id.RoomID    `json:"room_id"`
	SenderKey id.SenderKey `json:"sender_key"`
	SessionID id.SessionID `json:"session_id"`
}

// RoomKeyWithheldCode is the "code" value of RoomKeyWithheldEventContent.
type RoomKeyWithheldCode string

const (
	RoomKeyWithheldBlacklisted  RoomKeyWithheldCode = "m.blacklisted"
	RoomKeyWithheldUnverified   RoomKeyWithheldCode = "m.unverified"
	RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorized"
	RoomKeyWithheldUnavailable  RoomKeyWithheldCode = "m.unavailable"
	RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm"
)

// RoomKeyWithheldEventContent is the content type TypeMap registers for both
// ToDeviceRoomKeyWithheld and ToDeviceOrgMatrixRoomKeyWithheld.
type RoomKeyWithheldEventContent struct {
	RoomID    id.RoomID           `json:"room_id,omitempty"`
	Algorithm id.Algorithm        `json:"algorithm"`
	SessionID id.SessionID        `json:"session_id,omitempty"`
	SenderKey id.SenderKey        `json:"sender_key"`
	Code      RoomKeyWithheldCode `json:"code"`
	Reason    string              `json:"reason,omitempty"`
}

// DummyEventContent is the (empty) content type TypeMap registers for ToDeviceDummy.
type DummyEventContent struct{}
go-0.11.1/event/ephemeral.go000066400000000000000000000046141436100171500156140ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"encoding/json"

	"maunium.net/go/mautrix/id"
)

// TypingEventContent represents the content of a m.typing ephemeral event.
// https://spec.matrix.org/v1.2/client-server-api/#mtyping
type TypingEventContent struct {
	UserIDs []id.UserID `json:"user_ids"`
}

// ReceiptEventContent represents the content of a m.receipt ephemeral event.
// https://spec.matrix.org/v1.2/client-server-api/#mreceipt type ReceiptEventContent map[id.EventID]Receipts type Receipts struct { Read map[id.UserID]ReadReceipt `json:"m.read"` } type ReadReceipt struct { Timestamp int64 `json:"ts"` // Extra contains any unknown fields in the read receipt event. // Most servers don't allow clients to set them, so this will be empty in most cases. Extra map[string]interface{} `json:"-"` } func (rr *ReadReceipt) UnmarshalJSON(data []byte) error { // Hacky compatibility hack against crappy clients that send double-encoded read receipts. // TODO is this actually needed? clients can't currently set custom content in receipts ๐Ÿค” if data[0] == '"' && data[len(data)-1] == '"' { var strData string err := json.Unmarshal(data, &strData) if err != nil { return err } data = []byte(strData) } var parsed map[string]interface{} err := json.Unmarshal(data, &parsed) if err != nil { return err } ts, _ := parsed["ts"].(float64) delete(parsed, "ts") *rr = ReadReceipt{ Timestamp: int64(ts), Extra: parsed, } return nil } type Presence string const ( PresenceOnline Presence = "online" PresenceOffline Presence = "offline" PresenceUnavailable Presence = "unavailable" ) // PresenceEventContent represents the content of a m.presence ephemeral event. // https://spec.matrix.org/v1.2/client-server-api/#mpresence type PresenceEventContent struct { Presence Presence `json:"presence"` Displayname string `json:"displayname,omitempty"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` LastActiveAgo int64 `json:"last_active_ago,omitempty"` CurrentlyActive bool `json:"currently_active,omitempty"` StatusMessage string `json:"status_msg,omitempty"` } go-0.11.1/event/events.go000066400000000000000000000122221436100171500151500ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package event import ( "encoding/json" "maunium.net/go/mautrix/id" ) // Event represents a single Matrix event. type Event struct { StateKey *string `json:"state_key,omitempty"` // The state key for the event. Only present on State Events. Sender id.UserID `json:"sender,omitempty"` // The user ID of the sender of the event Type Type `json:"type"` // The event type Timestamp int64 `json:"origin_server_ts,omitempty"` // The unix timestamp when this message was sent by the origin server ID id.EventID `json:"event_id,omitempty"` // The unique ID of this event RoomID id.RoomID `json:"room_id,omitempty"` // The room the event was sent to. May be nil (e.g. for presence) Content Content `json:"content"` // The JSON content of the event. Redacts id.EventID `json:"redacts,omitempty"` // The event ID that was redacted if a m.room.redaction event Unsigned Unsigned `json:"unsigned,omitempty"` // Unsigned content set by own homeserver. Mautrix MautrixInfo `json:"-"` ToUserID id.UserID `json:"to_user_id,omitempty"` // The user ID that the to-device event was sent to. Only present in MSC2409 appservice transactions. ToDeviceID id.DeviceID `json:"to_device_id,omitempty"` // The device ID that the to-device event was sent to. Only present in MSC2409 appservice transactions. 
} type eventForMarshaling struct { StateKey *string `json:"state_key,omitempty"` Sender id.UserID `json:"sender,omitempty"` Type Type `json:"type"` Timestamp int64 `json:"origin_server_ts,omitempty"` ID id.EventID `json:"event_id,omitempty"` RoomID id.RoomID `json:"room_id,omitempty"` Content Content `json:"content"` Redacts id.EventID `json:"redacts,omitempty"` Unsigned *Unsigned `json:"unsigned,omitempty"` PrevContent *Content `json:"prev_content,omitempty"` ReplacesState *id.EventID `json:"replaces_state,omitempty"` ToUserID id.UserID `json:"to_user_id,omitempty"` ToDeviceID id.DeviceID `json:"to_device_id,omitempty"` } // UnmarshalJSON unmarshals the event, including moving prev_content from the top level to inside unsigned. func (evt *Event) UnmarshalJSON(data []byte) error { var efm eventForMarshaling err := json.Unmarshal(data, &efm) if err != nil { return err } evt.StateKey = efm.StateKey evt.Sender = efm.Sender evt.Type = efm.Type evt.Timestamp = efm.Timestamp evt.ID = efm.ID evt.RoomID = efm.RoomID evt.Content = efm.Content evt.Redacts = efm.Redacts if efm.Unsigned != nil { evt.Unsigned = *efm.Unsigned } if efm.PrevContent != nil && evt.Unsigned.PrevContent == nil { evt.Unsigned.PrevContent = efm.PrevContent } if efm.ReplacesState != nil && *efm.ReplacesState != "" && evt.Unsigned.ReplacesState == "" { evt.Unsigned.ReplacesState = *efm.ReplacesState } evt.ToUserID = efm.ToUserID evt.ToDeviceID = efm.ToDeviceID return nil } // MarshalJSON marshals the event, including omitting the unsigned field if it's empty. // // This is necessary because Unsigned is not a pointer (for convenience reasons), // and encoding/json doesn't know how to check if a non-pointer struct is empty. // // TODO(tulir): maybe it makes more sense to make Unsigned a pointer and make an easy and safe way to access it? 
func (evt *Event) MarshalJSON() ([]byte, error) { unsigned := &evt.Unsigned if unsigned.IsEmpty() { unsigned = nil } return json.Marshal(&eventForMarshaling{ StateKey: evt.StateKey, Sender: evt.Sender, Type: evt.Type, Timestamp: evt.Timestamp, ID: evt.ID, RoomID: evt.RoomID, Content: evt.Content, Redacts: evt.Redacts, Unsigned: unsigned, ToUserID: evt.ToUserID, ToDeviceID: evt.ToDeviceID, }) } type MautrixInfo struct { Verified bool } func (evt *Event) GetStateKey() string { if evt.StateKey != nil { return *evt.StateKey } return "" } type StrippedState struct { Content Content `json:"content"` Type Type `json:"type"` StateKey string `json:"state_key"` } type Unsigned struct { PrevContent *Content `json:"prev_content,omitempty"` PrevSender id.UserID `json:"prev_sender,omitempty"` ReplacesState id.EventID `json:"replaces_state,omitempty"` Age int64 `json:"age,omitempty"` TransactionID string `json:"transaction_id,omitempty"` Relations Relations `json:"m.relations,omitempty"` RedactedBecause *Event `json:"redacted_because,omitempty"` InviteRoomState []StrippedState `json:"invite_room_state,omitempty"` } func (us *Unsigned) IsEmpty() bool { return us.PrevContent == nil && us.PrevSender == "" && us.ReplacesState == "" && us.Age == 0 && us.TransactionID == "" && us.RedactedBecause == nil && us.InviteRoomState == nil && us.Relations.Raw == nil && us.Relations.Annotations.Map == nil && us.Relations.References.List == nil && us.Relations.Replaces.List == nil } go-0.11.1/event/member.go000066400000000000000000000032501436100171500151140ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package event import ( "encoding/json" "maunium.net/go/mautrix/id" ) // Membership is an enum specifying the membership state of a room member. 
type Membership string func (ms Membership) IsInviteOrJoin() bool { return ms == MembershipJoin || ms == MembershipInvite } func (ms Membership) IsLeaveOrBan() bool { return ms == MembershipLeave || ms == MembershipBan } // The allowed membership states as specified in spec section 10.5.5. const ( MembershipJoin Membership = "join" MembershipLeave Membership = "leave" MembershipInvite Membership = "invite" MembershipBan Membership = "ban" MembershipKnock Membership = "knock" ) // MemberEventContent represents the content of a m.room.member state event. // https://spec.matrix.org/v1.2/client-server-api/#mroommember type MemberEventContent struct { Membership Membership `json:"membership"` AvatarURL id.ContentURIString `json:"avatar_url,omitempty"` Displayname string `json:"displayname,omitempty"` IsDirect bool `json:"is_direct,omitempty"` ThirdPartyInvite *ThirdPartyInvite `json:"third_party_invite,omitempty"` Reason string `json:"reason,omitempty"` } type ThirdPartyInvite struct { DisplayName string `json:"display_name"` Signed struct { Token string `json:"token"` Signatures json.RawMessage `json:"signatures"` MXID string `json:"mxid"` } } go-0.11.1/event/message.go000066400000000000000000000167241436100171500153030ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package event import ( "encoding/json" "strconv" "strings" "golang.org/x/net/html" "maunium.net/go/mautrix/crypto/attachment" "maunium.net/go/mautrix/id" ) // MessageType is the sub-type of a m.room.message event. 
// https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes type MessageType string // Msgtypes const ( MsgText MessageType = "m.text" MsgEmote MessageType = "m.emote" MsgNotice MessageType = "m.notice" MsgImage MessageType = "m.image" MsgLocation MessageType = "m.location" MsgVideo MessageType = "m.video" MsgAudio MessageType = "m.audio" MsgFile MessageType = "m.file" MsgVerificationRequest MessageType = "m.key.verification.request" ) // Format specifies the format of the formatted_body in m.room.message events. // https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes type Format string // Message formats const ( FormatHTML Format = "org.matrix.custom.html" ) // RedactionEventContent represents the content of a m.room.redaction message event. // // The redacted event ID is still at the top level, but will move in a future room version. // See https://github.com/matrix-org/matrix-doc/pull/2244 and https://github.com/matrix-org/matrix-doc/pull/2174 // // https://spec.matrix.org/v1.2/client-server-api/#mroomredaction type RedactionEventContent struct { Reason string `json:"reason,omitempty"` } // ReactionEventContent represents the content of a m.reaction message event. // This is not yet in a spec release, see https://github.com/matrix-org/matrix-doc/pull/1849 type ReactionEventContent struct { RelatesTo RelatesTo `json:"m.relates_to"` } func (content *ReactionEventContent) GetRelatesTo() *RelatesTo { return &content.RelatesTo } func (content *ReactionEventContent) OptionalGetRelatesTo() *RelatesTo { return &content.RelatesTo } func (content *ReactionEventContent) SetRelatesTo(rel *RelatesTo) { content.RelatesTo = *rel } // MessageEventContent represents the content of a m.room.message event. // // It is also used to represent m.sticker events, as they are equivalent to m.room.message // with the exception of the msgtype field. 
// // https://spec.matrix.org/v1.2/client-server-api/#mroommessage type MessageEventContent struct { // Base m.room.message fields MsgType MessageType `json:"msgtype,omitempty"` Body string `json:"body"` // Extra fields for text types Format Format `json:"format,omitempty"` FormattedBody string `json:"formatted_body,omitempty"` // Extra field for m.location GeoURI string `json:"geo_uri,omitempty"` // Extra fields for media types URL id.ContentURIString `json:"url,omitempty"` Info *FileInfo `json:"info,omitempty"` File *EncryptedFileInfo `json:"file,omitempty"` // Edits and relations NewContent *MessageEventContent `json:"m.new_content,omitempty"` RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` // In-room verification To id.UserID `json:"to,omitempty"` FromDevice id.DeviceID `json:"from_device,omitempty"` Methods []VerificationMethod `json:"methods,omitempty"` replyFallbackRemoved bool } func (content *MessageEventContent) GetRelatesTo() *RelatesTo { if content.RelatesTo == nil { content.RelatesTo = &RelatesTo{} } return content.RelatesTo } func (content *MessageEventContent) OptionalGetRelatesTo() *RelatesTo { return content.RelatesTo } func (content *MessageEventContent) SetRelatesTo(rel *RelatesTo) { content.RelatesTo = rel } func (content *MessageEventContent) SetEdit(original id.EventID) { newContent := *content content.NewContent = &newContent content.RelatesTo = &RelatesTo{ Type: RelReplace, EventID: original, } if content.MsgType == MsgText || content.MsgType == MsgNotice { content.Body = "* " + content.Body if content.Format == FormatHTML && len(content.FormattedBody) > 0 { content.FormattedBody = "* " + content.FormattedBody } } } func (content *MessageEventContent) EnsureHasHTML() { if len(content.FormattedBody) == 0 || content.Format != FormatHTML { content.FormattedBody = strings.ReplaceAll(html.EscapeString(content.Body), "\n", "
") content.Format = FormatHTML } } func (content *MessageEventContent) GetFile() *EncryptedFileInfo { if content.File == nil { content.File = &EncryptedFileInfo{} } return content.File } func (content *MessageEventContent) GetInfo() *FileInfo { if content.Info == nil { content.Info = &FileInfo{} } return content.Info } type EncryptedFileInfo struct { attachment.EncryptedFile URL id.ContentURIString `json:"url"` } type FileInfo struct { MimeType string `json:"mimetype,omitempty"` ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"` ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` Width int `json:"-"` Height int `json:"-"` Duration int `json:"-"` Size int `json:"-"` } type serializableFileInfo struct { MimeType string `json:"mimetype,omitempty"` ThumbnailInfo *serializableFileInfo `json:"thumbnail_info,omitempty"` ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"` ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"` Width json.Number `json:"w,omitempty"` Height json.Number `json:"h,omitempty"` Duration json.Number `json:"duration,omitempty"` Size json.Number `json:"size,omitempty"` } func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileInfo { if fileInfo == nil { return nil } *sfi = serializableFileInfo{ MimeType: fileInfo.MimeType, ThumbnailURL: fileInfo.ThumbnailURL, ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo), ThumbnailFile: fileInfo.ThumbnailFile, } if fileInfo.Width > 0 { sfi.Width = json.Number(strconv.Itoa(fileInfo.Width)) } if fileInfo.Height > 0 { sfi.Height = json.Number(strconv.Itoa(fileInfo.Height)) } if fileInfo.Size > 0 { sfi.Size = json.Number(strconv.Itoa(fileInfo.Size)) } if fileInfo.Duration > 0 { sfi.Duration = json.Number(strconv.Itoa(int(fileInfo.Duration))) } return sfi } func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) { *fileInfo = FileInfo{ Width: 
numberToInt(sfi.Width), Height: numberToInt(sfi.Height), Size: numberToInt(sfi.Size), Duration: numberToInt(sfi.Duration), MimeType: sfi.MimeType, ThumbnailURL: sfi.ThumbnailURL, ThumbnailFile: sfi.ThumbnailFile, } if sfi.ThumbnailInfo != nil { fileInfo.ThumbnailInfo = &FileInfo{} sfi.ThumbnailInfo.CopyTo(fileInfo.ThumbnailInfo) } } func (fileInfo *FileInfo) UnmarshalJSON(data []byte) error { sfi := &serializableFileInfo{} if err := json.Unmarshal(data, sfi); err != nil { return err } sfi.CopyTo(fileInfo) return nil } func (fileInfo *FileInfo) MarshalJSON() ([]byte, error) { return json.Marshal((&serializableFileInfo{}).CopyFrom(fileInfo)) } func numberToInt(val json.Number) int { f64, _ := val.Float64() if f64 > 0 { return int(f64) } return 0 } func (fileInfo *FileInfo) GetThumbnailInfo() *FileInfo { if fileInfo.ThumbnailInfo == nil { fileInfo.ThumbnailInfo = &FileInfo{} } return fileInfo.ThumbnailInfo } go-0.11.1/event/message_test.go000066400000000000000000000113521436100171500163320ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package event_test

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// invalidMessageEvent is a well-formed event whose content.body is an object instead of a
// string, so the envelope decodes but ParseRaw must fail.
const invalidMessageEvent = `{ "sender": "@tulir:maunium.net", "type": "m.room.message", "origin_server_ts": 1587252684192, "event_id": "$foo", "room_id": "!bar", "content": { "body": { "hmm": false } } }`

// TestMessageEventContent__ParseInvalid checks that the event envelope decodes while
// parsing the malformed content returns an error.
func TestMessageEventContent__ParseInvalid(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(invalidMessageEvent), &evt)
	assert.Nil(t, err)

	assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
	assert.Equal(t, event.EventMessage, evt.Type)
	assert.Equal(t, int64(1587252684192), evt.Timestamp)
	assert.Equal(t, id.EventID("$foo"), evt.ID)
	assert.Equal(t, id.RoomID("!bar"), evt.RoomID)

	err = evt.Content.ParseRaw(evt.Type)
	assert.NotNil(t, err)
}

// messageEvent is an edit: a "* "-prefixed fallback body plus the replacement in m.new_content.
// NOTE(review): the formatted_body values appear to have lost their markup in this copy.
const messageEvent = `{ "sender": "@tulir:maunium.net", "type": "m.room.message", "origin_server_ts": 1587252684192, "event_id": "$foo", "room_id": "!bar", "content": { "msgtype": "m.text", "body": "* **Hello**, World!", "format": "org.matrix.custom.html", "formatted_body": "* Hello, World!", "m.new_content": { "msgtype": "m.text", "body": "**Hello**, World!", "format": "org.matrix.custom.html", "formatted_body": "Hello, World!"
} } }`

// TestMessageEventContent__ParseEdit checks that m.new_content is parsed into NewContent.
func TestMessageEventContent__ParseEdit(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(messageEvent), &evt)
	assert.Nil(t, err)

	assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
	assert.Equal(t, event.EventMessage, evt.Type)
	assert.Equal(t, int64(1587252684192), evt.Timestamp)
	assert.Equal(t, id.EventID("$foo"), evt.ID)
	assert.Equal(t, id.RoomID("!bar"), evt.RoomID)

	err = evt.Content.ParseRaw(evt.Type)
	require.NoError(t, err)
	assert.IsType(t, &event.MessageEventContent{}, evt.Content.Parsed)
	content := evt.Content.Parsed.(*event.MessageEventContent)
	assert.Equal(t, event.MsgText, content.MsgType)
	assert.Equal(t, event.MsgText, content.NewContent.MsgType)
	assert.Equal(t, "**Hello**, World!", content.NewContent.Body)
	assert.Equal(t, "Hello, World!", content.NewContent.FormattedBody)
}

// imageMessageEvent is an m.image message with a media URL and an info object (w/h/size).
const imageMessageEvent = `{ "sender": "@tulir:maunium.net", "type": "m.room.message", "origin_server_ts": 1587252684192, "event_id": "$foo", "room_id": "!bar", "content": { "msgtype": "m.image", "body": "image.png", "url": "mxc://example.com/image", "info": { "mimetype": "image/png", "w": 64, "h": 64, "size": 12345, "thumbnail_url": "mxc://example.com/image_thumb" } } }`

// TestMessageEventContent__ParseMedia checks URL parsing and the FileInfo numeric fields.
func TestMessageEventContent__ParseMedia(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(imageMessageEvent), &evt)
	assert.Nil(t, err)

	assert.Equal(t, id.UserID("@tulir:maunium.net"), evt.Sender)
	assert.Equal(t, event.EventMessage, evt.Type)
	assert.Equal(t, int64(1587252684192), evt.Timestamp)
	assert.Equal(t, id.EventID("$foo"), evt.ID)
	assert.Equal(t, id.RoomID("!bar"), evt.RoomID)

	err = evt.Content.ParseRaw(evt.Type)
	require.NoError(t, err)
	assert.IsType(t, &event.MessageEventContent{}, evt.Content.Parsed)
	content := evt.Content.Parsed.(*event.MessageEventContent)
	assert.Equal(t, event.MsgImage, content.MsgType)
	parsedURL, err := content.URL.Parse()
	assert.Nil(t, err)
	assert.Equal(t, id.ContentURI{Homeserver: "example.com", FileID: "image"}, parsedURL)
	assert.Nil(t, content.NewContent)
	assert.Equal(t, "image/png", content.GetInfo().MimeType)
	assert.EqualValues(t, 64, content.GetInfo().Width)
	assert.EqualValues(t, 64, content.GetInfo().Height)
	assert.EqualValues(t, 12345, content.GetInfo().Size)
}

// parsedMessage has only Parsed set, so marshaling must emit just the struct's fields.
var parsedMessage = &event.Content{
	Parsed: &event.MessageEventContent{
		MsgType: event.MsgText,
		Body:    "test",
	},
}

const expectedMarshalResult = `{"msgtype":"m.text","body":"test"}`

// TestMessageEventContent__Marshal checks marshaling of Content with only Parsed set.
func TestMessageEventContent__Marshal(t *testing.T) {
	data, err := json.Marshal(parsedMessage)
	assert.Nil(t, err)
	assert.Equal(t, expectedMarshalResult, string(data))
}

// customParsedMessage combines a Raw custom key with Parsed fields; the expected output
// below shows both present, with keys in sorted order.
var customParsedMessage = &event.Content{
	Raw: map[string]interface{}{
		"net.maunium.custom": "hello world",
	},
	Parsed: &event.MessageEventContent{
		MsgType: event.MsgText,
		Body:    "test",
	},
}

const expectedCustomMarshalResult = `{"body":"test","msgtype":"m.text","net.maunium.custom":"hello world"}`

// TestMessageEventContent__Marshal_Custom checks that Raw keys survive marshaling alongside Parsed.
func TestMessageEventContent__Marshal_Custom(t *testing.T) {
	data, err := json.Marshal(customParsedMessage)
	assert.Nil(t, err)
	assert.Equal(t, expectedCustomMarshalResult, string(data))
}
go-0.11.1/event/powerlevels.go000066400000000000000000000063771436100171500162310ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"sync"

	"maunium.net/go/mautrix/id"
)

// PowerLevelsEventContent represents the content of a m.room.power_levels state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroompower_levels type PowerLevelsEventContent struct { usersLock sync.RWMutex Users map[id.UserID]int `json:"users,omitempty"` UsersDefault int `json:"users_default,omitempty"` eventsLock sync.RWMutex Events map[string]int `json:"events,omitempty"` EventsDefault int `json:"events_default,omitempty"` StateDefaultPtr *int `json:"state_default,omitempty"` InvitePtr *int `json:"invite,omitempty"` KickPtr *int `json:"kick,omitempty"` BanPtr *int `json:"ban,omitempty"` RedactPtr *int `json:"redact,omitempty"` HistoricalPtr *int `json:"historical,omitempty"` } func (pl *PowerLevelsEventContent) Invite() int { if pl.InvitePtr != nil { return *pl.InvitePtr } return 50 } func (pl *PowerLevelsEventContent) Kick() int { if pl.KickPtr != nil { return *pl.KickPtr } return 50 } func (pl *PowerLevelsEventContent) Ban() int { if pl.BanPtr != nil { return *pl.BanPtr } return 50 } func (pl *PowerLevelsEventContent) Redact() int { if pl.RedactPtr != nil { return *pl.RedactPtr } return 50 } func (pl *PowerLevelsEventContent) Historical() int { if pl.HistoricalPtr != nil { return *pl.HistoricalPtr } return 100 } func (pl *PowerLevelsEventContent) StateDefault() int { if pl.StateDefaultPtr != nil { return *pl.StateDefaultPtr } return 50 } func (pl *PowerLevelsEventContent) GetUserLevel(userID id.UserID) int { pl.usersLock.RLock() defer pl.usersLock.RUnlock() level, ok := pl.Users[userID] if !ok { return pl.UsersDefault } return level } func (pl *PowerLevelsEventContent) SetUserLevel(userID id.UserID, level int) { pl.usersLock.Lock() defer pl.usersLock.Unlock() if level == pl.UsersDefault { delete(pl.Users, userID) } else { pl.Users[userID] = level } } func (pl *PowerLevelsEventContent) EnsureUserLevel(userID id.UserID, level int) bool { existingLevel := pl.GetUserLevel(userID) if existingLevel != level { pl.SetUserLevel(userID, level) return true } return false } func (pl *PowerLevelsEventContent) GetEventLevel(eventType Type) int { 
pl.eventsLock.RLock() defer pl.eventsLock.RUnlock() level, ok := pl.Events[eventType.String()] if !ok { if eventType.IsState() { return pl.StateDefault() } return pl.EventsDefault } return level } func (pl *PowerLevelsEventContent) SetEventLevel(eventType Type, level int) { pl.eventsLock.Lock() defer pl.eventsLock.Unlock() if (eventType.IsState() && level == pl.StateDefault()) || (!eventType.IsState() && level == pl.EventsDefault) { delete(pl.Events, eventType.String()) } else { pl.Events[eventType.String()] = level } } func (pl *PowerLevelsEventContent) EnsureEventLevel(eventType Type, level int) bool { existingLevel := pl.GetEventLevel(eventType) if existingLevel != level { pl.SetEventLevel(eventType, level) return true } return false } go-0.11.1/event/relations.go000066400000000000000000000112401436100171500156430ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package event

import (
	"encoding/json"

	"maunium.net/go/mautrix/id"
)

// RelationType is the rel_type of an m.relates_to object.
type RelationType string

const (
	RelReplace    RelationType = "m.replace"
	RelReference  RelationType = "m.reference"
	RelAnnotation RelationType = "m.annotation"
	// RelReply is a library-internal type used to represent m.in_reply_to relations.
	RelReply RelationType = "net.maunium.reply"
)

// RelatesTo is the decoded m.relates_to object. Replies (m.in_reply_to) are represented
// with Type == RelReply.
type RelatesTo struct {
	Type    RelationType
	EventID id.EventID
	Key     string
}

type serializableInReplyTo struct {
	EventID id.EventID `json:"event_id,omitempty"`
}

type serializableRelatesTo struct {
	InReplyTo *serializableInReplyTo `json:"m.in_reply_to,omitempty"`

	Type    RelationType `json:"rel_type,omitempty"`
	EventID id.EventID   `json:"event_id,omitempty"`
	Key     string       `json:"key,omitempty"`
}

// GetReplaceID returns the target event ID of an m.replace relation, or "".
func (rel *RelatesTo) GetReplaceID() id.EventID {
	if rel.Type == RelReplace {
		return rel.EventID
	}
	return ""
}

// GetReferenceID returns the target event ID of an m.reference relation, or "".
func (rel *RelatesTo) GetReferenceID() id.EventID {
	if rel.Type == RelReference {
		return rel.EventID
	}
	return ""
}

// GetReplyID returns the replied-to event ID, or "".
func (rel *RelatesTo) GetReplyID() id.EventID {
	if rel.Type == RelReply {
		return rel.EventID
	}
	return ""
}

// GetAnnotationID returns the annotated event ID of an m.annotation relation, or "".
func (rel *RelatesTo) GetAnnotationID() id.EventID {
	if rel.Type == RelAnnotation {
		return rel.EventID
	}
	return ""
}

// GetAnnotationKey returns the key of an m.annotation relation, or "".
func (rel *RelatesTo) GetAnnotationKey() string {
	if rel.Type == RelAnnotation {
		return rel.Key
	}
	return ""
}

// UnmarshalJSON prefers an explicit rel_type; otherwise a non-empty m.in_reply_to becomes
// a RelReply. If neither is present, rel is left unchanged.
func (rel *RelatesTo) UnmarshalJSON(data []byte) error {
	var srel serializableRelatesTo
	if err := json.Unmarshal(data, &srel); err != nil {
		return err
	}
	if len(srel.Type) > 0 {
		rel.Type = srel.Type
		rel.EventID = srel.EventID
		rel.Key = srel.Key
	} else if srel.InReplyTo != nil && len(srel.InReplyTo.EventID) > 0 {
		rel.Type = RelReply
		rel.EventID = srel.InReplyTo.EventID
		rel.Key = ""
	}
	return nil
}

// MarshalJSON writes rel_type/event_id/key and, for RelReply, also an m.in_reply_to object.
func (rel *RelatesTo) MarshalJSON() ([]byte, error) {
	srel := serializableRelatesTo{Type: rel.Type, EventID: rel.EventID, Key: rel.Key}
	if rel.Type == RelReply {
		srel.InReplyTo = &serializableInReplyTo{rel.EventID}
	}
	return json.Marshal(&srel)
}

// RelationChunkItem is one entry of an aggregated relation chunk.
type RelationChunkItem struct {
	Type    RelationType `json:"type"`
	EventID string       `json:"event_id,omitempty"`
	Key     string       `json:"key,omitempty"`
	Count   int          `json:"count,omitempty"`
}

// RelationChunk is an aggregated list of relations of one type.
type RelationChunk struct {
	Chunk []RelationChunkItem `json:"chunk"`

	Limited bool `json:"limited"`
	Count   int  `json:"count"`
}

// AnnotationChunk is an annotation RelationChunk plus Map, which sums Count per Key.
type AnnotationChunk struct {
	RelationChunk
	Map map[string]int `json:"-"`
}

type serializableAnnotationChunk AnnotationChunk

// UnmarshalJSON decodes the chunk and rebuilds Map from it.
func (ac *AnnotationChunk) UnmarshalJSON(data []byte) error {
	if err := json.Unmarshal(data, (*serializableAnnotationChunk)(ac)); err != nil {
		return err
	}
	ac.Map = make(map[string]int)
	for _, item := range ac.Chunk {
		ac.Map[item.Key] += item.Count
	}
	return nil
}

// Serialize rebuilds Chunk from Map (one RelAnnotation item per key) and returns the
// embedded RelationChunk. The order of items follows map iteration and is not stable.
func (ac *AnnotationChunk) Serialize() RelationChunk {
	// Previously this wrote every item to index 0 because the index was never advanced.
	ac.Chunk = make([]RelationChunkItem, 0, len(ac.Map))
	for key, count := range ac.Map {
		ac.Chunk = append(ac.Chunk, RelationChunkItem{
			Type:  RelAnnotation,
			Key:   key,
			Count: count,
		})
	}
	return ac.RelationChunk
}

// EventIDChunk is a RelationChunk plus List, the event IDs of its items in order.
type EventIDChunk struct {
	RelationChunk
	List []string `json:"-"`
}

type serializableEventIDChunk EventIDChunk

// UnmarshalJSON decodes the chunk and rebuilds List from it.
func (ec *EventIDChunk) UnmarshalJSON(data []byte) error {
	if err := json.Unmarshal(data, (*serializableEventIDChunk)(ec)); err != nil {
		return err
	}
	// Rebuild rather than append to a stale List (mirrors AnnotationChunk rebuilding Map);
	// an empty chunk leaves List nil.
	ec.List = nil
	for _, item := range ec.Chunk {
		ec.List = append(ec.List, item.EventID)
	}
	return nil
}

// Serialize rebuilds Chunk from List, tagging every item with typ, and returns the
// embedded RelationChunk.
func (ec *EventIDChunk) Serialize(typ RelationType) RelationChunk {
	ec.Chunk = make([]RelationChunkItem, len(ec.List))
	for i, eventID := range ec.List {
		ec.Chunk[i] = RelationChunkItem{
			Type:    typ,
			EventID: eventID,
		}
	}
	return ec.RelationChunk
}

// Relations is the m.relations object of unsigned data. Raw keeps every relation type;
// the typed fields expose the three well-known ones.
type Relations struct {
	Raw map[RelationType]RelationChunk `json:"-"`

	Annotations AnnotationChunk `json:"m.annotation,omitempty"`
	References  EventIDChunk    `json:"m.reference,omitempty"`
	Replaces    EventIDChunk    `json:"m.replace,omitempty"`
}

type serializableRelations Relations

// UnmarshalJSON fills both Raw and the typed fields from the same data.
func (relations *Relations) UnmarshalJSON(data []byte) error {
	if err := json.Unmarshal(data, &relations.Raw); err != nil {
		return err
	}
	return json.Unmarshal(data, (*serializableRelations)(relations))
}

// MarshalJSON mutates relations.Raw: it allocates Raw if nil and overwrites the three
// well-known keys with freshly serialized chunks before encoding Raw.
func (relations *Relations) MarshalJSON() ([]byte, error) {
	if relations.Raw == nil {
		relations.Raw = make(map[RelationType]RelationChunk)
	}
	relations.Raw[RelAnnotation] = relations.Annotations.Serialize()
	relations.Raw[RelReference] = relations.References.Serialize(RelReference)
	relations.Raw[RelReplace] = relations.Replaces.Serialize(RelReplace)
	return json.Marshal(relations.Raw)
}
go-0.11.1/event/reply.go000066400000000000000000000057411436100171500150070ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"fmt"
	"regexp"
	"strings"

	"golang.org/x/net/html"

	"maunium.net/go/mautrix/id"
)

// HTMLReplyFallbackRegex matches the prefix that TrimReplyFallbackHTML strips.
// NOTE(review): as it appears in this copy the pattern has no closing anchor, so the lazy
// +? matches only the first character; the reply-fallback tags may have been lost — confirm.
var HTMLReplyFallbackRegex = regexp.MustCompile(`^[\s\S]+?`)

// TrimReplyFallbackHTML removes the HTMLReplyFallbackRegex match from the start of html.
func TrimReplyFallbackHTML(html string) string {
	return HTMLReplyFallbackRegex.ReplaceAllString(html, "")
}

// TrimReplyFallbackText drops the leading "> "-quoted lines of a plain-text reply and trims
// surrounding whitespace. Text that does not start with "> " or has no newline is returned as-is.
func TrimReplyFallbackText(text string) string {
	if !strings.HasPrefix(text, "> ") || !strings.Contains(text, "\n") {
		return text
	}

	lines := strings.Split(text, "\n")
	for len(lines) > 0 && strings.HasPrefix(lines[0], "> ") {
		lines = lines[1:]
	}
	return strings.TrimSpace(strings.Join(lines, "\n"))
}

// RemoveReplyFallback strips the reply fallback from Body (and from FormattedBody when the
// format is HTML), at most once, and only if the content is a reply.
func (content *MessageEventContent) RemoveReplyFallback() {
	if len(content.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
		if content.Format == FormatHTML {
			content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
		}
		content.Body = TrimReplyFallbackText(content.Body)
		content.replyFallbackRemoved = true
	}
}

// GetReplyTo returns the replied-to event ID, or "" if the content is not a reply.
func (content *MessageEventContent) GetReplyTo() id.EventID {
	if content.RelatesTo != nil && content.RelatesTo.Type == RelReply {
		return content.RelatesTo.EventID
	}
	return ""
}

// ReplyFormat is the fmt template used by GenerateReplyFallbackHTML.
// NOTE(review): this copy shows two %s verbs while GenerateReplyFallbackHTML passes five
// arguments; the template's markup appears to have been lost — confirm against upstream.
const ReplyFormat = `
In reply to %s
%s
`

// GenerateReplyFallbackHTML renders the HTML reply fallback quoting evt. It returns "" if
// evt's content is not parsed as *MessageEventContent. Note that it calls
// RemoveReplyFallback on evt's content, mutating it.
func (evt *Event) GenerateReplyFallbackHTML() string {
	parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
	if !ok {
		return ""
	}
	parsedContent.RemoveReplyFallback()
	body := parsedContent.FormattedBody
	if len(body) == 0 {
		body = strings.ReplaceAll(html.EscapeString(parsedContent.Body), "\n", "
")
	}

	// No display-name lookup happens here; the sender's user ID is used.
	senderDisplayName := evt.Sender

	return fmt.Sprintf(ReplyFormat, evt.RoomID, evt.ID, evt.Sender, senderDisplayName, body)
}

// GenerateReplyFallbackText renders the plain-text reply fallback: every line of evt's body
// prefixed with "> ", the first one also with "<sender> ", followed by a blank line.
// Returns "" if evt's content is not parsed as *MessageEventContent; mutates it like
// GenerateReplyFallbackHTML does.
func (evt *Event) GenerateReplyFallbackText() string {
	parsedContent, ok := evt.Content.Parsed.(*MessageEventContent)
	if !ok {
		return ""
	}
	parsedContent.RemoveReplyFallback()
	body := parsedContent.Body
	// strings.Split always returns at least one element, so lines[0] is safe.
	lines := strings.Split(strings.TrimSpace(body), "\n")
	firstLine, lines := lines[0], lines[1:]

	senderDisplayName := evt.Sender

	var fallbackText strings.Builder
	_, _ = fmt.Fprintf(&fallbackText, "> <%s> %s", senderDisplayName, firstLine)
	for _, line := range lines {
		_, _ = fmt.Fprintf(&fallbackText, "\n> %s", line)
	}
	fallbackText.WriteString("\n\n")
	return fallbackText.String()
}

// SetReply marks content as a reply to inReplyTo. For text and notice messages it also
// prepends HTML and plain-text fallbacks (forcing an HTML formatted body first).
func (content *MessageEventContent) SetReply(inReplyTo *Event) {
	content.RelatesTo = &RelatesTo{
		EventID: inReplyTo.ID,
		Type:    RelReply,
	}

	if content.MsgType == MsgText || content.MsgType == MsgNotice {
		content.EnsureHasHTML()
		content.FormattedBody = inReplyTo.GenerateReplyFallbackHTML() + content.FormattedBody
		content.Body = inReplyTo.GenerateReplyFallbackText() + content.Body
		content.replyFallbackRemoved = false
	}
}
go-0.11.1/event/state.go000066400000000000000000000136511436100171500147730ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"maunium.net/go/mautrix/id"
)

// CanonicalAliasEventContent represents the content of a m.room.canonical_alias state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcanonical_alias
type CanonicalAliasEventContent struct {
	Alias      id.RoomAlias   `json:"alias"`
	AltAliases []id.RoomAlias `json:"alt_aliases,omitempty"`
}

// RoomNameEventContent represents the content of a m.room.name state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomname
type RoomNameEventContent struct {
	Name string `json:"name"`
}

// RoomAvatarEventContent represents the content of a m.room.avatar state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomavatar
type RoomAvatarEventContent struct {
	URL  id.ContentURI `json:"url"`
	Info *FileInfo     `json:"info,omitempty"`
}

// ServerACLEventContent represents the content of a m.room.server_acl state event.
// https://spec.matrix.org/v1.2/client-server-api/#server-access-control-lists-acls-for-rooms
type ServerACLEventContent struct {
	Allow           []string `json:"allow,omitempty"`
	AllowIPLiterals bool     `json:"allow_ip_literals"`
	Deny            []string `json:"deny,omitempty"`
}

// TopicEventContent represents the content of a m.room.topic state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomtopic
type TopicEventContent struct {
	Topic string `json:"topic"`
}

// TombstoneEventContent represents the content of a m.room.tombstone state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomtombstone
type TombstoneEventContent struct {
	Body            string    `json:"body"`
	ReplacementRoom id.RoomID `json:"replacement_room"`
}

// CreateEventContent represents the content of a m.room.create state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomcreate
type CreateEventContent struct {
	Type    RoomType  `json:"type,omitempty"`
	Creator id.UserID `json:"creator,omitempty"`
	// NOTE(review): the spec treats a missing m.federate as true, but with omitempty a false
	// value is dropped when marshaling, so a non-federating room can't be expressed on output — confirm.
	Federate    bool   `json:"m.federate,omitempty"`
	RoomVersion string `json:"version,omitempty"`
	// Predecessor is a struct value without omitempty, so it is always emitted (with empty
	// strings when unset).
	Predecessor struct {
		RoomID  id.RoomID  `json:"room_id"`
		EventID id.EventID `json:"event_id"`
	} `json:"predecessor"`
}

// JoinRule specifies how open a room is to new members.
// https://spec.matrix.org/v1.2/client-server-api/#mroomjoin_rules
type JoinRule string

// Values for JoinRulesEventContent.JoinRule.
const (
	JoinRulePublic     JoinRule = "public"
	JoinRuleKnock      JoinRule = "knock"
	JoinRuleInvite     JoinRule = "invite"
	JoinRuleRestricted JoinRule = "restricted"
	JoinRulePrivate    JoinRule = "private"
)

// JoinRulesEventContent represents the content of a m.room.join_rules state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomjoin_rules
type JoinRulesEventContent struct {
	JoinRule JoinRule `json:"join_rule"`
	// Allow lists the conditions attached to the join rule; omitted from JSON when empty.
	Allow []JoinRuleAllow `json:"allow,omitempty"`
}

// JoinRuleAllowType is the "type" of a single JoinRuleAllow entry.
type JoinRuleAllowType string

const (
	// JoinRuleAllowRoomMembership is the "m.room_membership" condition type.
	JoinRuleAllowRoomMembership JoinRuleAllowType = "m.room_membership"
)

// JoinRuleAllow is one entry of JoinRulesEventContent.Allow.
type JoinRuleAllow struct {
	RoomID id.RoomID         `json:"room_id"`
	Type   JoinRuleAllowType `json:"type"`
}

// PinnedEventsEventContent represents the content of a m.room.pinned_events state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroompinned_events
type PinnedEventsEventContent struct {
	Pinned []id.EventID `json:"pinned"`
}

// HistoryVisibility specifies who can see new messages.
// https://spec.matrix.org/v1.2/client-server-api/#mroomhistory_visibility
type HistoryVisibility string

// Values for HistoryVisibilityEventContent.HistoryVisibility.
const (
	HistoryVisibilityInvited       HistoryVisibility = "invited"
	HistoryVisibilityJoined        HistoryVisibility = "joined"
	HistoryVisibilityShared        HistoryVisibility = "shared"
	HistoryVisibilityWorldReadable HistoryVisibility = "world_readable"
)

// HistoryVisibilityEventContent represents the content of a m.room.history_visibility state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomhistory_visibility
type HistoryVisibilityEventContent struct {
	HistoryVisibility HistoryVisibility `json:"history_visibility"`
}

// GuestAccess specifies whether or not guest accounts can join.
// https://spec.matrix.org/v1.2/client-server-api/#mroomguest_access
type GuestAccess string

// Values for GuestAccessEventContent.GuestAccess.
const (
	GuestAccessCanJoin   GuestAccess = "can_join"
	GuestAccessForbidden GuestAccess = "forbidden"
)

// GuestAccessEventContent represents the content of a m.room.guest_access state event.
// https://spec.matrix.org/v1.2/client-server-api/#mroomguest_access
type GuestAccessEventContent struct {
	GuestAccess GuestAccess `json:"guest_access"`
}

// BridgeInfoSection describes one of the protocol/network/channel sections of BridgeEventContent.
type BridgeInfoSection struct {
	ID          string              `json:"id"`
	DisplayName string              `json:"displayname,omitempty"`
	AvatarURL   id.ContentURIString `json:"avatar_url,omitempty"`
	ExternalURL string              `json:"external_url,omitempty"`
}

// BridgeEventContent represents the content of a m.bridge state event.
// https://github.com/matrix-org/matrix-doc/pull/2346
type BridgeEventContent struct {
	BridgeBot id.UserID         `json:"bridgebot"`
	Creator   id.UserID         `json:"creator,omitempty"`
	Protocol  BridgeInfoSection `json:"protocol"`
	// Network is optional and omitted from JSON when nil.
	Network *BridgeInfoSection `json:"network,omitempty"`
	Channel BridgeInfoSection  `json:"channel"`
}

// SpaceChildEventContent is the content of an m.space.child state event.
type SpaceChildEventContent struct {
	Via   []string `json:"via,omitempty"`
	Order string   `json:"order,omitempty"`
}

// SpaceParentEventContent is the content of an m.space.parent state event.
type SpaceParentEventContent struct {
	Via       []string `json:"via,omitempty"`
	Canonical bool     `json:"canonical,omitempty"`
}

// ModPolicyContent represents the content of a m.policy.rule.user, m.policy.rule.room, and m.policy.rule.server state event.
// https://spec.matrix.org/v1.2/client-server-api/#moderation-policy-lists
type ModPolicyContent struct {
	Entity         string `json:"entity"`
	Reason         string `json:"reason"`
	Recommendation string `json:"recommendation"`
}
go-0.11.1/event/type.go000066400000000000000000000217131436100171500146320ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package event

import (
	"encoding/json"
	"fmt"
	"strings"
)

// RoomType is the value of the "type" key in m.room.create content.
type RoomType string

const (
	RoomTypeDefault RoomType = ""
	RoomTypeSpace   RoomType = "m.space"
)

// TypeClass is the broad category (message, state, ephemeral, ...) an event type belongs to.
type TypeClass int

// Name returns a short human-readable name for the class.
func (tc TypeClass) Name() string {
	switch tc {
	case MessageEventType:
		return "message"
	case StateEventType:
		return "state"
	case EphemeralEventType:
		return "ephemeral"
	case AccountDataEventType:
		return "account data"
	case ToDeviceEventType:
		return "to-device"
	default:
		return "unknown"
	}
}

const (
	// Unknown events
	UnknownEventType TypeClass = iota
	// Normal message events
	MessageEventType
	// State events
	StateEventType
	// Ephemeral events
	EphemeralEventType
	// Account data events
	AccountDataEventType
	// Device-to-device events
	ToDeviceEventType
)

// Type is an event type name paired with its class.
type Type struct {
	Type  string
	Class TypeClass
}

// NewEventType returns a Type for name with its class filled in by GuessClass.
func NewEventType(name string) Type {
	evtType := Type{Type: name}
	evtType.Class = evtType.GuessClass()
	return evtType
}

// IsState reports whether the type is classified as a state event.
func (et *Type) IsState() bool {
	return et.Class == StateEventType
}

// IsEphemeral reports whether the type is classified as an ephemeral event.
func (et *Type) IsEphemeral() bool {
	return et.Class == EphemeralEventType
}

// IsAccountData reports whether the type is classified as account data.
func (et *Type) IsAccountData() bool {
	return et.Class == AccountDataEventType
}

// IsToDevice reports whether the type is classified as a to-device event.
func (et *Type) IsToDevice() bool {
	return et.Class == ToDeviceEventType
}

// IsInRoomVerification reports whether the type name is one of the m.key.verification.* in-room types.
func (et *Type) IsInRoomVerification() bool {
	switch et.Type {
	case InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type,
		InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type:
		return true
	default:
		return false
	}
}

// IsCall reports whether the type name is one of the m.call.* types.
func (et *Type) IsCall() bool {
	switch et.Type {
	case CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type,
		CallSelectAnswer.Type, CallNegotiate.Type, CallHangup.Type:
		return true
	default:
		return false
	}
}

// IsCustom reports whether the type name is outside the "m." namespace.
func (et *Type) IsCustom() bool {
	return !strings.HasPrefix(et.Type, "m.")
}

// GuessClass returns the class of a known event type name, or UnknownEventType.
//
// Every type declared in this file is listed. Names shared between classes
// (m.room.encrypted and the m.key.verification.start/accept/key/mac/cancel types)
// resolve to MessageEventType, because that case is checked before the to-device one.
func (et *Type) GuessClass() TypeClass {
	switch et.Type {
	case StateAliases.Type, StateCanonicalAlias.Type, StateCreate.Type, StateJoinRules.Type, StateMember.Type,
		StatePowerLevels.Type, StateRoomName.Type, StateRoomAvatar.Type, StateServerACL.Type, StateTopic.Type,
		StatePinnedEvents.Type, StateTombstone.Type, StateEncryption.Type, StateBridge.Type,
		StateHalfShotBridge.Type, StateSpaceParent.Type, StateSpaceChild.Type, StatePolicyRoom.Type,
		StatePolicyServer.Type, StatePolicyUser.Type, StateHistoryVisibility.Type, StateGuestAccess.Type:
		return StateEventType
	case EphemeralEventReceipt.Type, EphemeralEventTyping.Type, EphemeralEventPresence.Type:
		return EphemeralEventType
	case AccountDataDirectChats.Type, AccountDataPushRules.Type, AccountDataRoomTags.Type,
		AccountDataSecretStorageKey.Type, AccountDataSecretStorageDefaultKey.Type,
		AccountDataCrossSigningMaster.Type, AccountDataCrossSigningSelf.Type, AccountDataCrossSigningUser.Type,
		AccountDataFullyRead.Type, AccountDataIgnoredUserList.Type:
		return AccountDataEventType
	case EventRedaction.Type, EventMessage.Type, EventEncrypted.Type, EventReaction.Type, EventSticker.Type,
		InRoomVerificationStart.Type, InRoomVerificationReady.Type, InRoomVerificationAccept.Type,
		InRoomVerificationKey.Type, InRoomVerificationMAC.Type, InRoomVerificationCancel.Type,
		CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
		CallNegotiate.Type, CallHangup.Type:
		return MessageEventType
	case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type,
		ToDeviceRoomKeyWithheld.Type, ToDeviceDummy.Type, ToDeviceVerificationRequest.Type,
		ToDeviceOrgMatrixRoomKeyWithheld.Type:
		return ToDeviceEventType
	default:
		return UnknownEventType
	}
}

// UnmarshalJSON decodes a JSON string into the type name and guesses its class.
func (et *Type) UnmarshalJSON(data []byte) error {
	err := json.Unmarshal(data, &et.Type)
	if err != nil {
		return err
	}
	et.Class = et.GuessClass()
	return nil
}

// MarshalJSON encodes only the type name, as a JSON string.
func (et *Type) MarshalJSON() ([]byte, error) {
	return json.Marshal(&et.Type)
}

// UnmarshalText sets the type name from data and guesses its class.
//
// This must use a pointer receiver. With a value receiver the assignments only modify a
// copy, so text decoding (e.g. encoding/json map keys of type Type) produced an empty Type.
func (et *Type) UnmarshalText(data []byte) error {
	et.Type = string(data)
	et.Class = et.GuessClass()
	return nil
}

// MarshalText encodes only the type name.
func (et Type) MarshalText() ([]byte, error) {
	return []byte(et.Type), nil
}

// String returns the type name.
func (et *Type) String() string {
	return et.Type
}

// Repr returns the type name followed by its class name in parentheses.
func (et *Type) Repr() string {
	return fmt.Sprintf("%s (%s)", et.Type, et.Class.Name())
}

// State events
var (
	StateAliases           = Type{"m.room.aliases", StateEventType}
	StateCanonicalAlias    = Type{"m.room.canonical_alias", StateEventType}
	StateCreate            = Type{"m.room.create", StateEventType}
	StateJoinRules         = Type{"m.room.join_rules", StateEventType}
	StateHistoryVisibility = Type{"m.room.history_visibility", StateEventType}
	StateGuestAccess       = Type{"m.room.guest_access", StateEventType}
	StateMember            = Type{"m.room.member", StateEventType}
	StatePowerLevels       = Type{"m.room.power_levels", StateEventType}
	StateRoomName          = Type{"m.room.name", StateEventType}
	StateTopic             = Type{"m.room.topic", StateEventType}
	StateRoomAvatar        = Type{"m.room.avatar", StateEventType}
	StatePinnedEvents      = Type{"m.room.pinned_events", StateEventType}
	StateServerACL         = Type{"m.room.server_acl", StateEventType}
	StateTombstone         = Type{"m.room.tombstone", StateEventType}
	StatePolicyRoom        = Type{"m.policy.rule.room", StateEventType}
	StatePolicyServer      = Type{"m.policy.rule.server", StateEventType}
	StatePolicyUser        = Type{"m.policy.rule.user", StateEventType}
	StateEncryption        = Type{"m.room.encryption", StateEventType}
	StateBridge            = Type{"m.bridge", StateEventType}
	StateHalfShotBridge    = Type{"uk.half-shot.bridge", StateEventType}
	StateSpaceChild        = Type{"m.space.child", StateEventType}
	StateSpaceParent       = Type{"m.space.parent", StateEventType}
)

// Message events
var (
	EventRedaction = Type{"m.room.redaction", MessageEventType}
	EventMessage   = Type{"m.room.message", MessageEventType}
	EventEncrypted = Type{"m.room.encrypted", MessageEventType}
	EventReaction  = Type{"m.reaction", MessageEventType}
	EventSticker   = Type{"m.sticker", MessageEventType}

	InRoomVerificationStart  = Type{"m.key.verification.start", MessageEventType}
	InRoomVerificationReady  = Type{"m.key.verification.ready", MessageEventType}
	InRoomVerificationAccept = Type{"m.key.verification.accept", MessageEventType}
	InRoomVerificationKey    = Type{"m.key.verification.key", MessageEventType}
	InRoomVerificationMAC    = Type{"m.key.verification.mac", MessageEventType}
	InRoomVerificationCancel = Type{"m.key.verification.cancel", MessageEventType}

	CallInvite       = Type{"m.call.invite", MessageEventType}
	CallCandidates   = Type{"m.call.candidates", MessageEventType}
	CallAnswer       = Type{"m.call.answer", MessageEventType}
	CallReject       = Type{"m.call.reject", MessageEventType}
	CallSelectAnswer = Type{"m.call.select_answer", MessageEventType}
	CallNegotiate    = Type{"m.call.negotiate", MessageEventType}
	CallHangup       = Type{"m.call.hangup", MessageEventType}
)

// Ephemeral events
var (
	EphemeralEventReceipt  = Type{"m.receipt", EphemeralEventType}
	EphemeralEventTyping   = Type{"m.typing", EphemeralEventType}
	EphemeralEventPresence = Type{"m.presence", EphemeralEventType}
)

// Account data events
var (
	AccountDataDirectChats     = Type{"m.direct", AccountDataEventType}
	AccountDataPushRules       = Type{"m.push_rules", AccountDataEventType}
	AccountDataRoomTags        = Type{"m.tag", AccountDataEventType}
	AccountDataFullyRead       = Type{"m.fully_read", AccountDataEventType}
	AccountDataIgnoredUserList = Type{"m.ignored_user_list", AccountDataEventType}

	AccountDataSecretStorageDefaultKey = Type{"m.secret_storage.default_key", AccountDataEventType}
	AccountDataSecretStorageKey        = Type{"m.secret_storage.key", AccountDataEventType}
	AccountDataCrossSigningMaster      = Type{"m.cross_signing.master", AccountDataEventType}
	AccountDataCrossSigningUser        = Type{"m.cross_signing.user_signing", AccountDataEventType}
	AccountDataCrossSigningSelf        = Type{"m.cross_signing.self_signing", AccountDataEventType}
)

// Device-to-device events
var (
	ToDeviceRoomKey          = Type{"m.room_key", ToDeviceEventType}
	ToDeviceRoomKeyRequest   = Type{"m.room_key_request", ToDeviceEventType}
	ToDeviceForwardedRoomKey = Type{"m.forwarded_room_key", ToDeviceEventType}
	ToDeviceEncrypted        = Type{"m.room.encrypted", ToDeviceEventType}
	ToDeviceRoomKeyWithheld  = Type{"m.room_key.withheld", ToDeviceEventType}
	ToDeviceDummy            = Type{"m.dummy", ToDeviceEventType}

	ToDeviceVerificationRequest = Type{"m.key.verification.request", ToDeviceEventType}
	ToDeviceVerificationStart   = Type{"m.key.verification.start", ToDeviceEventType}
	ToDeviceVerificationAccept  = Type{"m.key.verification.accept", ToDeviceEventType}
	ToDeviceVerificationKey     = Type{"m.key.verification.key", ToDeviceEventType}
	ToDeviceVerificationMAC     = Type{"m.key.verification.mac", ToDeviceEventType}
	ToDeviceVerificationCancel  = Type{"m.key.verification.cancel", ToDeviceEventType}

	ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
)
go-0.11.1/event/verification.go000066400000000000000000000274171436100171500163420ustar00rootroot00000000000000// Copyright (c) 2020 Nikos Filippakis
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package event

import (
	"maunium.net/go/mautrix/id"
)

// VerificationMethod names a key verification method.
type VerificationMethod string

const VerificationMethodSAS VerificationMethod = "m.sas.v1"

// VerificationRequestEventContent represents the content of a m.key.verification.request to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationrequest
type VerificationRequestEventContent struct {
	// The device ID which is initiating the request.
	FromDevice id.DeviceID `json:"from_device"`
	// An opaque identifier for the verification request. Must be unique with respect to the devices involved.
	TransactionID string `json:"transaction_id,omitempty"`
	// The verification methods supported by the sender.
	Methods []VerificationMethod `json:"methods"`
	// The POSIX timestamp in milliseconds for when the request was made.
	Timestamp int64 `json:"timestamp,omitempty"`
	// The user that the event is sent to for in-room verification.
	To id.UserID `json:"to,omitempty"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

// SupportsVerificationMethod reports whether meth is listed in Methods.
func (vrec *VerificationRequestEventContent) SupportsVerificationMethod(meth VerificationMethod) bool {
	for _, supportedMeth := range vrec.Methods {
		if supportedMeth == meth {
			return true
		}
	}
	return false
}

// KeyAgreementProtocol names a key agreement protocol offered in a verification start event.
type KeyAgreementProtocol string

const (
	KeyAgreementCurve25519           KeyAgreementProtocol = "curve25519"
	KeyAgreementCurve25519HKDFSHA256 KeyAgreementProtocol = "curve25519-hkdf-sha256"
)

// VerificationHashMethod names a hash method offered in a verification start event.
type VerificationHashMethod string

const VerificationHashSHA256 VerificationHashMethod = "sha256"

// MACMethod names a message authentication code method offered in a verification start event.
type MACMethod string

const HKDFHMACSHA256 MACMethod = "hkdf-hmac-sha256"

// SASMethod names a short authentication string display method.
type SASMethod string

const (
	SASDecimal SASMethod = "decimal"
	SASEmoji   SASMethod = "emoji"
)

// VerificationStartEventContent represents the content of a m.key.verification.start to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationstartmsasv1
type VerificationStartEventContent struct {
	// The device ID which is initiating the process.
	FromDevice id.DeviceID `json:"from_device"`
	// An opaque identifier for the verification process. Must be unique with respect to the devices involved.
	TransactionID string `json:"transaction_id,omitempty"`
	// The verification method to use.
	Method VerificationMethod `json:"method"`
	// The key agreement protocols the sending device understands.
	KeyAgreementProtocols []KeyAgreementProtocol `json:"key_agreement_protocols"`
	// The hash methods the sending device understands.
	Hashes []VerificationHashMethod `json:"hashes"`
	// The message authentication codes that the sending device understands.
	MessageAuthenticationCodes []MACMethod `json:"message_authentication_codes"`
	// The SAS methods the sending device (and the sending device's user) understands.
	ShortAuthenticationString []SASMethod `json:"short_authentication_string"`
	// The user that the event is sent to for in-room verification.
	To id.UserID `json:"to,omitempty"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

// SupportsKeyAgreementProtocol reports whether proto is listed in KeyAgreementProtocols.
func (vsec *VerificationStartEventContent) SupportsKeyAgreementProtocol(proto KeyAgreementProtocol) bool {
	for _, supportedProto := range vsec.KeyAgreementProtocols {
		if supportedProto == proto {
			return true
		}
	}
	return false
}

// SupportsHashMethod reports whether alg is listed in Hashes.
func (vsec *VerificationStartEventContent) SupportsHashMethod(alg VerificationHashMethod) bool {
	for _, supportedAlg := range vsec.Hashes {
		if supportedAlg == alg {
			return true
		}
	}
	return false
}

// SupportsMACMethod reports whether meth is listed in MessageAuthenticationCodes.
func (vsec *VerificationStartEventContent) SupportsMACMethod(meth MACMethod) bool {
	for _, supportedMeth := range vsec.MessageAuthenticationCodes {
		if supportedMeth == meth {
			return true
		}
	}
	return false
}

// SupportsSASMethod reports whether meth is listed in ShortAuthenticationString.
func (vsec *VerificationStartEventContent) SupportsSASMethod(meth SASMethod) bool {
	for _, supportedMeth := range vsec.ShortAuthenticationString {
		if supportedMeth == meth {
			return true
		}
	}
	return false
}

// GetRelatesTo returns RelatesTo, allocating an empty one first if it is nil.
func (vsec *VerificationStartEventContent) GetRelatesTo() *RelatesTo {
	if vsec.RelatesTo == nil {
		vsec.RelatesTo = &RelatesTo{}
	}
	return vsec.RelatesTo
}

// OptionalGetRelatesTo returns RelatesTo as-is, possibly nil.
func (vsec *VerificationStartEventContent) OptionalGetRelatesTo() *RelatesTo {
	return vsec.RelatesTo
}

// SetRelatesTo replaces RelatesTo.
func (vsec *VerificationStartEventContent) SetRelatesTo(rel *RelatesTo) {
	vsec.RelatesTo = rel
}

// VerificationReadyEventContent represents the content of a m.key.verification.ready event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationready
type VerificationReadyEventContent struct {
	// The device ID which accepted the process.
	FromDevice id.DeviceID `json:"from_device"`
	// The verification methods supported by the sender.
	Methods []VerificationMethod `json:"methods"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

var _ Relatable = (*VerificationReadyEventContent)(nil)

// GetRelatesTo returns RelatesTo, allocating an empty one first if it is nil.
func (vrec *VerificationReadyEventContent) GetRelatesTo() *RelatesTo {
	if vrec.RelatesTo == nil {
		vrec.RelatesTo = &RelatesTo{}
	}
	return vrec.RelatesTo
}

// OptionalGetRelatesTo returns RelatesTo as-is, possibly nil.
func (vrec *VerificationReadyEventContent) OptionalGetRelatesTo() *RelatesTo {
	return vrec.RelatesTo
}

// SetRelatesTo replaces RelatesTo.
func (vrec *VerificationReadyEventContent) SetRelatesTo(rel *RelatesTo) {
	vrec.RelatesTo = rel
}

// VerificationAcceptEventContent represents the content of a m.key.verification.accept to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationaccept
type VerificationAcceptEventContent struct {
	// An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
	TransactionID string `json:"transaction_id,omitempty"`
	// The verification method to use.
	Method VerificationMethod `json:"method"`
	// The key agreement protocol the device is choosing to use, out of the options in the m.key.verification.start message.
	KeyAgreementProtocol KeyAgreementProtocol `json:"key_agreement_protocol"`
	// The hash method the device is choosing to use, out of the options in the m.key.verification.start message.
	Hash VerificationHashMethod `json:"hash"`
	// The message authentication code the device is choosing to use, out of the options in the m.key.verification.start message.
	MessageAuthenticationCode MACMethod `json:"message_authentication_code"`
	// The SAS methods both devices involved in the verification process understand. Must be a subset of the options in the m.key.verification.start message.
	ShortAuthenticationString []SASMethod `json:"short_authentication_string"`
	// The hash (encoded as unpadded base64) of the concatenation of the device's ephemeral public key (encoded as unpadded base64) and the canonical JSON representation of the m.key.verification.start message.
	Commitment string `json:"commitment"`
	// The user that the event is sent to for in-room verification.
	To id.UserID `json:"to,omitempty"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

// GetRelatesTo returns RelatesTo, allocating an empty one first if it is nil.
func (vaec *VerificationAcceptEventContent) GetRelatesTo() *RelatesTo {
	if vaec.RelatesTo == nil {
		vaec.RelatesTo = &RelatesTo{}
	}
	return vaec.RelatesTo
}

// OptionalGetRelatesTo returns RelatesTo as-is, possibly nil.
func (vaec *VerificationAcceptEventContent) OptionalGetRelatesTo() *RelatesTo {
	return vaec.RelatesTo
}

// SetRelatesTo replaces RelatesTo.
func (vaec *VerificationAcceptEventContent) SetRelatesTo(rel *RelatesTo) {
	vaec.RelatesTo = rel
}

// VerificationKeyEventContent represents the content of a m.key.verification.key to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationkey
type VerificationKeyEventContent struct {
	// An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
	TransactionID string `json:"transaction_id,omitempty"`
	// The device's ephemeral public key, encoded as unpadded base64.
	Key string `json:"key"`
	// The user that the event is sent to for in-room verification.
	To id.UserID `json:"to,omitempty"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

// GetRelatesTo returns RelatesTo, allocating an empty one first if it is nil.
func (vkec *VerificationKeyEventContent) GetRelatesTo() *RelatesTo {
	if vkec.RelatesTo == nil {
		vkec.RelatesTo = &RelatesTo{}
	}
	return vkec.RelatesTo
}

// OptionalGetRelatesTo returns RelatesTo as-is, possibly nil.
func (vkec *VerificationKeyEventContent) OptionalGetRelatesTo() *RelatesTo {
	return vkec.RelatesTo
}

// SetRelatesTo replaces RelatesTo.
func (vkec *VerificationKeyEventContent) SetRelatesTo(rel *RelatesTo) {
	vkec.RelatesTo = rel
}

// VerificationMacEventContent represents the content of a m.key.verification.mac to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationmac
type VerificationMacEventContent struct {
	// An opaque identifier for the verification process. Must be the same as the one used for the m.key.verification.start message.
	TransactionID string `json:"transaction_id,omitempty"`
	// A map of the key ID to the MAC of the key, using the algorithm in the verification process. The MAC is encoded as unpadded base64.
	Mac map[id.KeyID]string `json:"mac"`
	// The MAC of the comma-separated, sorted, list of key IDs given in the mac property, encoded as unpadded base64.
	Keys string `json:"keys"`
	// The user that the event is sent to for in-room verification.
	To id.UserID `json:"to,omitempty"`
	// Original event ID for in-room verification.
	RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
}

// GetRelatesTo returns RelatesTo, allocating an empty one first if it is nil.
func (vmec *VerificationMacEventContent) GetRelatesTo() *RelatesTo {
	if vmec.RelatesTo == nil {
		vmec.RelatesTo = &RelatesTo{}
	}
	return vmec.RelatesTo
}

// OptionalGetRelatesTo returns RelatesTo as-is, possibly nil.
func (vmec *VerificationMacEventContent) OptionalGetRelatesTo() *RelatesTo {
	return vmec.RelatesTo
}

// SetRelatesTo replaces RelatesTo.
func (vmec *VerificationMacEventContent) SetRelatesTo(rel *RelatesTo) {
	vmec.RelatesTo = rel
}

// VerificationCancelCode is the machine-readable reason in a verification cancel event.
type VerificationCancelCode string

const (
	VerificationCancelByUser             VerificationCancelCode = "m.user"
	VerificationCancelByTimeout          VerificationCancelCode = "m.timeout"
	VerificationCancelUnknownTransaction VerificationCancelCode = "m.unknown_transaction"
	VerificationCancelUnknownMethod      VerificationCancelCode = "m.unknown_method"
	VerificationCancelUnexpectedMessage  VerificationCancelCode = "m.unexpected_message"
	VerificationCancelKeyMismatch        VerificationCancelCode = "m.key_mismatch"
	VerificationCancelUserMismatch       VerificationCancelCode = "m.user_mismatch"
	VerificationCancelInvalidMessage     VerificationCancelCode = "m.invalid_message"
	VerificationCancelAccepted           VerificationCancelCode = "m.accepted"
	VerificationCancelSASMismatch        VerificationCancelCode = "m.mismatched_sas"
	VerificationCancelCommitmentMismatch VerificationCancelCode = "m.mismatched_commitment"
)

// VerificationCancelEventContent represents the content of a m.key.verification.cancel to_device event.
// https://spec.matrix.org/v1.2/client-server-api/#mkeyverificationcancel type VerificationCancelEventContent struct { // The opaque identifier for the verification process/request. TransactionID string `json:"transaction_id,omitempty"` // A human readable description of the code. The client should only rely on this string if it does not understand the code. Reason string `json:"reason"` // The error code for why the process/request was cancelled by the user. Code VerificationCancelCode `json:"code"` // The user that the event is sent to for in-room verification. To id.UserID `json:"to,omitempty"` // Original event ID for in-room verification. RelatesTo *RelatesTo `json:"m.relates_to,omitempty"` } func (vcec *VerificationCancelEventContent) GetRelatesTo() *RelatesTo { if vcec.RelatesTo == nil { vcec.RelatesTo = &RelatesTo{} } return vcec.RelatesTo } func (vcec *VerificationCancelEventContent) OptionalGetRelatesTo() *RelatesTo { return vcec.RelatesTo } func (vcec *VerificationCancelEventContent) SetRelatesTo(rel *RelatesTo) { vcec.RelatesTo = rel } go-0.11.1/event/voip.go000066400000000000000000000053351436100171500146300ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package event

import (
	"encoding/json"
	"fmt"
	"strconv"
)

// CallHangupReason is the reason carried by CallHangupEventContent.
type CallHangupReason string

const (
	CallHangupICEFailed       CallHangupReason = "ice_failed"
	CallHangupInviteTimeout   CallHangupReason = "invite_timeout"
	CallHangupUserHangup      CallHangupReason = "user_hangup"
	CallHangupUserMediaFailed CallHangupReason = "user_media_failed"
	CallHangupUnknownError    CallHangupReason = "unknown_error"
)

// CallDataType is the kind of session description carried in CallData.
type CallDataType string

const (
	CallDataTypeOffer  CallDataType = "offer"
	CallDataTypeAnswer CallDataType = "answer"
)

// CallData is a session description (SDP) together with its type.
type CallData struct {
	SDP  string       `json:"sdp"`
	Type CallDataType `json:"type"`
}

// CallCandidate is a single ICE candidate as sent in CallCandidatesEventContent.
type CallCandidate struct {
	Candidate     string `json:"candidate"`
	SDPMLineIndex int    `json:"sdpMLineIndex"`
	SDPMID        string `json:"sdpMid"`
}

// CallVersion is the call protocol version. On the wire it may be a JSON integer or a
// JSON string; it is stored as its decimal or string form.
type CallVersion string

// UnmarshalJSON accepts either a JSON integer (stored in decimal form) or a JSON string.
// A JSON null is a no-op, following the encoding/json convention, rather than becoming "0".
func (cv *CallVersion) UnmarshalJSON(raw []byte) error {
	if string(raw) == "null" {
		return nil
	}
	var numberVersion int
	if err := json.Unmarshal(raw, &numberVersion); err == nil {
		*cv = CallVersion(strconv.Itoa(numberVersion))
		return nil
	}
	var stringVersion string
	if err := json.Unmarshal(raw, &stringVersion); err != nil {
		return fmt.Errorf("failed to parse CallVersion: %w", err)
	}
	*cv = CallVersion(stringVersion)
	return nil
}

// MarshalJSON encodes the version as a JSON integer when it is the canonical decimal form
// of an int (so it round-trips through UnmarshalJSON unchanged), and as a JSON string otherwise.
//
// The value receiver lets the method apply even when a CallVersion is marshaled as part of a
// non-addressable value (e.g. a content struct passed to json.Marshal by value).
func (cv CallVersion) MarshalJSON() ([]byte, error) {
	str := string(cv)
	if num, err := strconv.Atoi(str); err == nil && strconv.Itoa(num) == str {
		return []byte(str), nil
	}
	// "" and strings like "01" or "+1" are not valid or canonical JSON numbers, so emitting
	// them raw would produce invalid JSON or change the value. Quote them instead, as with
	// any non-numeric version.
	return json.Marshal(str)
}

// Int parses the version as a decimal integer.
func (cv *CallVersion) Int() (int, error) {
	return strconv.Atoi(string(*cv))
}

// BaseCallEventContent holds the fields shared by all m.call.* event contents.
type BaseCallEventContent struct {
	CallID  string      `json:"call_id"`
	PartyID string      `json:"party_id"`
	Version CallVersion `json:"version"`
}

type CallInviteEventContent struct {
	BaseCallEventContent
	Lifetime int      `json:"lifetime"`
	Offer    CallData `json:"offer"`
}

type CallCandidatesEventContent struct {
	BaseCallEventContent
	Candidates []CallCandidate `json:"candidates"`
}

type CallRejectEventContent struct {
	BaseCallEventContent
}

type CallAnswerEventContent struct {
	BaseCallEventContent
	Answer CallData `json:"answer"`
}

type CallSelectAnswerEventContent struct {
	BaseCallEventContent
	SelectedPartyID string `json:"selected_party_id"`
}

type CallNegotiateEventContent struct {
	BaseCallEventContent
	Lifetime    int      `json:"lifetime"`
	Description CallData `json:"description"`
}

type CallHangupEventContent struct {
	BaseCallEventContent
	Reason CallHangupReason `json:"reason"`
}
go-0.11.1/event/voip_test.go000066400000000000000000000110531436100171500156610ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package event_test

import (
	"encoding/json"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"maunium.net/go/mautrix/event"
)

const callCandidates = `{
  "type": "m.call.candidates",
  "event_id": "$143273582443PhrSn:example.org",
  "origin_server_ts": 1432735824653,
  "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
  "sender": "@example:example.org",
  "content": {
    "call_id": "12345",
    "candidates": [
      {
        "candidate": "candidate:863018703 1 udp 2122260223 10.9.64.156 43670 typ host generation 0",
        "sdpMLineIndex": 0,
        "sdpMid": "audio"
      }
    ],
    "version": 0
  },
  "unsigned": {
    "age": 1234
  }
}`

const callSelectAnswer = `{
  "type": "m.call.select_answer",
  "event_id": "$143273582443PhrSn:example.org",
  "origin_server_ts": 1432735824653,
  "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
  "sender": "@example:example.org",
  "content": {
    "version": 1,
    "call_id": "12345",
    "party_id": "67890",
    "selected_party_id": "111213"
  },
  "unsigned": {
    "age": 1234
  }
}`

const callAnswerStringVersion = `{
  "type": "m.call.answer",
  "event_id": "$143273582443PhrSn:example.org",
  "origin_server_ts": 1432735824653,
  "room_id": "!jEsUZKDJdhlrceRyVU:example.org",
  "sender": "@example:example.org",
  "content": {
    "answer": {
      "sdp": "v=0\r\no=- 6584580628695956864 2 IN IP4 127.0.0.1[...]",
      "type": "answer"
    },
    "call_id": "12345",
    "lifetime": 60000,
    "version": "com.example.call.version"
  },
  "unsigned": {
    "age": 1234
  }
}`

// TestCallCandidatesEventContent_Parse checks that an integer version 0 is parsed as "0".
func TestCallCandidatesEventContent_Parse(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(callCandidates), &evt)
	require.NoError(t, err)
	// testify's argument order is (expected, actual).
	require.Equal(t, event.CallCandidates, evt.Type)
	err = evt.Content.ParseRaw(evt.Type)
	require.NoError(t, err)
	content := evt.Content.AsCallCandidates()
	require.NotNil(t, content)
	assert.Equal(t, event.CallVersion("0"), content.Version)
}

// TestCallSelectAnswerEventContent_Parse checks that an integer version 1 is parsed as "1".
func TestCallSelectAnswerEventContent_Parse(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(callSelectAnswer), &evt)
	require.NoError(t, err)
	require.Equal(t, event.CallSelectAnswer, evt.Type)
	err = evt.Content.ParseRaw(evt.Type)
	require.NoError(t, err)
	content := evt.Content.AsCallSelectAnswer()
	require.NotNil(t, content)
	assert.Equal(t, event.CallVersion("1"), content.Version)
}

// TestCallAnswerContent_Parse checks that a string version is kept verbatim.
func TestCallAnswerContent_Parse(t *testing.T) {
	var evt *event.Event
	err := json.Unmarshal([]byte(callAnswerStringVersion), &evt)
	require.NoError(t, err)
	require.Equal(t, event.CallAnswer, evt.Type)
	err = evt.Content.ParseRaw(evt.Type)
	require.NoError(t, err)
	content := evt.Content.AsCallAnswer()
	require.NotNil(t, content)
	assert.Equal(t, event.CallVersion("com.example.call.version"), content.Version)
}

// TestCallVersion_MarshalJSON checks that digit versions encode as numbers and others as strings.
func TestCallVersion_MarshalJSON(t *testing.T) {
	var version event.CallVersion
	var data []byte
	var err error

	version = "1"
	data, err = json.Marshal(&version)
	assert.NoError(t, err)
	assert.Equal(t, []byte("1"), data)

	version = "0"
	data, err = json.Marshal(&version)
	assert.NoError(t, err)
	assert.Equal(t, []byte("0"), data)

	version = "1234"
	data, err = json.Marshal(&version)
	assert.NoError(t, err)
	assert.Equal(t, []byte("1234"), data)

	version = "com.example.call.version"
	data, err = json.Marshal(&version)
	assert.NoError(t, err)
	assert.Equal(t, []byte(`"com.example.call.version"`), data)
}

// TestCallVersion_UnmarshalJSON checks integer/string inputs and rejection of other JSON kinds.
func TestCallVersion_UnmarshalJSON(t *testing.T) {
	var version event.CallVersion
	var err error

	err = json.Unmarshal([]byte(`1`), &version)
	assert.NoError(t, err)
	assert.Equal(t, event.CallVersion("1"), version)

	err = json.Unmarshal([]byte(`0`), &version)
	assert.NoError(t, err)
	assert.Equal(t, event.CallVersion("0"), version)

	err = json.Unmarshal([]byte(`1234`), &version)
	assert.NoError(t, err)
	assert.Equal(t, event.CallVersion("1234"), version)

	err = json.Unmarshal([]byte(`"1234"`), &version)
	assert.NoError(t, err)
	assert.Equal(t, event.CallVersion("1234"), version)

	err = json.Unmarshal([]byte(`"com.example.call.version"`), &version)
	assert.NoError(t, err)
	assert.Equal(t, event.CallVersion("com.example.call.version"), version)

	err = json.Unmarshal([]byte(`1.234`), &version)
	assert.Error(t, err)

	err = json.Unmarshal([]byte(`false`), &version)
	assert.Error(t, err)

	err = json.Unmarshal([]byte(`["hmm"]`), &version)
	assert.Error(t, err)

	err = json.Unmarshal([]byte(`{"hmm": true}`), &version)
	assert.Error(t, err)
}
go-0.11.1/example/000077500000000000000000000000001436100171500136305ustar00rootroot00000000000000go-0.11.1/example/go.mod000066400000000000000000000001071436100171500147340ustar00rootroot00000000000000module mautrix-example

go 1.15

require maunium.net/go/mautrix v0.7.6
go-0.11.1/example/go.sum000066400000000000000000000067121436100171500147710ustar00rootroot00000000000000github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=
github.com/andybalholm/cascadia v1.1.0/go.mod h1:GsXiBklL0woXo1j/WYWtSYYC4ouU9PqHO0sqidkEA4Y=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gorilla/mux v1.7.4/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So=
github.com/lib/pq v1.7.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod
h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls= github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20180218175443-cbe0f9307d01/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200324143707-d3edc9973b7e/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20200602114024-627f9648deb9 h1:pNX+40auqi2JqRfOP1akLGtYcn15TUbkhwuCO3foqqM= golang.org/x/net v0.0.0-20200602114024-627f9648deb9/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/maulogger/v2 v2.1.1/go.mod 
h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= maunium.net/go/mautrix v0.7.6 h1:jB9oCimPq0mVyolwQBC/9N1fu21AU+Ryq837cLf4gOo= maunium.net/go/mautrix v0.7.6/go.mod h1:Va/74MijqaS0DQ3aUqxmFO54/PMfr1LVsCOcGRHbYmo= go-0.11.1/example/main.go000066400000000000000000000037131436100171500151070ustar00rootroot00000000000000// Copyright (C) 2017 Tulir Asokan // Copyright (C) 2018-2020 Luca Weiss // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // This program is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with this program. If not, see . 
package main import ( "flag" "fmt" "os" "maunium.net/go/mautrix" "maunium.net/go/mautrix/event" ) var homeserver = flag.String("homeserver", "", "Matrix homeserver") var username = flag.String("username", "", "Matrix username localpart") var password = flag.String("password", "", "Matrix password") func main() { flag.Parse() if *username == "" || *password == "" || *homeserver == "" { _, _ = fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) flag.PrintDefaults() os.Exit(1) } fmt.Println("Logging into", *homeserver, "as", *username) client, err := mautrix.NewClient(*homeserver, "", "") if err != nil { panic(err) } _, err = client.Login(&mautrix.ReqLogin{ Type: "m.login.password", Identifier: mautrix.UserIdentifier{Type: mautrix.IdentifierTypeUser, User: *username}, Password: *password, StoreCredentials: true, }) if err != nil { panic(err) } fmt.Println("Login successful") syncer := client.Syncer.(*mautrix.DefaultSyncer) syncer.OnEventType(event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { fmt.Printf("<%[1]s> %[4]s (%[2]s/%[3]s)\n", evt.Sender, evt.Type.String(), evt.ID, evt.Content.AsMessage().Body) }) err = client.Sync() if err != nil { panic(err) } } go-0.11.1/filter.go000066400000000000000000000057551436100171500140250ustar00rootroot00000000000000// Copyright 2017 Jan Christian Grรผnhage package mautrix import ( "errors" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) type EventFormat string const ( EventFormatClient EventFormat = "client" EventFormatFederation EventFormat = "federation" ) // Filter is used by clients to specify how the server should filter responses to e.g. 
sync requests // Specified by: https://spec.matrix.org/v1.2/client-server-api/#filtering type Filter struct { AccountData FilterPart `json:"account_data,omitempty"` EventFields []string `json:"event_fields,omitempty"` EventFormat EventFormat `json:"event_format,omitempty"` Presence FilterPart `json:"presence,omitempty"` Room RoomFilter `json:"room,omitempty"` } // RoomFilter is used to define filtering rules for room events type RoomFilter struct { AccountData FilterPart `json:"account_data,omitempty"` Ephemeral FilterPart `json:"ephemeral,omitempty"` IncludeLeave bool `json:"include_leave,omitempty"` NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` State FilterPart `json:"state,omitempty"` Timeline FilterPart `json:"timeline,omitempty"` } // FilterPart is used to define filtering rules for specific categories of events type FilterPart struct { NotRooms []id.RoomID `json:"not_rooms,omitempty"` Rooms []id.RoomID `json:"rooms,omitempty"` Limit int `json:"limit,omitempty"` NotSenders []id.UserID `json:"not_senders,omitempty"` NotTypes []event.Type `json:"not_types,omitempty"` Senders []id.UserID `json:"senders,omitempty"` Types []event.Type `json:"types,omitempty"` ContainsURL *bool `json:"contains_url,omitempty"` LazyLoadMembers bool `json:"lazy_load_members,omitempty"` IncludeRedundantMembers bool `json:"include_redundant_members,omitempty"` } // Validate checks if the filter contains valid property values func (filter *Filter) Validate() error { if filter.EventFormat != EventFormatClient && filter.EventFormat != EventFormatFederation { return errors.New("Bad event_format value. 
Must be one of [\"client\", \"federation\"]") } return nil } // DefaultFilter returns the default filter used by the Matrix server if no filter is provided in the request func DefaultFilter() Filter { return Filter{ AccountData: DefaultFilterPart(), EventFields: nil, EventFormat: "client", Presence: DefaultFilterPart(), Room: RoomFilter{ AccountData: DefaultFilterPart(), Ephemeral: DefaultFilterPart(), IncludeLeave: false, NotRooms: nil, Rooms: nil, State: DefaultFilterPart(), Timeline: DefaultFilterPart(), }, } } // DefaultFilterPart returns the default filter part used by the Matrix server if no filter is provided in the request func DefaultFilterPart() FilterPart { return FilterPart{ NotRooms: nil, Rooms: nil, Limit: 20, NotSenders: nil, NotTypes: nil, Senders: nil, Types: nil, } } go-0.11.1/format/000077500000000000000000000000001436100171500134655ustar00rootroot00000000000000go-0.11.1/format/doc.go000066400000000000000000000004021436100171500145550ustar00rootroot00000000000000// Package format contains utilities for working with Matrix HTML, specifically // methods to parse Markdown into HTML and to parse Matrix HTML into text or markdown. // // https://spec.matrix.org/v1.2/client-server-api/#mroommessage-msgtypes package format go-0.11.1/format/htmlparser.go000066400000000000000000000230361436100171500162010ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package format import ( "fmt" "math" "strconv" "strings" "golang.org/x/net/html" "maunium.net/go/mautrix/id" ) type Context map[string]interface{} type TextConverter func(string, Context) string type CodeBlockConverter func(code, language string, ctx Context) string type PillConverter func(displayname, mxid, eventID string, ctx Context) string func DefaultPillConverter(displayname, mxid, eventID string, _ Context) string { switch { case len(mxid) == 0, mxid[0] == '@': // User link, always just show the displayname return displayname case len(eventID) > 0: // Event ID link, always just show the link return fmt.Sprintf("https://matrix.to/#/%s/%s", mxid, eventID) case mxid[0] == '!' && displayname == mxid: // Room ID link with no separate display text, just show the link return fmt.Sprintf("https://matrix.to/#/%s", mxid) case mxid[0] == '#': // Room alias link, just show the alias return mxid default: // Other link (e.g. room ID link with display text), show text and link return fmt.Sprintf("%s (https://matrix.to/#/%s)", displayname, mxid) } } // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { PillConverter PillConverter TabsToSpaces int Newline string HorizontalLine string BoldConverter TextConverter ItalicConverter TextConverter StrikethroughConverter TextConverter UnderlineConverter TextConverter MonospaceBlockConverter CodeBlockConverter MonospaceConverter TextConverter } // TaggedString is a string that also contains a HTML tag. type TaggedString struct { string tag string } func (parser *HTMLParser) getAttribute(node *html.Node, attribute string) string { for _, attr := range node.Attr { if attr.Key == attribute { return attr.Val } } return "" } // Digits counts the number of digits (and the sign, if negative) in an integer. 
func Digits(num int) int { if num == 0 { return 1 } else if num < 0 { return Digits(-num) + 1 } return int(math.Floor(math.Log10(float64(num))) + 1) } func (parser *HTMLParser) listToString(node *html.Node, stripLinebreak bool, ctx Context) string { ordered := node.Data == "ol" taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, stripLinebreak, ctx) counter := 1 indentLength := 0 if ordered { start := parser.getAttribute(node, "start") if len(start) > 0 { counter, _ = strconv.Atoi(start) } longestIndex := (counter - 1) + len(taggedChildren) indentLength = Digits(longestIndex) } indent := strings.Repeat(" ", indentLength+2) var children []string for _, child := range taggedChildren { if child.tag != "li" { continue } var prefix string // TODO make bullets and numbering configurable if ordered { indexPadding := indentLength - Digits(counter) prefix = fmt.Sprintf("%d. %s", counter, strings.Repeat(" ", indexPadding)) } else { prefix = "* " } str := prefix + child.string counter++ parts := strings.Split(str, "\n") for i, part := range parts[1:] { parts[i+1] = indent + part } str = strings.Join(parts, "\n") children = append(children, str) } return strings.Join(children, "\n") } func (parser *HTMLParser) basicFormatToString(node *html.Node, stripLinebreak bool, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx) switch node.Data { case "b", "strong": if parser.BoldConverter != nil { return parser.BoldConverter(str, ctx) } return fmt.Sprintf("**%s**", str) case "i", "em": if parser.ItalicConverter != nil { return parser.ItalicConverter(str, ctx) } return fmt.Sprintf("_%s_", str) case "s", "del", "strike": if parser.StrikethroughConverter != nil { return parser.StrikethroughConverter(str, ctx) } return fmt.Sprintf("~~%s~~", str) case "u", "ins": if parser.UnderlineConverter != nil { return parser.UnderlineConverter(str, ctx) } case "tt", "code": if parser.MonospaceConverter != nil { return parser.MonospaceConverter(str, 
ctx) } return fmt.Sprintf("`%s`", str) } return str } func (parser *HTMLParser) headerToString(node *html.Node, stripLinebreak bool, ctx Context) string { children := parser.nodeToStrings(node.FirstChild, stripLinebreak, ctx) length := int(node.Data[1] - '0') prefix := strings.Repeat("#", length) + " " return prefix + strings.Join(children, "") } func (parser *HTMLParser) blockquoteToString(node *html.Node, stripLinebreak bool, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx) childrenArr := strings.Split(strings.TrimSpace(str), "\n") // TODO make blockquote prefix configurable for index, child := range childrenArr { childrenArr[index] = "> " + child } return strings.Join(childrenArr, "\n") } func (parser *HTMLParser) linkToString(node *html.Node, stripLinebreak bool, ctx Context) string { str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx) href := parser.getAttribute(node, "href") if len(href) == 0 { return str } if parser.PillConverter != nil { parsedMatrix, err := id.ParseMatrixURIOrMatrixToURL(href) if err == nil && parsedMatrix != nil { return parser.PillConverter(str, parsedMatrix.PrimaryIdentifier(), parsedMatrix.SecondaryIdentifier(), ctx) } } if str == href { return str } return fmt.Sprintf("%s (%s)", str, href) } func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx Context) string { switch node.Data { case "blockquote": return parser.blockquoteToString(node, stripLinebreak, ctx) case "ol", "ul": return parser.listToString(node, stripLinebreak, ctx) case "h1", "h2", "h3", "h4", "h5", "h6": return parser.headerToString(node, stripLinebreak, ctx) case "br": return parser.Newline case "b", "strong", "i", "em", "s", "strike", "del", "u", "ins", "tt", "code": return parser.basicFormatToString(node, stripLinebreak, ctx) case "a": return parser.linkToString(node, stripLinebreak, ctx) case "p": return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx) case 
"hr": return parser.HorizontalLine case "pre": var preStr, language string if node.FirstChild != nil && node.FirstChild.Type == html.ElementNode && node.FirstChild.Data == "code" { class := parser.getAttribute(node.FirstChild, "class") if strings.HasPrefix(class, "language-") { language = class[len("language-"):] } preStr = parser.nodeToString(node.FirstChild.FirstChild, false, ctx) } else { preStr = parser.nodeToString(node.FirstChild, false, ctx) } if parser.MonospaceBlockConverter != nil { return parser.MonospaceBlockConverter(preStr, language, ctx) } if len(preStr) == 0 || preStr[len(preStr)-1] != '\n' { preStr += "\n" } return fmt.Sprintf("```%s\n%s```", language, preStr) default: return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx) } } func (parser *HTMLParser) singleNodeToString(node *html.Node, stripLinebreak bool, ctx Context) TaggedString { switch node.Type { case html.TextNode: if stripLinebreak { node.Data = strings.Replace(node.Data, "\n", "", -1) } return TaggedString{node.Data, "text"} case html.ElementNode: return TaggedString{parser.tagToString(node, stripLinebreak, ctx), node.Data} case html.DocumentNode: return TaggedString{parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx), "html"} default: return TaggedString{"", "unknown"} } } func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []TaggedString) { for ; node != nil; node = node.NextSibling { strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx)) } return } var BlockTags = []string{"p", "h1", "h2", "h3", "h4", "h5", "h6", "ol", "ul", "pre", "blockquote", "div", "hr", "table"} func (parser *HTMLParser) isBlockTag(tag string) bool { for _, blockTag := range BlockTags { if tag == blockTag { return true } } return false } func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, stripLinebreak bool, ctx Context) string { strs := parser.nodeToTaggedStrings(node, stripLinebreak, ctx) var 
output strings.Builder for _, str := range strs { tstr := str.string if parser.isBlockTag(str.tag) { tstr = fmt.Sprintf("\n%s\n", tstr) } output.WriteString(tstr) } return strings.TrimSpace(output.String()) } func (parser *HTMLParser) nodeToStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []string) { for ; node != nil; node = node.NextSibling { strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx).string) } return } func (parser *HTMLParser) nodeToString(node *html.Node, stripLinebreak bool, ctx Context) string { return strings.Join(parser.nodeToStrings(node, stripLinebreak, ctx), "") } // Parse converts Matrix HTML into text using the settings in this parser. func (parser *HTMLParser) Parse(htmlData string, ctx Context) string { if parser.TabsToSpaces >= 0 { htmlData = strings.Replace(htmlData, "\t", strings.Repeat(" ", parser.TabsToSpaces), -1) } node, _ := html.Parse(strings.NewReader(htmlData)) return parser.nodeToTagAwareString(node, true, ctx) } // HTMLToText converts Matrix HTML into text with the default settings. func HTMLToText(html string) string { return (&HTMLParser{ TabsToSpaces: 4, Newline: "\n", HorizontalLine: "\n---\n", PillConverter: DefaultPillConverter, }).Parse(html, make(Context)) } go-0.11.1/format/markdown.go000066400000000000000000000033061436100171500156400ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package format import ( "fmt" "regexp" "strings" "github.com/yuin/goldmark" "github.com/yuin/goldmark/extension" "github.com/yuin/goldmark/renderer/html" "maunium.net/go/mautrix/event" ) var AntiParagraphRegex = regexp.MustCompile("^

(.+?)

$") var Extensions = goldmark.WithExtensions(extension.Strikethrough, extension.Table, ExtensionSpoiler) var HTMLOptions = goldmark.WithRendererOptions(html.WithHardWraps(), html.WithUnsafe()) var withHTML = goldmark.New(Extensions, HTMLOptions) var noHTML = goldmark.New(Extensions, HTMLOptions, goldmark.WithExtensions(ExtensionEscapeHTML)) func RenderMarkdown(text string, allowMarkdown, allowHTML bool) event.MessageEventContent { var htmlBody string if allowMarkdown { rndr := withHTML if !allowHTML { rndr = noHTML } var buf strings.Builder err := rndr.Convert([]byte(text), &buf) if err != nil { panic(fmt.Errorf("markdown parser errored: %w", err)) } htmlBody = strings.TrimRight(buf.String(), "\n") htmlBody = AntiParagraphRegex.ReplaceAllString(htmlBody, "$1") } else { htmlBody = strings.Replace(text, "\n", "
", -1) } if len(htmlBody) > 0 && (allowMarkdown || allowHTML) { text = HTMLToText(htmlBody) if htmlBody != text { return event.MessageEventContent{ FormattedBody: htmlBody, Format: event.FormatHTML, MsgType: event.MsgText, Body: text, } } } return event.MessageEventContent{ MsgType: event.MsgText, Body: text, } } go-0.11.1/format/markdown_test.go000066400000000000000000000035661436100171500167070ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package format_test import ( "strings" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/format" ) var spoilerTests = map[string]string{ "test ||bar||": "test bar", "test ||reason|**bar**||": `test bar`, "test ||reason|[bar](https://example.com)||": `test bar`, "test [||reason|foo||](https://example.com)": `test foo`, "test [||foo||](https://example.com)": `test foo`, "test [||*foo*||](https://example.com)": `test foo`, // FIXME wrapping spoilers in italic/bold/strikethrough doesn't work for some reason //"test **[||foo||](https://example.com)**": `test foo`, //"test **||foo||**": `test foo`, "* ||foo||": `
  • foo
`, "> ||foo||": "

foo

", } func TestRenderMarkdown_Spoiler(t *testing.T) { for markdown, html := range spoilerTests { rendered := format.RenderMarkdown(markdown, true, false) // FIXME the HTML parser doesn't do spoilers yet //assert.Equal(t, plaintext, rendered.Body) assert.Equal(t, html, strings.ReplaceAll(rendered.FormattedBody, "\n", "")) } } go-0.11.1/format/mdnohtml.go000066400000000000000000000034031436100171500156360ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package format import ( "github.com/yuin/goldmark" "github.com/yuin/goldmark/ast" "github.com/yuin/goldmark/renderer" "github.com/yuin/goldmark/renderer/html" "github.com/yuin/goldmark/util" ) type extEscapeHTML struct{} type escapingHTMLRenderer struct{} var ExtensionEscapeHTML = &extEscapeHTML{} var defaultEHR = &escapingHTMLRenderer{} func (eeh *extEscapeHTML) Extend(m goldmark.Markdown) { m.Renderer().AddOptions(renderer.WithNodeRenderers(util.Prioritized(defaultEHR, 0))) } func (ehr *escapingHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { reg.Register(ast.KindHTMLBlock, ehr.renderHTMLBlock) reg.Register(ast.KindRawHTML, ehr.renderRawHTML) } func (ehr *escapingHTMLRenderer) renderRawHTML(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { if !entering { return ast.WalkSkipChildren, nil } n := node.(*ast.RawHTML) l := n.Segments.Len() for i := 0; i < l; i++ { segment := n.Segments.At(i) html.DefaultWriter.RawWrite(w, segment.Value(source)) } return ast.WalkSkipChildren, nil } func (ehr *escapingHTMLRenderer) renderHTMLBlock(w util.BufWriter, source []byte, node ast.Node, entering bool) (ast.WalkStatus, error) { n := node.(*ast.HTMLBlock) if entering { l := n.Lines().Len() for i := 0; i < l; i++ { line := n.Lines().At(i) 
html.DefaultWriter.RawWrite(w, line.Value(source)) } } else { if n.HasClosure() { closure := n.ClosureLine html.DefaultWriter.RawWrite(w, closure.Value(source)) } } return ast.WalkContinue, nil } go-0.11.1/format/mdspoiler.go000066400000000000000000000077041436100171500160220ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package format import ( "bytes" "fmt" stdhtml "html" "regexp" "github.com/yuin/goldmark" "github.com/yuin/goldmark/ast" "github.com/yuin/goldmark/parser" "github.com/yuin/goldmark/renderer" "github.com/yuin/goldmark/renderer/html" "github.com/yuin/goldmark/text" "github.com/yuin/goldmark/util" ) var astKindSpoiler = ast.NewNodeKind("Spoiler") type astSpoiler struct { ast.BaseInline reason string } func (n *astSpoiler) Dump(source []byte, level int) { ast.DumpHelper(n, source, level, nil, nil) } func (n *astSpoiler) Kind() ast.NodeKind { return astKindSpoiler } type spoilerDelimiterProcessor struct{} var defaultSpoilerDelimiterProcessor = &spoilerDelimiterProcessor{} func (p *spoilerDelimiterProcessor) IsDelimiter(b byte) bool { return b == '|' } func (p *spoilerDelimiterProcessor) CanOpenCloser(opener, closer *parser.Delimiter) bool { return opener.Char == closer.Char } func (p *spoilerDelimiterProcessor) OnMatch(consumes int) ast.Node { return &astSpoiler{} } type spoilerParser struct{} var defaultSpoilerParser = &spoilerParser{} func newSpoilerParser() parser.InlineParser { return defaultSpoilerParser } func (s *spoilerParser) Trigger() []byte { return []byte{'|'} } var spoilerRegex = regexp.MustCompile(`^\|\|(?:([^|]+?)\|[^|])?`) var spoilerContextKey = parser.NewContextKey() type spoilerContext struct { reason string segment text.Segment bottom *parser.Delimiter } func (s *spoilerParser) Parse(parent ast.Node, block 
text.Reader, pc parser.Context) ast.Node { line, segment := block.PeekLine() if spoiler, ok := pc.Get(spoilerContextKey).(spoilerContext); ok { if !bytes.HasPrefix(line, []byte("||")) { return nil } block.Advance(2) pc.Set(spoilerContextKey, nil) n := &astSpoiler{ BaseInline: ast.BaseInline{}, reason: spoiler.reason, } parser.ProcessDelimiters(spoiler.bottom, pc) var c ast.Node = spoiler.bottom for c != nil { next := c.NextSibling() parent.RemoveChild(parent, c) n.AppendChild(n, c) c = next } return n } match := spoilerRegex.FindSubmatch(line) if match == nil { return nil } length := 2 reason := string(match[1]) if len(reason) > 0 { length += len(match[1]) + 1 } block.Advance(length) delim := parser.NewDelimiter(true, false, length, '|', defaultSpoilerDelimiterProcessor) pc.Set(spoilerContextKey, spoilerContext{ reason: reason, segment: segment, bottom: delim, }) return delim } func (s *spoilerParser) CloseBlock(parent ast.Node, pc parser.Context) { // nothing to do } type spoilerHTMLRenderer struct { html.Config } func newSpoilerHTMLRenderer(opts ...html.Option) renderer.NodeRenderer { r := &spoilerHTMLRenderer{ Config: html.NewConfig(), } for _, opt := range opts { opt.SetHTMLOption(&r.Config) } return r } func (r *spoilerHTMLRenderer) RegisterFuncs(reg renderer.NodeRendererFuncRegisterer) { reg.Register(astKindSpoiler, r.renderSpoiler) } func (r *spoilerHTMLRenderer) renderSpoiler(w util.BufWriter, source []byte, n ast.Node, entering bool) (ast.WalkStatus, error) { if entering { node := n.(*astSpoiler) if len(node.reason) == 0 { _, _ = w.WriteString("") } else { _, _ = fmt.Fprintf(w, ``, stdhtml.EscapeString(node.reason)) } } else { _, _ = w.WriteString("") } return ast.WalkContinue, nil } type extSpoiler struct{} // ExtensionSpoiler is an extension that allow you to use spoiler expression like '~~text~~' . 
var ExtensionSpoiler = &extSpoiler{} func (e *extSpoiler) Extend(m goldmark.Markdown) { m.Parser().AddOptions(parser.WithInlineParsers( util.Prioritized(newSpoilerParser(), 500), )) m.Renderer().AddOptions(renderer.WithNodeRenderers( util.Prioritized(newSpoilerHTMLRenderer(), 500), )) } go-0.11.1/go.mod000066400000000000000000000013331436100171500133030ustar00rootroot00000000000000module maunium.net/go/mautrix go 1.17 require ( github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 github.com/mattn/go-sqlite3 v1.14.13 github.com/stretchr/testify v1.7.1 github.com/tidwall/gjson v1.14.1 github.com/tidwall/sjson v1.2.4 github.com/yuin/goldmark v1.4.12 golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 golang.org/x/net v0.0.0-20220513224357-95641704303c gopkg.in/yaml.v2 v2.4.0 maunium.net/go/maulogger/v2 v2.3.2 ) require ( github.com/davecgh/go-spew v1.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect ) go-0.11.1/go.sum000066400000000000000000000102101436100171500133220ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc= github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/mattn/go-sqlite3 v1.14.13 h1:1tj15ngiFfcZzii7yd82foL+ks+ouQcj8j/TPq3fk1I= github.com/mattn/go-sqlite3 v1.14.13/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib 
v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.1 h1:iymTbGkQBhveq21bEvAQ81I0LEBork8BFe1CUZXdyuo= github.com/tidwall/gjson v1.14.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= github.com/yuin/goldmark v1.4.12 h1:6hffw6vALvEDqJ19dOJvJKOoAOKe4NDaTqvd2sktGN0= github.com/yuin/goldmark v1.4.12/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9 h1:NUzdAbFtCJSXU20AOXgeqaUwg8Ypg4MPYmL+d+rsB5c= golang.org/x/crypto v0.0.0-20220513210258-46612604a0f9/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220513224357-95641704303c h1:nF9mHSvoKBLkQNQhJZNsc66z2UzAMUbLGjC95CF3pU0= golang.org/x/net v0.0.0-20220513224357-95641704303c/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/maulogger/v2 v2.3.2 h1:1XmIYmMd3PoQfp9J+PaHhpt80zpfmMqaShzUTC7FwY0= maunium.net/go/maulogger/v2 v2.3.2/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= go-0.11.1/id/000077500000000000000000000000001436100171500125715ustar00rootroot00000000000000go-0.11.1/id/contenturi.go000066400000000000000000000063611436100171500153200ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package id import ( "bytes" "encoding/json" "errors" "fmt" "strings" ) var ( InvalidContentURI = errors.New("invalid Matrix content URI") InputNotJSONString = errors.New("input doesn't look like a JSON string") ) // ContentURIString is a string that's expected to be a Matrix content URI. // It's useful for delaying the parsing of the content URI to move errors from the event content // JSON parsing step to a later step where more appropriate errors can be produced. type ContentURIString string func (uriString ContentURIString) Parse() (ContentURI, error) { return ParseContentURI(string(uriString)) } func (uriString ContentURIString) ParseOrIgnore() ContentURI { parsed, _ := ParseContentURI(string(uriString)) return parsed } // ContentURI represents a Matrix content URI. // https://spec.matrix.org/v1.2/client-server-api/#matrix-content-mxc-uris type ContentURI struct { Homeserver string FileID string } func MustParseContentURI(uri string) ContentURI { parsed, err := ParseContentURI(uri) if err != nil { panic(err) } return parsed } // ParseContentURI parses a Matrix content URI. 
func ParseContentURI(uri string) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !strings.HasPrefix(uri, "mxc://") { err = InvalidContentURI } else if index := strings.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { err = InvalidContentURI } else { parsed.Homeserver = uri[6 : 6+index] parsed.FileID = uri[6+index+1:] } return } var mxcBytes = []byte("mxc://") func ParseContentURIBytes(uri []byte) (parsed ContentURI, err error) { if len(uri) == 0 { return } else if !bytes.HasPrefix(uri, mxcBytes) { err = InvalidContentURI } else if index := bytes.IndexRune(uri[6:], '/'); index == -1 || index == len(uri)-7 { err = InvalidContentURI } else { parsed.Homeserver = string(uri[6 : 6+index]) parsed.FileID = string(uri[6+index+1:]) } return } func (uri *ContentURI) UnmarshalJSON(raw []byte) (err error) { if string(raw) == "null" { *uri = ContentURI{} return nil } else if len(raw) < 2 || raw[0] != '"' || raw[len(raw)-1] != '"' { return InputNotJSONString } parsed, err := ParseContentURIBytes(raw[1 : len(raw)-1]) if err != nil { return err } *uri = parsed return nil } func (uri *ContentURI) MarshalJSON() ([]byte, error) { if uri.IsEmpty() { return []byte("null"), nil } return json.Marshal(uri.String()) } func (uri *ContentURI) UnmarshalText(raw []byte) (err error) { parsed, err := ParseContentURIBytes(raw) if err != nil { return err } *uri = parsed return nil } func (uri ContentURI) MarshalText() ([]byte, error) { if uri.IsEmpty() { return []byte(""), nil } return []byte(uri.String()), nil } func (uri *ContentURI) String() string { if uri.IsEmpty() { return "" } return fmt.Sprintf("mxc://%s/%s", uri.Homeserver, uri.FileID) } func (uri *ContentURI) CUString() ContentURIString { return ContentURIString(uri.String()) } func (uri *ContentURI) IsEmpty() bool { return len(uri.Homeserver) == 0 || len(uri.FileID) == 0 } go-0.11.1/id/crypto.go000066400000000000000000000057121436100171500144450ustar00rootroot00000000000000// Copyright (c) 2020 Tulir 
Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package id import ( "fmt" "strings" ) // OlmMsgType is an Olm message type type OlmMsgType int const ( OlmMsgTypePreKey OlmMsgType = 0 OlmMsgTypeMsg OlmMsgType = 1 ) // Algorithm is a Matrix message encryption algorithm. // https://spec.matrix.org/v1.2/client-server-api/#messaging-algorithm-names type Algorithm string const ( AlgorithmOlmV1 Algorithm = "m.olm.v1.curve25519-aes-sha2" AlgorithmMegolmV1 Algorithm = "m.megolm.v1.aes-sha2" ) type KeyAlgorithm string const ( KeyAlgorithmCurve25519 KeyAlgorithm = "curve25519" KeyAlgorithmEd25519 KeyAlgorithm = "ed25519" KeyAlgorithmSignedCurve25519 KeyAlgorithm = "signed_curve25519" ) type CrossSigningUsage string const ( XSUsageMaster CrossSigningUsage = "master" XSUsageSelfSigning CrossSigningUsage = "self_signing" XSUsageUserSigning CrossSigningUsage = "user_signing" ) // A SessionID is an arbitrary string that identifies an Olm or Megolm session. type SessionID string func (sessionID SessionID) String() string { return string(sessionID) } // Ed25519 is the base64 representation of an Ed25519 public key type Ed25519 string type SigningKey = Ed25519 func (ed25519 Ed25519) String() string { return string(ed25519) } // Curve25519 is the base64 representation of an Curve25519 public key type Curve25519 string type SenderKey = Curve25519 type IdentityKey = Curve25519 func (curve25519 Curve25519) String() string { return string(curve25519) } // A DeviceID is an arbitrary string that references a specific device. type DeviceID string func (deviceID DeviceID) String() string { return string(deviceID) } // A DeviceKeyID is a string formatted as : that is used as the key in deviceid-key mappings. 
type DeviceKeyID string func NewDeviceKeyID(algorithm KeyAlgorithm, deviceID DeviceID) DeviceKeyID { return DeviceKeyID(fmt.Sprintf("%s:%s", algorithm, deviceID)) } func (deviceKeyID DeviceKeyID) String() string { return string(deviceKeyID) } func (deviceKeyID DeviceKeyID) Parse() (Algorithm, DeviceID) { index := strings.IndexRune(string(deviceKeyID), ':') if index < 0 || len(deviceKeyID) <= index+1 { return "", "" } return Algorithm(deviceKeyID[:index]), DeviceID(deviceKeyID[index+1:]) } // A KeyID a string formatted as : that is used as the key in one-time-key mappings. type KeyID string func NewKeyID(algorithm KeyAlgorithm, keyID string) KeyID { return KeyID(fmt.Sprintf("%s:%s", algorithm, keyID)) } func (keyID KeyID) String() string { return string(keyID) } func (keyID KeyID) Parse() (KeyAlgorithm, string) { index := strings.IndexRune(string(keyID), ':') if index < 0 || len(keyID) <= index+1 { return "", "" } return KeyAlgorithm(keyID[:index]), string(keyID[index+1:]) } go-0.11.1/id/matrixuri.go000066400000000000000000000215321436100171500151470ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package id import ( "errors" "fmt" "net/url" "strings" ) // Errors that can happen when parsing matrix: URIs var ( ErrInvalidScheme = errors.New("matrix URI scheme must be exactly 'matrix'") ErrInvalidPartCount = errors.New("matrix URIs must have exactly 2 or 4 segments") ErrInvalidFirstSegment = errors.New("invalid identifier in first segment of matrix URI") ErrEmptySecondSegment = errors.New("the second segment of the matrix URI must not be empty") ErrInvalidThirdSegment = errors.New("invalid identifier in third segment of matrix URI") ErrEmptyFourthSegment = errors.New("the fourth segment of the matrix URI must not be empty when the third segment is present") ) // Errors that can happen when parsing matrix.to URLs var ( ErrNotMatrixTo = errors.New("that URL is not a matrix.to URL") ErrInvalidMatrixToPartCount = errors.New("matrix.to URLs must have exactly 1 or 2 segments") ErrEmptyMatrixToPrimaryIdentifier = errors.New("the primary identifier in the matrix.to URL is empty") ErrInvalidMatrixToPrimaryIdentifier = errors.New("the primary identifier in the matrix.to URL has an invalid sigil") ErrInvalidMatrixToSecondaryIdentifier = errors.New("the secondary identifier in the matrix.to URL has an invalid sigil") ) var ErrNotMatrixToOrMatrixURI = errors.New("that URL is not a matrix.to URL nor matrix: URI") // MatrixURI contains the result of parsing a matrix: URI using ParseMatrixURI type MatrixURI struct { Sigil1 rune Sigil2 rune MXID1 string MXID2 string Via []string Action string } // SigilToPathSegment contains a mapping from Matrix identifier sigils to matrix: URI path segments. var SigilToPathSegment = map[rune]string{ '$': "e", '#': "r", '!': "roomid", '@': "u", } func (uri *MatrixURI) getQuery() url.Values { q := make(url.Values) if uri.Via != nil && len(uri.Via) > 0 { q["via"] = uri.Via } if len(uri.Action) > 0 { q.Set("action", uri.Action) } return q } // String converts the parsed matrix: URI back into the string representation. 
func (uri *MatrixURI) String() string { parts := []string{ SigilToPathSegment[uri.Sigil1], uri.MXID1, } if uri.Sigil2 != 0 { parts = append(parts, SigilToPathSegment[uri.Sigil2], uri.MXID2) } return (&url.URL{ Scheme: "matrix", Opaque: strings.Join(parts, "/"), RawQuery: uri.getQuery().Encode(), }).String() } // MatrixToURL converts to parsed matrix: URI into a matrix.to URL func (uri *MatrixURI) MatrixToURL() string { fragment := fmt.Sprintf("#/%s", url.QueryEscape(uri.PrimaryIdentifier())) if uri.Sigil2 != 0 { fragment = fmt.Sprintf("%s/%s", fragment, url.QueryEscape(uri.SecondaryIdentifier())) } query := uri.getQuery().Encode() if len(query) > 0 { fragment = fmt.Sprintf("%s?%s", fragment, query) } // It would be nice to use URL{...}.String() here, but figuring out the Fragment vs RawFragment stuff is a pain return fmt.Sprintf("https://matrix.to/%s", fragment) } // PrimaryIdentifier returns the first Matrix identifier in the URI. // Currently room IDs, room aliases and user IDs can be in the primary identifier slot. func (uri *MatrixURI) PrimaryIdentifier() string { return fmt.Sprintf("%c%s", uri.Sigil1, uri.MXID1) } // SecondaryIdentifier returns the second Matrix identifier in the URI. // Currently only event IDs can be in the secondary identifier slot. func (uri *MatrixURI) SecondaryIdentifier() string { if uri.Sigil2 == 0 { return "" } return fmt.Sprintf("%c%s", uri.Sigil2, uri.MXID2) } // UserID returns the user ID from the URI if the primary identifier is a user ID. func (uri *MatrixURI) UserID() UserID { if uri.Sigil1 == '@' { return UserID(uri.PrimaryIdentifier()) } return "" } // RoomID returns the room ID from the URI if the primary identifier is a room ID. func (uri *MatrixURI) RoomID() RoomID { if uri.Sigil1 == '!' { return RoomID(uri.PrimaryIdentifier()) } return "" } // RoomAlias returns the room alias from the URI if the primary identifier is a room alias. 
func (uri *MatrixURI) RoomAlias() RoomAlias { if uri.Sigil1 == '#' { return RoomAlias(uri.PrimaryIdentifier()) } return "" } // EventID returns the event ID from the URI if the primary identifier is a room ID or alias and the secondary identifier is an event ID. func (uri *MatrixURI) EventID() EventID { if (uri.Sigil1 == '!' || uri.Sigil1 == '#') && uri.Sigil2 == '$' { return EventID(uri.SecondaryIdentifier()) } return "" } // ParseMatrixURIOrMatrixToURL parses the given matrix.to URL or matrix: URI into a unified representation. func ParseMatrixURIOrMatrixToURL(uri string) (*MatrixURI, error) { parsed, err := url.Parse(uri) if err != nil { return nil, fmt.Errorf("failed to parse URI: %w", err) } if parsed.Scheme == "matrix" { return ProcessMatrixURI(parsed) } else if strings.HasSuffix(parsed.Hostname(), "matrix.to") { return ProcessMatrixToURL(parsed) } else { return nil, ErrNotMatrixToOrMatrixURI } } // ParseMatrixURI implements the matrix: URI parsing algorithm. // // Currently specified in https://github.com/matrix-org/matrix-doc/blob/master/proposals/2312-matrix-uri.md#uri-parsing-algorithm func ParseMatrixURI(uri string) (*MatrixURI, error) { // Step 1: parse the URI according to RFC 3986 parsed, err := url.Parse(uri) if err != nil { return nil, fmt.Errorf("failed to parse URI: %w", err) } return ProcessMatrixURI(parsed) } // ProcessMatrixURI implements steps 2-7 of the matrix: URI parsing algorithm // (i.e. 
everything except parsing the URI itself, which is done with url.Parse or ParseMatrixURI) func ProcessMatrixURI(uri *url.URL) (*MatrixURI, error) { // Step 2: check that scheme is exactly `matrix` if uri.Scheme != "matrix" { return nil, ErrInvalidScheme } // Step 3: split the path into segments separated by / parts := strings.Split(uri.Opaque, "/") // Step 4: Check that the URI contains either 2 or 4 segments if len(parts) != 2 && len(parts) != 4 { return nil, ErrInvalidPartCount } var parsed MatrixURI // Step 5: Construct the top-level Matrix identifier // a: find the sigil from the first segment switch parts[0] { case "u", "user": parsed.Sigil1 = '@' case "r", "room": parsed.Sigil1 = '#' case "roomid": parsed.Sigil1 = '!' default: return nil, fmt.Errorf("%w: '%s'", ErrInvalidFirstSegment, parts[0]) } // b: find the identifier from the second segment if len(parts[1]) == 0 { return nil, ErrEmptySecondSegment } parsed.MXID1 = parts[1] // Step 6: if the first part is a room and the URI has 4 segments, construct a second level identifier if (parsed.Sigil1 == '!' || parsed.Sigil1 == '#') && len(parts) == 4 { // a: find the sigil from the third segment switch parts[2] { case "e", "event": parsed.Sigil2 = '$' default: return nil, fmt.Errorf("%w: '%s'", ErrInvalidThirdSegment, parts[0]) } // b: find the identifier from the fourth segment if len(parts[3]) == 0 { return nil, ErrEmptyFourthSegment } parsed.MXID2 = parts[3] } // Step 7: parse the query and extract via and action items via, ok := uri.Query()["via"] if ok && len(via) > 0 { parsed.Via = via } action, ok := uri.Query()["action"] if ok && len(action) > 0 { parsed.Action = action[len(action)-1] } return &parsed, nil } // ParseMatrixToURL parses a matrix.to URL into the same container as ParseMatrixURI parses matrix: URIs. 
func ParseMatrixToURL(uri string) (*MatrixURI, error) { parsed, err := url.Parse(uri) if err != nil { return nil, fmt.Errorf("failed to parse URL: %w", err) } return ProcessMatrixToURL(parsed) } // ProcessMatrixToURL is the equivalent of ProcessMatrixURI for matrix.to URLs. func ProcessMatrixToURL(uri *url.URL) (*MatrixURI, error) { if !strings.HasSuffix(uri.Hostname(), "matrix.to") { return nil, ErrNotMatrixTo } initialSplit := strings.SplitN(uri.Fragment, "?", 2) parts := strings.Split(initialSplit[0], "/") if len(initialSplit) > 1 { uri.RawQuery = initialSplit[1] } if len(parts) < 2 || len(parts) > 3 { return nil, ErrInvalidMatrixToPartCount } if len(parts[1]) == 0 { return nil, ErrEmptyMatrixToPrimaryIdentifier } var parsed MatrixURI parsed.Sigil1 = rune(parts[1][0]) parsed.MXID1 = parts[1][1:] _, isKnown := SigilToPathSegment[parsed.Sigil1] if !isKnown { return nil, ErrInvalidMatrixToPrimaryIdentifier } if len(parts) == 3 && len(parts[2]) > 0 { parsed.Sigil2 = rune(parts[2][0]) parsed.MXID2 = parts[2][1:] _, isKnown = SigilToPathSegment[parsed.Sigil2] if !isKnown { return nil, ErrInvalidMatrixToSecondaryIdentifier } } via, ok := uri.Query()["via"] if ok && len(via) > 0 { parsed.Via = via } action, ok := uri.Query()["action"] if ok && len(action) > 0 { parsed.Action = action[len(action)-1] } return &parsed, nil } go-0.11.1/id/matrixuri_test.go000066400000000000000000000170601436100171500162070ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package id_test import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "maunium.net/go/mautrix/id" ) var ( roomIDLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org"} roomIDViaLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Via: []string{"maunium.net", "matrix.org"}} roomAliasLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org"} roomIDEventLink = id.MatrixURI{Sigil1: '!', MXID1: "7NdBVvkd4aLSbgKt9RXl:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} roomAliasEventLink = id.MatrixURI{Sigil1: '#', MXID1: "someroom:example.org", Sigil2: '$', MXID2: "uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s"} userLink = id.MatrixURI{Sigil1: '@', MXID1: "user:example.org"} ) func TestMatrixURI_MatrixToURL(t *testing.T) { assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org", roomIDLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org?via=maunium.net&via=matrix.org", roomIDViaLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%23someroom%3Aexample.org", roomAliasLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomIDEventLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%23someroom%3Aexample.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.MatrixToURL()) assert.Equal(t, "https://matrix.to/#/%40user%3Aexample.org", userLink.MatrixToURL()) } func TestMatrixURI_String(t *testing.T) { assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org", roomIDLink.String()) assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org", roomIDViaLink.String()) assert.Equal(t, "matrix:r/someroom:example.org", roomAliasLink.String()) assert.Equal(t, "matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", 
roomIDEventLink.String()) assert.Equal(t, "matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s", roomAliasEventLink.String()) assert.Equal(t, "matrix:u/user:example.org", userLink.String()) } func TestParseMatrixURIOrMatrixToURL(t *testing.T) { const inputURI = "matrix:u/user:example.org" const inputMatrixToURL = "https://matrix.to/#/%40user%3Aexample.org" parsed1, err := id.ParseMatrixURIOrMatrixToURL(inputURI) require.NoError(t, err) require.NotNil(t, parsed1) parsed2, err := id.ParseMatrixURIOrMatrixToURL(inputMatrixToURL) require.NoError(t, err) require.NotNil(t, parsed2) assert.Equal(t, parsed1, parsed2) assert.Equal(t, inputURI, parsed2.String()) assert.Equal(t, inputMatrixToURL, parsed1.MatrixToURL()) } func TestParseMatrixURI_RoomAlias(t *testing.T) { parsed1, err := id.ParseMatrixURI("matrix:r/someroom:example.org") require.NoError(t, err) require.NotNil(t, parsed1) parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org") require.NoError(t, err) require.NotNil(t, parsed2) assert.Equal(t, roomAliasLink, *parsed1) assert.Equal(t, roomAliasLink, *parsed2) } func TestParseMatrixURI_RoomID(t *testing.T) { parsed, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org") require.NoError(t, err) require.NotNil(t, parsed) parsedVia, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) assert.Equal(t, roomIDLink, *parsed) assert.Equal(t, roomIDViaLink, *parsedVia) } func TestParseMatrixURI_UserID(t *testing.T) { parsed1, err := id.ParseMatrixURI("matrix:u/user:example.org") require.NoError(t, err) require.NotNil(t, parsed1) parsed2, err := id.ParseMatrixURI("matrix:user/user:example.org") require.NoError(t, err) require.NotNil(t, parsed2) assert.Equal(t, userLink, *parsed1) assert.Equal(t, userLink, *parsed2) } func TestParseMatrixURI_EventID(t *testing.T) { parsed1, err := 
id.ParseMatrixURI("matrix:r/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed1) parsed2, err := id.ParseMatrixURI("matrix:room/someroom:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed2) parsed3, err := id.ParseMatrixURI("matrix:roomid/7NdBVvkd4aLSbgKt9RXl:example.org/e/uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed3) assert.Equal(t, roomAliasEventLink, *parsed1) assert.Equal(t, roomAliasEventLink, *parsed2) assert.Equal(t, roomIDEventLink, *parsed3) } func TestParseMatrixToURL_RoomAlias(t *testing.T) { parsed, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org") require.NoError(t, err) require.NotNil(t, parsed) parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom%3Aexample.org") require.NoError(t, err) require.NotNil(t, parsedEncoded) assert.Equal(t, roomAliasLink, *parsed) assert.Equal(t, roomAliasLink, *parsedEncoded) } func TestParseMatrixToURL_RoomID(t *testing.T) { parsed, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org") require.NoError(t, err) require.NotNil(t, parsed) parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org") require.NoError(t, err) require.NotNil(t, parsedEncoded) parsedVia, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedVia) parsedViaEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl%3Aexample.org?via=maunium.net&via=matrix.org") require.NoError(t, err) require.NotNil(t, parsedViaEncoded) assert.Equal(t, roomIDLink, *parsed) assert.Equal(t, roomIDLink, *parsedEncoded) assert.Equal(t, roomIDViaLink, *parsedVia) assert.Equal(t, roomIDViaLink, *parsedViaEncoded) } func TestParseMatrixToURL_UserID(t 
*testing.T) { parsed, err := id.ParseMatrixToURL("https://matrix.to/#/@user:example.org") require.NoError(t, err) require.NotNil(t, parsed) parsedEncoded, err := id.ParseMatrixToURL("https://matrix.to/#/%40user%3Aexample.org") require.NoError(t, err) require.NotNil(t, parsedEncoded) assert.Equal(t, userLink, *parsed) assert.Equal(t, userLink, *parsedEncoded) } func TestParseMatrixToURL_EventID(t *testing.T) { parsed1, err := id.ParseMatrixToURL("https://matrix.to/#/#someroom:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed1) parsed2, err := id.ParseMatrixToURL("https://matrix.to/#/!7NdBVvkd4aLSbgKt9RXl:example.org/$uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed2) parsed1Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%23someroom:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed1) parsed2Encoded, err := id.ParseMatrixToURL("https://matrix.to/#/%217NdBVvkd4aLSbgKt9RXl:example.org/%24uOH4C9cK4HhMeFWkUXMbdF_dtndJ0j9je-kIK3XpV1s") require.NoError(t, err) require.NotNil(t, parsed2) assert.Equal(t, roomAliasEventLink, *parsed1) assert.Equal(t, roomAliasEventLink, *parsed1Encoded) assert.Equal(t, roomIDEventLink, *parsed2) assert.Equal(t, roomIDEventLink, *parsed2Encoded) } go-0.11.1/id/opaque.go000066400000000000000000000037201436100171500144140ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package id import ( "fmt" ) // A RoomID is a string starting with ! that references a specific room. // https://matrix.org/docs/spec/appendices#room-ids-and-event-ids type RoomID string // A RoomAlias is a string starting with # that can be resolved into. 
// https://matrix.org/docs/spec/appendices#room-aliases type RoomAlias string func NewRoomAlias(localpart, server string) RoomAlias { return RoomAlias(fmt.Sprintf("#%s:%s", localpart, server)) } // An EventID is a string starting with $ that references a specific event. // // https://matrix.org/docs/spec/appendices#room-ids-and-event-ids // https://matrix.org/docs/spec/rooms/v4#event-ids type EventID string // A BatchID is a string identifying a batch of events being backfilled to a room. // https://github.com/matrix-org/matrix-doc/pull/2716 type BatchID string func (roomID RoomID) String() string { return string(roomID) } func (roomID RoomID) URI(via ...string) *MatrixURI { return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], Via: via, } } func (roomID RoomID) EventURI(eventID EventID, via ...string) *MatrixURI { return &MatrixURI{ Sigil1: '!', MXID1: string(roomID)[1:], Sigil2: '$', MXID2: string(eventID)[1:], Via: via, } } func (roomAlias RoomAlias) String() string { return string(roomAlias) } func (roomAlias RoomAlias) URI() *MatrixURI { return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], } } func (roomAlias RoomAlias) EventURI(eventID EventID) *MatrixURI { return &MatrixURI{ Sigil1: '#', MXID1: string(roomAlias)[1:], Sigil2: '$', MXID2: string(eventID)[1:], } } func (eventID EventID) String() string { return string(eventID) } func (batchID BatchID) String() string { return string(batchID) } go-0.11.1/id/userid.go000066400000000000000000000160761436100171500144250ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package id import ( "bytes" "encoding/hex" "errors" "fmt" "regexp" "strings" ) // UserID represents a Matrix user ID. 
// https://matrix.org/docs/spec/appendices#user-identifiers type UserID string const UserIDMaxLength = 255 func NewUserID(localpart, homeserver string) UserID { return UserID(fmt.Sprintf("@%s:%s", localpart, homeserver)) } func NewEncodedUserID(localpart, homeserver string) UserID { return NewUserID(EncodeUserLocalpart(localpart), homeserver) } var ( ErrInvalidUserID = errors.New("is not a valid user ID") ErrNoncompliantLocalpart = errors.New("contains characters that are not allowed") ErrUserIDTooLong = errors.New("the given user ID is longer than 255 characters") ErrEmptyLocalpart = errors.New("empty localparts are not allowed") ) // Parse parses the user ID into the localpart and server name. // // Note that this only enforces very basic user ID formatting requirements: user IDs start with // a @, and contain a : after the @. If you want to enforce localpart validity, see the // ParseAndValidate and ValidateUserLocalpart functions. func (userID UserID) Parse() (localpart, homeserver string, err error) { if len(userID) == 0 || userID[0] != '@' || !strings.ContainsRune(string(userID), ':') { // This error wrapping lets you use errors.Is() nicely even though the message contains the user ID err = fmt.Errorf("'%s' %w", userID, ErrInvalidUserID) return } parts := strings.SplitN(string(userID), ":", 2) localpart, homeserver = strings.TrimPrefix(parts[0], "@"), parts[1] return } // URI returns the user ID as a MatrixURI struct, which can then be stringified into a matrix: URI or a matrix.to URL. // // This does not parse or validate the user ID. Use the ParseAndValidate method if you want to ensure the user ID is valid first. 
func (userID UserID) URI() *MatrixURI { return &MatrixURI{ Sigil1: '@', MXID1: string(userID)[1:], } } var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/]+$") // ValidateUserLocalpart validates a Matrix user ID localpart using the grammar // in https://matrix.org/docs/spec/appendices#user-identifier func ValidateUserLocalpart(localpart string) error { if len(localpart) == 0 { return ErrEmptyLocalpart } else if !ValidLocalpartRegex.MatchString(localpart) { return fmt.Errorf("'%s' %w", localpart, ErrNoncompliantLocalpart) } return nil } // ParseAndValidate parses the user ID into the localpart and server name like Parse, // and also validates that the localpart is allowed according to the user identifiers spec. func (userID UserID) ParseAndValidate() (localpart, homeserver string, err error) { localpart, homeserver, err = userID.Parse() if err == nil { err = ValidateUserLocalpart(localpart) } if err == nil && len(userID) > UserIDMaxLength { err = ErrUserIDTooLong } return } func (userID UserID) ParseAndDecode() (localpart, homeserver string, err error) { localpart, homeserver, err = userID.ParseAndValidate() if err == nil { localpart, err = DecodeUserLocalpart(localpart) } return } func (userID UserID) String() string { return string(userID) } const lowerhex = "0123456789abcdef" // encode the given byte using quoted-printable encoding (e.g "=2f") // and writes it to the buffer // See https://golang.org/src/mime/quotedprintable/writer.go func encode(buf *bytes.Buffer, b byte) { buf.WriteByte('=') buf.WriteByte(lowerhex[b>>4]) buf.WriteByte(lowerhex[b&0x0f]) } // escape the given alpha character and writes it to the buffer func escape(buf *bytes.Buffer, b byte) { buf.WriteByte('_') if b == '_' { buf.WriteByte('_') // another _ } else { buf.WriteByte(b + 0x20) // ASCII shift A-Z to a-z } } func shouldEncode(b byte) bool { return b != '-' && b != '.' 
&& b != '_' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z') } func shouldEscape(b byte) bool { return (b >= 'A' && b <= 'Z') || b == '_' } func isValidByte(b byte) bool { return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-' } func isValidEscapedChar(b byte) bool { return b == '_' || (b >= 'a' && b <= 'z') } // EncodeUserLocalpart encodes the given string into Matrix-compliant user ID localpart form. // See https://spec.matrix.org/v1.2/appendices/#mapping-from-other-character-sets // // This returns a string with only the characters "a-z0-9._=-". The uppercase range A-Z // are encoded using leading underscores ("_"). Characters outside the aforementioned ranges // (including literal underscores ("_") and equals ("=")) are encoded as UTF8 code points (NOT NCRs) // and converted to lower-case hex with a leading "=". For example: // Alph@Bet_50up => _alph=40_bet=5f50up func EncodeUserLocalpart(str string) string { strBytes := []byte(str) var outputBuffer bytes.Buffer for _, b := range strBytes { if shouldEncode(b) { encode(&outputBuffer, b) } else if shouldEscape(b) { escape(&outputBuffer, b) } else { outputBuffer.WriteByte(b) } } return outputBuffer.String() } // DecodeUserLocalpart decodes the given string back into the original input string. // Returns an error if the given string is not a valid user ID localpart encoding. // See https://spec.matrix.org/v1.2/appendices/#mapping-from-other-character-sets // // This decodes quoted-printable bytes back into UTF8, and unescapes casing. For // example: // _alph=40_bet=5f50up => Alph@Bet_50up // Returns an error if the input string contains characters outside the // range "a-z0-9._=-", has an invalid quote-printable byte (e.g. not hex), or has // an invalid _ escaped byte (e.g. "_5"). 
func DecodeUserLocalpart(str string) (string, error) { strBytes := []byte(str) var outputBuffer bytes.Buffer for i := 0; i < len(strBytes); i++ { b := strBytes[i] if !isValidByte(b) { return "", fmt.Errorf("Byte pos %d: Invalid byte", i) } if b == '_' { // next byte is a-z and should be upper-case or is another _ and should be a literal _ if i+1 >= len(strBytes) { return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding but ran out of string", i) } if !isValidEscapedChar(strBytes[i+1]) { // invalid escaping return "", fmt.Errorf("Byte pos %d: expected _[a-z_] encoding", i) } if strBytes[i+1] == '_' { outputBuffer.WriteByte('_') } else { outputBuffer.WriteByte(strBytes[i+1] - 0x20) // ASCII shift a-z to A-Z } i++ // skip next byte since we just handled it } else if b == '=' { // next 2 bytes are hex and should be buffered ready to be read as utf8 if i+2 >= len(strBytes) { return "", fmt.Errorf("Byte pos: %d: expected quote-printable encoding but ran out of string", i) } dst := make([]byte, 1) _, err := hex.Decode(dst, strBytes[i+1:i+3]) if err != nil { return "", err } outputBuffer.WriteByte(dst[0]) i += 2 // skip next 2 bytes since we just handled it } else { // pass through outputBuffer.WriteByte(b) } } return outputBuffer.String(), nil } go-0.11.1/id/userid_test.go000066400000000000000000000064201436100171500154540ustar00rootroot00000000000000// Copyright (c) 2021 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package id_test import ( "errors" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/id" ) func TestUserID_Parse(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse() assert.NoError(t, err) assert.Equal(t, "s p a c e", parsedLocalpart) assert.Equal(t, "maunium.net", parsedServerName) } func TestUserID_Parse_Empty(t *testing.T) { const inputUserID = "@:ponies.im" parsedLocalpart, parsedServerName, err := id.UserID(inputUserID).Parse() assert.NoError(t, err) assert.Equal(t, "", parsedLocalpart) assert.Equal(t, "ponies.im", parsedServerName) } func TestUserID_Parse_Invalid(t *testing.T) { const inputUserID = "hello world" _, _, err := id.UserID(inputUserID).Parse() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrInvalidUserID)) } func TestUserID_ParseAndValidate_Invalid(t *testing.T) { const inputUserID = "@s p a c e:maunium.net" _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrNoncompliantLocalpart)) } func TestUserID_ParseAndValidate_Empty(t *testing.T) { const inputUserID = "@:ponies.im" _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrEmptyLocalpart)) } func TestUserID_ParseAndValidate_Long(t *testing.T) { const inputUserID = "@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.Error(t, err) assert.True(t, errors.Is(err, id.ErrUserIDTooLong)) } func TestUserID_ParseAndValidate_NotLong(t *testing.T) { const inputUserID = 
"@aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa:example.com" _, _, err := id.UserID(inputUserID).ParseAndValidate() assert.NoError(t, err) } func TestUserIDEncoding(t *testing.T) { const inputLocalpart = "This localpart contains IlLeGaL chรคracters ๐Ÿšจ" const encodedLocalpart = "_this=20localpart=20contains=20_il_le_ga_l=20ch=c3=a4racters=20=f0=9f=9a=a8" const inputServerName = "example.com" userID := id.NewEncodedUserID(inputLocalpart, inputServerName) parsedLocalpart, parsedServerName, err := userID.ParseAndValidate() assert.NoError(t, err) assert.Equal(t, encodedLocalpart, parsedLocalpart) assert.Equal(t, inputServerName, parsedServerName) decodedLocalpart, decodedServerName, err := userID.ParseAndDecode() assert.NoError(t, err) assert.Equal(t, inputLocalpart, decodedLocalpart) assert.Equal(t, inputServerName, decodedServerName) } func TestUserID_URI(t *testing.T) { userID := id.NewUserID("hello", "example.com") assert.Equal(t, userID.URI().String(), "matrix:u/hello:example.com") } go-0.11.1/pushrules/000077500000000000000000000000001436100171500142275ustar00rootroot00000000000000go-0.11.1/pushrules/action.go000066400000000000000000000072331436100171500160400ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules import "encoding/json" // PushActionType is the type of a PushAction type PushActionType string // The allowed push action types as specified in spec section 11.12.1.4.1. 
const ( ActionNotify PushActionType = "notify" ActionDontNotify PushActionType = "dont_notify" ActionCoalesce PushActionType = "coalesce" ActionSetTweak PushActionType = "set_tweak" ) // PushActionTweak is the type of the tweak in SetTweak push actions. type PushActionTweak string // The allowed tweak types as specified in spec section 11.12.1.4.1.1. const ( TweakSound PushActionTweak = "sound" TweakHighlight PushActionTweak = "highlight" ) // PushActionArray is an array of PushActions. type PushActionArray []*PushAction // PushActionArrayShould contains the important information parsed from a PushActionArray. type PushActionArrayShould struct { // Whether or not the array contained a Notify, DontNotify or Coalesce action type. NotifySpecified bool // Whether or not the event in question should trigger a notification. Notify bool // Whether or not the event in question should be highlighted. Highlight bool // Whether or not the event in question should trigger a sound alert. PlaySound bool // The name of the sound to play if PlaySound is true. SoundName string } // Should parses this push action array and returns the relevant details wrapped in a PushActionArrayShould struct. func (actions PushActionArray) Should() (should PushActionArrayShould) { for _, action := range actions { switch action.Action { case ActionNotify, ActionCoalesce: should.Notify = true should.NotifySpecified = true case ActionDontNotify: should.Notify = false should.NotifySpecified = true case ActionSetTweak: switch action.Tweak { case TweakHighlight: var ok bool should.Highlight, ok = action.Value.(bool) if !ok { // Highlight value not specified, so assume true since the tweak is set. should.Highlight = true } case TweakSound: should.SoundName = action.Value.(string) should.PlaySound = len(should.SoundName) > 0 } } } return } // PushAction is a single action that should be triggered when receiving a message. 
type PushAction struct { Action PushActionType Tweak PushActionTweak Value interface{} } // UnmarshalJSON parses JSON into this PushAction. // // * If the JSON is a single string, the value is stored in the Action field. // * If the JSON is an object with the set_tweak field, Action will be set to // "set_tweak", Tweak will be set to the value of the set_tweak field and // and Value will be set to the value of the value field. // * In any other case, the function does nothing. func (action *PushAction) UnmarshalJSON(raw []byte) error { var data interface{} err := json.Unmarshal(raw, &data) if err != nil { return err } switch val := data.(type) { case string: action.Action = PushActionType(val) case map[string]interface{}: tweak, ok := val["set_tweak"].(string) if ok { action.Action = ActionSetTweak action.Tweak = PushActionTweak(tweak) action.Value, _ = val["value"] } } return nil } // MarshalJSON is the reverse of UnmarshalJSON() func (action *PushAction) MarshalJSON() (raw []byte, err error) { if action.Action == ActionSetTweak { data := map[string]interface{}{ "set_tweak": action.Tweak, "value": action.Value, } return json.Marshal(&data) } data := string(action.Action) return json.Marshal(&data) } go-0.11.1/pushrules/action_test.go000066400000000000000000000143671436100171500171050ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package pushrules_test import ( "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/pushrules" ) func TestPushActionArray_Should_EmptyArrayReturnsDefaults(t *testing.T) { should := pushrules.PushActionArray{}.Should() assert.False(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_MixedArrayReturnsExpected1(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionNotify}, {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight}, {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "ping"}, }.Should() assert.True(t, should.NotifySpecified) assert.True(t, should.Notify) assert.True(t, should.Highlight) assert.True(t, should.PlaySound) assert.Equal(t, "ping", should.SoundName) } func TestPushActionArray_Should_MixedArrayReturnsExpected2(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionDontNotify}, {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false}, {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: ""}, }.Should() assert.True(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_NotifySet(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionNotify}, }.Should() assert.True(t, should.NotifySpecified) assert.True(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_NotifyAndCoalesceDoTheSameThing(t *testing.T) { should1 := pushrules.PushActionArray{ {Action: pushrules.ActionNotify}, }.Should() should2 := pushrules.PushActionArray{ {Action: pushrules.ActionCoalesce}, }.Should() assert.Equal(t, should1, should2) } func 
TestPushActionArray_Should_DontNotify(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionDontNotify}, }.Should() assert.True(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_HighlightBlank(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight}, }.Should() assert.False(t, should.NotifySpecified) assert.False(t, should.Notify) assert.True(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_HighlightFalse(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false}, }.Should() assert.False(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushActionArray_Should_SoundName(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "ping"}, }.Should() assert.False(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.True(t, should.PlaySound) assert.Equal(t, "ping", should.SoundName) } func TestPushActionArray_Should_SoundNameEmpty(t *testing.T) { should := pushrules.PushActionArray{ {Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: ""}, }.Should() assert.False(t, should.NotifySpecified) assert.False(t, should.Notify) assert.False(t, should.Highlight) assert.False(t, should.PlaySound) assert.Empty(t, should.SoundName) } func TestPushAction_UnmarshalJSON_InvalidJSONFails(t *testing.T) { pa := &pushrules.PushAction{} err := pa.UnmarshalJSON([]byte("Not JSON")) assert.NotNil(t, err) } func TestPushAction_UnmarshalJSON_InvalidTypeDoesNothing(t 
*testing.T) { pa := &pushrules.PushAction{ Action: pushrules.PushActionType("unchanged"), Tweak: pushrules.PushActionTweak("unchanged"), Value: "unchanged", } err := pa.UnmarshalJSON([]byte(`{"foo": "bar"}`)) assert.Nil(t, err) err = pa.UnmarshalJSON([]byte(`9001`)) assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("unchanged"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) assert.Equal(t, "unchanged", pa.Value) } func TestPushAction_UnmarshalJSON_StringChangesActionType(t *testing.T) { pa := &pushrules.PushAction{ Action: pushrules.PushActionType("unchanged"), Tweak: pushrules.PushActionTweak("unchanged"), Value: "unchanged", } err := pa.UnmarshalJSON([]byte(`"foo"`)) assert.Nil(t, err) assert.Equal(t, pushrules.PushActionType("foo"), pa.Action) assert.Equal(t, pushrules.PushActionTweak("unchanged"), pa.Tweak) assert.Equal(t, "unchanged", pa.Value) } func TestPushAction_UnmarshalJSON_SetTweakChangesTweak(t *testing.T) { pa := &pushrules.PushAction{ Action: pushrules.PushActionType("unchanged"), Tweak: pushrules.PushActionTweak("unchanged"), Value: "unchanged", } err := pa.UnmarshalJSON([]byte(`{"set_tweak": "foo", "value": 123.0}`)) assert.Nil(t, err) assert.Equal(t, pushrules.ActionSetTweak, pa.Action) assert.Equal(t, pushrules.PushActionTweak("foo"), pa.Tweak) assert.Equal(t, 123.0, pa.Value) } func TestPushAction_MarshalJSON_TweakOutputWorks(t *testing.T) { pa := &pushrules.PushAction{ Action: pushrules.ActionSetTweak, Tweak: pushrules.PushActionTweak("foo"), Value: "bar", } data, err := pa.MarshalJSON() assert.Nil(t, err) assert.Equal(t, []byte(`{"set_tweak":"foo","value":"bar"}`), data) } func TestPushAction_MarshalJSON_OtherOutputWorks(t *testing.T) { pa := &pushrules.PushAction{ Action: pushrules.PushActionType("something else"), Tweak: pushrules.PushActionTweak("foo"), Value: "bar", } data, err := pa.MarshalJSON() assert.Nil(t, err) assert.Equal(t, []byte(`"something else"`), data) } 
go-0.11.1/pushrules/condition.go000066400000000000000000000102701436100171500165440ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules import ( "regexp" "strconv" "strings" "unicode" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/pushrules/glob" ) // Room is an interface with the functions that are needed for processing room-specific push conditions type Room interface { GetOwnDisplayname() string GetMemberCount() int } // PushCondKind is the type of a push condition. type PushCondKind string // The allowed push condition kinds as specified in https://spec.matrix.org/v1.2/client-server-api/#conditions-1 const ( KindEventMatch PushCondKind = "event_match" KindContainsDisplayName PushCondKind = "contains_display_name" KindRoomMemberCount PushCondKind = "room_member_count" ) // PushCondition wraps a condition that is required for a specific PushRule to be used. type PushCondition struct { // The type of the condition. Kind PushCondKind `json:"kind"` // The dot-separated field of the event to match. Only applicable if kind is EventMatch. Key string `json:"key,omitempty"` // The glob-style pattern to match the field against. Only applicable if kind is EventMatch. Pattern string `json:"pattern,omitempty"` // The condition that needs to be fulfilled for RoomMemberCount-type conditions. // A decimal integer optionally prefixed by ==, <, >, >= or <=. Prefix "==" is assumed if no prefix found. MemberCountCondition string `json:"is,omitempty"` } // MemberCountFilterRegex is the regular expression to parse the MemberCountCondition of PushConditions. var MemberCountFilterRegex = regexp.MustCompile("^(==|[<>]=?)?([0-9]+)$") // Match checks if this condition is fulfilled for the given event in the given room. 
func (cond *PushCondition) Match(room Room, evt *event.Event) bool { switch cond.Kind { case KindEventMatch: return cond.matchValue(room, evt) case KindContainsDisplayName: return cond.matchDisplayName(room, evt) case KindRoomMemberCount: return cond.matchMemberCount(room) default: return false } } func (cond *PushCondition) matchValue(room Room, evt *event.Event) bool { index := strings.IndexRune(cond.Key, '.') key := cond.Key subkey := "" if index > 0 { subkey = key[index+1:] key = key[0:index] } pattern, err := glob.Compile(cond.Pattern) if err != nil { return false } switch key { case "type": return pattern.MatchString(evt.Type.String()) case "sender": return pattern.MatchString(string(evt.Sender)) case "room_id": return pattern.MatchString(string(evt.RoomID)) case "state_key": if evt.StateKey == nil { return cond.Pattern == "" } return pattern.MatchString(*evt.StateKey) case "content": val, _ := evt.Content.Raw[subkey].(string) return pattern.MatchString(val) default: return false } } func (cond *PushCondition) matchDisplayName(room Room, evt *event.Event) bool { displayname := room.GetOwnDisplayname() if len(displayname) == 0 { return false } msg, ok := evt.Content.Raw["body"].(string) if !ok { return false } isAcceptable := func(r uint8) bool { return unicode.IsSpace(rune(r)) || unicode.IsPunct(rune(r)) } length := len(displayname) for index := strings.Index(msg, displayname); index != -1; index = strings.Index(msg, displayname) { if (index <= 0 || isAcceptable(msg[index-1])) && (index+length >= len(msg) || isAcceptable(msg[index+length])) { return true } msg = msg[index+len(displayname):] } return false } func (cond *PushCondition) matchMemberCount(room Room) bool { group := MemberCountFilterRegex.FindStringSubmatch(cond.MemberCountCondition) if len(group) != 3 { return false } operator := group[1] wantedMemberCount, _ := strconv.Atoi(group[2]) memberCount := room.GetMemberCount() switch operator { case "==", "": return memberCount == wantedMemberCount case 
">": return memberCount > wantedMemberCount case ">=": return memberCount >= wantedMemberCount case "<": return memberCount < wantedMemberCount case "<=": return memberCount <= wantedMemberCount default: // Should be impossible due to regex. return false } } go-0.11.1/pushrules/condition_displayname_test.go000066400000000000000000000025331436100171500221740ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules_test import ( "maunium.net/go/mautrix/event" "testing" "github.com/stretchr/testify/assert" ) func TestPushCondition_Match_DisplayName(t *testing.T) { evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: "tulir: test mention", }) evt.Sender = "@someone_else:matrix.org" assert.True(t, displaynamePushCondition.Match(displaynameTestRoom, evt)) } func TestPushCondition_Match_DisplayName_Fail(t *testing.T) { evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: "not a mention", }) evt.Sender = "@someone_else:matrix.org" assert.False(t, displaynamePushCondition.Match(displaynameTestRoom, evt)) } func TestPushCondition_Match_DisplayName_FailsOnEmptyRoom(t *testing.T) { emptyRoom := newFakeRoom(0) evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: "tulir: this room doesn't have the owner Member available, so it fails.", }) evt.Sender = "@someone_else:matrix.org" assert.False(t, displaynamePushCondition.Match(emptyRoom, evt)) } go-0.11.1/pushrules/condition_eventmatch_test.go000066400000000000000000000061031436100171500220210ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules_test import ( "maunium.net/go/mautrix/event" "testing" "github.com/stretchr/testify/assert" ) func TestPushCondition_Match_KindEvent_MsgType(t *testing.T) { condition := newMatchPushCondition("content.msgtype", "m.emote") evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgEmote, Body: "tests gomuks pushconditions", }) assert.True(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_MsgType_Fail(t *testing.T) { condition := newMatchPushCondition("content.msgtype", "m.emote") evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{ MsgType: event.MsgText, Body: "I'm testing gomuks pushconditions", }) assert.False(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_EventType(t *testing.T) { condition := newMatchPushCondition("type", "m.room.foo") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) assert.True(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_EventType_IllegalGlob(t *testing.T) { condition := newMatchPushCondition("type", "m.room.invalid_glo[b") evt := newFakeEvent(event.NewEventType("m.room.invalid_glob"), &struct{}{}) assert.False(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_Sender_Fail(t *testing.T) { condition := newMatchPushCondition("sender", "@foo:maunium.net") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) assert.False(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_RoomID(t *testing.T) { condition := newMatchPushCondition("room_id", "!fakeroom:maunium.net") evt := newFakeEvent(event.NewEventType(""), &struct{}{}) assert.True(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_BlankStateKey(t *testing.T) { condition := newMatchPushCondition("state_key", 
"") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) assert.True(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_BlankStateKey_Fail(t *testing.T) { condition := newMatchPushCondition("state_key", "not blank") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) assert.False(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_NonBlankStateKey(t *testing.T) { condition := newMatchPushCondition("state_key", "*:maunium.net") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) evt.StateKey = (*string)(&evt.Sender) assert.True(t, condition.Match(blankTestRoom, evt)) } func TestPushCondition_Match_KindEvent_UnknownKey(t *testing.T) { condition := newMatchPushCondition("non-existent key", "doesn't affect anything") evt := newFakeEvent(event.NewEventType("m.room.foo"), &struct{}{}) assert.False(t, condition.Match(blankTestRoom, evt)) } go-0.11.1/pushrules/condition_membercount_test.go000066400000000000000000000037531436100171500222130ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package pushrules_test import ( "testing" "github.com/stretchr/testify/assert" ) func TestPushCondition_Match_KindMemberCount_OneToOne_ImplicitPrefix(t *testing.T) { condition := newCountPushCondition("2") room := newFakeRoom(2) assert.True(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_OneToOne_ExplicitPrefix(t *testing.T) { condition := newCountPushCondition("==2") room := newFakeRoom(2) assert.True(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_BigRoom(t *testing.T) { condition := newCountPushCondition(">200") room := newFakeRoom(201) assert.True(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_BigRoom_Fail(t *testing.T) { condition := newCountPushCondition(">=200") room := newFakeRoom(199) assert.False(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_SmallRoom(t *testing.T) { condition := newCountPushCondition("<10") room := newFakeRoom(9) assert.True(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_SmallRoom_Fail(t *testing.T) { condition := newCountPushCondition("<=10") room := newFakeRoom(11) assert.False(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_InvalidPrefix(t *testing.T) { condition := newCountPushCondition("??10") room := newFakeRoom(11) assert.False(t, condition.Match(room, countConditionTestEvent)) } func TestPushCondition_Match_KindMemberCount_InvalidCondition(t *testing.T) { condition := newCountPushCondition("foobar") room := newFakeRoom(1) assert.False(t, condition.Match(room, countConditionTestEvent)) } go-0.11.1/pushrules/condition_test.go000066400000000000000000000062621436100171500176110ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. 
If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules_test import ( "encoding/json" "fmt" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/pushrules" ) var ( blankTestRoom pushrules.Room displaynameTestRoom pushrules.Room countConditionTestEvent *event.Event displaynamePushCondition *pushrules.PushCondition ) func init() { blankTestRoom = newFakeRoom(1) countConditionTestEvent = &event.Event{ Sender: "@tulir:maunium.net", Type: event.EventMessage, Timestamp: 1523791120, ID: "$123:maunium.net", RoomID: "!fakeroom:maunium.net", Content: event.Content{ Raw: map[string]interface{}{ "msgtype": "m.text", "body": "test", }, Parsed: &event.MessageEventContent{ MsgType: event.MsgText, Body: "test", }, }, } displaynameTestRoom = newFakeRoom(4) displaynamePushCondition = &pushrules.PushCondition{ Kind: pushrules.KindContainsDisplayName, } } func newFakeEvent(evtType event.Type, parsed interface{}) *event.Event { data, err := json.Marshal(parsed) if err != nil { panic(err) } var raw map[string]interface{} err = json.Unmarshal(data, &raw) if err != nil { panic(err) } content := event.Content{ VeryRaw: data, Raw: raw, Parsed: parsed, } return &event.Event{ Sender: "@tulir:maunium.net", Type: evtType, Timestamp: 1523791120, ID: "$123:maunium.net", RoomID: "!fakeroom:maunium.net", Content: content, } } func newCountPushCondition(condition string) *pushrules.PushCondition { return &pushrules.PushCondition{ Kind: pushrules.KindRoomMemberCount, MemberCountCondition: condition, } } func newMatchPushCondition(key, pattern string) *pushrules.PushCondition { return &pushrules.PushCondition{ Kind: pushrules.KindEventMatch, Key: key, Pattern: pattern, } } func TestPushCondition_Match_InvalidKind(t *testing.T) { condition := &pushrules.PushCondition{ Kind: pushrules.PushCondKind("invalid"), } evt := newFakeEvent(event.Type{Type: "m.room.foobar"}, &struct{}{}) 
assert.False(t, condition.Match(blankTestRoom, evt)) } type FakeRoom struct { members map[string]*event.MemberEventContent owner string } func newFakeRoom(memberCount int) *FakeRoom { room := &FakeRoom{ owner: "@tulir:maunium.net", members: make(map[string]*event.MemberEventContent), } if memberCount >= 1 { room.members["@tulir:maunium.net"] = &event.MemberEventContent{ Membership: event.MembershipJoin, Displayname: "tulir", } } for i := 0; i < memberCount-1; i++ { mxid := fmt.Sprintf("@extrauser_%d:matrix.org", i) room.members[mxid] = &event.MemberEventContent{ Membership: event.MembershipJoin, Displayname: fmt.Sprintf("Extra User %d", i), } } return room } func (fr *FakeRoom) GetMemberCount() int { return len(fr.members) } func (fr *FakeRoom) GetOwnDisplayname() string { member, ok := fr.members[fr.owner] if ok { return member.Displayname } return "" } go-0.11.1/pushrules/doc.go000066400000000000000000000001341436100171500153210ustar00rootroot00000000000000// Package pushrules contains utilities to parse push notification rules. package pushrules go-0.11.1/pushrules/glob/000077500000000000000000000000001436100171500151525ustar00rootroot00000000000000go-0.11.1/pushrules/glob/LICENSE000066400000000000000000000021261436100171500161600ustar00rootroot00000000000000Glob is licensed under the MIT "Expat" License: Copyright (c) 2016: Zachary Yedidia. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. go-0.11.1/pushrules/glob/README.md000066400000000000000000000013431436100171500164320ustar00rootroot00000000000000# String globbing in Go [![GoDoc](https://godoc.org/github.com/zyedidia/glob?status.svg)](http://godoc.org/github.com/zyedidia/glob) This package adds support for globs in Go. It simply converts glob expressions to regexps. I try to follow the standard defined [here](http://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_13). # Example ```go package main import "github.com/zyedidia/glob" func main() { glob, err := glob.Compile("{*.go,*.c}") if err != nil { // Error } glob.Match([]byte("test.c")) // true glob.Match([]byte("hello.go")) // true glob.Match([]byte("test.d")) // false } ``` You can call all the same functions on a glob that you can call on a regexp. go-0.11.1/pushrules/glob/glob.go000066400000000000000000000040171436100171500164260ustar00rootroot00000000000000// Package glob provides objects for matching strings with globs package glob import "regexp" // Glob is a wrapper of *regexp.Regexp. // It should contain a glob expression compiled into a regular expression. type Glob struct { *regexp.Regexp } // Compile a takes a glob expression as a string and transforms it // into a *Glob object (which is really just a regular expression) // Compile also returns a possible error. 
func Compile(pattern string) (*Glob, error) { r, err := globToRegex(pattern) return &Glob{r}, err } func globToRegex(glob string) (*regexp.Regexp, error) { regex := "" inGroup := 0 inClass := 0 firstIndexInClass := -1 arr := []byte(glob) hasGlobCharacters := false for i := 0; i < len(arr); i++ { ch := arr[i] switch ch { case '\\': i++ if i >= len(arr) { regex += "\\" } else { next := arr[i] switch next { case ',': // Nothing case 'Q', 'E': regex += "\\\\" default: regex += "\\" } regex += string(next) } case '*': if inClass == 0 { regex += ".*" } else { regex += "*" } hasGlobCharacters = true case '?': if inClass == 0 { regex += "." } else { regex += "?" } hasGlobCharacters = true case '[': inClass++ firstIndexInClass = i + 1 regex += "[" hasGlobCharacters = true case ']': inClass-- regex += "]" case '.', '(', ')', '+', '|', '^', '$', '@', '%': if inClass == 0 || (firstIndexInClass == i && ch == '^') { regex += "\\" } regex += string(ch) hasGlobCharacters = true case '!': if firstIndexInClass == i { regex += "^" } else { regex += "!" } hasGlobCharacters = true case '{': inGroup++ regex += "(" hasGlobCharacters = true case '}': inGroup-- regex += ")" case ',': if inGroup > 0 { regex += "|" hasGlobCharacters = true } else { regex += "," } default: regex += string(ch) } } if hasGlobCharacters { return regexp.Compile("^" + regex + "$") } else { return regexp.Compile(regex) } } go-0.11.1/pushrules/pushrules.go000066400000000000000000000017451436100171500166170ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules import ( "encoding/gob" "encoding/json" "reflect" "maunium.net/go/mautrix/event" ) // EventContent represents the content of a m.push_rules account data event. 
// https://spec.matrix.org/v1.2/client-server-api/#mpush_rules type EventContent struct { Ruleset *PushRuleset `json:"global"` } func init() { event.TypeMap[event.AccountDataPushRules] = reflect.TypeOf(EventContent{}) gob.Register(&EventContent{}) } // EventToPushRules converts a m.push_rules event to a PushRuleset by passing the data through JSON. func EventToPushRules(evt *event.Event) (*PushRuleset, error) { content := &EventContent{} err := json.Unmarshal(evt.Content.VeryRaw, content) if err != nil { return nil, err } return content.Ruleset, nil } go-0.11.1/pushrules/pushrules_test.go000066400000000000000000000131361436100171500176530ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package pushrules_test import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/pushrules" ) func TestEventToPushRules(t *testing.T) { evt := &event.Event{ Type: event.AccountDataPushRules, Timestamp: 1523380910, Content: event.Content{ VeryRaw: json.RawMessage(JSONExamplePushRules), }, } pushRuleset, err := pushrules.EventToPushRules(evt) assert.Nil(t, err) assert.NotNil(t, pushRuleset) assert.IsType(t, pushRuleset.Override, pushrules.PushRuleArray{}) assert.IsType(t, pushRuleset.Content, pushrules.PushRuleArray{}) assert.IsType(t, pushRuleset.Room, pushrules.PushRuleMap{}) assert.IsType(t, pushRuleset.Sender, pushrules.PushRuleMap{}) assert.IsType(t, pushRuleset.Underride, pushrules.PushRuleArray{}) assert.Len(t, pushRuleset.Override, 2) assert.Len(t, pushRuleset.Content, 1) assert.Empty(t, pushRuleset.Room.Map) assert.Empty(t, pushRuleset.Sender.Map) assert.Len(t, pushRuleset.Underride, 6) assert.Len(t, pushRuleset.Content[0].Actions, 3) assert.True(t, pushRuleset.Content[0].Default) 
assert.True(t, pushRuleset.Content[0].Enabled) assert.Empty(t, pushRuleset.Content[0].Conditions) assert.Equal(t, "alice", pushRuleset.Content[0].Pattern) assert.Equal(t, ".m.rule.contains_user_name", pushRuleset.Content[0].RuleID) assert.False(t, pushRuleset.Override[0].Actions.Should().Notify) assert.True(t, pushRuleset.Override[0].Actions.Should().NotifySpecified) } const JSONExamplePushRules = `{ "global": { "content": [ { "actions": [ "notify", { "set_tweak": "sound", "value": "default" }, { "set_tweak": "highlight" } ], "default": true, "enabled": true, "pattern": "alice", "rule_id": ".m.rule.contains_user_name" } ], "override": [ { "actions": [ "dont_notify" ], "conditions": [], "default": true, "enabled": false, "rule_id": ".m.rule.master" }, { "actions": [ "dont_notify" ], "conditions": [ { "key": "content.msgtype", "kind": "event_match", "pattern": "m.notice" } ], "default": true, "enabled": true, "rule_id": ".m.rule.suppress_notices" } ], "room": [], "sender": [], "underride": [ { "actions": [ "notify", { "set_tweak": "sound", "value": "ring" }, { "set_tweak": "highlight", "value": false } ], "conditions": [ { "key": "type", "kind": "event_match", "pattern": "m.call.invite" } ], "default": true, "enabled": true, "rule_id": ".m.rule.call" }, { "actions": [ "notify", { "set_tweak": "sound", "value": "default" }, { "set_tweak": "highlight" } ], "conditions": [ { "kind": "contains_display_name" } ], "default": true, "enabled": true, "rule_id": ".m.rule.contains_display_name" }, { "actions": [ "notify", { "set_tweak": "sound", "value": "default" }, { "set_tweak": "highlight", "value": false } ], "conditions": [ { "is": "2", "kind": "room_member_count" } ], "default": true, "enabled": true, "rule_id": ".m.rule.room_one_to_one" }, { "actions": [ "notify", { "set_tweak": "sound", "value": "default" }, { "set_tweak": "highlight", "value": false } ], "conditions": [ { "key": "type", "kind": "event_match", "pattern": "m.room.member" }, { "key": 
"content.membership", "kind": "event_match", "pattern": "invite" }, { "key": "state_key", "kind": "event_match", "pattern": "@alice:example.com" } ], "default": true, "enabled": true, "rule_id": ".m.rule.invite_for_me" }, { "actions": [ "notify", { "set_tweak": "highlight", "value": false } ], "conditions": [ { "key": "type", "kind": "event_match", "pattern": "m.room.member" } ], "default": true, "enabled": true, "rule_id": ".m.rule.member_event" }, { "actions": [ "notify", { "set_tweak": "highlight", "value": false } ], "conditions": [ { "key": "type", "kind": "event_match", "pattern": "m.room.message" } ], "default": true, "enabled": true, "rule_id": ".m.rule.message" } ] } }` go-0.11.1/pushrules/rule.go000066400000000000000000000072131436100171500155300ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package pushrules

import (
	"encoding/gob"

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
	"maunium.net/go/mautrix/pushrules/glob"
)

// init registers the concrete rule collection types with encoding/gob, which
// gob requires for types transmitted as implementations of an interface
// (such as PushRuleCollection).
func init() {
	gob.Register(PushRuleArray{})
	gob.Register(PushRuleMap{})
}

// PushRuleCollection is a group of push rules that is evaluated against an
// event as a unit.
type PushRuleCollection interface {
	// GetActions returns the actions of the matching rule, or nil if no rule
	// in the collection matches the event.
	GetActions(room Room, evt *event.Event) PushActionArray
}

// PushRuleArray is an ordered list of push rules. Rules are evaluated in
// order and the first enabled rule that matches wins.
type PushRuleArray []*PushRule

// SetType sets the Type of every rule in the array to typ and returns the
// same array. Nil entries (for example from a JSON null) are left untouched
// instead of causing a nil pointer dereference.
func (rules PushRuleArray) SetType(typ PushRuleType) PushRuleArray {
	for _, rule := range rules {
		if rule != nil {
			rule.Type = typ
		}
	}
	return rules
}

// GetActions returns the actions of the first rule that matches the event,
// or nil if none match. Nil entries never match.
func (rules PushRuleArray) GetActions(room Room, evt *event.Event) PushActionArray {
	for _, rule := range rules {
		if rule.Match(room, evt) {
			return rule.Actions
		}
	}
	return nil
}

// PushRuleMap holds room- or sender-specific push rules keyed by rule ID,
// which for those rule kinds is the room ID or user ID respectively.
type PushRuleMap struct {
	Map  map[string]*PushRule
	Type PushRuleType
}

// SetTypeAndMap sets the Type of every rule to typ and indexes the rules by
// RuleID. If several rules share an ID, the last one wins. Nil entries are
// skipped.
func (rules PushRuleArray) SetTypeAndMap(typ PushRuleType) PushRuleMap {
	data := PushRuleMap{
		Map:  make(map[string]*PushRule, len(rules)),
		Type: typ,
	}
	for _, rule := range rules {
		if rule == nil {
			continue
		}
		rule.Type = typ
		data.Map[rule.RuleID] = rule
	}
	return data
}

// GetActions looks up the rule for the event's room (RoomRule maps) or sender
// (SenderRule maps) and returns its actions if it matches, or nil otherwise.
// Maps with any other Type never match.
func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArray {
	var rule *PushRule
	switch ruleMap.Type {
	case RoomRule:
		rule = ruleMap.Map[string(evt.RoomID)]
	case SenderRule:
		rule = ruleMap.Map[string(evt.Sender)]
	}
	// A missing key (or a nil value) yields a nil rule, which never matches.
	if rule != nil && rule.Match(room, evt) {
		return rule.Actions
	}
	return nil
}

// Unmap converts the map back into an array. The order of the returned rules
// follows map iteration order and is therefore unspecified. The result is
// never nil, so it encodes as an empty JSON array rather than null.
func (ruleMap PushRuleMap) Unmap() PushRuleArray {
	array := make(PushRuleArray, 0, len(ruleMap.Map))
	for _, rule := range ruleMap.Map {
		array = append(array, rule)
	}
	return array
}

// PushRuleType is the kind of a push rule, which decides how Match evaluates it.
type PushRuleType string

const (
	OverrideRule  PushRuleType = "override"
	ContentRule   PushRuleType = "content"
	RoomRule      PushRuleType = "room"
	SenderRule    PushRuleType = "sender"
	UnderrideRule PushRuleType = "underride"
)

// PushRule is a single push rule.
type PushRule struct {
	// The type of this rule.
	Type PushRuleType `json:"-"`
	// The ID of this rule.
	// For room-specific rules and user-specific rules, this is the room or user ID (respectively)
	// For other types of rules, this doesn't affect anything.
	RuleID string `json:"rule_id"`
	// The actions this rule should trigger when matched.
	Actions PushActionArray `json:"actions"`
	// Whether this is a default rule, or has been set explicitly.
	Default bool `json:"default"`
	// Whether or not this push rule is enabled.
	Enabled bool `json:"enabled"`
	// The conditions to match in order to trigger this rule.
	// Only applicable to generic underride/override rules.
	Conditions []*PushCondition `json:"conditions,omitempty"`

	// Pattern for content-specific push rules
	Pattern string `json:"pattern,omitempty"`
}

// Match reports whether the rule is enabled and matches the given event:
// override/underride rules require all conditions to match, content rules
// match Pattern against the message body, and room/sender rules compare
// RuleID with the event's room ID or sender. A nil rule or a rule of an
// unknown type never matches.
func (rule *PushRule) Match(room Room, evt *event.Event) bool {
	if rule == nil || !rule.Enabled {
		return false
	}
	switch rule.Type {
	case OverrideRule, UnderrideRule:
		return rule.matchConditions(room, evt)
	case ContentRule:
		return rule.matchPattern(room, evt)
	case RoomRule:
		return id.RoomID(rule.RuleID) == evt.RoomID
	case SenderRule:
		return id.UserID(rule.RuleID) == evt.Sender
	default:
		return false
	}
}

// matchConditions reports whether every condition matches; a rule without
// conditions always matches.
func (rule *PushRule) matchConditions(room Room, evt *event.Event) bool {
	for _, cond := range rule.Conditions {
		if !cond.Match(room, evt) {
			return false
		}
	}
	return true
}

// matchPattern compiles the rule's glob Pattern and matches it against the
// event's string "body" content field. Invalid patterns and events without a
// string body never match.
// NOTE(review): the glob is recompiled on every call; consider caching it if
// this shows up in profiles.
func (rule *PushRule) matchPattern(room Room, evt *event.Event) bool {
	pattern, err := glob.Compile(rule.Pattern)
	if err != nil {
		return false
	}
	msg, ok := evt.Content.Raw["body"].(string)
	if !ok {
		return false
	}
	return pattern.MatchString(msg)
}
go-0.11.1/pushrules/rule_array_test.go000066400000000000000000000175701436100171500177740ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package pushrules_test

import (
	"testing"

	"github.com/stretchr/testify/assert"

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/pushrules"
)

// TestPushRuleArray_GetActions_FirstMatchReturns checks that an array returns
// the actions of the first matching rule: rule1 fails one of its conditions,
// so rule2 wins even though rule3 would also match.
func TestPushRuleArray_GetActions_FirstMatchReturns(t *testing.T) {
	cond1 := newMatchPushCondition("content.msgtype", "m.emote")
	cond2 := newMatchPushCondition("content.body", "no match")
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "ping"},
	}
	rule1 := &pushrules.PushRule{
		Type:       pushrules.OverrideRule,
		Enabled:    true,
		Conditions: []*pushrules.PushCondition{cond1, cond2},
		Actions:    actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule2 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!fakeroom:maunium.net",
		Actions: actions2,
	}

	actions3 := pushrules.PushActionArray{
		{Action: pushrules.ActionCoalesce},
	}
	rule3 := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@tulir:maunium.net",
		Actions: actions3,
	}

	rules := pushrules.PushRuleArray{rule1, rule2, rule3}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.Equal(t, actions2, rules.GetActions(blankTestRoom, evt))
}

// TestPushRuleArray_GetActions_NoMatchesIsNil checks that an array in which no
// rule matches returns nil actions.
func TestPushRuleArray_GetActions_NoMatchesIsNil(t *testing.T) {
	cond1 := newMatchPushCondition("content.msgtype", "m.emote")
	cond2 := newMatchPushCondition("content.body", "no match")
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "ping"},
	}
	rule1 := &pushrules.PushRule{
		Type:       pushrules.OverrideRule,
		Enabled:    true,
		Conditions: []*pushrules.PushCondition{cond1, cond2},
		Actions:    actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule2 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!realroom:maunium.net",
		Actions: actions2,
	}

	actions3 := pushrules.PushActionArray{
		{Action: pushrules.ActionCoalesce},
	}
	rule3 := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@otheruser:maunium.net",
		Actions: actions3,
	}

	rules := pushrules.PushRuleArray{rule1, rule2, rule3}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.Nil(t, rules.GetActions(blankTestRoom, evt))
}

// TestPushRuleMap_GetActions_RoomRuleExists checks that a room rule map returns
// the actions of the rule keyed by the event's room ID.
func TestPushRuleMap_GetActions_RoomRuleExists(t *testing.T) {
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule1 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!realroom:maunium.net",
		Actions: actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
	}
	rule2 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!thirdroom:maunium.net",
		Actions: actions2,
	}

	actions3 := pushrules.PushActionArray{
		{Action: pushrules.ActionCoalesce},
	}
	rule3 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!fakeroom:maunium.net",
		Actions: actions3,
	}

	rules := pushrules.PushRuleMap{
		Map: map[string]*pushrules.PushRule{
			rule1.RuleID: rule1,
			rule2.RuleID: rule2,
			rule3.RuleID: rule3,
		},
		Type: pushrules.RoomRule,
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.Equal(t, actions3, rules.GetActions(blankTestRoom, evt))
}

// TestPushRuleMap_GetActions_RoomRuleDoesntExist checks that a room rule map
// without an entry for the event's room returns nil actions.
func TestPushRuleMap_GetActions_RoomRuleDoesntExist(t *testing.T) {
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule1 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!realroom:maunium.net",
		Actions: actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
	}
	rule2 := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!thirdroom:maunium.net",
		Actions: actions2,
	}

	rules := pushrules.PushRuleMap{
		Map: map[string]*pushrules.PushRule{
			rule1.RuleID: rule1,
			rule2.RuleID: rule2,
		},
		Type: pushrules.RoomRule,
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.Nil(t, rules.GetActions(blankTestRoom, evt))
}

// TestPushRuleMap_GetActions_SenderRuleExists checks that a sender rule map
// returns the actions of the rule keyed by the event's sender.
func TestPushRuleMap_GetActions_SenderRuleExists(t *testing.T) {
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule1 := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@tulir:maunium.net",
		Actions: actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
	}
	rule2 := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@someone:maunium.net",
		Actions: actions2,
	}

	actions3 := pushrules.PushActionArray{
		{Action: pushrules.ActionCoalesce},
	}
	rule3 := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@otheruser:matrix.org",
		Actions: actions3,
	}

	rules := pushrules.PushRuleMap{
		Map: map[string]*pushrules.PushRule{
			rule1.RuleID: rule1,
			rule2.RuleID: rule2,
			rule3.RuleID: rule3,
		},
		Type: pushrules.SenderRule,
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.Equal(t, actions1, rules.GetActions(blankTestRoom, evt))
}

// TestPushRuleArray_SetTypeAndMap checks that SetTypeAndMap tags every rule
// with the given type, indexes each rule by its ID, and that Unmap returns
// exactly the same set of rules.
func TestPushRuleArray_SetTypeAndMap(t *testing.T) {
	actions1 := pushrules.PushActionArray{
		{Action: pushrules.ActionDontNotify},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakHighlight, Value: false},
		{Action: pushrules.ActionSetTweak, Tweak: pushrules.TweakSound, Value: "pong"},
	}
	rule1 := &pushrules.PushRule{
		Enabled: true,
		RuleID:  "@tulir:maunium.net",
		Actions: actions1,
	}

	actions2 := pushrules.PushActionArray{
		{Action: pushrules.ActionNotify},
	}
	rule2 := &pushrules.PushRule{
		Enabled: true,
		RuleID:  "@someone:maunium.net",
		Actions: actions2,
	}

	actions3 := pushrules.PushActionArray{
		{Action: pushrules.ActionCoalesce},
	}
	rule3 := &pushrules.PushRule{
		Enabled: true,
		RuleID:  "@otheruser:matrix.org",
		Actions: actions3,
	}

	ruleArray := pushrules.PushRuleArray{rule1, rule2, rule3}
	ruleMap := ruleArray.SetTypeAndMap(pushrules.SenderRule)
	assert.Equal(t, pushrules.SenderRule, ruleMap.Type)
	assert.Len(t, ruleMap.Map, len(ruleArray))
	for _, rule := range ruleArray {
		assert.Equal(t, pushrules.SenderRule, rule.Type)
		assert.Equal(t, rule, ruleMap.Map[rule.RuleID])
	}

	newRuleArray := ruleMap.Unmap()
	assert.Len(t, newRuleArray, len(ruleArray))
	for _, rule := range ruleArray {
		assert.Contains(t, newRuleArray, rule)
	}
}
go-0.11.1/pushrules/rule_test.go000066400000000000000000000112371436100171500165700ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package pushrules_test

import (
	"testing"

	"github.com/stretchr/testify/assert"

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/pushrules"
)

// TestPushRule_Match_Conditions checks that an override rule matches when all
// of its conditions match.
func TestPushRule_Match_Conditions(t *testing.T) {
	cond1 := newMatchPushCondition("content.msgtype", "m.emote")
	cond2 := newMatchPushCondition("content.body", "*pushrules")
	rule := &pushrules.PushRule{
		Type:       pushrules.OverrideRule,
		Enabled:    true,
		Conditions: []*pushrules.PushCondition{cond1, cond2},
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.True(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Conditions_Disabled checks that a disabled rule never
// matches, even if its conditions would.
func TestPushRule_Match_Conditions_Disabled(t *testing.T) {
	cond1 := newMatchPushCondition("content.msgtype", "m.emote")
	cond2 := newMatchPushCondition("content.body", "*pushrules")
	rule := &pushrules.PushRule{
		Type:       pushrules.OverrideRule,
		Enabled:    false,
		Conditions: []*pushrules.PushCondition{cond1, cond2},
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Conditions_FailIfOneFails checks that a single failing
// condition (here the msgtype) makes the whole rule fail.
func TestPushRule_Match_Conditions_FailIfOneFails(t *testing.T) {
	cond1 := newMatchPushCondition("content.msgtype", "m.emote")
	cond2 := newMatchPushCondition("content.body", "*pushrules")
	rule := &pushrules.PushRule{
		Type:       pushrules.OverrideRule,
		Enabled:    true,
		Conditions: []*pushrules.PushCondition{cond1, cond2},
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgText,
		Body:    "I'm testing pushrules",
	})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Content checks that a content rule matches a body that
// fits its glob pattern.
func TestPushRule_Match_Content(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.ContentRule,
		Enabled: true,
		Pattern: "is testing*",
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is testing pushrules",
	})
	assert.True(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Content_Fail checks that a content rule does not match a
// body that does not fit its glob pattern.
func TestPushRule_Match_Content_Fail(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.ContentRule,
		Enabled: true,
		Pattern: "is testing*",
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is not testing pushrules",
	})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Content_ImplicitGlob checks that a pattern without
// wildcards still matches a word inside the body.
func TestPushRule_Match_Content_ImplicitGlob(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.ContentRule,
		Enabled: true,
		Pattern: "testing",
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "is not testing pushrules",
	})
	assert.True(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Content_IllegalGlob checks that an invalid glob pattern
// never matches.
func TestPushRule_Match_Content_IllegalGlob(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.ContentRule,
		Enabled: true,
		Pattern: "this is not a valid glo[b",
	}

	evt := newFakeEvent(event.EventMessage, &event.MessageEventContent{
		MsgType: event.MsgEmote,
		Body:    "this is not a valid glob",
	})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Room checks that a room rule matches events in its room.
func TestPushRule_Match_Room(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!fakeroom:maunium.net",
	}

	evt := newFakeEvent(event.EventMessage, &struct{}{})
	assert.True(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Room_Fail checks that a room rule does not match events
// in other rooms.
func TestPushRule_Match_Room_Fail(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.RoomRule,
		Enabled: true,
		RuleID:  "!otherroom:maunium.net",
	}

	evt := newFakeEvent(event.EventMessage, &struct{}{})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Sender checks that a sender rule matches events from its
// user.
func TestPushRule_Match_Sender(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@tulir:maunium.net",
	}

	evt := newFakeEvent(event.EventMessage, &struct{}{})
	assert.True(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_Sender_Fail checks that a sender rule for a different
// user does not match. The rule must be a SenderRule: a RoomRule with a user
// ID as RuleID would pass trivially without exercising the sender branch.
func TestPushRule_Match_Sender_Fail(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.SenderRule,
		Enabled: true,
		RuleID:  "@someone:matrix.org",
	}

	evt := newFakeEvent(event.EventMessage, &struct{}{})
	assert.False(t, rule.Match(blankTestRoom, evt))
}

// TestPushRule_Match_UnknownTypeAlwaysFail checks that a rule with an
// unrecognized type never matches.
func TestPushRule_Match_UnknownTypeAlwaysFail(t *testing.T) {
	rule := &pushrules.PushRule{
		Type:    pushrules.PushRuleType("foobar"),
		Enabled: true,
		RuleID:  "@someone:matrix.org",
	}

	evt := newFakeEvent(event.EventMessage, &struct{}{})
	assert.False(t, rule.Match(blankTestRoom, evt))
}
go-0.11.1/pushrules/ruleset.go000066400000000000000000000052431436100171500162450ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package pushrules

import (
	"encoding/json"

	"maunium.net/go/mautrix/event"
)

// PushRuleset is a complete set of push rules grouped by kind. Room and
// sender rules are indexed by rule ID; the other kinds are ordered arrays.
type PushRuleset struct {
	Override  PushRuleArray
	Content   PushRuleArray
	Room      PushRuleMap
	Sender    PushRuleMap
	Underride PushRuleArray
}

// rawPushRuleset is the JSON representation of a PushRuleset, in which every
// kind of rule is a plain array.
type rawPushRuleset struct {
	Override  PushRuleArray `json:"override"`
	Content   PushRuleArray `json:"content"`
	Room      PushRuleArray `json:"room"`
	Sender    PushRuleArray `json:"sender"`
	Underride PushRuleArray `json:"underride"`
}

// UnmarshalJSON parses JSON into this PushRuleset.
//
// For override, content and underride push rule arrays, the type is added
// to each PushRule and the array is used as-is.
//
// For room and sender push rule arrays, the type is added to each PushRule
// and the array is converted to a map with the rule ID as the key and the
// PushRule as the value.
func (rs *PushRuleset) UnmarshalJSON(raw []byte) (err error) { data := rawPushRuleset{} err = json.Unmarshal(raw, &data) if err != nil { return } rs.Override = data.Override.SetType(OverrideRule) rs.Content = data.Content.SetType(ContentRule) rs.Room = data.Room.SetTypeAndMap(RoomRule) rs.Sender = data.Sender.SetTypeAndMap(SenderRule) rs.Underride = data.Underride.SetType(UnderrideRule) return } // MarshalJSON is the reverse of UnmarshalJSON() func (rs *PushRuleset) MarshalJSON() ([]byte, error) { data := rawPushRuleset{ Override: rs.Override, Content: rs.Content, Room: rs.Room.Unmap(), Sender: rs.Sender.Unmap(), Underride: rs.Underride, } return json.Marshal(&data) } // DefaultPushActions is the value returned if none of the rule // collections in a Ruleset match the event given to GetActions() var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}} // GetActions matches the given event against all of the push rule // collections in this push ruleset in the order of priority as // specified in spec section 11.12.1.4. func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushActionArray) { // Add push rule collections to array in priority order arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride} // Loop until one of the push rule collections matches the room/event combo. for _, pra := range arrays { if pra == nil { continue } if match = pra.GetActions(room, evt); match != nil { // Match found, return it. return } } // No match found, return default actions. 
return DefaultPushActions } go-0.11.1/requests.go000066400000000000000000000247521436100171500144110ustar00rootroot00000000000000package mautrix import ( "encoding/json" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" "maunium.net/go/mautrix/pushrules" ) type AuthType string const ( AuthTypePassword AuthType = "m.login.password" AuthTypeReCAPTCHA AuthType = "m.login.recaptcha" AuthTypeOAuth2 AuthType = "m.login.oauth2" AuthTypeSSO AuthType = "m.login.sso" AuthTypeEmail AuthType = "m.login.email.identity" AuthTypeMSISDN AuthType = "m.login.msisdn" AuthTypeToken AuthType = "m.login.token" AuthTypeDummy AuthType = "m.login.dummy" AuthTypeAppservice AuthType = "m.login.application_service" ) type IdentifierType string const ( IdentifierTypeUser = "m.id.user" IdentifierTypeThirdParty = "m.id.thirdparty" IdentifierTypePhone = "m.id.phone" ) // ReqRegister is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register type ReqRegister struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` InhibitLogin bool `json:"inhibit_login,omitempty"` Auth interface{} `json:"auth,omitempty"` // Type for registration, only used for appservice user registrations // https://spec.matrix.org/v1.2/application-service-api/#server-admin-style-permissions Type AuthType `json:"type,omitempty"` } type BaseAuthData struct { Type AuthType `json:"type"` Session string `json:"session,omitempty"` } type UserIdentifier struct { Type IdentifierType `json:"type"` User string `json:"user,omitempty"` Medium string `json:"medium,omitempty"` Address string `json:"address,omitempty"` Country string `json:"country,omitempty"` Phone string `json:"phone,omitempty"` } // ReqLogin is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login type ReqLogin struct { 
Type AuthType `json:"type"` Identifier UserIdentifier `json:"identifier"` Password string `json:"password,omitempty"` Token string `json:"token,omitempty"` DeviceID id.DeviceID `json:"device_id,omitempty"` InitialDeviceDisplayName string `json:"initial_device_display_name,omitempty"` // Whether or not the returned credentials should be stored in the Client StoreCredentials bool `json:"-"` // Whether or not the returned .well-known data should update the homeserver URL in the Client StoreHomeserverURL bool `json:"-"` } type ReqUIAuthFallback struct { Session string `json:"session"` User string `json:"user"` } type ReqUIAuthLogin struct { BaseAuthData User string `json:"user"` Password string `json:"password"` } // ReqCreateRoom is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom type ReqCreateRoom struct { Visibility string `json:"visibility,omitempty"` RoomAliasName string `json:"room_alias_name,omitempty"` Name string `json:"name,omitempty"` Topic string `json:"topic,omitempty"` Invite []id.UserID `json:"invite,omitempty"` Invite3PID []ReqInvite3PID `json:"invite_3pid,omitempty"` CreationContent map[string]interface{} `json:"creation_content,omitempty"` InitialState []*event.Event `json:"initial_state,omitempty"` Preset string `json:"preset,omitempty"` IsDirect bool `json:"is_direct,omitempty"` PowerLevelOverride *event.PowerLevelsEventContent `json:"power_level_content_override,omitempty"` } // ReqRedact is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidredacteventidtxnid type ReqRedact struct { Reason string TxnID string Extra map[string]interface{} } type ReqMembers struct { At string `json:"at"` Membership event.Membership `json:"membership,omitempty"` NotMembership event.Membership `json:"not_membership,omitempty"` } // ReqInvite3PID is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 // It is also a 
JSON object used in https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom type ReqInvite3PID struct { IDServer string `json:"id_server"` Medium string `json:"medium"` Address string `json:"address"` } type ReqLeave struct { Reason string `json:"reason,omitempty"` } // ReqInviteUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite type ReqInviteUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` } // ReqKickUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick type ReqKickUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` } // ReqBanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban type ReqBanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` } // ReqUnbanUser is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban type ReqUnbanUser struct { Reason string `json:"reason,omitempty"` UserID id.UserID `json:"user_id"` } // ReqTyping is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidtypinguserid type ReqTyping struct { Typing bool `json:"typing"` Timeout int64 `json:"timeout,omitempty"` } type ReqPresence struct { Presence event.Presence `json:"presence"` } type ReqAliasCreate struct { RoomID id.RoomID `json:"room_id"` } type OneTimeKey struct { Key id.Curve25519 `json:"key"` IsSigned bool `json:"-"` Signatures Signatures `json:"signatures,omitempty"` Unsigned map[string]interface{} `json:"unsigned,omitempty"` } type serializableOTK OneTimeKey func (otk *OneTimeKey) UnmarshalJSON(data []byte) (err error) { if len(data) > 0 && data[0] == '"' && data[len(data)-1] == '"' { err = json.Unmarshal(data, &otk.Key) otk.Signatures = nil otk.Unsigned = nil otk.IsSigned = 
false } else { err = json.Unmarshal(data, (*serializableOTK)(otk)) otk.IsSigned = true } return err } func (otk *OneTimeKey) MarshalJSON() ([]byte, error) { if !otk.IsSigned { return json.Marshal(otk.Key) } else { return json.Marshal((*serializableOTK)(otk)) } } type ReqUploadKeys struct { DeviceKeys *DeviceKeys `json:"device_keys,omitempty"` OneTimeKeys map[id.KeyID]OneTimeKey `json:"one_time_keys"` } type ReqKeysSignatures struct { UserID id.UserID `json:"user_id"` DeviceID id.DeviceID `json:"device_id,omitempty"` Algorithms []id.Algorithm `json:"algorithms,omitempty"` Usage []id.CrossSigningUsage `json:"usage,omitempty"` Keys map[id.KeyID]string `json:"keys"` Signatures Signatures `json:"signatures"` } type ReqUploadSignatures map[id.UserID]map[string]ReqKeysSignatures type DeviceKeys struct { UserID id.UserID `json:"user_id"` DeviceID id.DeviceID `json:"device_id"` Algorithms []id.Algorithm `json:"algorithms"` Keys KeyMap `json:"keys"` Signatures Signatures `json:"signatures"` Unsigned map[string]interface{} `json:"unsigned,omitempty"` } type CrossSigningKeys struct { UserID id.UserID `json:"user_id"` Usage []id.CrossSigningUsage `json:"usage"` Keys map[id.KeyID]id.Ed25519 `json:"keys"` Signatures map[id.UserID]map[id.KeyID]string `json:"signatures,omitempty"` } func (csk *CrossSigningKeys) FirstKey() id.Ed25519 { for _, key := range csk.Keys { return key } return "" } type UploadCrossSigningKeysReq struct { Master CrossSigningKeys `json:"master_key"` SelfSigning CrossSigningKeys `json:"self_signing_key"` UserSigning CrossSigningKeys `json:"user_signing_key"` Auth interface{} `json:"auth,omitempty"` } type KeyMap map[id.DeviceKeyID]string func (km KeyMap) GetEd25519(deviceID id.DeviceID) id.Ed25519 { val, ok := km[id.NewDeviceKeyID(id.KeyAlgorithmEd25519, deviceID)] if !ok { return "" } return id.Ed25519(val) } func (km KeyMap) GetCurve25519(deviceID id.DeviceID) id.Curve25519 { val, ok := km[id.NewDeviceKeyID(id.KeyAlgorithmCurve25519, deviceID)] if !ok { 
return "" } return id.Curve25519(val) } type Signatures map[id.UserID]map[id.KeyID]string type ReqQueryKeys struct { DeviceKeys DeviceKeysRequest `json:"device_keys"` Timeout int64 `json:"timeout,omitempty"` Token string `json:"token,omitempty"` } type DeviceKeysRequest map[id.UserID]DeviceIDList type DeviceIDList []id.DeviceID type ReqClaimKeys struct { OneTimeKeys OneTimeKeysRequest `json:"one_time_keys"` Timeout int64 `json:"timeout,omitempty"` } type OneTimeKeysRequest map[id.UserID]map[id.DeviceID]id.KeyAlgorithm type ReqSendToDevice struct { Messages map[id.UserID]map[id.DeviceID]*event.Content `json:"messages"` } // ReqDeviceInfo is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3devicesdeviceid type ReqDeviceInfo struct { DisplayName string `json:"display_name,omitempty"` } // ReqDeleteDevice is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#delete_matrixclientv3devicesdeviceid type ReqDeleteDevice struct { Auth interface{} `json:"auth,omitempty"` } // ReqDeleteDevices is the JSON request for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3delete_devices type ReqDeleteDevices struct { Devices []id.DeviceID `json:"devices"` Auth interface{} `json:"auth,omitempty"` } type ReqPutPushRule struct { Before string `json:"-"` After string `json:"-"` Actions []pushrules.PushActionType `json:"actions"` Conditions []pushrules.PushCondition `json:"conditions"` Pattern string `json:"pattern"` } type ReqBatchSend struct { PrevEventID id.EventID `json:"-"` BatchID id.BatchID `json:"-"` StateEventsAtStart []*event.Event `json:"state_events_at_start"` Events []*event.Event `json:"events"` } type ReqSetReadMarkers struct { Read id.EventID `json:"m.read"` FullyRead id.EventID `json:"m.fully_read"` } go-0.11.1/responses.go000066400000000000000000000272761436100171500145630ustar00rootroot00000000000000package mautrix import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // RespWhoami 
is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami type RespWhoami struct { UserID id.UserID `json:"user_id"` DeviceID id.DeviceID `json:"device_id"` } // RespCreateFilter is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter type RespCreateFilter struct { FilterID string `json:"filter_id"` } // RespJoinRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidjoin type RespJoinRoom struct { RoomID id.RoomID `json:"room_id"` } // RespLeaveRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidleave type RespLeaveRoom struct{} // RespForgetRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget type RespForgetRoom struct{} // RespInviteUser is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite type RespInviteUser struct{} // RespKickUser is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick type RespKickUser struct{} // RespBanUser is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban type RespBanUser struct{} // RespUnbanUser is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban type RespUnbanUser struct{} // RespTyping is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidtypinguserid type RespTyping struct{} // RespPresence is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3presenceuseridstatus type RespPresence struct { Presence event.Presence `json:"presence"` LastActiveAgo int `json:"last_active_ago"` StatusMsg string `json:"status_msg"` CurrentlyActive bool `json:"currently_active"` } // 
RespJoinedRooms is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3joined_rooms
type RespJoinedRooms struct {
	JoinedRooms []id.RoomID `json:"joined_rooms"`
}

// RespJoinedMembers is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidjoined_members
type RespJoinedMembers struct {
	Joined map[id.UserID]struct {
		DisplayName *string `json:"display_name"`
		AvatarURL   *string `json:"avatar_url"`
	} `json:"joined"`
}

// RespMessages is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages
type RespMessages struct {
	Start string         `json:"start"`
	Chunk []*event.Event `json:"chunk"`
	State []*event.Event `json:"state"`
	End   string         `json:"end"`
}

// RespContext is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidcontexteventid
type RespContext struct {
	End          string         `json:"end"`
	Event        *event.Event   `json:"event"`
	EventsAfter  []*event.Event `json:"events_after"`
	EventsBefore []*event.Event `json:"events_before"`
	Start        string         `json:"start"`
	State        []*event.Event `json:"state"`
}

// RespSendEvent is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3roomsroomidsendeventtypetxnid
type RespSendEvent struct {
	EventID id.EventID `json:"event_id"`
}

// RespMediaUpload is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixmediav3upload
type RespMediaUpload struct {
	ContentURI id.ContentURI `json:"content_uri"`
}

// RespCreateMXC is the JSON response for /_matrix/media/v3/create as specified in https://github.com/matrix-org/matrix-spec-proposals/pull/2246
type RespCreateMXC struct {
	ContentURI      id.ContentURI `json:"content_uri"`
	UnusedExpiresAt int           `json:"unused_expires_at,omitempty"`
}

// RespPreviewURL is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixmediav3preview_url
type RespPreviewURL struct {
	CanonicalURL string              `json:"og:url,omitempty"`
	Title        string              `json:"og:title,omitempty"`
	Type         string              `json:"og:type,omitempty"`
	Description  string              `json:"og:description,omitempty"`
	ImageURL     id.ContentURIString `json:"og:image,omitempty"`

	ImageSize   int    `json:"matrix:image:size,omitempty"`
	ImageWidth  int    `json:"og:image:width,omitempty"`
	ImageHeight int    `json:"og:image:height,omitempty"`
	ImageType   string `json:"og:image:type,omitempty"`
}

// RespUserInteractive is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#user-interactive-authentication-api
type RespUserInteractive struct {
	Flows []struct {
		Stages []AuthType `json:"stages"`
	} `json:"flows"`
	Params    map[AuthType]interface{} `json:"params"`
	Session   string                   `json:"session"`
	Completed []string                 `json:"completed"`

	ErrCode string `json:"errcode"`
	Error   string `json:"error"`
}

// HasSingleStageFlow returns true if there exists at least 1 Flow with a single stage of stageName.
func (r RespUserInteractive) HasSingleStageFlow(stageName AuthType) bool {
	for _, f := range r.Flows {
		if len(f.Stages) == 1 && f.Stages[0] == stageName {
			return true
		}
	}
	return false
}

// RespUserDisplayName is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
type RespUserDisplayName struct {
	DisplayName string `json:"displayname"`
}

// RespRegister is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3register
type RespRegister struct {
	AccessToken  string      `json:"access_token"`
	DeviceID     id.DeviceID `json:"device_id"`
	HomeServer   string      `json:"home_server"`
	RefreshToken string      `json:"refresh_token"`
	UserID       id.UserID   `json:"user_id"`
}

// LoginFlow is a single entry of RespLoginFlows.Flows.
type LoginFlow struct {
	Type AuthType `json:"type"`
}

// RespLoginFlows is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login
type RespLoginFlows struct {
	Flows []LoginFlow `json:"flows"`
}

// FirstFlowOfType returns the first flow in rlf.Flows whose type equals any of
// flowTypes, or nil if there is none. Flows are scanned in order, so the order of
// flowTypes does not act as a preference.
//
// The returned pointer refers to the element stored in rlf.Flows, not to a copy
// of it.
func (rlf *RespLoginFlows) FirstFlowOfType(flowTypes ...AuthType) *LoginFlow {
	for i := range rlf.Flows {
		for _, flowType := range flowTypes {
			if rlf.Flows[i].Type == flowType {
				return &rlf.Flows[i]
			}
		}
	}
	return nil
}

// HasFlow reports whether rlf contains at least one flow of any of the given types.
func (rlf *RespLoginFlows) HasFlow(flowType ...AuthType) bool {
	return rlf.FirstFlowOfType(flowType...) != nil
}

// RespLogin is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3login
type RespLogin struct {
	AccessToken string           `json:"access_token"`
	DeviceID    id.DeviceID      `json:"device_id"`
	UserID      id.UserID        `json:"user_id"`
	WellKnown   *ClientWellKnown `json:"well_known"`
}

// RespLogout is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3logout
type RespLogout struct{}

// RespCreateRoom is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3createroom
type RespCreateRoom struct {
	RoomID id.RoomID `json:"room_id"`
}

// RespMembers holds a chunk of member events.
type RespMembers struct {
	Chunk []*event.Event `json:"chunk"`
}

// LazyLoadSummary is the "summary" object of a room in a sync response.
type LazyLoadSummary struct {
	Heroes             []id.UserID `json:"m.heroes,omitempty"`
	JoinedMemberCount  *int        `json:"m.joined_member_count,omitempty"`
	InvitedMemberCount *int        `json:"m.invited_member_count,omitempty"`
}

// RespSync is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3sync
type RespSync struct {
	NextBatch string `json:"next_batch"`

	AccountData struct {
		Events []*event.Event `json:"events"`
	} `json:"account_data"`
	Presence struct {
		Events []*event.Event `json:"events"`
	} `json:"presence"`
	ToDevice struct {
		Events []*event.Event `json:"events"`
	} `json:"to_device"`

	DeviceLists    DeviceLists `json:"device_lists"`
	DeviceOTKCount OTKCount    `json:"device_one_time_keys_count"`

	Rooms struct {
		Leave  map[id.RoomID]SyncLeftRoom    `json:"leave"`
		Join   map[id.RoomID]SyncJoinedRoom  `json:"join"`
		Invite map[id.RoomID]SyncInvitedRoom `json:"invite"`
	} `json:"rooms"`
}

// DeviceLists is the "device_lists" object of a sync response.
type DeviceLists struct {
	Changed []id.UserID `json:"changed"`
	Left    []id.UserID `json:"left"`
}

// OTKCount holds one-time key counts per algorithm.
type OTKCount struct {
	Curve25519       int `json:"curve25519"`
	SignedCurve25519 int `json:"signed_curve25519"`

	// For appservice OTK counts only: the user ID in question
	UserID   id.UserID   `json:"-"`
	DeviceID id.DeviceID `json:"-"`
}

// SyncLeftRoom is a single entry of RespSync.Rooms.Leave.
type SyncLeftRoom struct {
	Summary LazyLoadSummary `json:"summary"`
	State   struct {
		Events []*event.Event `json:"events"`
	} `json:"state"`
	Timeline struct {
		Events    []*event.Event `json:"events"`
		Limited   bool           `json:"limited"`
		PrevBatch string         `json:"prev_batch"`
	} `json:"timeline"`
}

// SyncJoinedRoom is a single entry of RespSync.Rooms.Join.
type SyncJoinedRoom struct {
	Summary LazyLoadSummary `json:"summary"`
	State   struct {
		Events []*event.Event `json:"events"`
	} `json:"state"`
	Timeline struct {
		Events    []*event.Event `json:"events"`
		Limited   bool           `json:"limited"`
		PrevBatch string         `json:"prev_batch"`
	} `json:"timeline"`
	Ephemeral struct {
		Events []*event.Event `json:"events"`
	} `json:"ephemeral"`
	AccountData struct {
		Events []*event.Event `json:"events"`
	} `json:"account_data"`
}

// SyncInvitedRoom is a single entry of RespSync.Rooms.Invite.
type SyncInvitedRoom struct {
	Summary LazyLoadSummary `json:"summary"`
	State   struct {
		Events []*event.Event `json:"events"`
	} `json:"invite_state"`
}

type RespTurnServer struct {
	Username string   `json:"username"`
	Password string   `json:"password"`
	TTL      int      `json:"ttl"`
	URIs     []string `json:"uris"`
}

type RespAliasCreate struct{}
type RespAliasDelete struct{}
type RespAliasResolve struct {
	RoomID  id.RoomID `json:"room_id"`
	Servers []string  `json:"servers"`
}
type RespAliasList struct {
	Aliases []id.RoomAlias `json:"aliases"`
}

type RespUploadKeys struct {
	OneTimeKeyCounts OTKCount `json:"one_time_key_counts"`
}

type RespQueryKeys struct {
	Failures        map[string]interface{}                   `json:"failures"`
	DeviceKeys      map[id.UserID]map[id.DeviceID]DeviceKeys `json:"device_keys"`
	MasterKeys      map[id.UserID]CrossSigningKeys           `json:"master_keys"`
	SelfSigningKeys map[id.UserID]CrossSigningKeys           `json:"self_signing_keys"`
	UserSigningKeys map[id.UserID]CrossSigningKeys           `json:"user_signing_keys"`
}

type RespClaimKeys struct {
	Failures    map[string]interface{}                                `json:"failures"`
	OneTimeKeys map[id.UserID]map[id.DeviceID]map[id.KeyID]OneTimeKey `json:"one_time_keys"`
}

type RespUploadSignatures struct {
	Failures map[string]interface{} `json:"failures"`
}

type RespKeyChanges struct {
	Changed []id.UserID `json:"changed"`
	Left    []id.UserID `json:"left"`
}

type RespSendToDevice struct{}

// RespDevicesInfo is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3devices
type RespDevicesInfo struct {
	Devices []RespDeviceInfo `json:"devices"`
}

// RespDeviceInfo is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3devicesdeviceid
type RespDeviceInfo struct {
	DeviceID    id.DeviceID `json:"device_id"`
	DisplayName string      `json:"display_name"`
	LastSeenIP  string      `json:"last_seen_ip"`
	LastSeenTS  int64       `json:"last_seen_ts"`
}

type RespBatchSend struct {
	StateEventIDs []id.EventID `json:"state_event_ids"`
	EventIDs      []id.EventID `json:"event_ids"`

	InsertionEventID     id.EventID `json:"insertion_event_id"`
	BatchEventID         id.EventID `json:"batch_event_id"`
	BaseInsertionEventID id.EventID `json:"base_insertion_event_id"`

	NextBatchID id.BatchID `json:"next_batch_id"`
}
go-0.11.1/room.go000066400000000000000000000030301436100171500134740ustar00rootroot00000000000000package mautrix

import (
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

type RoomStateMap = map[event.Type]map[string]*event.Event

// Room represents a single Matrix room.
type Room struct {
	ID    id.RoomID
	State RoomStateMap
}

// UpdateState updates the room's current state with the given Event. This will clobber events based
// on the type/state_key combination.
//
// Events without a state key (evt.StateKey == nil) are not state events and are
// ignored instead of causing a nil pointer dereference. room.State must be
// non-nil (NewRoom initializes it).
func (room Room) UpdateState(evt *event.Event) {
	if evt.StateKey == nil {
		return
	}
	// Value receiver is fine here: RoomStateMap is a map, so writes are shared.
	byStateKey := room.State[evt.Type]
	if byStateKey == nil {
		byStateKey = make(map[string]*event.Event)
		room.State[evt.Type] = byStateKey
	}
	byStateKey[*evt.StateKey] = evt
}

// GetStateEvent returns the state event for the given type/state_key combo, or nil.
func (room Room) GetStateEvent(eventType event.Type, stateKey string) *event.Event { stateEventMap, _ := room.State[eventType] evt, _ := stateEventMap[stateKey] return evt } // GetMembershipState returns the membership state of the given user ID in this room. If there is // no entry for this member, 'leave' is returned for consistency with left users. func (room Room) GetMembershipState(userID id.UserID) event.Membership { state := event.MembershipLeave evt := room.GetStateEvent(event.StateMember, string(userID)) if evt != nil { membership, ok := evt.Content.Raw["membership"].(string) if ok { state = event.Membership(membership) } } return state } // NewRoom creates a new Room with the given ID func NewRoom(roomID id.RoomID) *Room { // Init the State map and return a pointer to the Room return &Room{ ID: roomID, State: make(RoomStateMap), } } go-0.11.1/store.go000066400000000000000000000106431436100171500136640ustar00rootroot00000000000000package mautrix import ( "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // Storer is an interface which must be satisfied to store client data. // // You can either write a struct which persists this data to disk, or you can use the // provided "InMemoryStore" which just keeps data around in-memory which is lost on // restarts. type Storer interface { SaveFilterID(userID id.UserID, filterID string) LoadFilterID(userID id.UserID) string SaveNextBatch(userID id.UserID, nextBatchToken string) LoadNextBatch(userID id.UserID) string SaveRoom(room *Room) LoadRoom(roomID id.RoomID) *Room } // InMemoryStore implements the Storer interface. // // Everything is persisted in-memory as maps. It is not safe to load/save filter IDs // or next batch tokens on any goroutine other than the syncing goroutine: the one // which called Client.Sync(). type InMemoryStore struct { Filters map[id.UserID]string NextBatch map[id.UserID]string Rooms map[id.RoomID]*Room } // SaveFilterID to memory. 
func (s *InMemoryStore) SaveFilterID(userID id.UserID, filterID string) {
	filters := s.Filters
	filters[userID] = filterID
}

// LoadFilterID returns the filter ID previously saved for userID, or "" if none was saved.
func (s *InMemoryStore) LoadFilterID(userID id.UserID) string {
	filterID := s.Filters[userID]
	return filterID
}

// SaveNextBatch records the sync token for userID in memory.
func (s *InMemoryStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
	tokens := s.NextBatch
	tokens[userID] = nextBatchToken
}

// LoadNextBatch returns the sync token saved for userID, or "" if none was saved.
func (s *InMemoryStore) LoadNextBatch(userID id.UserID) string {
	token := s.NextBatch[userID]
	return token
}

// SaveRoom stores room in memory, keyed by its ID (replacing any previous entry).
func (s *InMemoryStore) SaveRoom(room *Room) {
	rooms := s.Rooms
	rooms[room.ID] = room
}

// LoadRoom returns the room stored under roomID, or nil if there is none.
func (s *InMemoryStore) LoadRoom(roomID id.RoomID) *Room {
	found := s.Rooms[roomID]
	return found
}

// UpdateState caches a state event in the matching Room, creating and saving the
// Room first if the store does not know it yet. Events whose type is not a state
// type are skipped. The signature matches EventHandler, so this method can be
// passed to DefaultSyncer.OnEvent to keep all room state cached.
func (s *InMemoryStore) UpdateState(_ EventSource, evt *event.Event) {
	if evt.Type.IsState() {
		target := s.LoadRoom(evt.RoomID)
		if target == nil {
			target = NewRoom(evt.RoomID)
			s.SaveRoom(target)
		}
		target.UpdateState(evt)
	}
}

// NewInMemoryStore returns an InMemoryStore whose three maps are all initialized.
func NewInMemoryStore() *InMemoryStore {
	store := new(InMemoryStore)
	store.Filters = make(map[id.UserID]string)
	store.NextBatch = make(map[id.UserID]string)
	store.Rooms = make(map[id.RoomID]*Room)
	return store
}

// AccountDataStore keeps the next batch token in the user's account data and
// delegates every other Storer operation to the embedded InMemoryStore.
type AccountDataStore struct {
	*InMemoryStore
	eventType string
	client    *Client
}

// accountData is the JSON content written to, and read from, account data.
type accountData struct {
	NextBatch string `json:"next_batch"`
}

// SaveNextBatch to account data.
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with bots") } data := accountData{ NextBatch: nextBatchToken, } err := s.client.SetAccountData(s.eventType, data) if err != nil { if s.client.Logger != nil { s.client.Logger.Debugfln("failed to save next batch token to account data: %s", err.Error()) } } } // LoadNextBatch from account data. func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string { if userID.String() != s.client.UserID.String() { panic("AccountDataStore must only be used with bots") } data := &accountData{} err := s.client.GetAccountData(s.eventType, data) if err != nil { if s.client.Logger != nil { s.client.Logger.Debugfln("failed to load next batch token to account data: %s", err.Error()) } return "" } return data.NextBatch } // NewAccountDataStore returns a new AccountDataStore, which stores // the next_batch token as a custom event in account data in the // homeserver. // // AccountDataStore is only appropriate for bots, not appservices. // // eventType should be a reversed DNS name like tld.domain.sub.internal and // must be unique for a client. The data stored in it is considered internal // and must not be modified through outside means. You should also add a filter // for account data changes of this event type, to avoid ending up in a sync // loop: // // mautrix.Filter{ // AccountData: mautrix.FilterPart{ // Limit: 20, // NotTypes: []event.Type{ // event.NewEventType(eventType), // }, // }, // } // mautrix.Client.CreateFilter(...) 
// func NewAccountDataStore(eventType string, client *Client) *AccountDataStore { return &AccountDataStore{ InMemoryStore: NewInMemoryStore(), eventType: eventType, client: client, } } go-0.11.1/sync.go000066400000000000000000000227631436100171500135120ustar00rootroot00000000000000// Copyright (c) 2020 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package mautrix import ( "fmt" "runtime/debug" "time" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) // EventSource represents the part of the sync response that an event came from. type EventSource int const ( EventSourcePresence EventSource = 1 << iota EventSourceJoin EventSourceInvite EventSourceLeave EventSourceAccountData EventSourceTimeline EventSourceState EventSourceEphemeral EventSourceToDevice ) func (es EventSource) String() string { switch { case es == EventSourcePresence: return "presence" case es == EventSourceAccountData: return "user account data" case es == EventSourceToDevice: return "to-device" case es&EventSourceJoin != 0: es -= EventSourceJoin switch es { case EventSourceState: return "joined state" case EventSourceTimeline: return "joined timeline" case EventSourceEphemeral: return "room ephemeral (joined)" case EventSourceAccountData: return "room account data (joined)" } case es&EventSourceInvite != 0: es -= EventSourceInvite switch es { case EventSourceState: return "invited state" } case es&EventSourceLeave != 0: es -= EventSourceLeave switch es { case EventSourceState: return "left state" case EventSourceTimeline: return "left timeline" } } return fmt.Sprintf("unknown (%d)", es) } // EventHandler handles a single event from a sync response. type EventHandler func(source EventSource, evt *event.Event) // SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely. 
type SyncHandler func(resp *RespSync, since string) bool // Syncer is an interface that must be satisfied in order to do /sync requests on a client. type Syncer interface { // Process the /sync response. The since parameter is the since= value that was used to produce the response. // This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped // permanently. ProcessResponse(resp *RespSync, since string) error // OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently. OnFailedSync(res *RespSync, err error) (time.Duration, error) // GetFilterJSON for the given user ID. NOT the filter ID. GetFilterJSON(userID id.UserID) *Filter } type ExtensibleSyncer interface { OnSync(callback SyncHandler) OnEvent(callback EventHandler) OnEventType(eventType event.Type, callback EventHandler) } // DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively // replace parts of this default syncer (e.g. the ProcessResponse method). The default syncer uses the observer // pattern to notify callers about incoming events. See DefaultSyncer.OnEventType for more information. type DefaultSyncer struct { // syncListeners want the whole sync response, e.g. the crypto machine syncListeners []SyncHandler // globalListeners want all events globalListeners []EventHandler // listeners want a specific event type listeners map[event.Type][]EventHandler // ParseEventContent determines whether or not event content should be parsed before passing to handlers. ParseEventContent bool // ParseErrorHandler is called when event.Content.ParseRaw returns an error. // If it returns false, the event will not be forwarded to listeners. 
ParseErrorHandler func(evt *event.Event, err error) bool } var _ Syncer = (*DefaultSyncer)(nil) var _ ExtensibleSyncer = (*DefaultSyncer)(nil) // NewDefaultSyncer returns an instantiated DefaultSyncer func NewDefaultSyncer() *DefaultSyncer { return &DefaultSyncer{ listeners: make(map[event.Type][]EventHandler), syncListeners: []SyncHandler{}, globalListeners: []EventHandler{}, ParseEventContent: true, ParseErrorHandler: func(evt *event.Event, err error) bool { return false }, } } // ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of // unrepeating events. Returns a fatal error if a listener panics. func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) { defer func() { if r := recover(); r != nil { err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack()) } }() for _, listener := range s.syncListeners { if !listener(res, since) { return } } s.processSyncEvents("", res.Presence.Events, EventSourcePresence) s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData) for roomID, roomData := range res.Rooms.Join { s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState) s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline) s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral) s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData) } for roomID, roomData := range res.Rooms.Invite { s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState) } for roomID, roomData := range res.Rooms.Leave { s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState) s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline) } return } func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source 
EventSource) { for _, evt := range events { s.processSyncEvent(roomID, evt, source) } } func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) { evt.RoomID = roomID // Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer. // Listeners are keyed by type structs, which means only the correct class will pass. switch { case evt.StateKey != nil: evt.Type.Class = event.StateEventType case source == EventSourcePresence, source&EventSourceEphemeral != 0: evt.Type.Class = event.EphemeralEventType case source&EventSourceAccountData != 0: evt.Type.Class = event.AccountDataEventType case source == EventSourceToDevice: evt.Type.Class = event.ToDeviceEventType default: evt.Type.Class = event.MessageEventType } if s.ParseEventContent { err := evt.Content.ParseRaw(evt.Type) if err != nil && !s.ParseErrorHandler(evt, err) { return } } s.notifyListeners(source, evt) } func (s *DefaultSyncer) notifyListeners(source EventSource, evt *event.Event) { for _, fn := range s.globalListeners { fn(source, evt) } listeners, exists := s.listeners[evt.Type] if exists { for _, fn := range listeners { fn(source, evt) } } } // OnEventType allows callers to be notified when there are new events for the given event type. // There are no duplicate checks. func (s *DefaultSyncer) OnEventType(eventType event.Type, callback EventHandler) { _, exists := s.listeners[eventType] if !exists { s.listeners[eventType] = []EventHandler{} } s.listeners[eventType] = append(s.listeners[eventType], callback) } func (s *DefaultSyncer) OnSync(callback SyncHandler) { s.syncListeners = append(s.syncListeners, callback) } func (s *DefaultSyncer) OnEvent(callback EventHandler) { s.globalListeners = append(s.globalListeners, callback) } // OnFailedSync always returns a 10 second wait period between failed /syncs, never a fatal error. 
func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, error) { return 10 * time.Second, nil } // GetFilterJSON returns a filter with a timeline limit of 50. func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter { return &Filter{ Room: RoomFilter{ Timeline: FilterPart{ Limit: 50, }, }, } } // OldEventIgnorer is an utility struct for bots to ignore events from before the bot joined the room. // Create a struct and call Register with your DefaultSyncer to register the sync handler. type OldEventIgnorer struct { UserID id.UserID } func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) { syncer.OnSync(oei.DontProcessOldEvents) } // DontProcessOldEvents returns true if a sync response should be processed. May modify the response to remove // stuff that shouldn't be processed. func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool { if since == "" { return false } // This is a horrible hack because /sync will return the most recent messages for a room // as soon as you /join it. We do NOT want to process those events in that particular room // because they may have already been processed (if you toggle the bot in/out of the room). // // Work around this by inspecting each room's timeline and seeing if an m.room.member event for us // exists and is "join" and then discard processing that room entirely if so. // TODO: We probably want to process messages from after the last join event in the timeline. 
for roomID, roomData := range resp.Rooms.Join { for i := len(roomData.Timeline.Events) - 1; i >= 0; i-- { evt := roomData.Timeline.Events[i] if evt.Type == event.StateMember && evt.GetStateKey() == string(oei.UserID) { membership, _ := evt.Content.Raw["membership"].(string) if membership == "join" { _, ok := resp.Rooms.Join[roomID] if !ok { continue } delete(resp.Rooms.Join, roomID) // don't re-process messages delete(resp.Rooms.Invite, roomID) // don't re-process invites break } } } } return true } go-0.11.1/url.go000066400000000000000000000057221436100171500133340ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package mautrix import ( "fmt" "net/url" "strconv" "strings" ) func parseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) { hsURL, err := url.Parse(homeserverURL) if err != nil { return nil, err } if hsURL.Scheme == "" { hsURL.Scheme = "https" fixedURL := hsURL.String() hsURL, err = url.Parse(fixedURL) if err != nil { return nil, fmt.Errorf("failed to parse fixed URL '%s': %v", fixedURL, err) } } hsURL.RawPath = hsURL.EscapedPath() return hsURL, nil } // BuildURL builds a URL with the given path parts func BuildURL(baseURL *url.URL, path ...interface{}) *url.URL { createdURL := *baseURL rawParts := make([]string, len(path)+1) rawParts[0] = strings.TrimSuffix(createdURL.RawPath, "/") parts := make([]string, len(path)+1) parts[0] = strings.TrimSuffix(createdURL.Path, "/") for i, part := range path { switch casted := part.(type) { case string: parts[i+1] = casted case int: parts[i+1] = strconv.Itoa(casted) case Stringifiable: parts[i+1] = casted.String() default: parts[i+1] = fmt.Sprint(casted) } rawParts[i+1] = url.PathEscape(parts[i+1]) } createdURL.Path = strings.Join(parts, "/") createdURL.RawPath = strings.Join(rawParts, "/") 
return &createdURL } // BuildURL builds a URL with the Client's homeserver and appservice user ID set already. func (cli *Client) BuildURL(urlPath PrefixableURLPath) string { return cli.BuildURLWithQuery(urlPath, nil) } // BuildClientURL builds a URL with the Client's homeserver and appservice user ID set already. // This method also automatically prepends the client API prefix (/_matrix/client). func (cli *Client) BuildClientURL(urlPath ...interface{}) string { return cli.BuildURLWithQuery(ClientURLPath(urlPath), nil) } type PrefixableURLPath interface { FullPath() []interface{} } type BaseURLPath []interface{} func (bup BaseURLPath) FullPath() []interface{} { return bup } type ClientURLPath []interface{} func (cup ClientURLPath) FullPath() []interface{} { return append([]interface{}{"_matrix", "client"}, []interface{}(cup)...) } type MediaURLPath []interface{} func (mup MediaURLPath) FullPath() []interface{} { return append([]interface{}{"_matrix", "media"}, []interface{}(mup)...) } // BuildURLWithQuery builds a URL with query parameters in addition to the Client's homeserver // and appservice user ID set already. func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string { hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...) query := hsURL.Query() if cli.AppServiceUserID != "" { query.Set("user_id", string(cli.AppServiceUserID)) } if urlQuery != nil { for k, v := range urlQuery { query.Set(k, v) } } hsURL.RawQuery = query.Encode() return hsURL.String() } go-0.11.1/url_test.go000066400000000000000000000052541436100171500143730ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
package mautrix_test

import (
	"testing"

	"github.com/stretchr/testify/assert"

	"maunium.net/go/mautrix"
)

// Note: assert.Equal takes (t, expected, actual); keeping that order makes failure
// messages report the values the right way round. Each test returns early if
// NewClient fails, so a nil client is never dereferenced.

func TestClient_BuildURL(t *testing.T) {
	cli, err := mautrix.NewClient("https://example.com", "", "")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "https", cli.HomeserverURL.Scheme)
	assert.Equal(t, "example.com", cli.HomeserverURL.Host)
	assert.Equal(t, "", cli.HomeserverURL.Path)
	built := cli.BuildClientURL("v3", "foo/bar%2F🐈 1", "hello", "world")
	assert.Equal(t, "https://example.com/_matrix/client/v3/foo%2Fbar%252F%F0%9F%90%88%201/hello/world", built)
}

func TestClient_BuildURL_HTTP(t *testing.T) {
	cli, err := mautrix.NewClient("http://example.com", "", "")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "http", cli.HomeserverURL.Scheme)
	assert.Equal(t, "example.com", cli.HomeserverURL.Host)
	assert.Equal(t, "", cli.HomeserverURL.Path)
	built := cli.BuildClientURL("v3", "foo/bar%2F🐈 1", "hello", "world")
	assert.Equal(t, "http://example.com/_matrix/client/v3/foo%2Fbar%252F%F0%9F%90%88%201/hello/world", built)
}

func TestClient_BuildURL_MissingScheme(t *testing.T) {
	cli, err := mautrix.NewClient("example.com", "", "")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "https", cli.HomeserverURL.Scheme)
	assert.Equal(t, "example.com", cli.HomeserverURL.Host)
	assert.Equal(t, "", cli.HomeserverURL.Path)
	built := cli.BuildClientURL("v3", "foo/bar%2F🐈 1", "hello", "world")
	assert.Equal(t, "https://example.com/_matrix/client/v3/foo%2Fbar%252F%F0%9F%90%88%201/hello/world", built)
}

func TestClient_BuildURL_WithPath(t *testing.T) {
	cli, err := mautrix.NewClient("https://example.com/base", "", "")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "https", cli.HomeserverURL.Scheme)
	assert.Equal(t, "example.com", cli.HomeserverURL.Host)
	assert.Equal(t, "/base", cli.HomeserverURL.Path)
	built := cli.BuildClientURL("v3", "foo/bar%2F🐈 1", "hello", "world")
	assert.Equal(t, "https://example.com/base/_matrix/client/v3/foo%2Fbar%252F%F0%9F%90%88%201/hello/world", built)
}

func TestClient_BuildURL_MissingSchemeWithPath(t *testing.T) {
	cli, err := mautrix.NewClient("example.com/base", "", "")
	if !assert.NoError(t, err) {
		return
	}
	assert.Equal(t, "https", cli.HomeserverURL.Scheme)
	assert.Equal(t, "example.com", cli.HomeserverURL.Host)
	assert.Equal(t, "/base", cli.HomeserverURL.Path)
	built := cli.BuildClientURL("v3", "foo/bar%2F🐈 1", "hello", "world")
	assert.Equal(t, "https://example.com/base/_matrix/client/v3/foo%2Fbar%252F%F0%9F%90%88%201/hello/world", built)
}
go-0.11.1/util/000077500000000000000000000000001436100171500131525ustar00rootroot00000000000000go-0.11.1/util/base58/000077500000000000000000000000001436100171500142415ustar00rootroot00000000000000go-0.11.1/util/base58/README.md000066400000000000000000000003031436100171500155140ustar00rootroot00000000000000base58
==========

This is a copy of .

## License

Package base58 is licensed under the [copyfree](http://copyfree.org) ISC License.
go-0.11.1/util/base58/alphabet.go000066400000000000000000000031561436100171500163550ustar00rootroot00000000000000// Copyright (c) 2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.

// AUTOGENERATED by genalphabet.go; do not edit.

package base58

const (
	// alphabet is the modified base58 alphabet used by Bitcoin.
	alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

	alphabetIdx0 = '1'
)

var b58 = [256]byte{
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 0, 1, 2, 3, 4, 5, 6,
	7, 8, 255, 255, 255, 255, 255, 255,
	255, 9, 10, 11, 12, 13, 14, 15,
	16, 255, 17, 18, 19, 20, 21, 255,
	22, 23, 24, 25, 26, 27, 28, 29,
	30, 31, 32, 255, 255, 255, 255, 255,
	255, 33, 34, 35, 36, 37, 38, 39,
	40, 41, 42, 43, 255, 44, 45, 46,
	47, 48, 49, 50, 51, 52, 53, 54,
	55, 56, 57, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
	255, 255, 255, 255, 255, 255, 255, 255,
}
go-0.11.1/util/base58/base58.go000066400000000000000000000063771436100171500156740ustar00rootroot00000000000000// Copyright (c) 2013-2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package base58

import (
	"math/big"
)

//go:generate go run genalphabet.go

// bigRadix[n] is 58^n for 1 <= n <= 10. Entry 0 is a placeholder (it holds 0, not
// 58^0) and is never read by Decode, because every chunk has at least one digit.
var bigRadix = [...]*big.Int{
	big.NewInt(0),
	big.NewInt(58),
	big.NewInt(58 * 58),
	big.NewInt(58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
	big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
	bigRadix10,
}

var bigRadix10 = big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58) // 58^10

// Decode decodes a modified base58 string to a byte slice.
//
// Each leading '1' in b becomes a leading zero byte. If b contains any character
// outside the alphabet (including any non-ASCII character), an empty, non-nil
// slice is returned.
func Decode(b string) []byte {
	answer := big.NewInt(0)
	scratch := new(big.Int)

	// Calculating with big.Int is slow for each iteration.
	//    x += b58[b[i]] * j
	//    j *= 58
	//
	// Instead we can try to do as much calculations on int64.
	// We can represent a 10 digit base58 number using an int64.
	//
	// Hence we'll try to convert 10, base58 digits at a time.
	// The rough idea is to calculate `t`, such that:
	//
	//   t := b58[b[i+9]] * 58^9 ... + b58[b[i+1]] * 58^1 + b58[b[i]] * 58^0
	//   x *= 58^10
	//   x += t
	//
	// Of course, in addition, we'll need to handle boundary condition when `b` is not multiple of 58^10.
	// In that case we'll use the bigRadix[n] lookup for the appropriate power.
	for t := b; len(t) > 0; {
		n := len(t)
		if n > 10 {
			n = 10
		}

		total := uint64(0)
		for _, v := range t[:n] {
			// Ranging over a string yields runes, and b58 only has 256 entries:
			// without this guard any rune above 255 (non-ASCII input, or the
			// U+FFFD produced by invalid UTF-8) would panic with index out of range.
			if v > 255 {
				return []byte("")
			}

			tmp := b58[v]
			if tmp == 255 {
				return []byte("")
			}
			total = total*58 + uint64(tmp)
		}

		answer.Mul(answer, bigRadix[n])
		scratch.SetUint64(total)
		answer.Add(answer, scratch)

		t = t[n:]
	}

	tmpval := answer.Bytes()

	// Leading '1' characters encode leading zero bytes, which the big.Int drops.
	var numZeros int
	for numZeros = 0; numZeros < len(b); numZeros++ {
		if b[numZeros] != alphabetIdx0 {
			break
		}
	}
	flen := numZeros + len(tmpval)
	val := make([]byte, flen)
	copy(val[numZeros:], tmpval)

	return val
}

// Encode encodes a byte slice to a modified base58 string.
func Encode(b []byte) string { x := new(big.Int) x.SetBytes(b) // maximum length of output is log58(2^(8*len(b))) == len(b) * 8 / log(58) maxlen := int(float64(len(b))*1.365658237309761) + 1 answer := make([]byte, 0, maxlen) mod := new(big.Int) for x.Sign() > 0 { // Calculating with big.Int is slow for each iteration. // x, mod = x / 58, x % 58 // // Instead we can try to do as much calculations on int64. // x, mod = x / 58^10, x % 58^10 // // Which will give us mod, which is 10 digit base58 number. // We'll loop that 10 times to convert to the answer. x.DivMod(x, bigRadix10, mod) if x.Sign() == 0 { // When x = 0, we need to ensure we don't add any extra zeros. m := mod.Int64() for m > 0 { answer = append(answer, alphabet[m%58]) m /= 58 } } else { m := mod.Int64() for i := 0; i < 10; i++ { answer = append(answer, alphabet[m%58]) m /= 58 } } } // leading zero bytes for _, i := range b { if i != 0 { break } answer = append(answer, alphabetIdx0) } // reverse alen := len(answer) for i := 0; i < alen/2; i++ { answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i] } return string(answer) } go-0.11.1/util/base58/base58_test.go000066400000000000000000000063441436100171500167250ustar00rootroot00000000000000// Copyright (c) 2013-2017 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. 
package base58_test import ( "bytes" "encoding/hex" "testing" "maunium.net/go/mautrix/util/base58" ) var stringTests = []struct { in string out string }{ {"", ""}, {" ", "Z"}, {"-", "n"}, {"0", "q"}, {"1", "r"}, {"-1", "4SU"}, {"11", "4k8"}, {"abc", "ZiCa"}, {"1234598760", "3mJr7AoUXx2Wqd"}, {"abcdefghijklmnopqrstuvwxyz", "3yxU3u1igY8WkgtjK92fbJQCd4BZiiT1v25f"}, {"00000000000000000000000000000000000000000000000000000000000000", "3sN2THZeE9Eh9eYrwkvZqNstbHGvrxSAM7gXUXvyFQP8XvQLUqNCS27icwUeDT7ckHm4FUHM2mTVh1vbLmk7y"}, } var invalidStringTests = []struct { in string out string }{ {"0", ""}, {"O", ""}, {"I", ""}, {"l", ""}, {"3mJr0", ""}, {"O3yxU", ""}, {"3sNI", ""}, {"4kl8", ""}, {"0OIl", ""}, {"!@#$%^&*()-_=+~`", ""}, } var hexTests = []struct { in string out string }{ {"", ""}, {"61", "2g"}, {"626262", "a3gV"}, {"636363", "aPEr"}, {"73696d706c792061206c6f6e6720737472696e67", "2cFupjhnEsSn59qHXstmK2ffpLv2"}, {"00eb15231dfceb60925886b67d065299925915aeb172c06647", "1NS17iag9jJgTHD1VXjvLCEnZuQ3rJDE9L"}, {"516b6fcd0f", "ABnLTmg"}, {"bf4f89001e670274dd", "3SEo3LWLoPntC"}, {"572e4794", "3EFU7m"}, {"ecac89cad93923c02321", "EJDM8drfXA6uyA"}, {"10c8511e", "Rt5zm"}, {"00000000000000000000", "1111111111"}, {"000111d38e5fc9071ffcd20b4a763cc9ae4f252bb4e48fd66a835e252ada93ff480d6dd43dc62a641155a5", "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"}, {"000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfeff", 
"1cWB5HCBdLjAuqGGReWE3R3CguuwSjw6RHn39s2yuDRTS5NsBgNiFpWgAnEx6VQi8csexkgYw3mdYrMHr8x9i7aEwP8kZ7vccXWqKDvGv3u1GxFKPuAkn8JCPPGDMf3vMMnbzm6Nh9zh1gcNsMvH3ZNLmP5fSG6DGbbi2tuwMWPthr4boWwCxf7ewSgNQeacyozhKDDQQ1qL5fQFUW52QKUZDZ5fw3KXNQJMcNTcaB723LchjeKun7MuGW5qyCBZYzA1KjofN1gYBV3NqyhQJ3Ns746GNuf9N2pQPmHz4xpnSrrfCvy6TVVz5d4PdrjeshsWQwpZsZGzvbdAdN8MKV5QsBDY"}, } func TestBase58(t *testing.T) { // Encode tests for x, test := range stringTests { tmp := []byte(test.in) if res := base58.Encode(tmp); res != test.out { t.Errorf("Encode test #%d failed: got: %s want: %s", x, res, test.out) continue } } // Decode tests for x, test := range hexTests { b, err := hex.DecodeString(test.in) if err != nil { t.Errorf("hex.DecodeString failed failed #%d: got: %s", x, test.in) continue } if res := base58.Decode(test.out); !bytes.Equal(res, b) { t.Errorf("Decode test #%d failed: got: %q want: %q", x, res, test.in) continue } } // Decode with invalid input for x, test := range invalidStringTests { if res := base58.Decode(test.in); string(res) != test.out { t.Errorf("Decode invalidString test #%d failed: got: %q want: %q", x, res, test.out) continue } } } go-0.11.1/util/base58/base58bench_test.go000066400000000000000000000017351436100171500177240ustar00rootroot00000000000000// Copyright (c) 2013-2014 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. 
package base58_test import ( "bytes" "testing" "maunium.net/go/mautrix/util/base58" ) var ( raw5k = bytes.Repeat([]byte{0xff}, 5000) raw100k = bytes.Repeat([]byte{0xff}, 100*1000) encoded5k = base58.Encode(raw5k) encoded100k = base58.Encode(raw100k) ) func BenchmarkBase58Encode_5K(b *testing.B) { b.SetBytes(int64(len(raw5k))) for i := 0; i < b.N; i++ { base58.Encode(raw5k) } } func BenchmarkBase58Encode_100K(b *testing.B) { b.SetBytes(int64(len(raw100k))) for i := 0; i < b.N; i++ { base58.Encode(raw100k) } } func BenchmarkBase58Decode_5K(b *testing.B) { b.SetBytes(int64(len(encoded5k))) for i := 0; i < b.N; i++ { base58.Decode(encoded5k) } } func BenchmarkBase58Decode_100K(b *testing.B) { b.SetBytes(int64(len(encoded100k))) for i := 0; i < b.N; i++ { base58.Decode(encoded100k) } } go-0.11.1/util/base58/base58check.go000066400000000000000000000027421436100171500166620ustar00rootroot00000000000000// Copyright (c) 2013-2014 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package base58 import ( "crypto/sha256" "errors" ) // ErrChecksum indicates that the checksum of a check-encoded string does not verify against // the checksum. var ErrChecksum = errors.New("checksum error") // ErrInvalidFormat indicates that the check-encoded string has an invalid format. var ErrInvalidFormat = errors.New("invalid format: version and/or checksum bytes missing") // checksum: first four bytes of sha256^2 func checksum(input []byte) (cksum [4]byte) { h := sha256.Sum256(input) h2 := sha256.Sum256(h[:]) copy(cksum[:], h2[:4]) return } // CheckEncode prepends a version byte and appends a four byte checksum. func CheckEncode(input []byte, version byte) string { b := make([]byte, 0, 1+len(input)+4) b = append(b, version) b = append(b, input...) cksum := checksum(b) b = append(b, cksum[:]...) return Encode(b) } // CheckDecode decodes a string that was encoded with CheckEncode and verifies the checksum. 
func CheckDecode(input string) (result []byte, version byte, err error) { decoded := Decode(input) if len(decoded) < 5 { return nil, 0, ErrInvalidFormat } version = decoded[0] var cksum [4]byte copy(cksum[:], decoded[len(decoded)-4:]) if checksum(decoded[:len(decoded)-4]) != cksum { return nil, 0, ErrChecksum } payload := decoded[1 : len(decoded)-4] result = append(result, payload...) return } go-0.11.1/util/base58/base58check_test.go000066400000000000000000000040071436100171500177150ustar00rootroot00000000000000// Copyright (c) 2013-2014 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package base58_test import ( "testing" "maunium.net/go/mautrix/util/base58" ) var checkEncodingStringTests = []struct { version byte in string out string }{ {20, "", "3MNQE1X"}, {20, " ", "B2Kr6dBE"}, {20, "-", "B3jv1Aft"}, {20, "0", "B482yuaX"}, {20, "1", "B4CmeGAC"}, {20, "-1", "mM7eUf6kB"}, {20, "11", "mP7BMTDVH"}, {20, "abc", "4QiVtDjUdeq"}, {20, "1234598760", "ZmNb8uQn5zvnUohNCEPP"}, {20, "abcdefghijklmnopqrstuvwxyz", "K2RYDcKfupxwXdWhSAxQPCeiULntKm63UXyx5MvEH2"}, {20, "00000000000000000000000000000000000000000000000000000000000000", "bi1EWXwJay2udZVxLJozuTb8Meg4W9c6xnmJaRDjg6pri5MBAxb9XwrpQXbtnqEoRV5U2pixnFfwyXC8tRAVC8XxnjK"}, } func TestBase58Check(t *testing.T) { for x, test := range checkEncodingStringTests { // test encoding if res := base58.CheckEncode([]byte(test.in), test.version); res != test.out { t.Errorf("CheckEncode test #%d failed: got %s, want: %s", x, res, test.out) } // test decoding res, version, err := base58.CheckDecode(test.out) switch { case err != nil: t.Errorf("CheckDecode test #%d failed with err: %v", x, err) case version != test.version: t.Errorf("CheckDecode test #%d failed: got version: %d want: %d", x, version, test.version) case string(res) != test.in: t.Errorf("CheckDecode test #%d failed: got: %s want: %s", x, res, test.in) } } // test the two decoding failure cases // 
case 1: checksum error _, _, err := base58.CheckDecode("3MNQE1Y") if err != base58.ErrChecksum { t.Error("Checkdecode test failed, expected ErrChecksum") } // case 2: invalid formats (string lengths below 5 mean the version byte and/or the checksum // bytes are missing). testString := "" for len := 0; len < 4; len++ { testString += "x" _, _, err = base58.CheckDecode(testString) if err != base58.ErrInvalidFormat { t.Error("Checkdecode test failed, expected ErrInvalidFormat") } } } go-0.11.1/util/base58/doc.go000066400000000000000000000023541436100171500153410ustar00rootroot00000000000000// Copyright (c) 2014 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. /* Package base58 provides an API for working with modified base58 and Base58Check encodings. Modified Base58 Encoding Standard base58 encoding is similar to standard base64 encoding except, as the name implies, it uses a 58 character alphabet which results in an alphanumeric string and allows some characters which are problematic for humans to be excluded. Due to this, there can be various base58 alphabets. The modified base58 alphabet used by Bitcoin, and hence this package, omits the 0, O, I, and l characters that look the same in many fonts and are therefore hard to humans to distinguish. Base58Check Encoding Scheme The Base58Check encoding scheme is primarily used for Bitcoin addresses at the time of this writing, however it can be used to generically encode arbitrary byte arrays into human-readable strings along with a version byte that can be used to differentiate the same payload. For Bitcoin addresses, the extra version is used to differentiate the network of otherwise identical public keys which helps prevent using an address intended for one network on another. 
*/ package base58 go-0.11.1/util/base58/example_test.go000066400000000000000000000033651436100171500172710ustar00rootroot00000000000000// Copyright (c) 2014 The btcsuite developers // Use of this source code is governed by an ISC // license that can be found in the LICENSE file. package base58_test import ( "fmt" "maunium.net/go/mautrix/util/base58" ) // This example demonstrates how to decode modified base58 encoded data. func ExampleDecode() { // Decode example modified base58 encoded data. encoded := "25JnwSn7XKfNQ" decoded := base58.Decode(encoded) // Show the decoded data. fmt.Println("Decoded Data:", string(decoded)) // Output: // Decoded Data: Test data } // This example demonstrates how to encode data using the modified base58 // encoding scheme. func ExampleEncode() { // Encode example data with the modified base58 encoding scheme. data := []byte("Test data") encoded := base58.Encode(data) // Show the encoded data. fmt.Println("Encoded Data:", encoded) // Output: // Encoded Data: 25JnwSn7XKfNQ } // This example demonstrates how to decode Base58Check encoded data. func ExampleCheckDecode() { // Decode an example Base58Check encoded data. encoded := "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa" decoded, version, err := base58.CheckDecode(encoded) if err != nil { fmt.Println(err) return } // Show the decoded data. fmt.Printf("Decoded data: %x\n", decoded) fmt.Println("Version Byte:", version) // Output: // Decoded data: 62e907b15cbf27d5425399ebf6f0fb50ebb88f18 // Version Byte: 0 } // This example demonstrates how to encode data using the Base58Check encoding // scheme. func ExampleCheckEncode() { // Encode example data with the Base58Check encoding scheme. data := []byte("Test data") encoded := base58.CheckEncode(data, 0) // Show the encoded data. 
fmt.Println("Encoded Data:", encoded) // Output: // Encoded Data: 182iP79GRURMp7oMHDU } go-0.11.1/util/ffmpeg/000077500000000000000000000000001436100171500144165ustar00rootroot00000000000000go-0.11.1/util/ffmpeg/convert.go000066400000000000000000000055551436100171500164370ustar00rootroot00000000000000// Copyright (c) 2022 Sumner Evans // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package ffmpeg import ( "fmt" "io/ioutil" "os" "os/exec" "path/filepath" "strings" log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/util" ) var ffmpegDefaultParams = []string{"-hide_banner", "-loglevel", "warning"} // Convert a media file on the disk using ffmpeg. // // Args: // * inputFile: The full path to the file. // * outputExtension: The extension that the output file should be. // * inputArgs: Arguments to tell ffmpeg how to parse the input file. // * outputArgs: Arguments to tell ffmpeg how to convert the file to reach the wanted output. // * removeInput: Whether the input file should be removed after converting. // // Returns: the path to the converted file. func ConvertPath(inputFile string, outputExtension string, inputArgs []string, outputArgs []string, removeInput bool) (string, error) { outputFilename := strings.TrimSuffix(inputFile, filepath.Ext(inputFile)) + outputExtension args := []string{} args = append(args, ffmpegDefaultParams...) args = append(args, inputArgs...) args = append(args, "-i", inputFile) args = append(args, outputArgs...) args = append(args, outputFilename) cmd := exec.Command("ffmpeg", args...) vcLog := log.Sub("ffmpeg").Writer(log.LevelWarn) cmd.Stdout = vcLog cmd.Stderr = vcLog err := cmd.Run() if err != nil { return "", fmt.Errorf("ffmpeg error: %+v", err) } if removeInput { os.Remove(inputFile) } return outputFilename, nil } // Convert media data using ffmpeg. 
// // Args: // * data: The media data to convert // * outputExtension: The extension that the output file should be. // * inputArgs: Arguments to tell ffmpeg how to parse the input file. // * outputArgs: Arguments to tell ffmpeg how to convert the file to reach the wanted output. // * inputMime: The mimetype of the input data. // // Returns: the converted data func ConvertBytes(data []byte, outputExtension string, inputArgs []string, outputArgs []string, inputMime string) ([]byte, error) { tempdir, err := ioutil.TempDir("", "mautrix_ffmpeg_*") if err != nil { return nil, err } defer os.RemoveAll(tempdir) inputFileName := fmt.Sprintf("%s/input%s", tempdir, util.ExtensionFromMimetype(inputMime)) inputFile, err := os.OpenFile(inputFileName, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) if err != nil { return nil, fmt.Errorf("failed to open input file: %w", err) } _, err = inputFile.Write(data) if err != nil { inputFile.Close() return nil, fmt.Errorf("failed to write data to input file: %w", err) } inputFile.Close() outputPath, err := ConvertPath(inputFileName, outputExtension, inputArgs, outputArgs, false) if err != nil { return nil, err } return ioutil.ReadFile(outputPath) } go-0.11.1/util/mimetypes.go000066400000000000000000000024111436100171500155130ustar00rootroot00000000000000// Copyright (c) 2022 Sumner Evans // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package util import ( "mime" "strings" ) // MimeExtensionSanityOverrides includes extensions for various common mimetypes. // // This is necessary because sometimes the OS mimetype database and Go interact in weird ways, // which causes very obscure extensions to be first in the array for common mimetypes // (e.g. image/jpeg -> .jpe, text/plain -> ,v). 
var MimeExtensionSanityOverrides = map[string]string{ "image/png": ".png", "image/webp": ".webp", "image/jpeg": ".jpg", "image/tiff": ".tiff", "image/heif": ".heic", "image/heic": ".heic", "audio/mpeg": ".mp3", "audio/ogg": ".ogg", "audio/webm": ".webm", "audio/x-caf": ".caf", "video/mp4": ".mp4", "video/mpeg": ".mpeg", "video/webm": ".webm", "text/plain": ".txt", "text/html": ".html", "application/xml": ".xml", } func ExtensionFromMimetype(mimetype string) string { ext, ok := MimeExtensionSanityOverrides[strings.Split(mimetype, ";")[0]] if !ok { exts, _ := mime.ExtensionsByType(mimetype) if len(exts) > 0 { ext = exts[0] } } return ext } go-0.11.1/util/variationselector/000077500000000000000000000000001436100171500167075ustar00rootroot00000000000000go-0.11.1/util/variationselector/emojis-with-variations.json000066400000000000000000000043551436100171500242250ustar00rootroot00000000000000["*","0","1","2","3","4","5","6","7","8","9","ยฉ","ยฎ","โ€ผ","โ‰","โ„ข","โ„น","โ†”","โ†•","โ†–","โ†—","โ†˜","โ†™","โ†ฉ","โ†ช","โŒš","โŒ›","โŒจ","โ","โฉ","โช","โญ","โฎ","โฏ","โฑ","โฒ","โณ","โธ","โน","โบ","โ“‚","โ–ช","โ–ซ","โ–ถ","โ—€","โ—ป","โ—ผ","โ—ฝ","โ—พ","โ˜€","โ˜","โ˜‚","โ˜ƒ","โ˜„","โ˜Ž","โ˜‘","โ˜”","โ˜•","โ˜˜","โ˜","โ˜ ","โ˜ข","โ˜ฃ","โ˜ฆ","โ˜ช","โ˜ฎ","โ˜ฏ","โ˜ธ","โ˜น","โ˜บ","โ™€","โ™‚","โ™ˆ","โ™‰","โ™Š","โ™‹","โ™Œ","โ™","โ™Ž","โ™","โ™","โ™‘","โ™’","โ™“","โ™Ÿ","โ™ ","โ™ฃ","โ™ฅ","โ™ฆ","โ™จ","โ™ป","โ™พ","โ™ฟ","โš’","โš“","โš”","โš•","โš–","โš—","โš™","โš›","โšœ","โš 
","โšก","โšง","โšช","โšซ","โšฐ","โšฑ","โšฝ","โšพ","โ›„","โ›…","โ›ˆ","โ›","โ›‘","โ›“","โ›”","โ›ฉ","โ›ช","โ›ฐ","โ›ฑ","โ›ฒ","โ›ณ","โ›ด","โ›ต","โ›ท","โ›ธ","โ›น","โ›บ","โ›ฝ","โœ‚","โœˆ","โœ‰","โœŒ","โœ","โœ","โœ’","โœ”","โœ–","โœ","โœก","โœณ","โœด","โ„","โ‡","โ“","โ—","โฃ","โค","โžก","โคด","โคต","โฌ…","โฌ†","โฌ‡","โฌ›","โฌœ","โญ","โญ•","ใ€ฐ","ใ€ฝ","ใŠ—","ใŠ™","๐Ÿ€„","๐Ÿ…ฐ","๐Ÿ…ฑ","๐Ÿ…พ","๐Ÿ…ฟ","๐Ÿˆ‚","๐Ÿˆš","๐Ÿˆฏ","๐Ÿˆท","๐ŸŒ","๐ŸŒŽ","๐ŸŒ","๐ŸŒ•","๐ŸŒœ","๐ŸŒก","๐ŸŒค","๐ŸŒฅ","๐ŸŒฆ","๐ŸŒง","๐ŸŒจ","๐ŸŒฉ","๐ŸŒช","๐ŸŒซ","๐ŸŒฌ","๐ŸŒถ","๐Ÿธ","๐Ÿฝ","๐ŸŽ“","๐ŸŽ–","๐ŸŽ—","๐ŸŽ™","๐ŸŽš","๐ŸŽ›","๐ŸŽž","๐ŸŽŸ","๐ŸŽง","๐ŸŽฌ","๐ŸŽญ","๐ŸŽฎ","๐Ÿ‚","๐Ÿ„","๐Ÿ†","๐ŸŠ","๐Ÿ‹","๐ŸŒ","๐Ÿ","๐ŸŽ","๐Ÿ”","๐Ÿ•","๐Ÿ–","๐Ÿ—","๐Ÿ˜","๐Ÿ™","๐Ÿš","๐Ÿ›","๐Ÿœ","๐Ÿ","๐Ÿž","๐ŸŸ","๐Ÿ ","๐Ÿญ","๐Ÿณ","๐Ÿต","๐Ÿท","๐Ÿˆ","๐Ÿ•","๐ŸŸ","๐Ÿฆ","๐Ÿฟ","๐Ÿ‘","๐Ÿ‘‚","๐Ÿ‘†","๐Ÿ‘‡","๐Ÿ‘ˆ","๐Ÿ‘‰","๐Ÿ‘","๐Ÿ‘Ž","๐Ÿ‘“","๐Ÿ‘ช","๐Ÿ‘ฝ","๐Ÿ’ฃ","๐Ÿ’ฐ","๐Ÿ’ณ","๐Ÿ’ป","๐Ÿ’ฟ","๐Ÿ“‹","๐Ÿ“š","๐Ÿ“Ÿ","๐Ÿ“ค","๐Ÿ“ฅ","๐Ÿ“ฆ","๐Ÿ“ช","๐Ÿ“ซ","๐Ÿ“ฌ","๐Ÿ“ญ","๐Ÿ“ท","๐Ÿ“น","๐Ÿ“บ","๐Ÿ“ป","๐Ÿ“ฝ","๐Ÿ”ˆ","๐Ÿ”","๐Ÿ”’","๐Ÿ”“","๐Ÿ•‰","๐Ÿ•Š","๐Ÿ•","๐Ÿ•‘","๐Ÿ•’","๐Ÿ•“","๐Ÿ•”","๐Ÿ••","๐Ÿ•–","๐Ÿ•—","๐Ÿ•˜","๐Ÿ•™","๐Ÿ•š","๐Ÿ•›","๐Ÿ•œ","๐Ÿ•","๐Ÿ•ž","๐Ÿ•Ÿ","๐Ÿ• ","๐Ÿ•ก","๐Ÿ•ข","๐Ÿ•ฃ","๐Ÿ•ค","๐Ÿ•ฅ","๐Ÿ•ฆ","๐Ÿ•ง","๐Ÿ•ฏ","๐Ÿ•ฐ","๐Ÿ•ณ","๐Ÿ•ด","๐Ÿ•ต","๐Ÿ•ถ","๐Ÿ•ท","๐Ÿ•ธ","๐Ÿ•น","๐Ÿ–‡","๐Ÿ–Š","๐Ÿ–‹","๐Ÿ–Œ","๐Ÿ–","๐Ÿ–","๐Ÿ–ฅ","๐Ÿ–จ","๐Ÿ–ฑ","๐Ÿ–ฒ","๐Ÿ–ผ","๐Ÿ—‚","๐Ÿ—ƒ","๐Ÿ—„","๐Ÿ—‘","๐Ÿ—’","๐Ÿ—“","๐Ÿ—œ","๐Ÿ—","๐Ÿ—ž","๐Ÿ—ก","๐Ÿ—ฃ","๐Ÿ—จ","๐Ÿ—ฏ","๐Ÿ—ณ","๐Ÿ—บ","๐Ÿ˜","๐Ÿš‡","๐Ÿš","๐Ÿš‘","๐Ÿš”","๐Ÿš˜","๐Ÿšญ","๐Ÿšฒ","๐Ÿšน","๐Ÿšบ","๐Ÿšผ","๐Ÿ›‹","๐Ÿ›","๐Ÿ›Ž","๐Ÿ›","๐Ÿ› ","๐Ÿ›ก","๐Ÿ›ข","๐Ÿ›ฃ","๐Ÿ›ค","๐Ÿ›ฅ","๐Ÿ›ฉ","๐Ÿ›ฐ","๐Ÿ›ณ"] go-0.11.1/util/variationselector/generate.sh000077500000000000000000000003551436100171500210430ustar00rootroot00000000000000#!/bin/bash echo -e "$( curl -s https://www.unicode.org/Public/14.0.0/ucd/emoji/emoji-variation-sequences.txt \ | grep FE0F \ | awk '{ printf("\\U%8s\n", $1) }' \ | sed 's/ /0/g' )" | 
jq -RcM '[inputs]' > emojis-with-variations.json go-0.11.1/util/variationselector/variationselector.go000066400000000000000000000034231436100171500227750ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. // Package variationselector provides utility functions for adding and removing emoji variation selectors (16) // that matches the suggestions in the Matrix spec. package variationselector import ( _ "embed" "encoding/json" "strings" ) //go:generate ./generate.sh //go:embed emojis-with-variations.json var emojisWithVariationsJSON []byte var variationReplacer *strings.Replacer // The variation replacer will add incorrect variation selectors before skin tones, this removes those. var skinToneReplacer = strings.NewReplacer( "\ufe0f\U0001F3FB", "\U0001F3FB", "\ufe0f\U0001F3FC", "\U0001F3FC", "\ufe0f\U0001F3FD", "\U0001F3FD", "\ufe0f\U0001F3FE", "\U0001F3FE", "\ufe0f\U0001F3FF", "\U0001F3FF", ) func init() { var emojisWithVariations []string err := json.Unmarshal(emojisWithVariationsJSON, &emojisWithVariations) if err != nil { panic(err) } replaceInput := make([]string, 2*len(emojisWithVariations)) for i, emoji := range emojisWithVariations { replaceInput[i*2] = emoji replaceInput[(i*2)+1] = emoji + VS16 } variationReplacer = strings.NewReplacer(replaceInput...) } const VS16 = "\ufe0f" // Add adds emoji variation selectors to all emojis that have multiple forms in the given string. // // This will remove all variation selectors first to make sure it doesn't add duplicates. func Add(val string) string { return skinToneReplacer.Replace(variationReplacer.Replace(Remove(val))) } // Remove removes all emoji variation selectors in the given string. 
func Remove(val string) string { return strings.ReplaceAll(val, VS16, "") } go-0.11.1/util/variationselector/variationselector_test.go000066400000000000000000000041251436100171500240340ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package variationselector_test import ( "fmt" "strconv" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/util/variationselector" ) func TestAdd(t *testing.T) { assert.Equal(t, "\U0001f44d\U0001F3FD", variationselector.Add("\U0001f44d\U0001F3FD")) assert.Equal(t, "\U0001f44d\ufe0f", variationselector.Add("\U0001f44d")) assert.Equal(t, "\U0001f44d\ufe0f", variationselector.Add("\U0001f44d\ufe0f")) assert.Equal(t, "4\ufe0f\u20e3", variationselector.Add("4\u20e3")) assert.Equal(t, "4\ufe0f\u20e3", variationselector.Add("4\ufe0f\u20e3")) assert.Equal(t, "\U0001f914", variationselector.Add("\U0001f914")) } func TestRemove(t *testing.T) { assert.Equal(t, "\U0001f44d", variationselector.Remove("\U0001f44d")) assert.Equal(t, "\U0001f44d", variationselector.Remove("\U0001f44d\ufe0f")) assert.Equal(t, "4\u20e3", variationselector.Remove("4\u20e3")) assert.Equal(t, "4\u20e3", variationselector.Remove("4\ufe0f\u20e3")) assert.Equal(t, "\U0001f914", variationselector.Remove("\U0001f914")) } func ExampleAdd() { fmt.Println(strconv.QuoteToASCII(variationselector.Add("\U0001f44d"))) // thumbs up (needs selector) fmt.Println(strconv.QuoteToASCII(variationselector.Add("\U0001f44d\ufe0f"))) // thumbs up with variation selector (stays as-is) fmt.Println(strconv.QuoteToASCII(variationselector.Add("\U0001f44d\U0001F3FD"))) // thumbs up with skin tone (shouldn't get selector) fmt.Println(strconv.QuoteToASCII(variationselector.Add("\U0001f914"))) // thinking face (shouldn't get selector) // Output: // "\U0001f44d\ufe0f" // 
"\U0001f44d\ufe0f" // "\U0001f44d\U0001f3fd" // "\U0001f914" } func ExampleRemove() { fmt.Println(strconv.QuoteToASCII(variationselector.Remove("\U0001f44d"))) fmt.Println(strconv.QuoteToASCII(variationselector.Remove("\U0001f44d\ufe0f"))) // Output: // "\U0001f44d" // "\U0001f44d" } go-0.11.1/version.go000066400000000000000000000001331436100171500142060ustar00rootroot00000000000000package mautrix const Version = "v0.11.1" var DefaultUserAgent = "mautrix-go/" + Version go-0.11.1/versions.go000066400000000000000000000072021436100171500143750ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package mautrix import ( "fmt" "regexp" "strconv" ) // RespVersions is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions type RespVersions struct { Versions []SpecVersion `json:"versions"` UnstableFeatures map[string]bool `json:"unstable_features"` } func (versions *RespVersions) ContainsFunc(match func(found SpecVersion) bool) bool { for _, found := range versions.Versions { if match(found) { return true } } return false } func (versions *RespVersions) Contains(version SpecVersion) bool { return versions.ContainsFunc(func(found SpecVersion) bool { return found == version }) } func (versions *RespVersions) ContainsGreaterOrEqual(version SpecVersion) bool { return versions.ContainsFunc(func(found SpecVersion) bool { return found.GreaterThan(version) || found == version }) } type SpecVersionFormat int const ( SpecVersionFormatUnknown SpecVersionFormat = iota SpecVersionFormatR SpecVersionFormatV ) var ( SpecR010 = MustParseSpecVersion("r0.1.0") SpecR020 = MustParseSpecVersion("r0.2.0") SpecR030 = MustParseSpecVersion("r0.3.0") SpecR040 = MustParseSpecVersion("r0.4.0") SpecR050 = MustParseSpecVersion("r0.5.0") SpecR060 = 
MustParseSpecVersion("r0.6.0") SpecR061 = MustParseSpecVersion("r0.6.1") SpecV11 = MustParseSpecVersion("v1.1") SpecV12 = MustParseSpecVersion("v1.2") ) func (svf SpecVersionFormat) String() string { switch svf { case SpecVersionFormatR: return "r" case SpecVersionFormatV: return "v" default: return "" } } type SpecVersion struct { Format SpecVersionFormat Major int Minor int Patch int Raw string } var legacyVersionRegex = regexp.MustCompile(`^r(\d+)\.(\d+)\.(\d+)$`) var modernVersionRegex = regexp.MustCompile(`^v(\d+)\.(\d+)$`) func MustParseSpecVersion(version string) SpecVersion { sv, err := ParseSpecVersion(version) if err != nil { panic(err) } return sv } func ParseSpecVersion(version string) (sv SpecVersion, err error) { sv.Raw = version if parts := modernVersionRegex.FindStringSubmatch(version); parts != nil { sv.Major, _ = strconv.Atoi(parts[1]) sv.Minor, _ = strconv.Atoi(parts[2]) sv.Format = SpecVersionFormatV } else if parts = legacyVersionRegex.FindStringSubmatch(version); parts != nil { sv.Major, _ = strconv.Atoi(parts[1]) sv.Minor, _ = strconv.Atoi(parts[2]) sv.Patch, _ = strconv.Atoi(parts[3]) sv.Format = SpecVersionFormatR } else { err = fmt.Errorf("version '%s' doesn't match either known syntax", version) } return } func (sv *SpecVersion) UnmarshalText(version []byte) error { *sv, _ = ParseSpecVersion(string(version)) return nil } func (sv *SpecVersion) MarshalText() ([]byte, error) { return []byte(sv.String()), nil } func (sv *SpecVersion) String() string { switch sv.Format { case SpecVersionFormatR: return fmt.Sprintf("r%d.%d.%d", sv.Major, sv.Minor, sv.Patch) case SpecVersionFormatV: return fmt.Sprintf("v%d.%d", sv.Major, sv.Minor) default: return sv.Raw } } func (sv SpecVersion) LessThan(other SpecVersion) bool { return sv != other && !sv.GreaterThan(other) } func (sv SpecVersion) GreaterThan(other SpecVersion) bool { return sv.Format > other.Format || (sv.Format == other.Format && sv.Major > other.Major) || (sv.Format == other.Format && 
sv.Major == other.Major && sv.Minor > other.Minor) || (sv.Format == other.Format && sv.Major == other.Major && sv.Minor == other.Minor && sv.Patch > other.Patch) } go-0.11.1/versions_test.go000066400000000000000000000077701436100171500154460ustar00rootroot00000000000000// Copyright (c) 2022 Tulir Asokan // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. package mautrix_test import ( "encoding/json" "testing" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix" ) const sampleVersions = `{ "versions": [ "r0.0.1", "r0.1.0", "r0.2.0", "r0.3.0", "r0.4.0", "r0.5.0", "r0.6.0", "r0.6.1", "v1.1", "v1.2" ], "unstable_features": { "org.matrix.label_based_filtering": true, "org.matrix.e2e_cross_signing": true, "org.matrix.msc2432": true, "uk.half-shot.msc2666.mutual_rooms": true, "io.element.e2ee_forced.public": false, "io.element.e2ee_forced.private": false, "io.element.e2ee_forced.trusted_private": false, "org.matrix.msc3026.busy_presence": false, "org.matrix.msc2285": true, "org.matrix.msc2716": false, "org.matrix.msc3030": false, "org.matrix.msc3440.stable": true, "fi.mau.msc2815": false } }` func TestRespVersions_UnmarshalJSON(t *testing.T) { var resp mautrix.RespVersions err := json.Unmarshal([]byte(sampleVersions), &resp) assert.NoError(t, err) assert.True(t, resp.ContainsGreaterOrEqual(mautrix.SpecV11)) assert.True(t, resp.Contains(mautrix.SpecV12)) assert.True(t, resp.Contains(mautrix.SpecR061)) assert.True(t, resp.ContainsGreaterOrEqual(mautrix.MustParseSpecVersion("r0.0.0"))) assert.True(t, !resp.ContainsGreaterOrEqual(mautrix.MustParseSpecVersion("v123.456"))) } func TestParseSpecVersion(t *testing.T) { assert.Equal(t, mautrix.SpecVersion{mautrix.SpecVersionFormatR, 0, 1, 0, "r0.1.0"}, mautrix.MustParseSpecVersion("r0.1.0")) assert.Equal(t, mautrix.SpecVersion{mautrix.SpecVersionFormatV, 1, 1, 0, "v1.1"}, 
mautrix.MustParseSpecVersion("v1.1")) assert.Equal(t, mautrix.SpecVersion{mautrix.SpecVersionFormatV, 123, 456, 0, "v123.456"}, mautrix.MustParseSpecVersion("v123.456")) invalidVer, err := mautrix.ParseSpecVersion("not a version") assert.Error(t, err) assert.Equal(t, mautrix.SpecVersion{Raw: "not a version"}, invalidVer) // v syntax doesn't allow patch versions invalidVer, err = mautrix.ParseSpecVersion("v1.2.3") assert.Error(t, err) assert.Equal(t, mautrix.SpecVersion{Raw: "v1.2.3"}, invalidVer) invalidVer, err = mautrix.ParseSpecVersion("r0.6") assert.Error(t, err) assert.Equal(t, mautrix.SpecVersion{Raw: "r0.6"}, invalidVer) } func TestSpecVersion_String(t *testing.T) { assert.Equal(t, "r0.1.0", (&mautrix.SpecVersion{mautrix.SpecVersionFormatR, 0, 1, 0, ""}).String()) assert.Equal(t, "v1.2", (&mautrix.SpecVersion{mautrix.SpecVersionFormatV, 1, 2, 0, ""}).String()) assert.Equal(t, "v567.890", (&mautrix.SpecVersion{mautrix.SpecVersionFormatV, 567, 890, 0, ""}).String()) assert.Equal(t, "invalid version", (&mautrix.SpecVersion{Raw: "invalid version"}).String()) } func TestSpecVersion_GreaterThan(t *testing.T) { assert.True(t, mautrix.MustParseSpecVersion("r0.1.0").GreaterThan(mautrix.MustParseSpecVersion("r0.0.0"))) assert.True(t, mautrix.MustParseSpecVersion("r0.6.0").GreaterThan(mautrix.MustParseSpecVersion("r0.1.0"))) assert.True(t, mautrix.MustParseSpecVersion("r0.6.1").GreaterThan(mautrix.MustParseSpecVersion("r0.1.0"))) assert.True(t, mautrix.MustParseSpecVersion("v1.1").GreaterThan(mautrix.MustParseSpecVersion("r0.6.1"))) assert.True(t, mautrix.MustParseSpecVersion("v11.11").GreaterThan(mautrix.MustParseSpecVersion("v1.23"))) assert.True(t, mautrix.MustParseSpecVersion("v1.123").GreaterThan(mautrix.MustParseSpecVersion("v1.1"))) assert.True(t, !mautrix.MustParseSpecVersion("v1.23").GreaterThan(mautrix.MustParseSpecVersion("v2.31"))) assert.True(t, !mautrix.MustParseSpecVersion("r0.6.0").GreaterThan(mautrix.MustParseSpecVersion("r0.6.1"))) assert.True(t, 
!mautrix.MustParseSpecVersion("r0.6.0").GreaterThan(mautrix.MustParseSpecVersion("r0.6.0"))) assert.True(t, !mautrix.MustParseSpecVersion("r0.6.0").LessThan(mautrix.MustParseSpecVersion("r0.6.0"))) }