// Source: mautrix-whatsapp/database/portal.go
//
// 271 lines
// 8.2 KiB
// Go

// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
// Copyright (C) 2021 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
"database/sql"
"fmt"
"time"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/whatsmeow/types"
)
// PortalKey is the composite identifier of a portal row: the jid and receiver
// columns of the portal table. Use NewPortalKey to build a normalized key.
type PortalKey struct {
// JID is the value stored in the portal table's jid column.
JID types.JID
// Receiver is the value stored in the receiver column. NewPortalKey sets it
// equal to JID when the JID's server is types.GroupServer.
Receiver types.JID
}
// NewPortalKey builds a normalized PortalKey from a chat JID and a receiver JID.
//
// When jid is on types.GroupServer the receiver argument is ignored and the
// chat JID itself is used as the receiver. When jid is on
// types.LegacyUserServer its server is rewritten to types.DefaultUserServer.
// Both halves of the returned key are passed through ToNonAD.
func NewPortalKey(jid, receiver types.JID) PortalKey {
	switch jid.Server {
	case types.GroupServer:
		// Group portals are keyed by the group JID alone.
		receiver = jid
	case types.LegacyUserServer:
		jid.Server = types.DefaultUserServer
	}

	var key PortalKey
	key.JID = jid.ToNonAD()
	key.Receiver = receiver.ToNonAD()
	return key
}
// String renders the key for logging. If the receiver differs from the chat
// JID the result is "<jid>-<receiver>"; otherwise it is just the chat JID.
func (key PortalKey) String() string {
	chat := key.JID.String()
	if key.Receiver != key.JID {
		return chat + "-" + key.Receiver.String()
	}
	return chat
}
// PortalQuery runs queries against the portal table and hands its database
// handle and logger to every Portal it creates (see New).
type PortalQuery struct {
// db is the database handle all queries are executed on.
db *Database
// log is the logger passed on to created Portal values.
log log.Logger
}
// New returns an empty Portal that shares this query object's database handle
// and logger. All other fields are left at their zero values.
func (pq *PortalQuery) New() *Portal {
	portal := new(Portal)
	portal.db = pq.db
	portal.log = pq.log
	return portal
}
// portalColumns is the column list shared by every SELECT in this file.
// The order must match the argument order of the row.Scan call in Portal.Scan.
const portalColumns = "jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set, encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id, relay_user_id, expiration_time"
// GetAll loads every row of the portal table.
// It returns nil if the query itself fails.
func (pq *PortalQuery) GetAll() []*Portal {
	query := fmt.Sprintf("SELECT %s FROM portal", portalColumns)
	return pq.getAll(query)
}
// GetByJID loads the portal whose jid and receiver columns match key.
// It returns nil when the row cannot be scanned, which includes the case
// where no such row exists.
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
	query := fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1 AND receiver=$2", portalColumns)
	return pq.get(query, key.JID, key.Receiver)
}
// GetByMXID loads the portal whose mxid column equals the given room ID.
// It returns nil when the row cannot be scanned, which includes the case
// where no such row exists.
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
	query := fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns)
	return pq.get(query, mxid)
}
// GetAllByJID loads every portal whose jid column equals the given JID,
// regardless of receiver. The JID is passed through ToNonAD before it is
// used as the query argument.
func (pq *PortalQuery) GetAllByJID(jid types.JID) []*Portal {
	query := fmt.Sprintf("SELECT %s FROM portal WHERE jid=$1", portalColumns)
	chat := jid.ToNonAD()
	return pq.getAll(query, chat)
}
// FindPrivateChats loads every portal with the given receiver whose jid ends
// in "@s.whatsapp.net". The receiver is passed through ToNonAD first.
func (pq *PortalQuery) FindPrivateChats(receiver types.JID) []*Portal {
	// "%%" is the fmt escape for a literal "%", i.e. the SQL LIKE wildcard.
	query := fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1 AND jid LIKE '%%@s.whatsapp.net'", portalColumns)
	owner := receiver.ToNonAD()
	return pq.getAll(query, owner)
}
// FindPrivateChatsNotInSpace returns the keys of all portals that have a
// non-empty mxid, belong to the given receiver, and whose in_space value is
// false or NULL after left-joining user_portal.
//
// The receiver is passed through ToNonAD and is used both as the query
// argument and as the Receiver of every returned key. If the query fails the
// result is nil. Rows that fail to scan are logged and skipped.
func (pq *PortalQuery) FindPrivateChatsNotInSpace(receiver types.JID) (keys []PortalKey) {
	receiver = receiver.ToNonAD()
	// NOTE(review): in_space is unqualified below while the portal table also
	// has an in_space column (see portalColumns). If user_portal has a column
	// of the same name this reference is ambiguous; confirm against the schema.
	rows, err := pq.db.Query(`
SELECT jid FROM portal
LEFT JOIN user_portal ON portal.jid=user_portal.portal_jid AND portal.receiver=user_portal.portal_receiver
WHERE mxid<>'' AND receiver=$1 AND (in_space=false OR in_space IS NULL)
`, receiver)
	if err != nil || rows == nil {
		return
	}
	// Release the result set on every exit path, matching getAll.
	defer rows.Close()
	for rows.Next() {
		var key PortalKey
		key.Receiver = receiver
		if err = rows.Scan(&key.JID); err != nil {
			pq.log.Errorln("Database scan failed:", err)
			continue
		}
		keys = append(keys, key)
	}
	return
}
// getAll runs a multi-row SELECT over portalColumns and converts every row
// into a Portal.
//
// It returns nil if the query fails. Portal.Scan returns nil (after logging)
// when a row cannot be decoded; such rows are skipped so that callers never
// receive nil entries in the slice.
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
	rows, err := pq.db.Query(query, args...)
	if err != nil || rows == nil {
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		if portal := pq.New().Scan(rows); portal != nil {
			portals = append(portals, portal)
		}
	}
	return portals
}
// get runs a single-row SELECT over portalColumns and converts the row into a
// Portal. It returns nil if no row object is produced or if scanning fails.
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
	if row := pq.db.QueryRow(query, args...); row != nil {
		return pq.New().Scan(row)
	}
	return nil
}
// Portal is the in-memory form of one row of the portal table.
// Scan fills it from a row; Insert, Update and Delete write it back.
type Portal struct {
// db and log are injected by PortalQuery.New and are not persisted.
db *Database
log log.Logger
// Key holds the jid and receiver columns; Update and Delete use it to
// locate the row.
Key PortalKey
// MXID is the mxid column. An empty value is written as NULL (see mxidPtr).
MXID id.RoomID
// Name and NameSet map to the name and name_set columns.
Name string
NameSet bool
// Topic and TopicSet map to the topic and topic_set columns.
Topic string
TopicSet bool
// Avatar, AvatarURL and AvatarSet map to the avatar, avatar_url and
// avatar_set columns. AvatarURL is stored in its string form.
Avatar string
AvatarURL id.ContentURI
AvatarSet bool
// Encrypted maps to the encrypted column.
Encrypted bool
// LastSync is stored in last_sync as Unix seconds; the zero time is stored
// as 0 (see lastSyncTs).
LastSync time.Time
// IsParent maps to the is_parent column.
IsParent bool
// ParentGroup maps to parent_group. An empty JID is written as NULL
// (see parentGroupPtr).
ParentGroup types.JID
// InSpace maps to the in_space column.
InSpace bool
// FirstEventID and NextBatchID map to first_event_id and next_batch_id and
// are stored as strings.
FirstEventID id.EventID
NextBatchID id.BatchID
// RelayUserID maps to relay_user_id. An empty value is written as NULL
// (see relayUserPtr).
RelayUserID id.UserID
// ExpirationTime maps to expiration_time.
// NOTE(review): the unit is not visible in this file; confirm with callers.
ExpirationTime uint32
}
// Scan fills the portal from a row whose columns are in portalColumns order
// and returns the receiver, or nil if the row could not be scanned.
//
// sql.ErrNoRows is treated as a silent miss; any other scan error is logged.
// Nullable text columns are read through sql.NullString, so a NULL becomes an
// empty value. Parse errors for the avatar URL and parent group JID are
// ignored, leaving whatever the parser returned.
func (portal *Portal) Scan(row dbutil.Scannable) *Portal {
	var (
		lastSync    int64
		roomID      sql.NullString
		avatar      sql.NullString
		parentGroup sql.NullString
		firstEvent  sql.NullString
		nextBatch   sql.NullString
		relayUser   sql.NullString
	)
	err := row.Scan(
		&portal.Key.JID, &portal.Key.Receiver, &roomID,
		&portal.Name, &portal.NameSet,
		&portal.Topic, &portal.TopicSet,
		&portal.Avatar, &avatar, &portal.AvatarSet,
		&portal.Encrypted, &lastSync,
		&portal.IsParent, &parentGroup, &portal.InSpace,
		&firstEvent, &nextBatch, &relayUser,
		&portal.ExpirationTime,
	)
	switch {
	case err == sql.ErrNoRows:
		return nil
	case err != nil:
		portal.log.Errorln("Database scan failed:", err)
		return nil
	}

	// A non-positive timestamp leaves LastSync untouched.
	if lastSync > 0 {
		portal.LastSync = time.Unix(lastSync, 0)
	}
	portal.MXID = id.RoomID(roomID.String)
	portal.AvatarURL, _ = id.ParseContentURI(avatar.String)
	if parentGroup.Valid {
		portal.ParentGroup, _ = types.ParseJID(parentGroup.String)
	}
	portal.FirstEventID = id.EventID(firstEvent.String)
	portal.NextBatchID = id.BatchID(nextBatch.String)
	portal.RelayUserID = id.UserID(relayUser.String)
	return portal
}
// mxidPtr returns a pointer to the portal's MXID, or nil when the MXID is
// empty, so that an unset room ID is written to the database as NULL.
func (portal *Portal) mxidPtr() *id.RoomID {
	if len(portal.MXID) == 0 {
		return nil
	}
	return &portal.MXID
}
// relayUserPtr returns a pointer to the portal's RelayUserID, or nil when it
// is empty, so that an unset relay user is written to the database as NULL.
func (portal *Portal) relayUserPtr() *id.UserID {
	if len(portal.RelayUserID) == 0 {
		return nil
	}
	return &portal.RelayUserID
}
// parentGroupPtr returns the string form of ParentGroup as a pointer, or nil
// when ParentGroup is empty, so that a missing parent is written as NULL.
func (portal *Portal) parentGroupPtr() *string {
	if portal.ParentGroup.IsEmpty() {
		return nil
	}
	group := portal.ParentGroup.String()
	return &group
}
// lastSyncTs converts LastSync to Unix seconds for storage.
// The zero time is stored as 0.
func (portal *Portal) lastSyncTs() int64 {
	if ts := portal.LastSync; !ts.IsZero() {
		return ts.Unix()
	}
	return 0
}
// Insert writes the portal as a new row of the portal table.
//
// Empty MXID, ParentGroup and RelayUserID values are sent as NULL via the
// *Ptr helpers; LastSync is sent as Unix seconds. A failure is logged as a
// warning and not returned to the caller.
func (portal *Portal) Insert() {
	const insertQuery = `
INSERT INTO portal (jid, receiver, mxid, name, name_set, topic, topic_set, avatar, avatar_url, avatar_set,
encrypted, last_sync, is_parent, parent_group, in_space, first_event_id, next_batch_id,
relay_user_id, expiration_time)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
`
	// Argument order follows the column list in the statement above.
	_, err := portal.db.Exec(insertQuery,
		portal.Key.JID,
		portal.Key.Receiver,
		portal.mxidPtr(),
		portal.Name,
		portal.NameSet,
		portal.Topic,
		portal.TopicSet,
		portal.Avatar,
		portal.AvatarURL.String(),
		portal.AvatarSet,
		portal.Encrypted,
		portal.lastSyncTs(),
		portal.IsParent,
		portal.parentGroupPtr(),
		portal.InSpace,
		portal.FirstEventID.String(),
		portal.NextBatchID.String(),
		portal.relayUserPtr(),
		portal.ExpirationTime,
	)
	if err == nil {
		return
	}
	portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
}
// Update writes all non-key columns of the portal back to the row identified
// by Key.
//
// txn is the executor to run the statement on; when it is nil the portal's
// own database handle is used. A failure is logged as a warning and not
// returned to the caller.
func (portal *Portal) Update(txn dbutil.Execable) {
	const updateQuery = `
UPDATE portal
SET mxid=$1, name=$2, name_set=$3, topic=$4, topic_set=$5, avatar=$6, avatar_url=$7, avatar_set=$8,
encrypted=$9, last_sync=$10, is_parent=$11, parent_group=$12, in_space=$13,
first_event_id=$14, next_batch_id=$15, relay_user_id=$16, expiration_time=$17
WHERE jid=$18 AND receiver=$19
`
	if txn == nil {
		txn = portal.db
	}
	// $1-$17 are the SET values; $18 and $19 are the key in the WHERE clause.
	_, err := txn.Exec(updateQuery,
		portal.mxidPtr(),
		portal.Name,
		portal.NameSet,
		portal.Topic,
		portal.TopicSet,
		portal.Avatar,
		portal.AvatarURL.String(),
		portal.AvatarSet,
		portal.Encrypted,
		portal.lastSyncTs(),
		portal.IsParent,
		portal.parentGroupPtr(),
		portal.InSpace,
		portal.FirstEventID.String(),
		portal.NextBatchID.String(),
		portal.relayUserPtr(),
		portal.ExpirationTime,
		portal.Key.JID,
		portal.Key.Receiver,
	)
	if err == nil {
		return
	}
	portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)
}
// Delete removes the portal's row inside a single transaction.
//
// It first clears in_space on every portal whose parent_group is this
// portal's JID, then deletes the row identified by Key. Both statements run on
// the transaction, so a failure in either one rolls back both. Errors are
// logged and not returned to the caller.
func (portal *Portal) Delete() {
	txn, err := portal.db.Begin()
	if err != nil {
		portal.log.Errorfln("Failed to begin transaction to delete portal %v: %v", portal.Key, err)
		return
	}
	// The deferred closure reads err as left by the statements below:
	// non-nil means roll back, nil means commit.
	defer func() {
		if err != nil {
			err = txn.Rollback()
			if err != nil {
				portal.log.Warnfln("Failed to rollback failed portal delete transaction: %v", err)
			}
		} else if err = txn.Commit(); err != nil {
			portal.log.Warnfln("Failed to commit portal delete transaction: %v", err)
		}
	}()
	// Execute on txn, not portal.db; otherwise the statements run outside the
	// transaction and the commit/rollback above has nothing to act on.
	_, err = txn.Exec("UPDATE portal SET in_space=false WHERE parent_group=$1", portal.Key.JID)
	if err != nil {
		portal.log.Warnfln("Failed to mark child groups of %v as not in space: %v", portal.Key.JID, err)
		return
	}
	_, err = txn.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver)
	if err != nil {
		portal.log.Warnfln("Failed to delete %v: %v", portal.Key, err)
	}
}