Merge branch 'e2be'

This commit is contained in:
Tulir Asokan 2020-05-21 20:58:57 +03:00
commit 8bb5407f98
37 changed files with 1677 additions and 533 deletions

View file

@ -10,3 +10,6 @@ insert_final_newline = true
[*.{yaml,yml}] [*.{yaml,yml}]
indent_style = space indent_style = space
[.gitlab-ci.yml]
indent_size = 2

View file

@ -3,16 +3,15 @@ stages:
- build docker - build docker
- manifest - manifest
build: .build: &build
image: golang:1-alpine image: golang:1-alpine
stage: build stage: build
tags:
- amd64
cache: cache:
paths: paths:
- .cache - .cache
before_script: before_script:
- apk add git build-base - echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
- apk add build-base olm-dev@edge_community
- mkdir -p .cache - mkdir -p .cache
- export GOPATH="$CI_PROJECT_DIR/.cache" - export GOPATH="$CI_PROJECT_DIR/.cache"
script: script:
@ -22,31 +21,62 @@ build:
- mautrix-whatsapp - mautrix-whatsapp
- example-config.yaml - example-config.yaml
build docker amd64: .build-docker: &build-docker
image: docker:stable image: docker:stable
stage: build docker stage: build docker
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
script:
- docker pull $CI_REGISTRY_IMAGE:latest || true
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH . --file Dockerfile.ci
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-$DOCKER_ARCH
build static amd64:
image: golang:1-alpine
stage: build
tags: tags:
- amd64 - amd64
cache:
paths:
- .cache
before_script: before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY - mkdir -p .cache
- export GOPATH="$CI_PROJECT_DIR/.cache"
script: script:
- docker pull $CI_REGISTRY_IMAGE:latest || true - CGO_ENABLED=0 go build -o mautrix-whatsapp
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 . --file Dockerfile.ci artifacts:
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 paths:
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 - mautrix-whatsapp
- example-config.yaml
build docker arm64: build amd64:
image: docker:stable <<: *build
stage: build docker tags:
- amd64
build arm64:
<<: *build
tags: tags:
- arm64 - arm64
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY build docker amd64:
script: <<: *build-docker
- docker pull $CI_REGISTRY_IMAGE:latest || true tags:
- docker build --pull --cache-from $CI_REGISTRY_IMAGE:latest --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 . - amd64
- docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 dependencies:
- docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 - build amd64
variables:
DOCKER_ARCH: amd64
build docker arm64:
<<: *build-docker
tags:
- arm64
dependencies:
- build arm64
variables:
DOCKER_ARCH: arm64
manifest: manifest:
stage: manifest stage: manifest

View file

@ -1,6 +1,7 @@
FROM golang:1.12-alpine AS builder FROM golang:1-alpine AS builder
RUN apk add --no-cache git ca-certificates build-base su-exec RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
RUN apk add --no-cache git ca-certificates build-base su-exec olm-dev@edge_community
WORKDIR /build WORKDIR /build
COPY go.mod go.sum /build/ COPY go.mod go.sum /build/
@ -14,7 +15,8 @@ FROM alpine:latest
ENV UID=1337 \ ENV UID=1337 \
GID=1337 GID=1337
RUN apk add --no-cache su-exec ca-certificates RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
RUN apk add --no-cache su-exec ca-certificates olm@edge_community
COPY --from=builder /usr/bin/mautrix-whatsapp /usr/bin/mautrix-whatsapp COPY --from=builder /usr/bin/mautrix-whatsapp /usr/bin/mautrix-whatsapp
COPY --from=builder /build/example-config.yaml /opt/mautrix-whatsapp/example-config.yaml COPY --from=builder /build/example-config.yaml /opt/mautrix-whatsapp/example-config.yaml

View file

@ -3,7 +3,8 @@ FROM alpine:latest
ENV UID=1337 \ ENV UID=1337 \
GID=1337 GID=1337
RUN apk add --no-cache su-exec ca-certificates RUN echo "@edge_community http://dl-cdn.alpinelinux.org/alpine/edge/community" >> /etc/apk/repositories
RUN apk add --no-cache su-exec ca-certificates olm@edge_community
ARG EXECUTABLE=./mautrix-whatsapp ARG EXECUTABLE=./mautrix-whatsapp
COPY $EXECUTABLE /usr/bin/mautrix-whatsapp COPY $EXECUTABLE /usr/bin/mautrix-whatsapp

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -18,6 +18,7 @@ package main
import ( import (
"fmt" "fmt"
"strconv"
"strings" "strings"
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
@ -25,11 +26,12 @@ import (
"maunium.net/go/maulogger/v2" "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
@ -51,7 +53,7 @@ type CommandEvent struct {
Bot *appservice.IntentAPI Bot *appservice.IntentAPI
Bridge *Bridge Bridge *Bridge
Handler *CommandHandler Handler *CommandHandler
RoomID types.MatrixRoomID RoomID id.RoomID
User *User User *User
Command string Command string
Args []string Args []string
@ -59,20 +61,20 @@ type CommandEvent struct {
// Reply sends a reply to command as notice // Reply sends a reply to command as notice
func (ce *CommandEvent) Reply(msg string, args ...interface{}) { func (ce *CommandEvent) Reply(msg string, args ...interface{}) {
content := format.RenderMarkdown(fmt.Sprintf(msg, args...)) content := format.RenderMarkdown(fmt.Sprintf(msg, args...), true, false)
content.MsgType = mautrix.MsgNotice content.MsgType = event.MsgNotice
room := ce.User.ManagementRoom room := ce.User.ManagementRoom
if len(room) == 0 { if len(room) == 0 {
room = ce.RoomID room = ce.RoomID
} }
_, err := ce.Bot.SendMessageEvent(room, mautrix.EventMessage, content) _, err := ce.Bot.SendMessageEvent(room, event.EventMessage, content)
if err != nil { if err != nil {
ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err) ce.Handler.log.Warnfln("Failed to reply to command from %s: %v", ce.User.MXID, err)
} }
} }
// Handle handles messages to the bridge // Handle handles messages to the bridge
func (handler *CommandHandler) Handle(roomID types.MatrixRoomID, user *User, message string) { func (handler *CommandHandler) Handle(roomID id.RoomID, user *User, message string) {
args := strings.Split(message, " ") args := strings.Split(message, " ")
ce := &CommandEvent{ ce := &CommandEvent{
Bot: handler.bridge.Bot, Bot: handler.bridge.Bot,
@ -117,7 +119,11 @@ func (handler *CommandHandler) CommandMux(ce *CommandEvent) {
handler.CommandDeleteAllPortals(ce) handler.CommandDeleteAllPortals(ce)
case "dev-test": case "dev-test":
handler.CommandDevTest(ce) handler.CommandDevTest(ce)
case "login-matrix", "logout", "sync", "list", "open", "pm": case "set-pl":
handler.CommandSetPowerLevel(ce)
case "logout":
handler.CommandLogout(ce)
case "login-matrix", "sync", "list", "open", "pm":
if !ce.User.HasSession() { if !ce.User.HasSession() {
ce.Reply("You are not logged in. Use the `login` command to log into WhatsApp.") ce.Reply("You are not logged in. Use the `login` command to log into WhatsApp.")
return return
@ -129,8 +135,6 @@ func (handler *CommandHandler) CommandMux(ce *CommandEvent) {
switch ce.Command { switch ce.Command {
case "login-matrix": case "login-matrix":
handler.CommandLoginMatrix(ce) handler.CommandLoginMatrix(ce)
case "logout":
handler.CommandLogout(ce)
case "sync": case "sync":
handler.CommandSync(ce) handler.CommandSync(ce)
case "list": case "list":
@ -168,6 +172,45 @@ func (handler *CommandHandler) CommandDevTest(ce *CommandEvent) {
} }
func (handler *CommandHandler) CommandSetPowerLevel(ce *CommandEvent) {
portal := ce.Bridge.GetPortalByMXID(ce.RoomID)
if portal == nil {
ce.Reply("Not a portal room")
return
}
var level int
var userID id.UserID
var err error
if len(ce.Args) == 1 {
level, err = strconv.Atoi(ce.Args[0])
if err != nil {
ce.Reply("Invalid power level \"%s\"", ce.Args[0])
return
}
userID = ce.User.MXID
} else if len(ce.Args) == 2 {
userID = id.UserID(ce.Args[0])
_, _, err := userID.Parse()
if err != nil {
ce.Reply("Invalid user ID \"%s\"", ce.Args[0])
return
}
level, err = strconv.Atoi(ce.Args[1])
if err != nil {
ce.Reply("Invalid power level \"%s\"", ce.Args[1])
return
}
} else {
ce.Reply("**Usage:** `set-pl [user] <level>`")
return
}
intent := portal.MainIntent()
_, err = intent.SetPowerLevel(ce.RoomID, userID, level)
if err != nil {
ce.Reply("Failed to set power levels: %v", err)
}
}
const cmdLoginHelp = `login - Authenticate this Bridge as WhatsApp Web Client` const cmdLoginHelp = `login - Authenticate this Bridge as WhatsApp Web Client`
// CommandLogin handles login command // CommandLogin handles login command
@ -186,6 +229,16 @@ func (handler *CommandHandler) CommandLogout(ce *CommandEvent) {
if ce.User.Session == nil { if ce.User.Session == nil {
ce.Reply("You're not logged in.") ce.Reply("You're not logged in.")
return return
} else if !ce.User.IsConnected() {
ce.Reply("You are not connected to WhatsApp. Use the `reconnect` command to reconnect, or `delete-session` to forget all login information.")
return
}
puppet := handler.bridge.GetPuppetByJID(ce.User.JID)
if puppet.CustomMXID != "" {
err := puppet.SwitchCustomMXID("", "")
if err != nil {
ce.User.log.Warnln("Failed to logout-matrix while logging out of WhatsApp:", err)
}
} }
err := ce.User.Conn.Logout() err := ce.User.Conn.Logout()
if err != nil { if err != nil {
@ -199,6 +252,9 @@ func (handler *CommandHandler) CommandLogout(ce *CommandEvent) {
} }
ce.User.Conn.RemoveHandlers() ce.User.Conn.RemoveHandlers()
ce.User.Conn = nil ce.User.Conn = nil
ce.User.removeFromJIDMap()
// TODO this causes a foreign key violation, which should be fixed
//ce.User.JID = ""
ce.User.SetSession(nil) ce.User.SetSession(nil)
ce.Reply("Logged out successfully.") ce.Reply("Logged out successfully.")
} }
@ -516,6 +572,7 @@ func (handler *CommandHandler) CommandOpen(ce *CommandEvent) {
portal.Sync(user, contact) portal.Sync(user, contact)
ce.Reply("Portal room created.") ce.Reply("Portal room created.")
} }
_, _ = portal.MainIntent().InviteUser(portal.MXID, &mautrix.ReqInviteUser{UserID: user.MXID})
} }
const cmdPMHelp = `pm [--force] <_international phone number_> - Open a private chat with the given phone number.` const cmdPMHelp = `pm [--force] <_international phone number_> - Open a private chat with the given phone number.`

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -21,7 +21,6 @@ import (
"net/http" "net/http"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
appservice "maunium.net/go/mautrix-appservice"
) )
func (user *User) inviteToCommunity() { func (user *User) inviteToCommunity() {
@ -51,7 +50,7 @@ func (user *User) createCommunity() {
return return
} }
localpart, server := appservice.ParseUserID(user.MXID) localpart, server, _ := user.MXID.Parse()
community := user.bridge.Config.Bridge.FormatCommunity(localpart, server) community := user.bridge.Config.Bridge.FormatCommunity(localpart, server)
user.log.Debugln("Creating personal filtering community", community) user.log.Debugln("Creating personal filtering community", community)
bot := user.bridge.Bot bot := user.bridge.Bot
@ -100,8 +99,8 @@ func (user *User) addPuppetToCommunity(puppet *Puppet) bool {
"type": "private", "type": "private",
}, },
} }
url = bot.BuildURLWithQuery([]string{"groups", user.CommunityID, "self", "accept_invite"}, map[string]string{ url = bot.BuildURLWithQuery(mautrix.URLPath{"groups", user.CommunityID, "self", "accept_invite"}, map[string]string{
"user_id": puppet.MXID, "user_id": puppet.MXID.String(),
}) })
_, err = bot.MakeRequest(http.MethodPut, url, &reqBody, nil) _, err = bot.MakeRequest(http.MethodPut, url, &reqBody, nil)
if err != nil { if err != nil {

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -24,8 +24,8 @@ import (
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
) )
@ -54,8 +54,8 @@ type BridgeConfig struct {
RecoverHistory bool `yaml:"recovery_history_backfill"` RecoverHistory bool `yaml:"recovery_history_backfill"`
SyncChatMaxAge uint64 `yaml:"sync_max_chat_age"` SyncChatMaxAge uint64 `yaml:"sync_max_chat_age"`
SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"` SyncWithCustomPuppets bool `yaml:"sync_with_custom_puppets"`
LoginSharedSecret string `yaml:"login_shared_secret"` LoginSharedSecret string `yaml:"login_shared_secret"`
InviteOwnPuppetForBackfilling bool `yaml:"invite_own_puppet_for_backfilling"` InviteOwnPuppetForBackfilling bool `yaml:"invite_own_puppet_for_backfilling"`
PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"` PrivateChatPortalMeta bool `yaml:"private_chat_portal_meta"`
@ -64,6 +64,11 @@ type BridgeConfig struct {
CommandPrefix string `yaml:"command_prefix"` CommandPrefix string `yaml:"command_prefix"`
Encryption struct {
Allow bool `yaml:"allow"`
Default bool `yaml:"default"`
} `yaml:"encryption"`
Permissions PermissionConfig `yaml:"permissions"` Permissions PermissionConfig `yaml:"permissions"`
Relaybot RelaybotConfig `yaml:"relaybot"` Relaybot RelaybotConfig `yaml:"relaybot"`
@ -127,7 +132,7 @@ func (bc *BridgeConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
} }
type UsernameTemplateArgs struct { type UsernameTemplateArgs struct {
UserID string UserID id.UserID
} }
func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) (string, int8) { func (bc BridgeConfig) FormatDisplayname(contact whatsapp.Contact) (string, int8) {
@ -232,25 +237,25 @@ func (pc *PermissionConfig) MarshalYAML() (interface{}, error) {
return rawPC, nil return rawPC, nil
} }
func (pc PermissionConfig) IsRelaybotWhitelisted(userID string) bool { func (pc PermissionConfig) IsRelaybotWhitelisted(userID id.UserID) bool {
return pc.GetPermissionLevel(userID) >= PermissionLevelRelaybot return pc.GetPermissionLevel(userID) >= PermissionLevelRelaybot
} }
func (pc PermissionConfig) IsWhitelisted(userID string) bool { func (pc PermissionConfig) IsWhitelisted(userID id.UserID) bool {
return pc.GetPermissionLevel(userID) >= PermissionLevelUser return pc.GetPermissionLevel(userID) >= PermissionLevelUser
} }
func (pc PermissionConfig) IsAdmin(userID string) bool { func (pc PermissionConfig) IsAdmin(userID id.UserID) bool {
return pc.GetPermissionLevel(userID) >= PermissionLevelAdmin return pc.GetPermissionLevel(userID) >= PermissionLevelAdmin
} }
func (pc PermissionConfig) GetPermissionLevel(userID string) PermissionLevel { func (pc PermissionConfig) GetPermissionLevel(userID id.UserID) PermissionLevel {
permissions, ok := pc[userID] permissions, ok := pc[string(userID)]
if ok { if ok {
return permissions return permissions
} }
_, homeserver := appservice.ParseUserID(userID) _, homeserver, _ := userID.Parse()
permissions, ok = pc[homeserver] permissions, ok = pc[homeserver]
if len(homeserver) > 0 && ok { if len(homeserver) > 0 && ok {
return permissions return permissions
@ -265,12 +270,12 @@ func (pc PermissionConfig) GetPermissionLevel(userID string) PermissionLevel {
} }
type RelaybotConfig struct { type RelaybotConfig struct {
Enabled bool `yaml:"enabled"` Enabled bool `yaml:"enabled"`
ManagementRoom string `yaml:"management"` ManagementRoom id.RoomID `yaml:"management"`
InviteUsers []types.MatrixUserID `yaml:"invites"` InviteUsers []id.UserID `yaml:"invites"`
MessageFormats map[mautrix.MessageType]string `yaml:"message_formats"` MessageFormats map[event.MessageType]string `yaml:"message_formats"`
messageTemplates *template.Template `yaml:"-"` messageTemplates *template.Template `yaml:"-"`
} }
type umRelaybotConfig RelaybotConfig type umRelaybotConfig RelaybotConfig
@ -293,25 +298,25 @@ func (rc *RelaybotConfig) UnmarshalYAML(unmarshal func(interface{}) error) error
} }
type Sender struct { type Sender struct {
UserID types.MatrixUserID UserID id.UserID
mautrix.Member *event.MemberEventContent
} }
type formatData struct { type formatData struct {
Sender Sender Sender Sender
Message string Message string
Content mautrix.Content Content *event.MessageEventContent
} }
func (rc *RelaybotConfig) FormatMessage(evt *mautrix.Event, member mautrix.Member) (string, error) { func (rc *RelaybotConfig) FormatMessage(content *event.MessageEventContent, sender id.UserID, member *event.MemberEventContent) (string, error) {
var output strings.Builder var output strings.Builder
err := rc.messageTemplates.ExecuteTemplate(&output, string(evt.Content.MsgType), formatData{ err := rc.messageTemplates.ExecuteTemplate(&output, string(content.MsgType), formatData{
Sender: Sender{ Sender: Sender{
UserID: evt.Sender, UserID: sender,
Member: member, MemberEventContent: member,
}, },
Content: evt.Content, Content: content,
Message: evt.Content.FormattedBody, Message: content.FormattedBody,
}) })
return output.String(), err return output.String(), err
} }

View file

@ -21,7 +21,7 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
) )
type Config struct { type Config struct {

View file

@ -20,7 +20,7 @@ import (
"fmt" "fmt"
"regexp" "regexp"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
) )
func (config *Config) NewRegistration() (*appservice.Registration, error) { func (config *Config) NewRegistration() (*appservice.Registration, error) {

242
crypto.go Normal file
View file

@ -0,0 +1,242 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// +build cgo
package main
import (
"crypto/hmac"
"crypto/sha512"
"encoding/hex"
"fmt"
"time"
"github.com/pkg/errors"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var levelTrace = maulogger.Level{
Name: "Trace",
Severity: -10,
Color: -1,
}
type CryptoHelper struct {
bridge *Bridge
client *mautrix.Client
mach *crypto.OlmMachine
store *database.SQLCryptoStore
log maulogger.Logger
baseLog maulogger.Logger
}
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Debugln("Bridge built with end-to-bridge encryption, but disabled in config")
return nil
} else if bridge.Config.Bridge.LoginSharedSecret == "" {
bridge.Log.Warnln("End-to-bridge encryption enabled, but login_shared_secret not set")
return nil
}
baseLog := bridge.Log.Sub("Crypto")
return &CryptoHelper{
bridge: bridge,
log: baseLog.Sub("Helper"),
baseLog: baseLog,
}
}
func (helper *CryptoHelper) Init() error {
helper.log.Debugln("Initializing end-to-bridge encryption...")
var err error
helper.client, err = helper.loginBot()
if err != nil {
return err
}
helper.log.Debugln("Logged in as bridge bot with device ID", helper.client.DeviceID)
logger := &cryptoLogger{helper.baseLog}
stateStore := &cryptoStateStore{helper.bridge}
helper.store = database.NewSQLCryptoStore(helper.bridge.DB, helper.client.DeviceID)
helper.store.UserID = helper.client.UserID
helper.store.GhostIDFormat = fmt.Sprintf("@%s:%s", helper.bridge.Config.Bridge.FormatUsername("%"), helper.bridge.AS.HomeserverDomain)
helper.mach = crypto.NewOlmMachine(helper.client, logger, helper.store, stateStore)
helper.client.Logger = logger.int.Sub("Bot")
helper.client.Syncer = &cryptoSyncer{helper.mach}
helper.client.Store = &cryptoClientStore{helper.store}
return helper.mach.Load()
}
func (helper *CryptoHelper) loginBot() (*mautrix.Client, error) {
deviceID := helper.bridge.DB.FindDeviceID()
if len(deviceID) > 0 {
helper.log.Debugln("Found existing device ID for bot in database:", deviceID)
}
mac := hmac.New(sha512.New, []byte(helper.bridge.Config.Bridge.LoginSharedSecret))
mac.Write([]byte(helper.bridge.AS.BotMXID()))
resp, err := helper.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
Type: "m.login.password",
Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(helper.bridge.AS.BotMXID())},
Password: hex.EncodeToString(mac.Sum(nil)),
DeviceID: deviceID,
InitialDeviceDisplayName: "WhatsApp Bridge",
})
if err != nil {
return nil, err
}
client, err := mautrix.NewClient(helper.bridge.AS.HomeserverURL, helper.bridge.AS.BotMXID(), resp.AccessToken)
if err != nil {
return nil, err
}
client.DeviceID = resp.DeviceID
return client, nil
}
func (helper *CryptoHelper) Start() {
helper.log.Debugln("Starting syncer for receiving to-device messages")
err := helper.client.Sync()
if err != nil {
helper.log.Errorln("Fatal error syncing:", err)
}
}
func (helper *CryptoHelper) Stop() {
helper.client.StopSync()
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
return helper.mach.DecryptMegolmEvent(evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content event.Content) (*event.EncryptedEventContent, error) {
encrypted, err := helper.mach.EncryptMegolmEvent(roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return nil, err
}
helper.log.Debugfln("Got %v while encrypting event for %s, sharing group session and trying again...", err, roomID)
users, err := helper.store.GetRoomMembers(roomID)
if err != nil {
return nil, errors.Wrap(err, "failed to get room member list")
}
err = helper.mach.ShareGroupSession(roomID, users)
if err != nil {
return nil, errors.Wrap(err, "failed to share group session")
}
encrypted, err = helper.mach.EncryptMegolmEvent(roomID, evtType, content)
if err != nil {
return nil, errors.Wrap(err, "failed to encrypt event after re-sharing group session")
}
}
return encrypted, nil
}
func (helper *CryptoHelper) HandleMemberEvent(evt *event.Event) {
helper.mach.HandleMemberEvent(evt)
}
type cryptoSyncer struct {
*crypto.OlmMachine
}
func (syncer *cryptoSyncer) ProcessResponse(resp *mautrix.RespSync, since string) error {
syncer.ProcessSyncResponse(resp, since)
return nil
}
func (syncer *cryptoSyncer) OnFailedSync(_ *mautrix.RespSync, err error) (time.Duration, error) {
syncer.Log.Error("Error /syncing, waiting 10 seconds: %v", err)
return 10 * time.Second, nil
}
func (syncer *cryptoSyncer) GetFilterJSON(_ id.UserID) *mautrix.Filter {
everything := []event.Type{{Type: "*"}}
return &mautrix.Filter{
Presence: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
Room: mautrix.RoomFilter{
IncludeLeave: false,
Ephemeral: mautrix.FilterPart{NotTypes: everything},
AccountData: mautrix.FilterPart{NotTypes: everything},
State: mautrix.FilterPart{NotTypes: everything},
Timeline: mautrix.FilterPart{NotTypes: everything},
},
}
}
type cryptoLogger struct {
int maulogger.Logger
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}
type cryptoClientStore struct {
int *database.SQLCryptoStore
}
func (c cryptoClientStore) SaveFilterID(_ id.UserID, _ string) {}
func (c cryptoClientStore) LoadFilterID(_ id.UserID) string { return "" }
func (c cryptoClientStore) SaveRoom(_ *mautrix.Room) {}
func (c cryptoClientStore) LoadRoom(_ id.RoomID) *mautrix.Room { return nil }
func (c cryptoClientStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
c.int.PutNextBatch(nextBatchToken)
}
func (c cryptoClientStore) LoadNextBatch(_ id.UserID) string {
return c.int.GetNextBatch()
}
var _ mautrix.Storer = (*cryptoClientStore)(nil)
type cryptoStateStore struct {
bridge *Bridge
}
func (c *cryptoStateStore) IsEncrypted(id id.RoomID) bool {
portal := c.bridge.GetPortalByMXID(id)
if portal != nil {
return portal.Encrypted
}
return false
}
func (c *cryptoStateStore) FindSharedRooms(id id.UserID) []id.RoomID {
return c.bridge.StateStore.FindSharedRooms(id)
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -20,17 +20,16 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/sha512" "crypto/sha512"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt"
"os"
"strings"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
appservice "maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
) )
var ( var (
@ -38,7 +37,7 @@ var (
ErrMismatchingMXID = errors.New("whoami result does not match custom mxid") ErrMismatchingMXID = errors.New("whoami result does not match custom mxid")
) )
func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid string) error { func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error {
prevCustomMXID := puppet.CustomMXID prevCustomMXID := puppet.CustomMXID
if puppet.customIntent != nil { if puppet.customIntent != nil {
puppet.stopSyncing() puppet.stopSyncing()
@ -63,12 +62,12 @@ func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid string) error {
return nil return nil
} }
func (puppet *Puppet) loginWithSharedSecret(mxid string) (string, error) { func (puppet *Puppet) loginWithSharedSecret(mxid id.UserID) (string, error) {
mac := hmac.New(sha512.New, []byte(puppet.bridge.Config.Bridge.LoginSharedSecret)) mac := hmac.New(sha512.New, []byte(puppet.bridge.Config.Bridge.LoginSharedSecret))
mac.Write([]byte(mxid)) mac.Write([]byte(mxid))
resp, err := puppet.bridge.AS.BotClient().Login(&mautrix.ReqLogin{ resp, err := puppet.bridge.AS.BotClient().Login(&mautrix.ReqLogin{
Type: "m.login.password", Type: "m.login.password",
Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: mxid}, Identifier: mautrix.UserIdentifier{Type: "m.id.user", User: string(mxid)},
Password: hex.EncodeToString(mac.Sum(nil)), Password: hex.EncodeToString(mac.Sum(nil)),
DeviceID: "WhatsApp Bridge", DeviceID: "WhatsApp Bridge",
InitialDeviceDisplayName: "WhatsApp Bridge", InitialDeviceDisplayName: "WhatsApp Bridge",
@ -87,13 +86,13 @@ func (puppet *Puppet) newCustomIntent() (*appservice.IntentAPI, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
client.Logger = puppet.bridge.AS.Log.Sub(puppet.CustomMXID) client.Logger = puppet.bridge.AS.Log.Sub(string(puppet.CustomMXID))
client.Syncer = puppet client.Syncer = puppet
client.Store = puppet client.Store = puppet
ia := puppet.bridge.AS.NewIntentAPI("custom") ia := puppet.bridge.AS.NewIntentAPI("custom")
ia.Client = client ia.Client = client
ia.Localpart = puppet.CustomMXID[1:strings.IndexRune(puppet.CustomMXID, ':')] ia.Localpart, _, _ = puppet.CustomMXID.Parse()
ia.UserID = puppet.CustomMXID ia.UserID = puppet.CustomMXID
ia.IsCustomPuppet = true ia.IsCustomPuppet = true
return ia, nil return ia, nil
@ -117,11 +116,7 @@ func (puppet *Puppet) StartCustomMXID() error {
puppet.clearCustomMXID() puppet.clearCustomMXID()
return err return err
} }
urlPath := intent.BuildURL("account", "whoami") resp, err := intent.Whoami()
var resp struct {
UserID string `json:"user_id"`
}
_, err = intent.MakeRequest("GET", urlPath, nil, &resp)
if err != nil { if err != nil {
puppet.clearCustomMXID() puppet.clearCustomMXID()
return err return err
@ -131,7 +126,7 @@ func (puppet *Puppet) StartCustomMXID() error {
return ErrMismatchingMXID return ErrMismatchingMXID
} }
puppet.customIntent = intent puppet.customIntent = intent
puppet.customTypingIn = make(map[string]bool) puppet.customTypingIn = make(map[id.RoomID]bool)
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID) puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
puppet.startSyncing() puppet.startSyncing()
return nil return nil
@ -158,28 +153,6 @@ func (puppet *Puppet) stopSyncing() {
puppet.customIntent.StopSync() puppet.customIntent.StopSync()
} }
func parseEvent(roomID string, data json.RawMessage) *mautrix.Event {
event := &mautrix.Event{}
err := json.Unmarshal(data, event)
if err != nil {
// TODO add separate handler for these
_, _ = fmt.Fprintf(os.Stderr, "Failed to unmarshal event: %v\n%s\n", err, string(data))
return nil
}
return event
}
func parsePresenceEvent(data json.RawMessage) *mautrix.Event {
event := &mautrix.Event{}
err := json.Unmarshal(data, event)
if err != nil {
// TODO add separate handler for these
_, _ = fmt.Fprintf(os.Stderr, "Failed to unmarshal event: %v\n%s\n", err, string(data))
return nil
}
return event
}
func (puppet *Puppet) ProcessResponse(resp *mautrix.RespSync, since string) error { func (puppet *Puppet) ProcessResponse(resp *mautrix.RespSync, since string) error {
if !puppet.customUser.IsConnected() { if !puppet.customUser.IsConnected() {
puppet.log.Debugln("Skipping sync processing: custom user not connected to whatsapp") puppet.log.Debugln("Skipping sync processing: custom user not connected to whatsapp")
@ -190,31 +163,33 @@ func (puppet *Puppet) ProcessResponse(resp *mautrix.RespSync, since string) erro
if portal == nil { if portal == nil {
continue continue
} }
for _, data := range events.Ephemeral.Events { for _, evt := range events.Ephemeral.Events {
event := parseEvent(roomID, data) err := evt.Content.ParseRaw(evt.Type)
if event != nil { if err != nil {
switch event.Type { continue
case mautrix.EphemeralEventReceipt: }
go puppet.handleReceiptEvent(portal, event) switch evt.Type {
case mautrix.EphemeralEventTyping: case event.EphemeralEventReceipt:
go puppet.handleTypingEvent(portal, event) go puppet.handleReceiptEvent(portal, evt)
} case event.EphemeralEventTyping:
go puppet.handleTypingEvent(portal, evt)
} }
} }
} }
for _, data := range resp.Presence.Events { for _, evt := range resp.Presence.Events {
event := parsePresenceEvent(data) if evt.Sender != puppet.CustomMXID {
if event != nil { continue
if event.Sender != puppet.CustomMXID {
continue
}
go puppet.handlePresenceEvent(event)
} }
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
continue
}
go puppet.handlePresenceEvent(evt)
} }
return nil return nil
} }
func (puppet *Puppet) handlePresenceEvent(event *mautrix.Event) { func (puppet *Puppet) handlePresenceEvent(event *event.Event) {
presence := whatsapp.PresenceAvailable presence := whatsapp.PresenceAvailable
if event.Content.Raw["presence"].(string) != "online" { if event.Content.Raw["presence"].(string) != "online" {
presence = whatsapp.PresenceUnavailable presence = whatsapp.PresenceUnavailable
@ -228,13 +203,9 @@ func (puppet *Puppet) handlePresenceEvent(event *mautrix.Event) {
} }
} }
func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *mautrix.Event) { func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *event.Event) {
for eventID, rawReceipts := range event.Content.Raw { for eventID, receipts := range *event.Content.AsReceipt() {
if receipts, ok := rawReceipts.(map[string]interface{}); !ok { if _, ok := receipts.Read[puppet.CustomMXID]; !ok {
continue
} else if readReceipt, ok := receipts["m.read"].(map[string]interface{}); !ok {
continue
} else if _, ok = readReceipt[puppet.CustomMXID].(map[string]interface{}); !ok {
continue continue
} }
message := puppet.bridge.DB.Message.GetByMXID(eventID) message := puppet.bridge.DB.Message.GetByMXID(eventID)
@ -249,16 +220,16 @@ func (puppet *Puppet) handleReceiptEvent(portal *Portal, event *mautrix.Event) {
} }
} }
func (puppet *Puppet) handleTypingEvent(portal *Portal, event *mautrix.Event) { func (puppet *Puppet) handleTypingEvent(portal *Portal, evt *event.Event) {
isTyping := false isTyping := false
for _, userID := range event.Content.TypingUserIDs { for _, userID := range evt.Content.AsTyping().UserIDs {
if userID == puppet.CustomMXID { if userID == puppet.CustomMXID {
isTyping = true isTyping = true
break break
} }
} }
if puppet.customTypingIn[event.RoomID] != isTyping { if puppet.customTypingIn[evt.RoomID] != isTyping {
puppet.customTypingIn[event.RoomID] = isTyping puppet.customTypingIn[evt.RoomID] = isTyping
presence := whatsapp.PresenceComposing presence := whatsapp.PresenceComposing
if !isTyping { if !isTyping {
puppet.customUser.log.Infofln("Marking not typing in %s/%s", portal.Key.JID, portal.MXID) puppet.customUser.log.Infofln("Marking not typing in %s/%s", portal.Key.JID, portal.MXID)
@ -278,36 +249,27 @@ func (puppet *Puppet) OnFailedSync(res *mautrix.RespSync, err error) (time.Durat
return 10 * time.Second, nil return 10 * time.Second, nil
} }
func (puppet *Puppet) GetFilterJSON(_ string) json.RawMessage { func (puppet *Puppet) GetFilterJSON(_ id.UserID) *mautrix.Filter {
mxid, _ := json.Marshal(puppet.CustomMXID) everything := []event.Type{{Type: "*"}}
return json.RawMessage(fmt.Sprintf(`{ return &mautrix.Filter{
"account_data": { "types": [] }, Presence: mautrix.FilterPart{
"presence": { Senders: []id.UserID{puppet.CustomMXID},
"senders": [ Types: []event.Type{event.EphemeralEventPresence},
%s },
], AccountData: mautrix.FilterPart{NotTypes: everything},
"types": [ Room: mautrix.RoomFilter{
"m.presence" Ephemeral: mautrix.FilterPart{Types: []event.Type{event.EphemeralEventTyping, event.EphemeralEventReceipt}},
] IncludeLeave: false,
}, AccountData: mautrix.FilterPart{NotTypes: everything},
"room": { State: mautrix.FilterPart{NotTypes: everything},
"ephemeral": { Timeline: mautrix.FilterPart{NotTypes: everything},
"types": [ },
"m.typing", }
"m.receipt"
]
},
"include_leave": false,
"account_data": { "types": [] },
"state": { "types": [] },
"timeline": { "types": [] }
}
}`, mxid))
} }
func (puppet *Puppet) SaveFilterID(_, _ string) {} func (puppet *Puppet) SaveFilterID(_ id.UserID, _ string) {}
func (puppet *Puppet) SaveNextBatch(_, nbt string) { puppet.NextBatch = nbt; puppet.Update() } func (puppet *Puppet) SaveNextBatch(_ id.UserID, nbt string) { puppet.NextBatch = nbt; puppet.Update() }
func (puppet *Puppet) SaveRoom(room *mautrix.Room) {} func (puppet *Puppet) SaveRoom(room *mautrix.Room) {}
func (puppet *Puppet) LoadFilterID(_ string) string { return "" } func (puppet *Puppet) LoadFilterID(_ id.UserID) string { return "" }
func (puppet *Puppet) LoadNextBatch(_ string) string { return puppet.NextBatch } func (puppet *Puppet) LoadNextBatch(_ id.UserID) string { return puppet.NextBatch }
func (puppet *Puppet) LoadRoom(roomID string) *mautrix.Room { return nil } func (puppet *Puppet) LoadRoom(roomID id.RoomID) *mautrix.Room { return nil }

438
database/cryptostore.go Normal file
View file

@ -0,0 +1,438 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// +build cgo
package database
import (
"database/sql"
"fmt"
"strings"
"github.com/lib/pq"
"github.com/pkg/errors"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
)
type SQLCryptoStore struct {
db *Database
log log.Logger
UserID id.UserID
DeviceID id.DeviceID
SyncToken string
PickleKey []byte
Account *crypto.OlmAccount
GhostIDFormat string
}
var _ crypto.Store = (*SQLCryptoStore)(nil)
func NewSQLCryptoStore(db *Database, deviceID id.DeviceID) *SQLCryptoStore {
return &SQLCryptoStore{
db: db,
log: db.log.Sub("CryptoStore"),
PickleKey: []byte("maunium.net/go/mautrix-whatsapp"),
DeviceID: deviceID,
}
}
func (db *Database) FindDeviceID() (deviceID id.DeviceID) {
err := db.QueryRow("SELECT device_id FROM crypto_account LIMIT 1").Scan(&deviceID)
if err != nil && err != sql.ErrNoRows {
db.log.Warnln("Failed to scan device ID:", err)
}
return
}
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) (members []id.UserID, err error) {
var rows *sql.Rows
rows, err = store.db.Query(`
SELECT user_id FROM mx_user_profile
WHERE room_id=$1
AND (membership='join' OR membership='invite')
AND user_id<>$2
AND user_id NOT LIKE $3
`, roomID, store.UserID, store.GhostIDFormat)
if err != nil {
return
}
for rows.Next() {
var userID id.UserID
err := rows.Scan(&userID)
if err != nil {
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else {
members = append(members, userID)
}
}
return
}
func (store *SQLCryptoStore) Flush() error {
return nil
}
func (store *SQLCryptoStore) PutNextBatch(nextBatch string) {
store.SyncToken = nextBatch
_, err := store.db.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE device_id=$2`, store.SyncToken, store.DeviceID)
if err != nil {
store.log.Warnln("Failed to store sync token:", err)
}
}
func (store *SQLCryptoStore) GetNextBatch() string {
if store.SyncToken == "" {
err := store.db.
QueryRow("SELECT sync_token FROM crypto_account WHERE device_id=$1", store.DeviceID).
Scan(&store.SyncToken)
if err != nil && err != sql.ErrNoRows {
store.log.Warnln("Failed to scan sync token:", err)
}
}
return store.SyncToken
}
func (store *SQLCryptoStore) PutAccount(account *crypto.OlmAccount) error {
store.Account = account
bytes := account.Internal.Pickle(store.PickleKey)
var err error
if store.db.dialect == "postgres" {
_, err = store.db.Exec(`
INSERT INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)
ON CONFLICT (device_id) DO UPDATE SET shared=$2, sync_token=$3, account=$4`,
store.DeviceID, account.Shared, store.SyncToken, bytes)
} else if store.db.dialect == "sqlite3" {
_, err = store.db.Exec("INSERT OR REPLACE INTO crypto_account (device_id, shared, sync_token, account) VALUES ($1, $2, $3, $4)",
store.DeviceID, account.Shared, store.SyncToken, bytes)
} else {
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
}
if err != nil {
store.log.Warnln("Failed to store account:", err)
}
return nil
}
func (store *SQLCryptoStore) GetAccount() (*crypto.OlmAccount, error) {
if store.Account == nil {
row := store.db.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE device_id=$1", store.DeviceID)
acc := &crypto.OlmAccount{Internal: *olm.NewBlankAccount()}
var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
err = acc.Internal.Unpickle(accountBytes, store.PickleKey)
if err != nil {
return nil, err
}
store.Account = acc
}
return store.Account, nil
}
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
// TODO this may need to be changed if olm sessions start expiring
var sessionID id.SessionID
err := store.db.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 LIMIT 1", key).Scan(&sessionID)
if err == sql.ErrNoRows {
return false
}
return len(sessionID) > 0
}
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
rows, err := store.db.Query("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id", key)
if err != nil {
return nil, err
}
list := crypto.OlmSessionList{}
for rows.Next() {
sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
var sessionBytes []byte
err := rows.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
if err != nil {
return nil, err
}
err = sess.Internal.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
list = append(list, &sess)
}
return list, nil
}
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
row := store.db.QueryRow("SELECT session, created_at, last_used FROM crypto_olm_session WHERE sender_key=$1 ORDER BY session_id DESC LIMIT 1", key)
sess := crypto.OlmSession{Internal: *olm.NewBlankSession()}
var sessionBytes []byte
err := row.Scan(&sessionBytes, &sess.CreationTime, &sess.UseTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
return &sess, sess.Internal.Unpickle(sessionBytes, store.PickleKey)
}
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.db.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_used) VALUES ($1, $2, $3, $4, $5)",
session.ID(), key, sessionBytes, session.CreationTime, session.UseTime)
return err
}
func (store *SQLCryptoStore) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.db.Exec("UPDATE crypto_olm_session SET session=$1, last_used=$2 WHERE session_id=$3",
sessionBytes, session.UseTime, session.ID())
return err
}
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
_, err := store.db.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, signing_key, room_id, session, forwarding_chains) VALUES ($1, $2, $3, $4, $5, $6)",
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains)
return err
}
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
var signingKey id.Ed25519
var sessionBytes []byte
var forwardingChains string
err := store.db.QueryRow(`
SELECT signing_key, session, forwarding_chains
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3`,
roomID, senderKey, sessionID,
).Scan(&signingKey, &sessionBytes, &forwardingChains)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
return &crypto.InboundGroupSession{
Internal: *igs,
SigningKey: signingKey,
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: strings.Split(forwardingChains, ","),
}, nil
}
func (store *SQLCryptoStore) AddOutboundGroupSession(session *crypto.OutboundGroupSession) (err error) {
sessionBytes := session.Internal.Pickle(store.PickleKey)
if store.db.dialect == "postgres" {
_, err = store.db.Exec(`
INSERT INTO crypto_megolm_outbound_session (
room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
ON CONFLICT (room_id) DO UPDATE SET session_id=$2, session=$3, shared=$4, max_messages=$5, message_count=$6, max_age=$7, created_at=$8, last_used=$9`,
session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
} else if store.db.dialect == "sqlite3" {
_, err = store.db.Exec(`
INSERT OR REPLACE INTO crypto_megolm_outbound_session (
room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`,
session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount, session.MaxAge, session.CreationTime, session.UseTime)
} else {
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
}
return
}
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.db.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5",
sessionBytes, session.MessageCount, session.UseTime, session.RoomID, session.ID())
return err
}
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
var ogs crypto.OutboundGroupSession
var sessionBytes []byte
err := store.db.QueryRow(`
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1`,
roomID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.UseTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
intOGS := olm.NewBlankOutboundGroupSession()
err = intOGS.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
ogs.Internal = *intOGS
ogs.RoomID = roomID
return &ogs, nil
}
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
_, err := store.db.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1", roomID)
return err
}
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) bool {
var resultEventID id.EventID
var resultTimestamp int64
err := store.db.QueryRow(
`SELECT event_id, timestamp FROM crypto_message_index WHERE sender_key=$1 AND session_id=$2 AND "index"=$3`,
senderKey, sessionID, index,
).Scan(&resultEventID, &resultTimestamp)
if err == sql.ErrNoRows {
_, err := store.db.Exec(`INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp) VALUES ($1, $2, $3, $4, $5)`,
senderKey, sessionID, index, eventID, timestamp)
if err != nil {
store.log.Warnln("Failed to store message index:", err)
}
return true
} else if err != nil {
store.log.Warnln("Failed to scan message index:", err)
return true
}
if resultEventID != eventID || resultTimestamp != timestamp {
return false
}
return true
}
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*crypto.DeviceIdentity, error) {
var ignore id.UserID
err := store.db.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
rows, err := store.db.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1", userID)
if err != nil {
return nil, err
}
data := make(map[id.DeviceID]*crypto.DeviceIdentity)
for rows.Next() {
var identity crypto.DeviceIdentity
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
return nil, err
}
identity.UserID = userID
data[identity.DeviceID] = &identity
}
return data, nil
}
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*crypto.DeviceIdentity) error {
tx, err := store.db.Begin()
if err != nil {
return err
}
if store.db.dialect == "postgres" {
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
} else if store.db.dialect == "sqlite3" {
_, err = tx.Exec("INSERT OR IGNORE INTO crypto_tracked_user (user_id) VALUES ($1)", userID)
} else {
err = fmt.Errorf("unsupported dialect %s", store.db.dialect)
}
if err != nil {
return errors.Wrap(err, "failed to add user to tracked users list")
}
_, err = tx.Exec("DELETE FROM crypto_device WHERE user_id=$1", userID)
if err != nil {
_ = tx.Rollback()
return errors.Wrap(err, "failed to delete old devices")
}
if len(devices) == 0 {
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "failed to commit changes (no devices added)")
}
return nil
}
// TODO do this in batches to avoid too large db queries
values := make([]interface{}, 1, len(devices)*6+1)
values[0] = userID
valueStrings := make([]string, 0, len(devices))
i := 2
for deviceID, identity := range devices {
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
valueStrings = append(valueStrings, fmt.Sprintf("($1, $%d, $%d, $%d, $%d, $%d, $%d)", i, i+1, i+2, i+3, i+4, i+5))
i += 6
}
valueString := strings.Join(valueStrings, ",")
_, err = tx.Exec("INSERT INTO crypto_device (user_id, device_id, identity_key, signing_key, trust, deleted, name) VALUES "+valueString, values...)
if err != nil {
_ = tx.Rollback()
return errors.Wrap(err, "failed to insert new devices")
}
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "failed to commit changes")
}
return nil
}
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) []id.UserID {
var rows *sql.Rows
var err error
if store.db.dialect == "postgres" {
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", pq.Array(users))
} else {
queryString := make([]string, len(users))
params := make([]interface{}, len(users))
for i, user := range users {
queryString[i] = fmt.Sprintf("$%d", i+1)
params[i] = user
}
rows, err = store.db.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
}
if err != nil {
store.log.Warnln("Failed to filter tracked users:", err)
return users
}
var ptr int
for rows.Next() {
err = rows.Scan(&users[ptr])
if err != nil {
store.log.Warnln("Failed to tracked user ID:", err)
} else {
ptr++
}
}
return users[:ptr]
}

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -26,6 +26,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix/id"
) )
type MessageQuery struct { type MessageQuery struct {
@ -57,7 +58,7 @@ func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *M
"FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid) "FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
} }
func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message { func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message {
return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " + return mq.get("SELECT chat_jid, chat_receiver, jid, mxid, sender, timestamp, content " +
"FROM message WHERE mxid=$1", mxid) "FROM message WHERE mxid=$1", mxid)
} }
@ -86,7 +87,7 @@ type Message struct {
Chat PortalKey Chat PortalKey
JID types.WhatsAppMessageID JID types.WhatsAppMessageID
MXID types.MatrixEventID MXID id.EventID
Sender types.WhatsAppID Sender types.WhatsAppID
Timestamp uint64 Timestamp uint64
Content *waProto.Message Content *waProto.Message

View file

@ -89,7 +89,7 @@ func migrateTable(old *Database, new *Database, table string, columns ...string)
} }
func Migrate(old *Database, new *Database) { func Migrate(old *Database, new *Database) {
err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url") err := migrateTable(old, new, "portal", "jid", "receiver", "mxid", "name", "topic", "avatar", "avatar_url", "encrypted")
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -121,4 +121,32 @@ func Migrate(old *Database, new *Database) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = migrateTable(old, new, "crypto_account", "device_id", "shared", "sync_token", "account")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_message_index", "sender_key", "session_id", `"index"`, "event_id", "timestamp")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_tracked_user", "user_id")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_device", "user_id", "device_id", "identity_key", "signing_key", "trust", "deleted", "name")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_olm_session", "session_id", "sender_key", "session", "created_at", "last_used")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_megolm_inbound_session", "session_id", "sender_key", "signing_key", "room_id", "session", "forwarding_chains")
if err != nil {
panic(err)
}
err = migrateTable(old, new, "crypto_megolm_outbound_session", "room_id", "session_id", "session", "shared", "max_messages", "message_count", "max_age", "created_at", "last_used")
if err != nil {
panic(err)
}
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -22,6 +22,8 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
) )
@ -74,7 +76,7 @@ func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver) return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
} }
func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal { func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid) return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
} }
@ -107,29 +109,30 @@ type Portal struct {
log log.Logger log log.Logger
Key PortalKey Key PortalKey
MXID types.MatrixRoomID MXID id.RoomID
Name string Name string
Topic string Topic string
Avatar string Avatar string
AvatarURL string AvatarURL id.ContentURI
Encrypted bool
} }
func (portal *Portal) Scan(row Scannable) *Portal { func (portal *Portal) Scan(row Scannable) *Portal {
var mxid, avatarURL sql.NullString var mxid, avatarURL sql.NullString
err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL) err := row.Scan(&portal.Key.JID, &portal.Key.Receiver, &mxid, &portal.Name, &portal.Topic, &portal.Avatar, &avatarURL, &portal.Encrypted)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
portal.log.Errorln("Database scan failed:", err) portal.log.Errorln("Database scan failed:", err)
} }
return nil return nil
} }
portal.MXID = mxid.String portal.MXID = id.RoomID(mxid.String)
portal.AvatarURL = avatarURL.String portal.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
return portal return portal
} }
func (portal *Portal) mxidPtr() *string { func (portal *Portal) mxidPtr() *id.RoomID {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
return &portal.MXID return &portal.MXID
} }
@ -137,20 +140,20 @@ func (portal *Portal) mxidPtr() *string {
} }
func (portal *Portal) Insert() { func (portal *Portal) Insert() {
_, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6, $7)", _, err := portal.db.Exec("INSERT INTO portal (jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL) portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
} }
} }
func (portal *Portal) Update() { func (portal *Portal) Update() {
var mxid *string var mxid *id.RoomID
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
mxid = &portal.MXID mxid = &portal.MXID
} }
_, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5 WHERE jid=$6 AND receiver=$7", _, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6 WHERE jid=$7 AND receiver=$8",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL, portal.Key.JID, portal.Key.Receiver) mxid, portal.Name, portal.Topic, portal.Avatar, portal.AvatarURL.String(), portal.Encrypted, portal.Key.JID, portal.Key.Receiver)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
} }
@ -163,7 +166,7 @@ func (portal *Portal) Delete() {
} }
} }
func (portal *Portal) GetUserIDs() []types.MatrixUserID { func (portal *Portal) GetUserIDs() []id.UserID {
rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
WHERE "user".jid=user_portal.user_jid WHERE "user".jid=user_portal.user_jid
AND user_portal.portal_jid=$1 AND user_portal.portal_jid=$1
@ -173,9 +176,9 @@ func (portal *Portal) GetUserIDs() []types.MatrixUserID {
portal.log.Debugln("Failed to get portal user ids:", err) portal.log.Debugln("Failed to get portal user ids:", err)
return nil return nil
} }
var userIDs []types.MatrixUserID var userIDs []id.UserID
for rows.Next() { for rows.Next() {
var userID types.MatrixUserID var userID id.UserID
err = rows.Scan(&userID) err = rows.Scan(&userID)
if err != nil { if err != nil {
portal.log.Warnln("Failed to scan row:", err) portal.log.Warnln("Failed to scan row:", err)

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -22,6 +22,7 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix/id"
) )
type PuppetQuery struct { type PuppetQuery struct {
@ -56,7 +57,7 @@ func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
return pq.New().Scan(row) return pq.New().Scan(row)
} }
func (pq *PuppetQuery) GetByCustomMXID(mxid types.MatrixUserID) *Puppet { func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet {
row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid=$1", mxid) row := pq.db.QueryRow("SELECT jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch FROM puppet WHERE custom_mxid=$1", mxid)
if row == nil { if row == nil {
return nil return nil
@ -82,11 +83,11 @@ type Puppet struct {
JID types.WhatsAppID JID types.WhatsAppID
Avatar string Avatar string
AvatarURL string AvatarURL id.ContentURI
Displayname string Displayname string
NameQuality int8 NameQuality int8
CustomMXID string CustomMXID id.UserID
AccessToken string AccessToken string
NextBatch string NextBatch string
} }
@ -103,9 +104,9 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
} }
puppet.Displayname = displayname.String puppet.Displayname = displayname.String
puppet.Avatar = avatar.String puppet.Avatar = avatar.String
puppet.AvatarURL = avatarURL.String puppet.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
puppet.NameQuality = int8(quality.Int64) puppet.NameQuality = int8(quality.Int64)
puppet.CustomMXID = customMXID.String puppet.CustomMXID = id.UserID(customMXID.String)
puppet.AccessToken = accessToken.String puppet.AccessToken = accessToken.String
puppet.NextBatch = nextBatch.String puppet.NextBatch = nextBatch.String
return puppet return puppet
@ -113,7 +114,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
func (puppet *Puppet) Insert() { func (puppet *Puppet) Insert() {
_, err := puppet.db.Exec("INSERT INTO puppet (jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)", _, err := puppet.db.Exec("INSERT INTO puppet (jid, avatar, avatar_url, displayname, name_quality, custom_mxid, access_token, next_batch) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)",
puppet.JID, puppet.Avatar, puppet.AvatarURL, puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch) puppet.JID, puppet.Avatar, puppet.AvatarURL.String(), puppet.Displayname, puppet.NameQuality, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
} }
@ -121,7 +122,7 @@ func (puppet *Puppet) Insert() {
func (puppet *Puppet) Update() { func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7 WHERE jid=$8", _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3, avatar_url=$4, custom_mxid=$5, access_token=$6, next_batch=$7 WHERE jid=$8",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL, puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID) puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.AvatarURL.String(), puppet.CustomMXID, puppet.AccessToken, puppet.NextBatch, puppet.JID)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)
} }

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -24,8 +24,9 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
) )
type SQLStateStore struct { type SQLStateStore struct {
@ -34,19 +35,21 @@ type SQLStateStore struct {
db *Database db *Database
log log.Logger log log.Logger
Typing map[string]map[string]int64 Typing map[id.RoomID]map[id.UserID]int64
typingLock sync.RWMutex typingLock sync.RWMutex
} }
var _ appservice.StateStore = (*SQLStateStore)(nil)
func NewSQLStateStore(db *Database) *SQLStateStore { func NewSQLStateStore(db *Database) *SQLStateStore {
return &SQLStateStore{ return &SQLStateStore{
TypingStateStore: appservice.NewTypingStateStore(), TypingStateStore: appservice.NewTypingStateStore(),
db: db, db: db,
log: log.Sub("StateStore"), log: db.log.Sub("StateStore"),
} }
} }
func (store *SQLStateStore) IsRegistered(userID string) bool { func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
row := store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID) row := store.db.QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID)
var isRegistered bool var isRegistered bool
err := row.Scan(&isRegistered) err := row.Scan(&isRegistered)
@ -56,7 +59,7 @@ func (store *SQLStateStore) IsRegistered(userID string) bool {
return isRegistered return isRegistered
} }
func (store *SQLStateStore) MarkRegistered(userID string) { func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
var err error var err error
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
_, err = store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID) _, err = store.db.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
@ -70,28 +73,28 @@ func (store *SQLStateStore) MarkRegistered(userID string) {
} }
} }
func (store *SQLStateStore) GetRoomMembers(roomID string) map[string]mautrix.Member { func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
members := make(map[string]mautrix.Member) members := make(map[id.UserID]*event.MemberEventContent)
rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID) rows, err := store.db.Query("SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1", roomID)
if err != nil { if err != nil {
return members return members
} }
var userID string var userID id.UserID
var member mautrix.Member var member event.MemberEventContent
for rows.Next() { for rows.Next() {
err := rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL) err := rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil { if err != nil {
store.log.Warnfln("Failed to scan member in %s: %v", roomID, err) store.log.Warnfln("Failed to scan member in %s: %v", roomID, err)
} else { } else {
members[userID] = member members[userID] = &member
} }
} }
return members return members
} }
func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Membership { func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID) row := store.db.QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
membership := mautrix.MembershipLeave membership := event.MembershipLeave
err := row.Scan(&membership) err := row.Scan(&membership)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err) store.log.Warnfln("Failed to scan membership of %s in %s: %v", userID, roomID, err)
@ -99,33 +102,55 @@ func (store *SQLStateStore) GetMembership(roomID, userID string) mautrix.Members
return membership return membership
} }
func (store *SQLStateStore) GetMember(roomID, userID string) mautrix.Member { func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID) member, ok := store.TryGetMember(roomID, userID)
if !ok { if !ok {
member.Membership = mautrix.MembershipLeave member.Membership = event.MembershipLeave
} }
return member return member
} }
func (store *SQLStateStore) TryGetMember(roomID, userID string) (mautrix.Member, bool) { func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
row := store.db.QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID) row := store.db.QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID)
var member mautrix.Member var member event.MemberEventContent
err := row.Scan(&member.Membership, &member.Displayname, &member.AvatarURL) err := row.Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err) store.log.Warnfln("Failed to scan member info of %s in %s: %v", userID, roomID, err)
} }
return member, err == nil return &member, err == nil
} }
func (store *SQLStateStore) IsInRoom(roomID, userID string) bool { func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
rows, err := store.db.Query(`
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
WHERE user_id=$1 AND portal.encrypted=true
`, userID)
if err != nil {
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
return
}
for rows.Next() {
var roomID id.RoomID
err := rows.Scan(&roomID)
if err != nil {
store.log.Warnfln("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join") return store.IsMembership(roomID, userID, "join")
} }
func (store *SQLStateStore) IsInvited(roomID, userID string) bool { func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite") return store.IsMembership(roomID, userID, "join", "invite")
} }
func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMemberships ...mautrix.Membership) bool { func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID) membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships { for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership { if allowedMembership == membership {
@ -135,7 +160,7 @@ func (store *SQLStateStore) IsMembership(roomID, userID string, allowedMembershi
return false return false
} }
func (store *SQLStateStore) SetMembership(roomID, userID string, membership mautrix.Membership) { func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
var err error var err error
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3) _, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership) VALUES ($1, $2, $3)
@ -150,7 +175,7 @@ func (store *SQLStateStore) SetMembership(roomID, userID string, membership maut
} }
} }
func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Member) { func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
var err error var err error
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
_, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5) _, err = store.db.Exec(`INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
@ -166,7 +191,7 @@ func (store *SQLStateStore) SetMember(roomID, userID string, member mautrix.Memb
} }
} }
func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerLevels) { func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels) levelsBytes, err := json.Marshal(levels)
if err != nil { if err != nil {
store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err) store.log.Errorfln("Failed to marshal power levels of %s: %v", roomID, err)
@ -185,7 +210,7 @@ func (store *SQLStateStore) SetPowerLevels(roomID string, levels *mautrix.PowerL
} }
} }
func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.PowerLevels) { func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
row := store.db.QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID) row := store.db.QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID)
if row == nil { if row == nil {
return return
@ -196,7 +221,7 @@ func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.Power
store.log.Errorln("Failed to scan power levels of %s: %v", roomID, err) store.log.Errorln("Failed to scan power levels of %s: %v", roomID, err)
return return
} }
levels = &mautrix.PowerLevels{} levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels) err = json.Unmarshal(data, levels)
if err != nil { if err != nil {
store.log.Errorln("Failed to parse power levels of %s: %v", roomID, err) store.log.Errorln("Failed to parse power levels of %s: %v", roomID, err)
@ -205,7 +230,7 @@ func (store *SQLStateStore) GetPowerLevels(roomID string) (levels *mautrix.Power
return return
} }
func (store *SQLStateStore) GetPowerLevel(roomID, userID string) int { func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
row := store.db.QueryRow(`SELECT row := store.db.QueryRow(`SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0) COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
@ -224,7 +249,7 @@ func (store *SQLStateStore) GetPowerLevel(roomID, userID string) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID) return store.GetPowerLevels(roomID).GetUserLevel(userID)
} }
func (store *SQLStateStore) GetPowerLevelRequirement(roomID string, eventType mautrix.EventType) int { func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
defaultType := "events_default" defaultType := "events_default"
defaultValue := 0 defaultValue := 0
@ -249,7 +274,7 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID string, eventType ma
return store.GetPowerLevels(roomID).GetEventLevel(eventType) return store.GetPowerLevels(roomID).GetEventLevel(eventType)
} }
func (store *SQLStateStore) HasPowerLevel(roomID, userID string, eventType mautrix.EventType) bool { func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
if store.db.dialect == "postgres" { if store.db.dialect == "postgres" {
defaultType := "events_default" defaultType := "events_default"
defaultValue := 0 defaultValue := 0

View file

@ -2,17 +2,10 @@ package upgrades
import ( import (
"database/sql" "database/sql"
"fmt"
) )
func init() { func init() {
upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error { upgrades[0] = upgrade{"Initial schema", func(tx *sql.Tx, ctx context) error {
var byteType string
if ctx.dialect == SQLite {
byteType = "BLOB"
} else {
byteType = "bytea"
}
_, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal ( _, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(255), jid VARCHAR(255),
receiver VARCHAR(255), receiver VARCHAR(255),
@ -38,7 +31,7 @@ func init() {
return err return err
} }
_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "user" ( _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY, mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE, jid VARCHAR(255) UNIQUE,
@ -47,24 +40,24 @@ func init() {
client_id VARCHAR(255), client_id VARCHAR(255),
client_token VARCHAR(255), client_token VARCHAR(255),
server_token VARCHAR(255), server_token VARCHAR(255),
enc_key %[1]s, enc_key bytea,
mac_key %[1]s mac_key bytea
)`, byteType)) )`)
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS message ( _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255), chat_jid VARCHAR(255),
chat_receiver VARCHAR(255), chat_receiver VARCHAR(255),
jid VARCHAR(255), jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE, mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL, sender VARCHAR(255) NOT NULL,
content %[1]s NOT NULL, content bytea NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid), PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) ON DELETE CASCADE
)`, byteType)) )`)
if err != nil { if err != nil {
return err return err
} }

View file

@ -8,7 +8,7 @@ import (
"os" "os"
"strings" "strings"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/event"
) )
func init() { func init() {
@ -46,7 +46,7 @@ func init() {
return executeBatch(tx, valueStrings, values...) return executeBatch(tx, valueStrings, values...)
} }
migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]mautrix.Membership) error { migrateMemberships := func(tx *sql.Tx, rooms map[string]map[string]event.Membership) error {
for roomID, members := range rooms { for roomID, members := range rooms {
if len(members) == 0 { if len(members) == 0 {
continue continue
@ -68,7 +68,7 @@ func init() {
return nil return nil
} }
migratePowerLevels := func(tx *sql.Tx, rooms map[string]*mautrix.PowerLevels) error { migratePowerLevels := func(tx *sql.Tx, rooms map[string]*event.PowerLevelsEventContent) error {
if len(rooms) == 0 { if len(rooms) == 0 {
return nil return nil
} }
@ -106,9 +106,9 @@ func init() {
)` )`
type TempStateStore struct { type TempStateStore struct {
Registrations map[string]bool `json:"registrations"` Registrations map[string]bool `json:"registrations"`
Members map[string]map[string]mautrix.Membership `json:"memberships"` Members map[string]map[string]event.Membership `json:"memberships"`
PowerLevels map[string]*mautrix.PowerLevels `json:"power_levels"` PowerLevels map[string]*event.PowerLevelsEventContent `json:"power_levels"`
} }
upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error { upgrades[9] = upgrade{"Move state store to main DB", func(tx *sql.Tx, ctx context) error {

View file

@ -0,0 +1,12 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[12] = upgrade{"Add encryption status to portal table", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false`)
return err
}}
}

View file

@ -0,0 +1,73 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[13] = upgrade{"Add crypto store to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_account (
device_id VARCHAR(255) PRIMARY KEY,
shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL,
account bytea NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_message_index (
sender_key CHAR(43),
session_id CHAR(43),
"index" INTEGER,
event_id VARCHAR(255) NOT NULL,
timestamp BIGINT NOT NULL,
PRIMARY KEY (sender_key, session_id, "index")
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_tracked_user (
user_id VARCHAR(255) PRIMARY KEY
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_device (
user_id VARCHAR(255),
device_id VARCHAR(255),
identity_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
trust SMALLINT NOT NULL,
deleted BOOLEAN NOT NULL,
name VARCHAR(255) NOT NULL,
PRIMARY KEY (user_id, device_id)
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_olm_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
session bytea NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
_, err = tx.Exec(`CREATE TABLE crypto_megolm_inbound_session (
session_id CHAR(43) PRIMARY KEY,
sender_key CHAR(43) NOT NULL,
signing_key CHAR(43) NOT NULL,
room_id VARCHAR(255) NOT NULL,
session bytea NOT NULL,
forwarding_chains bytea NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -0,0 +1,25 @@
package upgrades
import (
"database/sql"
)
func init() {
upgrades[14] = upgrade{"Add outbound group sessions to database", func(tx *sql.Tx, ctx context) error {
_, err := tx.Exec(`CREATE TABLE crypto_megolm_outbound_session (
room_id VARCHAR(255) PRIMARY KEY,
session_id CHAR(43) NOT NULL UNIQUE,
session bytea NOT NULL,
shared BOOLEAN NOT NULL,
max_messages INTEGER NOT NULL,
message_count INTEGER NOT NULL,
max_age BIGINT NOT NULL,
created_at timestamp NOT NULL,
last_used timestamp NOT NULL
)`)
if err != nil {
return err
}
return nil
}}
}

View file

@ -28,7 +28,7 @@ type upgrade struct {
fn upgradeFunc fn upgradeFunc
} }
const NumberOfUpgrades = 12 const NumberOfUpgrades = 15
var upgrades [NumberOfUpgrades]upgrade var upgrades [NumberOfUpgrades]upgrade

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -28,6 +28,7 @@ import (
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
"maunium.net/go/mautrix/id"
) )
type UserQuery struct { type UserQuery struct {
@ -54,7 +55,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
return return
} }
func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User { func (uq *UserQuery) GetByMXID(userID id.UserID) *User {
row := uq.db.QueryRow(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user" WHERE mxid=$1`, userID) row := uq.db.QueryRow(`SELECT mxid, jid, management_room, last_connection, client_id, client_token, server_token, enc_key, mac_key FROM "user" WHERE mxid=$1`, userID)
if row == nil { if row == nil {
return nil return nil
@ -74,9 +75,9 @@ type User struct {
db *Database db *Database
log log.Logger log log.Logger
MXID types.MatrixUserID MXID id.UserID
JID types.WhatsAppID JID types.WhatsAppID
ManagementRoom types.MatrixRoomID ManagementRoom id.RoomID
Session *whatsapp.Session Session *whatsapp.Session
LastConnection uint64 LastConnection uint64
} }

View file

@ -138,6 +138,19 @@ bridge:
# The prefix for commands. Only required in non-management rooms. # The prefix for commands. Only required in non-management rooms.
command_prefix: "!wa" command_prefix: "!wa"
# End-to-bridge encryption support options. This requires login_shared_secret to be configured
# in order to get a device for the bridge bot.
#
# Additionally, https://github.com/matrix-org/synapse/pull/5758 is required if using a normal
# application service.
encryption:
# Allow encryption, work in group chat rooms with e2ee enabled
allow: false
# Default to encryption, force-enable encryption in all portals the bridge creates
# This will cause the bridge bot to be in private chats for the encryption to work properly.
# It is recommended to also set private_chat_portal_meta to true when using this.
default: false
# Permissions for using the bridge. # Permissions for using the bridge.
# Permitted values: # Permitted values:
# relaybot - Talk through the relaybot (if enabled), no access otherwise # relaybot - Talk through the relaybot (if enabled), no access otherwise

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -22,8 +22,9 @@ import (
"regexp" "regexp"
"strings" "strings"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
@ -54,8 +55,7 @@ func NewFormatter(bridge *Bridge) *Formatter {
PillConverter: func(mxid, eventID string) string { PillConverter: func(mxid, eventID string) string {
if mxid[0] == '@' { if mxid[0] == '@' {
puppet := bridge.GetPuppetByMXID(mxid) puppet := bridge.GetPuppetByMXID(id.UserID(mxid))
fmt.Println(mxid, puppet)
if puppet != nil { if puppet != nil {
return "@" + puppet.PhoneNumber() return "@" + puppet.PhoneNumber()
} }
@ -106,10 +106,10 @@ func NewFormatter(bridge *Bridge) *Formatter {
return formatter return formatter
} }
func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid, displayname string) { func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid id.UserID, displayname string) {
if user := formatter.bridge.GetUserByJID(jid); user != nil { if user := formatter.bridge.GetUserByJID(jid); user != nil {
mxid = user.MXID mxid = user.MXID
displayname = user.MXID displayname = string(user.MXID)
} else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil { } else if puppet := formatter.bridge.GetPuppetByJID(jid); puppet != nil {
mxid = puppet.MXID mxid = puppet.MXID
displayname = puppet.Displayname displayname = puppet.Displayname
@ -117,7 +117,7 @@ func (formatter *Formatter) getMatrixInfoByJID(jid types.WhatsAppID) (mxid, disp
return return
} }
func (formatter *Formatter) ParseWhatsApp(content *mautrix.Content) { func (formatter *Formatter) ParseWhatsApp(content *event.MessageEventContent) {
output := html.EscapeString(content.Body) output := html.EscapeString(content.Body)
for regex, replacement := range formatter.waReplString { for regex, replacement := range formatter.waReplString {
output = regex.ReplaceAllString(output, replacement) output = regex.ReplaceAllString(output, replacement)
@ -128,7 +128,7 @@ func (formatter *Formatter) ParseWhatsApp(content *mautrix.Content) {
if output != content.Body { if output != content.Body {
output = strings.Replace(output, "\n", "<br/>", -1) output = strings.Replace(output, "\n", "<br/>", -1)
content.FormattedBody = output content.FormattedBody = output
content.Format = mautrix.FormatHTML content.Format = event.FormatHTML
for regex, replacer := range formatter.waReplFuncText { for regex, replacer := range formatter.waReplFuncText {
content.Body = regex.ReplaceAllStringFunc(content.Body, replacer) content.Body = regex.ReplaceAllStringFunc(content.Body, replacer)
} }

8
go.mod
View file

@ -5,17 +5,17 @@ go 1.14
require ( require (
github.com/Rhymen/go-whatsapp v0.1.0 github.com/Rhymen/go-whatsapp v0.1.0
github.com/chai2010/webp v1.1.0 github.com/chai2010/webp v1.1.0
github.com/gorilla/websocket v1.4.1 github.com/gorilla/websocket v1.4.2
github.com/lib/pq v1.3.0 github.com/lib/pq v1.5.2
github.com/mattn/go-sqlite3 v2.0.3+incompatible github.com/mattn/go-sqlite3 v2.0.3+incompatible
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect github.com/shurcooL/sanitized_anchor_name v1.0.0 // indirect
github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086 github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086
golang.org/x/image v0.0.0-20200430140353-33d19683fad8
gopkg.in/yaml.v2 v2.2.8 gopkg.in/yaml.v2 v2.2.8
maunium.net/go/mauflag v1.0.0 maunium.net/go/mauflag v1.0.0
maunium.net/go/maulogger/v2 v2.1.1 maunium.net/go/maulogger/v2 v2.1.1
maunium.net/go/mautrix v0.1.0-beta.2 maunium.net/go/mautrix v0.4.5
maunium.net/go/mautrix-appservice v0.1.0-alpha.6
) )
replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6 replace github.com/Rhymen/go-whatsapp => github.com/tulir/go-whatsapp v0.2.6

38
go.sum
View file

@ -1,7 +1,10 @@
github.com/chai2010/webp v1.1.0 h1:4Ei0/BRroMF9FaXDG2e4OxwFcuW2vcXd+A6tyqTJUQQ= github.com/chai2010/webp v1.1.0 h1:4Ei0/BRroMF9FaXDG2e4OxwFcuW2vcXd+A6tyqTJUQQ=
github.com/chai2010/webp v1.1.0/go.mod h1:LP12PG5IFmLGHUU26tBiCBKnghxx3toZFwDjOYvd3Ow= github.com/chai2010/webp v1.1.0/go.mod h1:LP12PG5IFmLGHUU26tBiCBKnghxx3toZFwDjOYvd3Ow=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo=
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc= github.com/gorilla/mux v1.7.4 h1:VuZ8uybHlWmqV03+zRzdwKL4tUnIp1MAQtp1mIFE1bc=
@ -12,6 +15,8 @@ github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0U
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU= github.com/lib/pq v1.3.0 h1:/qkRGz8zljWiDcFvgpwUpwIAPu3r07TDvs3Rws+o/pU=
github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.3.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.5.2 h1:yTSXVswvWUOQ3k1sd7vJfDrbSl8lKuscqFJRqjC0ifw=
github.com/lib/pq v1.5.2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
@ -22,12 +27,24 @@ github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q= github.com/russross/blackfriday/v2 v2.0.1 h1:lPqVAte+HuHNfhJ/0LC98ESWRz8afy9tM/0RK8m9o+Q=
github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo= github.com/shurcooL/sanitized_anchor_name v1.0.0 h1:PdmoCO6wvbs+7yrJyMORt4/BmY5IYyJwS/kOiWx8mHo=
github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc=
github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086 h1:RYiqpb2ii2Z6J4x0wxK46kvPBbFuZcdhS+CIztmYgZs= github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086 h1:RYiqpb2ii2Z6J4x0wxK46kvPBbFuZcdhS+CIztmYgZs=
github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086/go.mod h1:PLPIyL7ikehBD1OAjmKKiOEhbvWyHGaNDjquXMcYABo= github.com/skip2/go-qrcode v0.0.0-20191027152451-9434209cb086/go.mod h1:PLPIyL7ikehBD1OAjmKKiOEhbvWyHGaNDjquXMcYABo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/tidwall/gjson v1.6.0 h1:9VEQWz6LLMUsUl6PueE49ir4Ka6CzLymOAZDxpFsTDc=
github.com/tidwall/gjson v1.6.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8=
github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/tidwall/sjson v1.1.1 h1:7h1vk049Jnd5EH9NyzNiEuwYW4b5qgreBbqRC19AS3U=
github.com/tidwall/sjson v1.1.1/go.mod h1:yvVuSnpEQv5cYIrO+AT6kw4QVfd5SDZoGIS7/5+fZFs=
github.com/tulir/go-whatsapp v0.2.0 h1:JWK/Xxrc1qsZsVz6gYVX5AtvzYmqaHNjt34Ipnrgz88= github.com/tulir/go-whatsapp v0.2.0 h1:JWK/Xxrc1qsZsVz6gYVX5AtvzYmqaHNjt34Ipnrgz88=
github.com/tulir/go-whatsapp v0.2.0/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok= github.com/tulir/go-whatsapp v0.2.0/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok=
github.com/tulir/go-whatsapp v0.2.1 h1:Owoss2AbvZMgt3nxoFlsG+bqLHDnO+PhXNhhoCmb/3M= github.com/tulir/go-whatsapp v0.2.1 h1:Owoss2AbvZMgt3nxoFlsG+bqLHDnO+PhXNhhoCmb/3M=
@ -42,6 +59,8 @@ github.com/tulir/go-whatsapp v0.2.6 h1:d58cqz/iqcCDeT+uFjLso8oSgMTYqoxGhGhGOyyHB
github.com/tulir/go-whatsapp v0.2.6/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok= github.com/tulir/go-whatsapp v0.2.6/go.mod h1:gyw9zGup1/Y3ZQUueZaqz3iR/WX9a2Lth4aqEbXjkok=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90PveolxSbWFaJdECFbxSq0Mqo2M=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8 h1:6WW6V3x1P/jokJBpRQYUJnMHRP6isStQwCozxnU7XQw=
golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0= golang.org/x/net v0.0.0-20200301022130-244492dfa37a h1:GuSPYbZzB5/dcLNCwLQLsg3obCJtX9IJhpXkvY7kzk0=
golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200301022130-244492dfa37a/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
@ -52,6 +71,7 @@ golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
@ -60,5 +80,19 @@ maunium.net/go/maulogger/v2 v2.1.1 h1:NAZNc6XUFJzgzfewCzVoGkxNAsblLCSSEdtDuIjP0X
maunium.net/go/maulogger/v2 v2.1.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A= maunium.net/go/maulogger/v2 v2.1.1/go.mod h1:TYWy7wKwz/tIXTpsx8G3mZseIRiC5DoMxSZazOHy68A=
maunium.net/go/mautrix v0.1.0-beta.2 h1:RxYTqTzW6iXu83gf8ucqGwYx8JLa+a17LWjiPkVV/fU= maunium.net/go/mautrix v0.1.0-beta.2 h1:RxYTqTzW6iXu83gf8ucqGwYx8JLa+a17LWjiPkVV/fU=
maunium.net/go/mautrix v0.1.0-beta.2/go.mod h1:YFMU9DBeXH7cqx7sJLg0DkVxwNPbih8QbpUTYf/IjMM= maunium.net/go/mautrix v0.1.0-beta.2/go.mod h1:YFMU9DBeXH7cqx7sJLg0DkVxwNPbih8QbpUTYf/IjMM=
maunium.net/go/mautrix-appservice v0.1.0-alpha.6 h1:dNE+RykOC0UhSyRNbMHXEk3BzSOp3dj8aQwKuNMELWM= maunium.net/go/mautrix v0.3.6 h1:bXUo8WFdv7sUpvr7jgJ6TVMEQgVHtw1z1T3eUcLpPCA=
maunium.net/go/mautrix-appservice v0.1.0-alpha.6/go.mod h1:Dfiwiuicvn8s2VKrBDrZ9eCjlKUMbuCi91TE6xeEHRM= maunium.net/go/mautrix v0.3.6/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
maunium.net/go/mautrix v0.3.7 h1:N0czrZeAwjvBrw2a/B2G6U3EwIYaWpt7OuSslGp8DRc=
maunium.net/go/mautrix v0.3.7/go.mod h1:SkGZzch8CvU2qKtNpYxtzZ0sQxfVEJ3IsVVLSUBUx9Y=
maunium.net/go/mautrix v0.4.0 h1:IYfmxCoxR/6UMi92IncsSZeKQbZm8Xa35XIRX814KJ4=
maunium.net/go/mautrix v0.4.0/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.1 h1:i2lJNT+TE4AAL3cVKUN4jKVRkujCE/oS8aIsj8+7iNE=
maunium.net/go/mautrix v0.4.1/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.2 h1:GBU++Z7o/fLPcEsNMkNOUsnDknwV/MGPQ0BN4ikK6tw=
maunium.net/go/mautrix v0.4.2/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.3 h1:fVoJy992TjBEvuK5NeO9fpBh+9JuSFsxaEdGjFp/7h4=
maunium.net/go/mautrix v0.4.3/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.4 h1:C5yYDzUdRtJj/9Vot5YBPQUsWmn19sTySew7f4ACLhM=
maunium.net/go/mautrix v0.4.4/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=
maunium.net/go/mautrix v0.4.5 h1:cQhlPURW0TGjlqEoac+4+J/aS5/Rg8x1b+fiFZZz6LI=
maunium.net/go/mautrix v0.4.5/go.mod h1:8Y+NqmROJyWYvvP4yPfX9tLM59VCfgE/kcQ0SeX68ho=

57
main.go
View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -18,7 +18,6 @@ package main
import ( import (
"fmt" "fmt"
"net/http"
"os" "os"
"os/signal" "os/signal"
"sync" "sync"
@ -29,7 +28,9 @@ import (
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/config" "maunium.net/go/mautrix-whatsapp/config"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
@ -106,29 +107,39 @@ type Bridge struct {
Bot *appservice.IntentAPI Bot *appservice.IntentAPI
Formatter *Formatter Formatter *Formatter
Relaybot *User Relaybot *User
Crypto Crypto
usersByMXID map[types.MatrixUserID]*User usersByMXID map[id.UserID]*User
usersByJID map[types.WhatsAppID]*User usersByJID map[types.WhatsAppID]*User
usersLock sync.Mutex usersLock sync.Mutex
managementRooms map[types.MatrixRoomID]*User managementRooms map[id.RoomID]*User
managementRoomsLock sync.Mutex managementRoomsLock sync.Mutex
portalsByMXID map[types.MatrixRoomID]*Portal portalsByMXID map[id.RoomID]*Portal
portalsByJID map[database.PortalKey]*Portal portalsByJID map[database.PortalKey]*Portal
portalsLock sync.Mutex portalsLock sync.Mutex
puppets map[types.WhatsAppID]*Puppet puppets map[types.WhatsAppID]*Puppet
puppetsByCustomMXID map[types.MatrixUserID]*Puppet puppetsByCustomMXID map[id.UserID]*Puppet
puppetsLock sync.Mutex puppetsLock sync.Mutex
} }
type Crypto interface {
HandleMemberEvent(*event.Event)
Decrypt(*event.Event) (*event.Event, error)
Encrypt(id.RoomID, event.Type, event.Content) (*event.EncryptedEventContent, error)
Init() error
Start()
Stop()
}
func NewBridge() *Bridge { func NewBridge() *Bridge {
bridge := &Bridge{ bridge := &Bridge{
usersByMXID: make(map[types.MatrixUserID]*User), usersByMXID: make(map[id.UserID]*User),
usersByJID: make(map[types.WhatsAppID]*User), usersByJID: make(map[types.WhatsAppID]*User),
managementRooms: make(map[types.MatrixRoomID]*User), managementRooms: make(map[id.RoomID]*User),
portalsByMXID: make(map[types.MatrixRoomID]*Portal), portalsByMXID: make(map[id.RoomID]*Portal),
portalsByJID: make(map[database.PortalKey]*Portal), portalsByJID: make(map[database.PortalKey]*Portal),
puppets: make(map[types.WhatsAppID]*Puppet), puppets: make(map[types.WhatsAppID]*Puppet),
puppetsByCustomMXID: make(map[types.MatrixUserID]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet),
} }
var err error var err error
@ -141,12 +152,8 @@ func NewBridge() *Bridge {
} }
func (bridge *Bridge) ensureConnection() { func (bridge *Bridge) ensureConnection() {
url := bridge.Bot.BuildURL("account", "whoami")
resp := struct {
UserID string `json:"user_id"`
}{}
for { for {
_, err := bridge.Bot.MakeRequest(http.MethodGet, url, nil, &resp) resp, err := bridge.Bot.Whoami()
if err != nil { if err != nil {
if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_UNKNOWN_ACCESS_TOKEN" { if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_UNKNOWN_ACCESS_TOKEN" {
bridge.Log.Fatalln("Access token invalid. Is the registration installed in your homeserver correctly?") bridge.Log.Fatalln("Access token invalid. Is the registration installed in your homeserver correctly?")
@ -219,6 +226,7 @@ func (bridge *Bridge) Init() {
bridge.Log.Debugln("Initializing Matrix event handler") bridge.Log.Debugln("Initializing Matrix event handler")
bridge.MatrixHandler = NewMatrixHandler(bridge) bridge.MatrixHandler = NewMatrixHandler(bridge)
bridge.Formatter = NewFormatter(bridge) bridge.Formatter = NewFormatter(bridge)
bridge.Crypto = NewCryptoHelper(bridge)
} }
func (bridge *Bridge) Start() { func (bridge *Bridge) Start() {
@ -227,6 +235,13 @@ func (bridge *Bridge) Start() {
bridge.Log.Fatalln("Failed to initialize database:", err) bridge.Log.Fatalln("Failed to initialize database:", err)
os.Exit(15) os.Exit(15)
} }
if bridge.Crypto != nil {
err := bridge.Crypto.Init()
if err != nil {
bridge.Log.Fatalln("Error initializing end-to-bridge encryption:", err)
os.Exit(19)
}
}
if bridge.Provisioning != nil { if bridge.Provisioning != nil {
bridge.Log.Debugln("Initializing provisioning API") bridge.Log.Debugln("Initializing provisioning API")
bridge.Provisioning.Init() bridge.Provisioning.Init()
@ -239,6 +254,7 @@ func (bridge *Bridge) Start() {
bridge.Log.Debugln("Starting event processor") bridge.Log.Debugln("Starting event processor")
go bridge.EventProcessor.Start() go bridge.EventProcessor.Start()
go bridge.UpdateBotProfile() go bridge.UpdateBotProfile()
go bridge.Crypto.Start()
go bridge.StartUsers() go bridge.StartUsers()
} }
@ -262,10 +278,14 @@ func (bridge *Bridge) UpdateBotProfile() {
botConfig := bridge.Config.AppService.Bot botConfig := bridge.Config.AppService.Bot
var err error var err error
var mxc id.ContentURI
if botConfig.Avatar == "remove" { if botConfig.Avatar == "remove" {
err = bridge.Bot.SetAvatarURL("") err = bridge.Bot.SetAvatarURL(mxc)
} else if len(botConfig.Avatar) > 0 { } else if len(botConfig.Avatar) > 0 {
err = bridge.Bot.SetAvatarURL(botConfig.Avatar) mxc, err = id.ParseContentURI(botConfig.Avatar)
if err == nil {
err = bridge.Bot.SetAvatarURL(mxc)
}
} }
if err != nil { if err != nil {
bridge.Log.Warnln("Failed to update bot avatar:", err) bridge.Log.Warnln("Failed to update bot avatar:", err)
@ -299,6 +319,7 @@ func (bridge *Bridge) StartUsers() {
} }
func (bridge *Bridge) Stop() { func (bridge *Bridge) Stop() {
bridge.Crypto.Stop()
bridge.AS.Stop() bridge.AS.Stop()
bridge.EventProcessor.Stop() bridge.EventProcessor.Stop()
for _, user := range bridge.usersByJID { for _, user := range bridge.usersByJID {

153
matrix.go
View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -22,11 +22,10 @@ import (
"maunium.net/go/maulogger/v2" "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/types"
) )
type MatrixHandler struct { type MatrixHandler struct {
@ -43,17 +42,32 @@ func NewMatrixHandler(bridge *Bridge) *MatrixHandler {
log: bridge.Log.Sub("Matrix"), log: bridge.Log.Sub("Matrix"),
cmd: NewCommandHandler(bridge), cmd: NewCommandHandler(bridge),
} }
bridge.EventProcessor.On(mautrix.EventMessage, handler.HandleMessage) bridge.EventProcessor.On(event.EventMessage, handler.HandleMessage)
bridge.EventProcessor.On(mautrix.EventSticker, handler.HandleMessage) bridge.EventProcessor.On(event.EventEncrypted, handler.HandleEncrypted)
bridge.EventProcessor.On(mautrix.EventRedaction, handler.HandleRedaction) bridge.EventProcessor.On(event.EventSticker, handler.HandleMessage)
bridge.EventProcessor.On(mautrix.StateMember, handler.HandleMembership) bridge.EventProcessor.On(event.EventRedaction, handler.HandleRedaction)
bridge.EventProcessor.On(mautrix.StateRoomName, handler.HandleRoomMetadata) bridge.EventProcessor.On(event.StateMember, handler.HandleMembership)
bridge.EventProcessor.On(mautrix.StateRoomAvatar, handler.HandleRoomMetadata) bridge.EventProcessor.On(event.StateRoomName, handler.HandleRoomMetadata)
bridge.EventProcessor.On(mautrix.StateTopic, handler.HandleRoomMetadata) bridge.EventProcessor.On(event.StateRoomAvatar, handler.HandleRoomMetadata)
bridge.EventProcessor.On(event.StateTopic, handler.HandleRoomMetadata)
bridge.EventProcessor.On(event.StateEncryption, handler.HandleEncryption)
return handler return handler
} }
func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) { func (mx *MatrixHandler) HandleEncryption(evt *event.Event) {
if evt.Content.AsEncryption().Algorithm != id.AlgorithmMegolmV1 {
return
}
portal := mx.bridge.GetPortalByMXID(evt.RoomID)
mx.log.Debugln(portal)
if portal != nil && !portal.Encrypted {
mx.log.Debugfln("%s enabled encryption in %s", evt.Sender, evt.RoomID)
portal.Encrypted = true
portal.Update()
}
}
func (mx *MatrixHandler) HandleBotInvite(evt *event.Event) {
intent := mx.as.BotIntent() intent := mx.as.BotIntent()
user := mx.bridge.GetUserByMXID(evt.Sender) user := mx.bridge.GetUserByMXID(evt.Sender)
@ -61,7 +75,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
return return
} }
resp, err := intent.JoinRoom(evt.RoomID, "", nil) resp, err := intent.JoinRoomByID(evt.RoomID)
if err != nil { if err != nil {
mx.log.Debugln("Failed to join room", evt.RoomID, "with invite from", evt.Sender) mx.log.Debugln("Failed to join room", evt.RoomID, "with invite from", evt.Sender)
return return
@ -97,7 +111,7 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
for mxid, _ := range members.Joined { for mxid, _ := range members.Joined {
if mxid == intent.UserID || mxid == evt.Sender { if mxid == intent.UserID || mxid == evt.Sender {
continue continue
} else if _, ok := mx.bridge.ParsePuppetMXID(types.MatrixUserID(mxid)); ok { } else if _, ok := mx.bridge.ParsePuppetMXID(mxid); ok {
hasPuppets = true hasPuppets = true
continue continue
} }
@ -108,15 +122,24 @@ func (mx *MatrixHandler) HandleBotInvite(evt *mautrix.Event) {
} }
if !hasPuppets { if !hasPuppets {
user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(evt.Sender)
user.SetManagementRoom(types.MatrixRoomID(resp.RoomID)) user.SetManagementRoom(resp.RoomID)
intent.SendNotice(string(user.ManagementRoom), "This room has been registered as your bridge management/status room. Send `help` to get a list of commands.") intent.SendNotice(user.ManagementRoom, "This room has been registered as your bridge management/status room. Send `help` to get a list of commands.")
mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender) mx.log.Debugln(resp.RoomID, "registered as a management room with", evt.Sender)
} }
} }
func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) { func (mx *MatrixHandler) HandleMembership(evt *event.Event) {
if evt.Content.Membership == "invite" && evt.GetStateKey() == mx.as.BotMXID() { if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
return
}
if mx.bridge.Crypto != nil {
mx.bridge.Crypto.HandleMemberEvent(evt)
}
content := evt.Content.AsMember()
if content.Membership == event.MembershipInvite && id.UserID(evt.GetStateKey()) == mx.as.BotMXID() {
mx.HandleBotInvite(evt) mx.HandleBotInvite(evt)
} }
@ -125,15 +148,21 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
return return
} }
user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(evt.Sender)
if user == nil || !user.Whitelisted || !user.IsConnected() { if user == nil || !user.Whitelisted || !user.IsConnected() {
return return
} }
if evt.Content.Membership == "leave" { if content.Membership == event.MembershipLeave {
if evt.GetStateKey() == evt.Sender { if id.UserID(evt.GetStateKey()) == evt.Sender {
if portal.IsPrivateChat() || evt.Unsigned.PrevContent.Membership == "join" { if evt.Unsigned.PrevContent != nil {
portal.HandleMatrixLeave(user) _ = evt.Unsigned.PrevContent.ParseRaw(evt.Type)
prevContent, ok := evt.Unsigned.PrevContent.Parsed.(*event.MemberEventContent)
if ok {
if portal.IsPrivateChat() || prevContent.Membership == "join" {
portal.HandleMatrixLeave(user)
}
}
} }
} else { } else {
portal.HandleMatrixKick(user, evt) portal.HandleMatrixKick(user, evt)
@ -141,8 +170,8 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) {
} }
} }
func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) { func (mx *MatrixHandler) HandleRoomMetadata(evt *event.Event) {
user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) user := mx.bridge.GetUserByMXID(evt.Sender)
if user == nil || !user.Whitelisted || !user.IsConnected() { if user == nil || !user.Whitelisted || !user.IsConnected() {
return return
} }
@ -154,12 +183,12 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {
var resp <-chan string var resp <-chan string
var err error var err error
switch evt.Type { switch content := evt.Content.Parsed.(type) {
case mautrix.StateRoomName: case *event.RoomNameEventContent:
resp, err = user.Conn.UpdateGroupSubject(evt.Content.Name, portal.Key.JID) resp, err = user.Conn.UpdateGroupSubject(content.Name, portal.Key.JID)
case mautrix.StateTopic: case *event.TopicEventContent:
resp, err = user.Conn.UpdateGroupDescription(portal.Key.JID, evt.Content.Topic) resp, err = user.Conn.UpdateGroupDescription(portal.Key.JID, content.Topic)
case mautrix.StateRoomAvatar: case *event.RoomAvatarEventContent:
return return
} }
if err != nil { if err != nil {
@ -170,47 +199,65 @@ func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) {
} }
} }
func (mx *MatrixHandler) HandleMessage(evt *mautrix.Event) { func (mx *MatrixHandler) shouldIgnoreEvent(evt *event.Event) bool {
if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet { if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
return return true
} }
isCustomPuppet, ok := evt.Content.Raw["net.maunium.whatsapp.puppet"].(bool) isCustomPuppet, ok := evt.Content.Raw["net.maunium.whatsapp.puppet"].(bool)
if ok && isCustomPuppet && mx.bridge.GetPuppetByCustomMXID(evt.Sender) != nil { if ok && isCustomPuppet && mx.bridge.GetPuppetByCustomMXID(evt.Sender) != nil {
return return true
} }
user := mx.bridge.GetUserByMXID(evt.Sender)
roomID := types.MatrixRoomID(evt.RoomID)
user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
if !user.RelaybotWhitelisted { if !user.RelaybotWhitelisted {
return true
}
return false
}
func (mx *MatrixHandler) HandleEncrypted(evt *event.Event) {
if mx.shouldIgnoreEvent(evt) || mx.bridge.Crypto == nil {
return return
} }
if user.Whitelisted && evt.Content.MsgType == mautrix.MsgText { decrypted, err := mx.bridge.Crypto.Decrypt(evt)
if err != nil {
mx.log.Warnfln("Failed to decrypt %s: %v", evt.ID, err)
return
}
mx.bridge.EventProcessor.Dispatch(decrypted)
}
func (mx *MatrixHandler) HandleMessage(evt *event.Event) {
if mx.shouldIgnoreEvent(evt) {
return
}
user := mx.bridge.GetUserByMXID(evt.Sender)
content := evt.Content.AsMessage()
if user.Whitelisted && content.MsgType == event.MsgText {
commandPrefix := mx.bridge.Config.Bridge.CommandPrefix commandPrefix := mx.bridge.Config.Bridge.CommandPrefix
hasCommandPrefix := strings.HasPrefix(evt.Content.Body, commandPrefix) hasCommandPrefix := strings.HasPrefix(content.Body, commandPrefix)
if hasCommandPrefix { if hasCommandPrefix {
evt.Content.Body = strings.TrimLeft(evt.Content.Body[len(commandPrefix):], " ") content.Body = strings.TrimLeft(content.Body[len(commandPrefix):], " ")
} }
if hasCommandPrefix || roomID == user.ManagementRoom { if hasCommandPrefix || evt.RoomID == user.ManagementRoom {
mx.cmd.Handle(roomID, user, evt.Content.Body) mx.cmd.Handle(evt.RoomID, user, content.Body)
return return
} }
} }
portal := mx.bridge.GetPortalByMXID(roomID) portal := mx.bridge.GetPortalByMXID(evt.RoomID)
if portal != nil && (user.Whitelisted || portal.HasRelaybot()) { if portal != nil && (user.Whitelisted || portal.HasRelaybot()) {
portal.HandleMatrixMessage(user, evt) portal.HandleMatrixMessage(user, evt)
} }
} }
func (mx *MatrixHandler) HandleRedaction(evt *mautrix.Event) { func (mx *MatrixHandler) HandleRedaction(evt *event.Event) {
if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet { if _, isPuppet := mx.bridge.ParsePuppetMXID(evt.Sender); evt.Sender == mx.bridge.Bot.UserID || isPuppet {
return return
} }
roomID := types.MatrixRoomID(evt.RoomID) user := mx.bridge.GetUserByMXID(evt.Sender)
user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender))
if !user.Whitelisted { if !user.Whitelisted {
return return
@ -221,13 +268,13 @@ func (mx *MatrixHandler) HandleRedaction(evt *mautrix.Event) {
} else if !user.IsConnected() { } else if !user.IsConnected() {
msg := format.RenderMarkdown(fmt.Sprintf("[%[1]s](https://matrix.to/#/%[1]s): \u26a0 "+ msg := format.RenderMarkdown(fmt.Sprintf("[%[1]s](https://matrix.to/#/%[1]s): \u26a0 "+
"You are not connected to WhatsApp, so your redaction was not bridged. "+ "You are not connected to WhatsApp, so your redaction was not bridged. "+
"Use `%[2]s reconnect` to reconnect.", user.MXID, mx.bridge.Config.Bridge.CommandPrefix)) "Use `%[2]s reconnect` to reconnect.", user.MXID, mx.bridge.Config.Bridge.CommandPrefix), true, false)
msg.MsgType = mautrix.MsgNotice msg.MsgType = event.MsgNotice
_, _ = mx.bridge.Bot.SendMessageEvent(roomID, mautrix.EventMessage, msg) _, _ = mx.bridge.Bot.SendMessageEvent(evt.RoomID, event.EventMessage, msg)
return return
} }
portal := mx.bridge.GetPortalByMXID(roomID) portal := mx.bridge.GetPortalByMXID(evt.RoomID)
if portal != nil { if portal != nil {
portal.HandleMatrixRedaction(user, evt) portal.HandleMatrixRedaction(user, evt)
} }

38
no-cgo.go Normal file
View file

@ -0,0 +1,38 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2020 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
// +build !cgo
package main
import (
"image"
"io"
"golang.org/x/image/webp"
)
func NewCryptoHelper(bridge *Bridge) Crypto {
if !bridge.Config.Bridge.Encryption.Allow {
bridge.Log.Warnln("Bridge built without end-to-bridge encryption, but encryption is enabled in config")
}
bridge.Log.Debugln("Bridge built without end-to-bridge encryption")
return nil
}
func decodeWebp(r io.Reader) (image.Image, error) {
return webp.Decode(r)
}

327
portal.go
View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -20,7 +20,6 @@ import (
"bytes" "bytes"
"encoding/gob" "encoding/gob"
"encoding/hex" "encoding/hex"
"encoding/json"
"fmt" "fmt"
"html" "html"
"image" "image"
@ -35,22 +34,24 @@ import (
"sync" "sync"
"time" "time"
"github.com/chai2010/webp" "github.com/pkg/errors"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
waProto "github.com/Rhymen/go-whatsapp/binary/proto" waProto "github.com/Rhymen/go-whatsapp/binary/proto"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix-appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
func (bridge *Bridge) GetPortalByMXID(mxid types.MatrixRoomID) *Portal { func (bridge *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
bridge.portalsLock.Lock() bridge.portalsLock.Lock()
defer bridge.portalsLock.Unlock() defer bridge.portalsLock.Unlock()
portal, ok := bridge.portalsByMXID[mxid] portal, ok := bridge.portalsByMXID[mxid]
@ -233,7 +234,7 @@ func init() {
gob.Register(&waProto.Message{}) gob.Register(&waProto.Message{})
} }
func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid types.MatrixEventID) { func (portal *Portal) markHandled(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
msg := portal.bridge.DB.Message.New() msg := portal.bridge.DB.Message.New()
msg.Chat = portal.Key msg.Chat = portal.Key
msg.JID = message.GetKey().GetId() msg.JID = message.GetKey().GetId()
@ -269,7 +270,7 @@ func (portal *Portal) startHandling(info whatsapp.MessageInfo) bool {
return true return true
} }
func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid types.MatrixEventID) { func (portal *Portal) finishHandling(source *User, message *waProto.WebMessageInfo, mxid id.EventID) {
portal.markHandled(source, message, mxid) portal.markHandled(source, message, mxid)
portal.log.Debugln("Handled message", message.GetKey().GetId(), "->", mxid) portal.log.Debugln("Handled message", message.GetKey().GetId(), "->", mxid)
} }
@ -416,7 +417,7 @@ func (portal *Portal) UpdateMetadata(user *User) bool {
return update return update
} }
func (portal *Portal) userMXIDAction(user *User, fn func(mxid types.MatrixUserID)) { func (portal *Portal) userMXIDAction(user *User, fn func(mxid id.UserID)) {
if user == nil { if user == nil {
return return
} }
@ -430,7 +431,7 @@ func (portal *Portal) userMXIDAction(user *User, fn func(mxid types.MatrixUserID
} }
} }
func (portal *Portal) ensureMXIDInvited(mxid types.MatrixUserID) { func (portal *Portal) ensureMXIDInvited(mxid id.UserID) {
err := portal.MainIntent().EnsureInvited(portal.MXID, mxid) err := portal.MainIntent().EnsureInvited(portal.MXID, mxid)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to ensure %s is invited to %s: %v", mxid, portal.MXID, err) portal.log.Warnfln("Failed to ensure %s is invited to %s: %v", mxid, portal.MXID, err)
@ -481,27 +482,27 @@ func (portal *Portal) Sync(user *User, contact whatsapp.Contact) {
} }
} }
func (portal *Portal) GetBasePowerLevels() *mautrix.PowerLevels { func (portal *Portal) GetBasePowerLevels() *event.PowerLevelsEventContent {
anyone := 0 anyone := 0
nope := 99 nope := 99
invite := 99 invite := 99
if portal.bridge.Config.Bridge.AllowUserInvite { if portal.bridge.Config.Bridge.AllowUserInvite {
invite = 0 invite = 0
} }
return &mautrix.PowerLevels{ return &event.PowerLevelsEventContent{
UsersDefault: anyone, UsersDefault: anyone,
EventsDefault: anyone, EventsDefault: anyone,
RedactPtr: &anyone, RedactPtr: &anyone,
StateDefaultPtr: &nope, StateDefaultPtr: &nope,
BanPtr: &nope, BanPtr: &nope,
InvitePtr: &invite, InvitePtr: &invite,
Users: map[string]int{ Users: map[id.UserID]int{
portal.MainIntent().UserID: 100, portal.MainIntent().UserID: 100,
}, },
Events: map[string]int{ Events: map[string]int{
mautrix.StateRoomName.Type: anyone, event.StateRoomName.Type: anyone,
mautrix.StateRoomAvatar.Type: anyone, event.StateRoomAvatar.Type: anyone,
mautrix.StateTopic.Type: anyone, event.StateTopic.Type: anyone,
}, },
} }
} }
@ -559,9 +560,9 @@ func (portal *Portal) RestrictMetadataChanges(restrict bool) {
newLevel = 50 newLevel = 50
} }
changed := false changed := false
changed = levels.EnsureEventLevel(mautrix.StateRoomName, newLevel) || changed changed = levels.EnsureEventLevel(event.StateRoomName, newLevel) || changed
changed = levels.EnsureEventLevel(mautrix.StateRoomAvatar, newLevel) || changed changed = levels.EnsureEventLevel(event.StateRoomAvatar, newLevel) || changed
changed = levels.EnsureEventLevel(mautrix.StateTopic, newLevel) || changed changed = levels.EnsureEventLevel(event.StateTopic, newLevel) || changed
if changed { if changed {
_, err = portal.MainIntent().SetPowerLevels(portal.MXID, levels) _, err = portal.MainIntent().SetPowerLevels(portal.MXID, levels)
if err != nil { if err != nil {
@ -724,7 +725,6 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
portal.log.Infoln("Creating Matrix room. Info source:", user.MXID) portal.log.Infoln("Creating Matrix room. Info source:", user.MXID)
var metadata *whatsappExt.GroupInfo var metadata *whatsappExt.GroupInfo
isPrivateChat := false
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
puppet := portal.bridge.GetPuppetByJID(portal.Key.JID) puppet := portal.bridge.GetPuppetByJID(portal.Key.JID)
if portal.bridge.Config.Bridge.PrivateChatPortalMeta { if portal.bridge.Config.Bridge.PrivateChatPortalMeta {
@ -735,7 +735,6 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
portal.Name = "" portal.Name = ""
} }
portal.Topic = "WhatsApp private chat" portal.Topic = "WhatsApp private chat"
isPrivateChat = true
} else if portal.IsStatusBroadcastRoom() { } else if portal.IsStatusBroadcastRoom() {
portal.Name = "WhatsApp Status Broadcast" portal.Name = "WhatsApp Status Broadcast"
portal.Topic = "WhatsApp status updates from your contacts" portal.Topic = "WhatsApp status updates from your contacts"
@ -749,33 +748,46 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
portal.UpdateAvatar(user, nil) portal.UpdateAvatar(user, nil)
} }
initialState := []*mautrix.Event{{ initialState := []*event.Event{{
Type: mautrix.StatePowerLevels, Type: event.StatePowerLevels,
Content: mautrix.Content{ Content: event.Content{
PowerLevels: portal.GetBasePowerLevels(), Parsed: portal.GetBasePowerLevels(),
}, },
}} }}
if len(portal.AvatarURL) > 0 { if !portal.AvatarURL.IsEmpty() {
initialState = append(initialState, &mautrix.Event{ initialState = append(initialState, &event.Event{
Type: mautrix.StateRoomAvatar, Type: event.StateRoomAvatar,
Content: mautrix.Content{ Content: event.Content{
URL: portal.AvatarURL, Parsed: event.RoomAvatarEventContent{URL: portal.AvatarURL},
}, },
}) })
} }
invite := []string{user.MXID} invite := []id.UserID{user.MXID}
if user.IsRelaybot { if user.IsRelaybot {
invite = portal.bridge.Config.Bridge.Relaybot.InviteUsers invite = portal.bridge.Config.Bridge.Relaybot.InviteUsers
} }
if portal.bridge.Config.Bridge.Encryption.Default {
initialState = append(initialState, &event.Event{
Type: event.StateEncryption,
Content: event.Content{
Parsed: event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1},
},
})
portal.Encrypted = true
if portal.IsPrivateChat() {
invite = append(invite, portal.bridge.Bot.UserID)
}
}
resp, err := intent.CreateRoom(&mautrix.ReqCreateRoom{ resp, err := intent.CreateRoom(&mautrix.ReqCreateRoom{
Visibility: "private", Visibility: "private",
Name: portal.Name, Name: portal.Name,
Topic: portal.Topic, Topic: portal.Topic,
Invite: invite, Invite: invite,
Preset: "private_chat", Preset: "private_chat",
IsDirect: isPrivateChat, IsDirect: portal.IsPrivateChat(),
InitialState: initialState, InitialState: initialState,
}) })
if err != nil { if err != nil {
@ -783,6 +795,12 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
} }
portal.MXID = resp.RoomID portal.MXID = resp.RoomID
portal.Update() portal.Update()
// We set the memberships beforehand to make sure the encryption key exchange in initial backfill knows the users are here.
for _, user := range invite {
portal.bridge.StateStore.SetMembership(portal.MXID, user, event.MembershipInvite)
}
if metadata != nil { if metadata != nil {
portal.SyncParticipants(metadata) portal.SyncParticipants(metadata)
} else { } else {
@ -795,6 +813,13 @@ func (portal *Portal) CreateMatrixRoom(user *User) error {
if portal.IsPrivateChat() { if portal.IsPrivateChat() {
puppet := user.bridge.GetPuppetByJID(portal.Key.JID) puppet := user.bridge.GetPuppetByJID(portal.Key.JID)
user.addPuppetToCommunity(puppet) user.addPuppetToCommunity(puppet)
if portal.bridge.Config.Bridge.Encryption.Default {
err = portal.bridge.Bot.EnsureJoined(portal.MXID)
if err != nil {
portal.log.Errorln("Failed to join created portal with bridge bot for e2be:", err)
}
}
} }
err = portal.FillInitialHistory(user) err = portal.FillInitialHistory(user)
if err != nil { if err != nil {
@ -847,19 +872,18 @@ func (portal *Portal) GetMessageIntent(user *User, info whatsapp.MessageInfo) *a
return portal.bridge.GetPuppetByJID(info.SenderJid).IntentFor(portal) return portal.bridge.GetPuppetByJID(info.SenderJid).IntentFor(portal)
} }
func (portal *Portal) SetReply(content *mautrix.Content, info whatsapp.ContextInfo) { func (portal *Portal) SetReply(content *event.MessageEventContent, info whatsapp.ContextInfo) {
if len(info.QuotedMessageID) == 0 { if len(info.QuotedMessageID) == 0 {
return return
} }
message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID) message := portal.bridge.DB.Message.GetByJID(portal.Key, info.QuotedMessageID)
if message != nil { if message != nil {
event, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID) evt, err := portal.MainIntent().GetEvent(portal.MXID, message.MXID)
if err != nil { if err != nil {
portal.log.Warnln("Failed to get reply target:", err) portal.log.Warnln("Failed to get reply target:", err)
return return
} }
event.Content.RemoveReplyFallback() content.SetReply(evt)
content.SetReply(event)
} }
return return
} }
@ -895,7 +919,14 @@ func (portal *Portal) HandleFakeMessage(source *User, message FakeMessage) {
return return
} }
_, err := portal.MainIntent().SendNotice(portal.MXID, message.Text) content := event.MessageEventContent{
MsgType: event.MsgNotice,
Body: message.Text,
}
if message.Alert {
content.MsgType = event.MsgText
}
_, err := portal.sendMainIntentMessage(content)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to handle fake message %s: %v", message.ID, err) portal.log.Errorfln("Failed to handle fake message %s: %v", message.ID, err)
return return
@ -908,30 +939,30 @@ func (portal *Portal) HandleFakeMessage(source *User, message FakeMessage) {
portal.recentlyHandled[index] = message.ID portal.recentlyHandled[index] = message.ID
} }
type MessageContent struct { func (portal *Portal) sendMainIntentMessage(content interface{}) (*mautrix.RespSendEvent, error) {
*mautrix.Content return portal.sendMessage(portal.MainIntent(), event.EventMessage, content, 0)
IsCustomPuppet bool `json:"net.maunium.whatsapp.puppet,omitempty"`
} }
type serializableContent mautrix.Content func (portal *Portal) sendMessage(intent *appservice.IntentAPI, eventType event.Type, content interface{}, timestamp int64) (*mautrix.RespSendEvent, error) {
wrappedContent := event.Content{Parsed: content}
type serializableMessageContent struct { if timestamp != 0 && intent.IsCustomPuppet {
*serializableContent wrappedContent.Raw = map[string]interface{}{
IsCustomPuppet bool `json:"net.maunium.whatsapp.puppet,omitempty"` "net.maunium.whatsapp.puppet": intent.IsCustomPuppet,
}
// Hacky bypass for mautrix.Content's MarshalSJSON
func (content *MessageContent) MarshalJSON() ([]byte, error) {
if mautrix.DisableFancyEventParsing {
if content.IsCustomPuppet {
content.Raw["net.maunium.whatsapp.puppet"] = content.IsCustomPuppet
} }
return json.Marshal(content.Raw)
} }
return json.Marshal(&serializableMessageContent{ if portal.Encrypted && portal.bridge.Crypto != nil {
serializableContent: (*serializableContent)(content.Content), encrypted, err := portal.bridge.Crypto.Encrypt(portal.MXID, eventType, wrappedContent)
IsCustomPuppet: content.IsCustomPuppet, if err != nil {
}) return nil, errors.Wrap(err, "failed to encrypt event")
}
eventType = event.EventEncrypted
wrappedContent.Parsed = encrypted
}
if timestamp == 0 {
return intent.SendMessageEvent(portal.MXID, eventType, &wrappedContent)
} else {
return intent.SendMassagedMessageEvent(portal.MXID, eventType, &wrappedContent, timestamp)
}
} }
func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) { func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessage) {
@ -944,16 +975,16 @@ func (portal *Portal) HandleTextMessage(source *User, message whatsapp.TextMessa
return return
} }
content := &mautrix.Content{ content := &event.MessageEventContent{
Body: message.Text, Body: message.Text,
MsgType: mautrix.MsgText, MsgType: event.MsgText,
} }
portal.bridge.Formatter.ParseWhatsApp(content) portal.bridge.Formatter.ParseWhatsApp(content)
portal.SetReply(content, message.ContextInfo) portal.SetReply(content, message.ContextInfo)
_, _ = intent.UserTyping(portal.MXID, false, 0) _, _ = intent.UserTyping(portal.MXID, false, 0)
resp, err := intent.SendMassagedMessageEvent(portal.MXID, mautrix.EventMessage, &MessageContent{content, intent.IsCustomPuppet}, int64(message.Info.Timestamp*1000)) resp, err := portal.sendMessage(intent, event.EventMessage, content, int64(message.Info.Timestamp*1000))
if err != nil { if err != nil {
portal.log.Errorfln("Failed to handle message %s: %v", message.Info.Id, err) portal.log.Errorfln("Failed to handle message %s: %v", message.Info.Id, err)
return return
@ -977,7 +1008,10 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
return return
} else if err != nil { } else if err != nil {
portal.log.Errorfln("Failed to download media for %s: %v", info.Id, err) portal.log.Errorfln("Failed to download media for %s: %v", info.Id, err)
resp, err := portal.MainIntent().SendNotice(portal.MXID, "Failed to bridge media") resp, err := portal.sendMainIntentMessage(event.MessageEventContent{
MsgType: event.MsgNotice,
Body: "Failed to bridge media",
})
if err != nil { if err != nil {
portal.log.Errorfln("Failed to send media download error message for %s: %v", info.Id, err) portal.log.Errorfln("Failed to send media download error message for %s: %v", info.Id, err)
} else { } else {
@ -988,7 +1022,7 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
// synapse doesn't handle webp well, so we convert it. This can be dropped once https://github.com/matrix-org/synapse/issues/4382 is fixed // synapse doesn't handle webp well, so we convert it. This can be dropped once https://github.com/matrix-org/synapse/issues/4382 is fixed
if mimeType == "image/webp" { if mimeType == "image/webp" {
img, err := webp.Decode(bytes.NewReader(data)) img, err := decodeWebp(bytes.NewReader(data))
if err != nil { if err != nil {
portal.log.Errorfln("Failed to decode media for %s: %v", err) portal.log.Errorfln("Failed to decode media for %s: %v", err)
return return
@ -1016,10 +1050,10 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
fileName += exts[0] fileName += exts[0]
} }
content := &mautrix.Content{ content := &event.MessageEventContent{
Body: fileName, Body: fileName,
URL: uploaded.ContentURI, URL: uploaded.ContentURI.CUString(),
Info: &mautrix.FileInfo{ Info: &event.FileInfo{
Size: len(data), Size: len(data),
MimeType: mimeType, MimeType: mimeType,
}, },
@ -1030,9 +1064,9 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
thumbnailMime := http.DetectContentType(thumbnail) thumbnailMime := http.DetectContentType(thumbnail)
uploadedThumbnail, _ := intent.UploadBytes(thumbnail, thumbnailMime) uploadedThumbnail, _ := intent.UploadBytes(thumbnail, thumbnailMime)
if uploadedThumbnail != nil { if uploadedThumbnail != nil {
content.Info.ThumbnailURL = uploadedThumbnail.ContentURI content.Info.ThumbnailURL = uploadedThumbnail.ContentURI.CUString()
cfg, _, _ := image.DecodeConfig(bytes.NewReader(data)) cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
content.Info.ThumbnailInfo = &mautrix.FileInfo{ content.Info.ThumbnailInfo = &event.FileInfo{
Size: len(thumbnail), Size: len(thumbnail),
Width: cfg.Width, Width: cfg.Width,
Height: cfg.Height, Height: cfg.Height,
@ -1044,40 +1078,40 @@ func (portal *Portal) HandleMediaMessage(source *User, download func() ([]byte,
switch strings.ToLower(strings.Split(mimeType, "/")[0]) { switch strings.ToLower(strings.Split(mimeType, "/")[0]) {
case "image": case "image":
if !sendAsSticker { if !sendAsSticker {
content.MsgType = mautrix.MsgImage content.MsgType = event.MsgImage
} }
cfg, _, _ := image.DecodeConfig(bytes.NewReader(data)) cfg, _, _ := image.DecodeConfig(bytes.NewReader(data))
content.Info.Width = cfg.Width content.Info.Width = cfg.Width
content.Info.Height = cfg.Height content.Info.Height = cfg.Height
case "video": case "video":
content.MsgType = mautrix.MsgVideo content.MsgType = event.MsgVideo
case "audio": case "audio":
content.MsgType = mautrix.MsgAudio content.MsgType = event.MsgAudio
default: default:
content.MsgType = mautrix.MsgFile content.MsgType = event.MsgFile
} }
_, _ = intent.UserTyping(portal.MXID, false, 0) _, _ = intent.UserTyping(portal.MXID, false, 0)
ts := int64(info.Timestamp * 1000) ts := int64(info.Timestamp * 1000)
eventType := mautrix.EventMessage eventType := event.EventMessage
if sendAsSticker { if sendAsSticker {
eventType = mautrix.EventSticker eventType = event.EventSticker
} }
resp, err := intent.SendMassagedMessageEvent(portal.MXID, eventType, &MessageContent{content, intent.IsCustomPuppet}, ts) resp, err := portal.sendMessage(intent, eventType, content, ts)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to handle message %s: %v", info.Id, err) portal.log.Errorfln("Failed to handle message %s: %v", info.Id, err)
return return
} }
if len(caption) > 0 { if len(caption) > 0 {
captionContent := &mautrix.Content{ captionContent := &event.MessageEventContent{
Body: caption, Body: caption,
MsgType: mautrix.MsgNotice, MsgType: event.MsgNotice,
} }
portal.bridge.Formatter.ParseWhatsApp(captionContent) portal.bridge.Formatter.ParseWhatsApp(captionContent)
_, err := intent.SendMassagedMessageEvent(portal.MXID, mautrix.EventMessage, &MessageContent{captionContent, intent.IsCustomPuppet}, ts) _, err := portal.sendMessage(intent, event.EventMessage, content, ts)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to handle caption of message %s: %v", info.Id, err) portal.log.Warnfln("Failed to handle caption of message %s: %v", info.Id, err)
} }
@ -1094,14 +1128,17 @@ func makeMessageID() *string {
return &str return &str
} }
func (portal *Portal) downloadThumbnail(evt *mautrix.Event) []byte { func (portal *Portal) downloadThumbnail(content *event.MessageEventContent, id id.EventID) []byte {
if evt.Content.Info == nil || len(evt.Content.Info.ThumbnailURL) == 0 { if len(content.GetInfo().ThumbnailURL) == 0 {
return nil return nil
} }
mxc, err := content.GetInfo().ThumbnailURL.Parse()
thumbnail, err := portal.MainIntent().DownloadBytes(evt.Content.Info.ThumbnailURL)
if err != nil { if err != nil {
portal.log.Errorln("Failed to download thumbnail in %s: %v", evt.ID, err) portal.log.Errorln("Malformed thumbnail URL in %s: %v", id, err)
}
thumbnail, err := portal.MainIntent().DownloadBytes(mxc)
if err != nil {
portal.log.Errorln("Failed to download thumbnail in %s: %v", id, err)
return nil return nil
} }
thumbnailType := http.DetectContentType(thumbnail) thumbnailType := http.DetectContentType(thumbnail)
@ -1121,30 +1158,44 @@ func (portal *Portal) downloadThumbnail(evt *mautrix.Event) []byte {
Quality: jpeg.DefaultQuality, Quality: jpeg.DefaultQuality,
}) })
if err != nil { if err != nil {
portal.log.Errorln("Failed to re-encode thumbnail in %s: %v", evt.ID, err) portal.log.Errorln("Failed to re-encode thumbnail in %s: %v", id, err)
return nil return nil
} }
return buf.Bytes() return buf.Bytes()
} }
func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool, evt *mautrix.Event, mediaType whatsapp.MediaType) *MediaUpload { func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool, content *event.MessageEventContent, eventID id.EventID, mediaType whatsapp.MediaType) *MediaUpload {
if evt.Content.Info == nil {
evt.Content.Info = &mautrix.FileInfo{}
}
var caption string var caption string
if relaybotFormatted { if relaybotFormatted {
caption = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody) caption = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
} }
content, err := portal.MainIntent().DownloadBytes(evt.Content.URL) var file *event.EncryptedFileInfo
rawMXC := content.URL
if content.File != nil {
file = content.File
rawMXC = file.URL
}
mxc, err := rawMXC.Parse()
if err != nil { if err != nil {
portal.log.Errorfln("Failed to download media in %s: %v", evt.ID, err) portal.log.Errorln("Malformed content URL in %s: %v", eventID, err)
}
data, err := portal.MainIntent().DownloadBytes(mxc)
if err != nil {
portal.log.Errorfln("Failed to download media in %s: %v", eventID, err)
return nil return nil
} }
if file != nil {
data, err = file.Decrypt(data)
if err != nil {
portal.log.Errorfln("Failed to decrypt media in %s: %v", eventID, err)
return nil
}
}
url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(content), mediaType) url, mediaKey, fileEncSHA256, fileSHA256, fileLength, err := sender.Conn.Upload(bytes.NewReader(data), mediaType)
if err != nil { if err != nil {
portal.log.Errorfln("Failed to upload media in %s: %v", evt.ID, err) portal.log.Errorfln("Failed to upload media in %s: %v", eventID, err)
return nil return nil
} }
@ -1155,7 +1206,7 @@ func (portal *Portal) preprocessMatrixMedia(sender *User, relaybotFormatted bool
FileEncSHA256: fileEncSHA256, FileEncSHA256: fileEncSHA256,
FileSHA256: fileSHA256, FileSHA256: fileSHA256,
FileLength: fileLength, FileLength: fileLength,
Thumbnail: portal.downloadThumbnail(evt), Thumbnail: portal.downloadThumbnail(content, eventID),
} }
} }
@ -1169,7 +1220,7 @@ type MediaUpload struct {
Thumbnail []byte Thumbnail []byte
} }
func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bool { func (portal *Portal) sendMatrixConnectionError(sender *User, eventID id.EventID) bool {
if !sender.HasSession() { if !sender.HasSession() {
portal.log.Debugln("Ignoring event", eventID, "from", sender.MXID, "as user has no session") portal.log.Debugln("Ignoring event", eventID, "from", sender.MXID, "as user has no session")
return true return true
@ -1183,9 +1234,9 @@ func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bo
if sender.IsLoginInProgress() { if sender.IsLoginInProgress() {
reconnect = "You have a login attempt in progress, please wait." reconnect = "You have a login attempt in progress, please wait."
} }
msg := format.RenderMarkdown("\u26a0 You are not connected to WhatsApp, so your message was not bridged. " + reconnect) msg := format.RenderMarkdown("\u26a0 You are not connected to WhatsApp, so your message was not bridged. "+reconnect, true, false)
msg.MsgType = mautrix.MsgNotice msg.MsgType = event.MsgNotice
_, err := portal.MainIntent().SendMessageEvent(portal.MXID, mautrix.EventMessage, msg) _, err := portal.sendMainIntentMessage(msg)
if err != nil { if err != nil {
portal.log.Errorln("Failed to send bridging failure message:", err) portal.log.Errorln("Failed to send bridging failure message:", err)
} }
@ -1194,30 +1245,34 @@ func (portal *Portal) sendMatrixConnectionError(sender *User, eventID string) bo
return false return false
} }
func (portal *Portal) addRelaybotFormat(user *User, evt *mautrix.Event) bool { func (portal *Portal) addRelaybotFormat(sender *User, content *event.MessageEventContent) bool {
member := portal.MainIntent().Member(portal.MXID, evt.Sender) member := portal.MainIntent().Member(portal.MXID, sender.MXID)
if len(member.Displayname) == 0 { if len(member.Displayname) == 0 {
member.Displayname = evt.Sender member.Displayname = string(sender.MXID)
} }
if evt.Content.Format != mautrix.FormatHTML { if content.Format != event.FormatHTML {
evt.Content.FormattedBody = strings.Replace(html.EscapeString(evt.Content.Body), "\n", "<br/>", -1) content.FormattedBody = strings.Replace(html.EscapeString(content.Body), "\n", "<br/>", -1)
evt.Content.Format = mautrix.FormatHTML content.Format = event.FormatHTML
} }
data, err := portal.bridge.Config.Bridge.Relaybot.FormatMessage(evt, member) data, err := portal.bridge.Config.Bridge.Relaybot.FormatMessage(content, sender.MXID, member)
if err != nil { if err != nil {
portal.log.Errorln("Failed to apply relaybot format:", err) portal.log.Errorln("Failed to apply relaybot format:", err)
} }
evt.Content.FormattedBody = data content.FormattedBody = data
return true return true
} }
func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) { func (portal *Portal) HandleMatrixMessage(sender *User, evt *event.Event) {
if !portal.HasRelaybot() && ( if !portal.HasRelaybot() && (
(portal.IsPrivateChat() && sender.JID != portal.Key.Receiver) || (portal.IsPrivateChat() && sender.JID != portal.Key.Receiver) ||
portal.sendMatrixConnectionError(sender, evt.ID)) { portal.sendMatrixConnectionError(sender, evt.ID)) {
return return
} }
content := evt.Content.AsMessage()
if content == nil {
return
}
portal.log.Debugfln("Received event %s", evt.ID) portal.log.Debugfln("Received event %s", evt.ID)
ts := uint64(evt.Timestamp / 1000) ts := uint64(evt.Timestamp / 1000)
@ -1234,9 +1289,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
Status: &status, Status: &status,
} }
ctxInfo := &waProto.ContextInfo{} ctxInfo := &waProto.ContextInfo{}
replyToID := evt.Content.GetReplyTo() replyToID := content.GetReplyTo()
if len(replyToID) > 0 { if len(replyToID) > 0 {
evt.Content.RemoveReplyFallback() content.RemoveReplyFallback()
msg := portal.bridge.DB.Message.GetByMXID(replyToID) msg := portal.bridge.DB.Message.GetByMXID(replyToID)
if msg != nil && msg.Content != nil { if msg != nil && msg.Content != nil {
ctxInfo.StanzaId = &msg.JID ctxInfo.StanzaId = &msg.JID
@ -1254,21 +1309,21 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
return return
} }
} else { } else {
relaybotFormatted = portal.addRelaybotFormat(sender, evt) relaybotFormatted = portal.addRelaybotFormat(sender, content)
sender = portal.bridge.Relaybot sender = portal.bridge.Relaybot
} }
} }
if evt.Type == mautrix.EventSticker { if evt.Type == event.EventSticker {
evt.Content.MsgType = mautrix.MsgImage content.MsgType = event.MsgImage
} }
var err error var err error
switch evt.Content.MsgType { switch content.MsgType {
case mautrix.MsgText, mautrix.MsgEmote, mautrix.MsgNotice: case event.MsgText, event.MsgEmote, event.MsgNotice:
text := evt.Content.Body text := content.Body
if evt.Content.Format == mautrix.FormatHTML { if content.Format == event.FormatHTML {
text = portal.bridge.Formatter.ParseMatrix(evt.Content.FormattedBody) text = portal.bridge.Formatter.ParseMatrix(content.FormattedBody)
} }
if evt.Content.MsgType == mautrix.MsgEmote && !relaybotFormatted { if content.MsgType == event.MsgEmote && !relaybotFormatted {
text = "/me " + text text = "/me " + text
} }
ctxInfo.MentionedJid = mentionRegex.FindAllString(text, -1) ctxInfo.MentionedJid = mentionRegex.FindAllString(text, -1)
@ -1283,8 +1338,8 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
} else { } else {
info.Message.Conversation = &text info.Message.Conversation = &text
} }
case mautrix.MsgImage: case event.MsgImage:
media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaImage) media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaImage)
if media == nil { if media == nil {
return return
} }
@ -1293,53 +1348,53 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
JpegThumbnail: media.Thumbnail, JpegThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
Mimetype: &evt.Content.GetInfo().MimeType, Mimetype: &content.GetInfo().MimeType,
FileEncSha256: media.FileEncSHA256, FileEncSha256: media.FileEncSHA256,
FileSha256: media.FileSHA256, FileSha256: media.FileSHA256,
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case mautrix.MsgVideo: case event.MsgVideo:
media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaVideo) media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaVideo)
if media == nil { if media == nil {
return return
} }
duration := uint32(evt.Content.GetInfo().Duration) duration := uint32(content.GetInfo().Duration)
info.Message.VideoMessage = &waProto.VideoMessage{ info.Message.VideoMessage = &waProto.VideoMessage{
Caption: &media.Caption, Caption: &media.Caption,
JpegThumbnail: media.Thumbnail, JpegThumbnail: media.Thumbnail,
Url: &media.URL, Url: &media.URL,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
Mimetype: &evt.Content.GetInfo().MimeType, Mimetype: &content.GetInfo().MimeType,
Seconds: &duration, Seconds: &duration,
FileEncSha256: media.FileEncSHA256, FileEncSha256: media.FileEncSHA256,
FileSha256: media.FileSHA256, FileSha256: media.FileSHA256,
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case mautrix.MsgAudio: case event.MsgAudio:
media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaAudio) media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaAudio)
if media == nil { if media == nil {
return return
} }
duration := uint32(evt.Content.GetInfo().Duration) duration := uint32(content.GetInfo().Duration)
info.Message.AudioMessage = &waProto.AudioMessage{ info.Message.AudioMessage = &waProto.AudioMessage{
Url: &media.URL, Url: &media.URL,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
Mimetype: &evt.Content.GetInfo().MimeType, Mimetype: &content.GetInfo().MimeType,
Seconds: &duration, Seconds: &duration,
FileEncSha256: media.FileEncSHA256, FileEncSha256: media.FileEncSHA256,
FileSha256: media.FileSHA256, FileSha256: media.FileSHA256,
FileLength: &media.FileLength, FileLength: &media.FileLength,
} }
case mautrix.MsgFile: case event.MsgFile:
media := portal.preprocessMatrixMedia(sender, relaybotFormatted, evt, whatsapp.MediaDocument) media := portal.preprocessMatrixMedia(sender, relaybotFormatted, content, evt.ID, whatsapp.MediaDocument)
if media == nil { if media == nil {
return return
} }
info.Message.DocumentMessage = &waProto.DocumentMessage{ info.Message.DocumentMessage = &waProto.DocumentMessage{
Url: &media.URL, Url: &media.URL,
FileName: &evt.Content.Body, FileName: &content.Body,
MediaKey: media.MediaKey, MediaKey: media.MediaKey,
Mimetype: &evt.Content.GetInfo().MimeType, Mimetype: &content.GetInfo().MimeType,
FileEncSha256: media.FileEncSHA256, FileEncSha256: media.FileEncSHA256,
FileSha256: media.FileSHA256, FileSha256: media.FileSHA256,
FileLength: &media.FileLength, FileLength: &media.FileLength,
@ -1353,9 +1408,9 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
_, err = sender.Conn.Send(info) _, err = sender.Conn.Send(info)
if err != nil { if err != nil {
portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err) portal.log.Errorfln("Error handling Matrix event %s: %v", evt.ID, err)
msg := format.RenderMarkdown(fmt.Sprintf("\u26a0 Your message may not have been bridged: %v", err)) msg := format.RenderMarkdown(fmt.Sprintf("\u26a0 Your message may not have been bridged: %v", err), false, false)
msg.MsgType = mautrix.MsgNotice msg.MsgType = event.MsgNotice
_, err := portal.MainIntent().SendMessageEvent(portal.MXID, mautrix.EventMessage, msg) _, err := portal.sendMainIntentMessage(msg)
if err != nil { if err != nil {
portal.log.Errorln("Failed to send bridging failure message:", err) portal.log.Errorln("Failed to send bridging failure message:", err)
} }
@ -1364,7 +1419,7 @@ func (portal *Portal) HandleMatrixMessage(sender *User, evt *mautrix.Event) {
} }
} }
func (portal *Portal) HandleMatrixRedaction(sender *User, evt *mautrix.Event) { func (portal *Portal) HandleMatrixRedaction(sender *User, evt *event.Event) {
if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver { if portal.IsPrivateChat() && sender.JID != portal.Key.Receiver {
return return
} }
@ -1462,6 +1517,6 @@ func (portal *Portal) HandleMatrixLeave(sender *User) {
} }
} }
func (portal *Portal) HandleMatrixKick(sender *User, event *mautrix.Event) { func (portal *Portal) HandleMatrixKick(sender *User, event *event.Event) {
// TODO // TODO
} }

View file

@ -26,8 +26,8 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-whatsapp/types"
whatsappExt "maunium.net/go/mautrix-whatsapp/whatsapp-ext" whatsappExt "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
"maunium.net/go/mautrix/id"
) )
type ProvisioningAPI struct { type ProvisioningAPI struct {
@ -61,7 +61,7 @@ func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return return
} }
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(types.MatrixUserID(userID)) user := prov.bridge.GetUserByMXID(id.UserID(userID))
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user))) h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), "user", user)))
}) })
} }
@ -292,6 +292,9 @@ func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
} }
user.Conn.RemoveHandlers() user.Conn.RemoveHandlers()
user.Conn = nil user.Conn = nil
user.removeFromJIDMap()
// TODO this causes a foreign key violation, which should be fixed
//ce.User.JID = ""
user.SetSession(nil) user.SetSession(nil)
jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."}) jsonResponse(w, http.StatusOK, Response{true, "Logged out successfully."})
} }
@ -300,7 +303,7 @@ var upgrader = websocket.Upgrader{}
func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) { func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id") userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(types.MatrixUserID(userID)) user := prov.bridge.GetUserByMXID(id.UserID(userID))
c, err := upgrader.Upgrade(w, r, nil) c, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
@ -351,6 +354,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
} }
user.ConnectionErrors = 0 user.ConnectionErrors = 0
user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1) user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
user.addToJIDMap()
user.SetSession(&session) user.SetSession(&session)
_ = c.WriteJSON(map[string]interface{}{ _ = c.WriteJSON(map[string]interface{}{
"success": true, "success": true,

View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -25,14 +25,16 @@ import (
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix-appservice"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
"maunium.net/go/mautrix-whatsapp/whatsapp-ext" "maunium.net/go/mautrix-whatsapp/whatsapp-ext"
) )
func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID, bool) { func (bridge *Bridge) ParsePuppetMXID(mxid id.UserID) (types.WhatsAppID, bool) {
userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$", userIDRegex, err := regexp.Compile(fmt.Sprintf("^@%s:%s$",
bridge.Config.Bridge.FormatUsername("([0-9]+)"), bridge.Config.Bridge.FormatUsername("([0-9]+)"),
bridge.Config.Homeserver.Domain)) bridge.Config.Homeserver.Domain))
@ -49,7 +51,7 @@ func (bridge *Bridge) ParsePuppetMXID(mxid types.MatrixUserID) (types.WhatsAppID
return jid, true return jid, true
} }
func (bridge *Bridge) GetPuppetByMXID(mxid types.MatrixUserID) *Puppet { func (bridge *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
jid, ok := bridge.ParsePuppetMXID(mxid) jid, ok := bridge.ParsePuppetMXID(mxid)
if !ok { if !ok {
return nil return nil
@ -78,7 +80,7 @@ func (bridge *Bridge) GetPuppetByJID(jid types.WhatsAppID) *Puppet {
return puppet return puppet
} }
func (bridge *Bridge) GetPuppetByCustomMXID(mxid types.MatrixUserID) *Puppet { func (bridge *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
bridge.puppetsLock.Lock() bridge.puppetsLock.Lock()
defer bridge.puppetsLock.Unlock() defer bridge.puppetsLock.Unlock()
puppet, ok := bridge.puppetsByCustomMXID[mxid] puppet, ok := bridge.puppetsByCustomMXID[mxid]
@ -129,7 +131,7 @@ func (bridge *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
bridge: bridge, bridge: bridge,
log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)), log: bridge.Log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.JID)),
MXID: fmt.Sprintf("@%s:%s", MXID: id.NewUserID(
bridge.Config.Bridge.FormatUsername( bridge.Config.Bridge.FormatUsername(
strings.Replace( strings.Replace(
dbPuppet.JID, dbPuppet.JID,
@ -144,13 +146,13 @@ type Puppet struct {
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
typingIn types.MatrixRoomID typingIn id.RoomID
typingAt int64 typingAt int64
MXID types.MatrixUserID MXID id.UserID
customIntent *appservice.IntentAPI customIntent *appservice.IntentAPI
customTypingIn map[string]bool customTypingIn map[id.RoomID]bool
customUser *User customUser *User
} }
@ -159,7 +161,9 @@ func (puppet *Puppet) PhoneNumber() string {
} }
func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI { func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI {
if (!portal.IsPrivateChat() && puppet.customIntent == nil) || portal.backfilling || portal.Key.JID == puppet.JID { if (!portal.IsPrivateChat() && puppet.customIntent == nil) ||
(portal.backfilling && portal.bridge.Config.Bridge.InviteOwnPuppetForBackfilling) ||
portal.Key.JID == puppet.JID {
return puppet.DefaultIntent() return puppet.DefaultIntent()
} }
return puppet.customIntent return puppet.customIntent
@ -192,11 +196,11 @@ func (puppet *Puppet) UpdateAvatar(source *User, avatar *whatsappExt.ProfilePicI
} }
if len(avatar.URL) == 0 { if len(avatar.URL) == 0 {
err := puppet.DefaultIntent().SetAvatarURL("") err := puppet.DefaultIntent().SetAvatarURL(id.ContentURI{})
if err != nil { if err != nil {
puppet.log.Warnln("Failed to remove avatar:", err) puppet.log.Warnln("Failed to remove avatar:", err)
} }
puppet.AvatarURL = "" puppet.AvatarURL = id.ContentURI{}
puppet.Avatar = avatar.Tag puppet.Avatar = avatar.Tag
go puppet.updatePortalAvatar() go puppet.updatePortalAvatar()
return true return true

View file

@ -21,12 +21,3 @@ type WhatsAppID = string
// WhatsAppMessageID is the internal ID of a WhatsApp message. // WhatsAppMessageID is the internal ID of a WhatsApp message.
type WhatsAppMessageID = string type WhatsAppMessageID = string
// MatrixUserID is the ID of a Matrix user.
type MatrixUserID = string
// MatrixRoomID is the internal room ID of a Matrix room.
type MatrixRoomID = string
// MatrixEventID is the internal ID of a Matrix event.
type MatrixEventID = string

96
user.go
View file

@ -1,5 +1,5 @@
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge. // mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2019 Tulir Asokan // Copyright (C) 2020 Tulir Asokan
// //
// This program is free software: you can redistribute it and/or modify // This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by // it under the terms of the GNU Affero General Public License as published by
@ -32,8 +32,9 @@ import (
"github.com/Rhymen/go-whatsapp" "github.com/Rhymen/go-whatsapp"
waProto "github.com/Rhymen/go-whatsapp/binary/proto" waProto "github.com/Rhymen/go-whatsapp/binary/proto"
"maunium.net/go/mautrix" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix-whatsapp/database" "maunium.net/go/mautrix-whatsapp/database"
"maunium.net/go/mautrix-whatsapp/types" "maunium.net/go/mautrix-whatsapp/types"
@ -65,7 +66,7 @@ type User struct {
syncLock sync.Mutex syncLock sync.Mutex
} }
func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User { func (bridge *Bridge) GetUserByMXID(userID id.UserID) *User {
_, isPuppet := bridge.ParsePuppetMXID(userID) _, isPuppet := bridge.ParsePuppetMXID(userID)
if isPuppet || userID == bridge.Bot.UserID { if isPuppet || userID == bridge.Bot.UserID {
return nil return nil
@ -89,6 +90,18 @@ func (bridge *Bridge) GetUserByJID(userID types.WhatsAppID) *User {
return user return user
} }
func (user *User) addToJIDMap() {
user.bridge.usersLock.Lock()
user.bridge.usersByJID[user.JID] = user
user.bridge.usersLock.Unlock()
}
func (user *User) removeFromJIDMap() {
user.bridge.usersLock.Lock()
delete(user.bridge.usersByJID, user.JID)
user.bridge.usersLock.Unlock()
}
func (bridge *Bridge) GetAllUsers() []*User { func (bridge *Bridge) GetAllUsers() []*User {
bridge.usersLock.Lock() bridge.usersLock.Lock()
defer bridge.usersLock.Unlock() defer bridge.usersLock.Unlock()
@ -104,7 +117,7 @@ func (bridge *Bridge) GetAllUsers() []*User {
return output return output
} }
func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *types.MatrixUserID) *User { func (bridge *Bridge) loadDBUser(dbUser *database.User, mxid *id.UserID) *User {
if dbUser == nil { if dbUser == nil {
if mxid == nil { if mxid == nil {
return nil return nil
@ -160,7 +173,7 @@ func (bridge *Bridge) NewUser(dbUser *database.User) *User {
return user return user
} }
func (user *User) SetManagementRoom(roomID types.MatrixRoomID) { func (user *User) SetManagementRoom(roomID id.RoomID) {
existingUser, ok := user.bridge.managementRooms[roomID] existingUser, ok := user.bridge.managementRooms[roomID]
if ok { if ok {
existingUser.ManagementRoom = "" existingUser.ManagementRoom = ""
@ -194,9 +207,9 @@ func (user *User) Connect(evenIfNoSession bool) bool {
conn, err := whatsapp.NewConn(timeout * time.Second) conn, err := whatsapp.NewConn(timeout * time.Second)
if err != nil { if err != nil {
user.log.Errorln("Failed to connect to WhatsApp:", err) user.log.Errorln("Failed to connect to WhatsApp:", err)
msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp server. " + msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp server. "+
"This indicates a network problem on the bridge server. See bridge logs for more info.") "This indicates a network problem on the bridge server. See bridge logs for more info.", true, false)
_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, msg) _, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, msg)
return false return false
} }
user.Conn = whatsappExt.ExtendConn(conn) user.Conn = whatsappExt.ExtendConn(conn)
@ -213,9 +226,9 @@ func (user *User) RestoreSession() bool {
return true return true
} else if err != nil { } else if err != nil {
user.log.Errorln("Failed to restore session:", err) user.log.Errorln("Failed to restore session:", err)
msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp. Make sure WhatsApp " + msg := format.RenderMarkdown("\u26a0 Failed to connect to WhatsApp. Make sure WhatsApp "+
"on your phone is reachable and use `reconnect` to try connecting again.") "on your phone is reachable and use `reconnect` to try connecting again.", true, false)
_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, msg) _, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, msg)
user.log.Debugln("Disconnecting due to failed session restore...") user.log.Debugln("Disconnecting due to failed session restore...")
_, err := user.Conn.Disconnect() _, err := user.Conn.Disconnect()
if err != nil { if err != nil {
@ -243,8 +256,8 @@ func (user *User) IsLoginInProgress() bool {
return user.Conn != nil && user.Conn.IsLoginInProgress() return user.Conn != nil && user.Conn.IsLoginInProgress()
} }
func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventIDChan chan<- string) { func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventIDChan chan<- id.EventID) {
var qrEventID string var qrEventID id.EventID
for code := range qrChan { for code := range qrChan {
if code == "stop" { if code == "stop" {
return return
@ -274,17 +287,17 @@ func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventID
qrEventID = sendResp.EventID qrEventID = sendResp.EventID
eventIDChan <- qrEventID eventIDChan <- qrEventID
} else { } else {
_, err = bot.SendMessageEvent(ce.RoomID, mautrix.EventMessage, &mautrix.Content{ _, err = bot.SendMessageEvent(ce.RoomID, event.EventMessage, &event.MessageEventContent{
MsgType: mautrix.MsgImage, MsgType: event.MsgImage,
Body: code, Body: code,
URL: resp.ContentURI, URL: resp.ContentURI.CUString(),
NewContent: &mautrix.Content{ NewContent: &event.MessageEventContent{
MsgType: mautrix.MsgImage, MsgType: event.MsgImage,
Body: code, Body: code,
URL: resp.ContentURI, URL: resp.ContentURI.CUString(),
}, },
RelatesTo: &mautrix.RelatesTo{ RelatesTo: &event.RelatesTo{
Type: mautrix.RelReplace, Type: event.RelReplace,
EventID: qrEventID, EventID: qrEventID,
}, },
}) })
@ -297,18 +310,18 @@ func (user *User) loginQrChannel(ce *CommandEvent, qrChan <-chan string, eventID
func (user *User) Login(ce *CommandEvent) { func (user *User) Login(ce *CommandEvent) {
qrChan := make(chan string, 3) qrChan := make(chan string, 3)
eventIDChan := make(chan string, 1) eventIDChan := make(chan id.EventID, 1)
go user.loginQrChannel(ce, qrChan, eventIDChan) go user.loginQrChannel(ce, qrChan, eventIDChan)
session, err := user.Conn.LoginWithRetry(qrChan, user.bridge.Config.Bridge.LoginQRRegenCount) session, err := user.Conn.LoginWithRetry(qrChan, user.bridge.Config.Bridge.LoginQRRegenCount)
qrChan <- "stop" qrChan <- "stop"
if err != nil { if err != nil {
var eventID string var eventID id.EventID
select { select {
case eventID = <-eventIDChan: case eventID = <-eventIDChan:
default: default:
} }
reply := mautrix.Content{ reply := event.MessageEventContent{
MsgType: mautrix.MsgText, MsgType: event.MsgText,
} }
if err == whatsapp.ErrAlreadyLoggedIn { if err == whatsapp.ErrAlreadyLoggedIn {
reply.Body = "You're already logged in" reply.Body = "You're already logged in"
@ -323,16 +336,19 @@ func (user *User) Login(ce *CommandEvent) {
msg := reply msg := reply
if eventID != "" { if eventID != "" {
msg.NewContent = &reply msg.NewContent = &reply
msg.RelatesTo = &mautrix.RelatesTo{ msg.RelatesTo = &event.RelatesTo{
Type: mautrix.RelReplace, Type: event.RelReplace,
EventID: eventID, EventID: eventID,
} }
} }
_, _ = ce.Bot.SendMessageEvent(ce.RoomID, mautrix.EventMessage, &msg) _, _ = ce.Bot.SendMessageEvent(ce.RoomID, event.EventMessage, &msg)
return return
} }
// TODO there's a bit of duplication between this and the provisioning API login method
// Also between the two logout methods (commands.go and provisioning.go)
user.ConnectionErrors = 0 user.ConnectionErrors = 0
user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1) user.JID = strings.Replace(user.Conn.Info.Wid, whatsappExt.OldUserSuffix, whatsappExt.NewUserSuffix, 1)
user.addToJIDMap()
user.SetSession(&session) user.SetSession(&session)
ce.Reply("Successfully logged in, synchronizing chats...") ce.Reply("Successfully logged in, synchronizing chats...")
user.PostLogin() user.PostLogin()
@ -365,8 +381,11 @@ func (user *User) PostLogin() {
} }
func (user *User) tryAutomaticDoublePuppeting() { func (user *User) tryAutomaticDoublePuppeting() {
if len(user.bridge.Config.Bridge.LoginSharedSecret) == 0 || !strings.HasSuffix(user.MXID, user.bridge.Config.Homeserver.Domain) { if len(user.bridge.Config.Bridge.LoginSharedSecret) == 0 {
// Automatic login not enabled or user is on another homeserver // Automatic login not enabled
return
} else if _, homeserver, _ := user.MXID.Parse(); homeserver != user.bridge.Config.Homeserver.Domain {
// user is on another homeserver
return return
} }
@ -535,8 +554,8 @@ func (user *User) HandleError(err error) {
func (user *User) tryReconnect(msg string) { func (user *User) tryReconnect(msg string) {
if user.ConnectionErrors > user.bridge.Config.Bridge.MaxConnectionAttempts { if user.ConnectionErrors > user.bridge.Config.Bridge.MaxConnectionAttempts {
content := format.RenderMarkdown(fmt.Sprintf("%s. Use the `reconnect` command to reconnect.", msg)) content := format.RenderMarkdown(fmt.Sprintf("%s. Use the `reconnect` command to reconnect.", msg), true, false)
_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, content) _, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, content)
return return
} }
if user.bridge.Config.Bridge.ReportConnectionRetry { if user.bridge.Config.Bridge.ReportConnectionRetry {
@ -591,8 +610,8 @@ func (user *User) tryReconnect(msg string) {
"Use the `reconnect` command to try to reconnect.", msg, tries) "Use the `reconnect` command to try to reconnect.", msg, tries)
} }
content := format.RenderMarkdown(msg) content := format.RenderMarkdown(msg, true, false)
_, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, content) _, _ = user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, content)
} }
func (user *User) ShouldCallSynchronously() bool { func (user *User) ShouldCallSynchronously() bool {
@ -656,8 +675,9 @@ func (user *User) HandleMessageRevoke(message whatsappExt.MessageRevocation) {
} }
type FakeMessage struct { type FakeMessage struct {
Text string Text string
ID string ID string
Alert bool
} }
func (user *User) HandleCallInfo(info whatsappExt.CallInfo) { func (user *User) HandleCallInfo(info whatsappExt.CallInfo) {
@ -673,11 +693,13 @@ func (user *User) HandleCallInfo(info whatsappExt.CallInfo) {
return return
} }
data.Text = "Incoming call" data.Text = "Incoming call"
data.Alert = true
case whatsappExt.CallOfferVideo: case whatsappExt.CallOfferVideo:
if !user.bridge.Config.Bridge.CallNotices.Start { if !user.bridge.Config.Bridge.CallNotices.Start {
return return
} }
data.Text = "Incoming video call" data.Text = "Incoming video call"
data.Alert = true
case whatsappExt.CallTerminate: case whatsappExt.CallTerminate:
if !user.bridge.Config.Bridge.CallNotices.End { if !user.bridge.Config.Bridge.CallNotices.End {
return return
@ -766,7 +788,7 @@ func (user *User) HandleCommand(cmd whatsappExt.Command) {
"Use the `reconnect` command to reconnect.", cmd.Kind) "Use the `reconnect` command to reconnect.", cmd.Kind)
} }
user.cleanDisconnection = true user.cleanDisconnection = true
go user.bridge.Bot.SendMessageEvent(user.ManagementRoom, mautrix.EventMessage, format.RenderMarkdown(msg)) go user.bridge.Bot.SendMessageEvent(user.ManagementRoom, event.EventMessage, format.RenderMarkdown(msg, true, false))
} }
} }

14
webp.go Normal file
View file

@ -0,0 +1,14 @@
// +build cgo
package main
import (
"image"
"io"
"github.com/chai2010/webp"
)
func decodeWebp(r io.Reader) (image.Image, error) {
return webp.Decode(r)
}