From ba536f17e7545fe54270fecc5e28ff1c847744c2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 29 Dec 2021 21:40:58 +0200 Subject: [PATCH] Add command to sync DM rooms into space --- commands.go | 19 +++++++++++++++++-- database/portal.go | 21 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/commands.go b/commands.go index 0f41049..7995773 100644 --- a/commands.go +++ b/commands.go @@ -1072,13 +1072,14 @@ const cmdSyncHelp = `sync [--create-portals] - Synchr func (handler *CommandHandler) CommandSync(ce *CommandEvent) { if len(ce.Args) == 0 { - ce.Reply("**Usage:** `sync [--create-portals]`") + ce.Reply("**Usage:** `sync [--create-portals]`") return } args := strings.ToLower(strings.Join(ce.Args, " ")) contacts := strings.Contains(args, "contacts") appState := strings.Contains(args, "appstate") - groups := strings.Contains(args, "groups") + space := strings.Contains(args, "space") + groups := strings.Contains(args, "groups") || space createPortals := strings.Contains(args, "--create-portals") if appState { @@ -1100,6 +1101,20 @@ func (handler *CommandHandler) CommandSync(ce *CommandEvent) { ce.Reply("Resynced contacts") } } + if space { + keys := ce.Bridge.DB.Portal.FindPrivateChatsNotInSpace(ce.User.JID) + count := 0 + for _, key := range keys { + portal := ce.Bridge.GetPortalByJID(key) + portal.addToSpace(ce.User) + count++ + } + plural := "s" + if count == 1 { + plural = "" + } + ce.Reply("Added %d DM room%s to space", count, plural) + } if groups { err := ce.User.ResyncGroups(createPortals) if err != nil { diff --git a/database/portal.go b/database/portal.go index 7fc765e..9bec4ed 100644 --- a/database/portal.go +++ b/database/portal.go @@ -82,6 +82,27 @@ func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal { return pq.getAll("SELECT * FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'", receiver.ToNonAD()) } +func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) { + receiver = receiver.ToNonAD() + rows, err := pq.db.Query(` + SELECT jid FROM portal + LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver + WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL) + `, receiver) + if err != nil || rows == nil { + return + } + for rows.Next() { + var key PortalKey + key.Receiver = receiver + err = rows.Scan(&key.JID) + if err == nil { + keys = append(keys, key) + } + } + return +} + func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) { rows, err := pq.db.Query(query, args...) if err != nil || rows == nil {