2018-08-13 22:24:44 +02:00
// mautrix-whatsapp - A Matrix-WhatsApp puppeting bridge.
2021-10-22 19:14:34 +02:00
// Copyright (C) 2021 Tulir Asokan
2018-08-13 22:24:44 +02:00
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
package database
import (
2018-08-26 21:53:13 +02:00
"database/sql"
2019-01-11 20:17:31 +01:00
log "maunium.net/go/maulogger/v2"
2021-10-26 16:01:10 +02:00
2021-02-17 00:21:30 +01:00
"maunium.net/go/mautrix/id"
2021-10-22 19:14:34 +02:00
"go.mau.fi/whatsmeow/types"
2018-08-13 22:24:44 +02:00
)
2018-08-28 23:40:54 +02:00
type PortalKey struct {
2021-10-22 19:14:34 +02:00
JID types . JID
Receiver types . JID
2018-08-28 23:40:54 +02:00
}
2021-10-22 19:14:34 +02:00
func GroupPortalKey ( jid types . JID ) PortalKey {
2021-10-26 16:01:10 +02:00
return NewPortalKey ( jid , jid )
2018-08-28 23:40:54 +02:00
}
2021-10-22 19:14:34 +02:00
func NewPortalKey ( jid , receiver types . JID ) PortalKey {
if jid . Server == types . GroupServer {
2018-08-28 23:40:54 +02:00
receiver = jid
2021-10-26 16:01:10 +02:00
} else if jid . Server == types . LegacyUserServer {
jid . Server = types . DefaultUserServer
2018-08-28 23:40:54 +02:00
}
return PortalKey {
2021-10-22 19:14:34 +02:00
JID : jid . ToNonAD ( ) ,
Receiver : receiver . ToNonAD ( ) ,
2018-08-28 23:40:54 +02:00
}
}
func ( key PortalKey ) String ( ) string {
if key . Receiver == key . JID {
2021-10-22 19:14:34 +02:00
return key . JID . String ( )
2018-08-28 23:40:54 +02:00
}
2021-10-22 19:14:34 +02:00
return key . JID . String ( ) + "-" + key . Receiver . String ( )
2018-08-28 23:40:54 +02:00
}
2018-08-13 22:24:44 +02:00
type PortalQuery struct {
db * Database
2018-08-16 18:20:07 +02:00
log log . Logger
2018-08-13 22:24:44 +02:00
}
// New returns an empty Portal that shares this query's database handle and
// logger. It is not written to the database until Insert is called.
func (pq *PortalQuery) New() *Portal {
	portal := &Portal{db: pq.db, log: pq.log}
	return portal
}
2019-06-01 19:03:29 +02:00
func ( pq * PortalQuery ) GetAll ( ) [ ] * Portal {
return pq . getAll ( "SELECT * FROM portal" )
2018-08-13 22:24:44 +02:00
}
2018-08-28 23:40:54 +02:00
func ( pq * PortalQuery ) GetByJID ( key PortalKey ) * Portal {
2019-03-06 16:33:42 +01:00
return pq . get ( "SELECT * FROM portal WHERE jid=$1 AND receiver=$2" , key . JID , key . Receiver )
2018-08-13 22:24:44 +02:00
}
2020-05-08 21:32:22 +02:00
func ( pq * PortalQuery ) GetByMXID ( mxid id . RoomID ) * Portal {
2019-03-06 16:33:42 +01:00
return pq . get ( "SELECT * FROM portal WHERE mxid=$1" , mxid )
2018-08-13 22:24:44 +02:00
}
2021-10-22 19:14:34 +02:00
func ( pq * PortalQuery ) GetAllByJID ( jid types . JID ) [ ] * Portal {
2021-10-26 16:01:10 +02:00
return pq . getAll ( "SELECT * FROM portal WHERE jid=$1" , jid . ToNonAD ( ) )
2019-06-01 19:03:29 +02:00
}
2021-10-22 19:14:34 +02:00
func ( pq * PortalQuery ) FindPrivateChats ( receiver types . JID ) [ ] * Portal {
2021-10-26 16:01:10 +02:00
return pq . getAll ( "SELECT * FROM portal WHERE receiver=$1 AND jid LIKE '%@s.whatsapp.net'" , receiver . ToNonAD ( ) )
2020-08-22 12:07:55 +02:00
}
2019-06-01 19:03:29 +02:00
// getAll runs query with args and scans each resulting row into a Portal.
//
// Error handling:
//   - A query error is logged and nil is returned.
//   - Rows that Scan fails to read are skipped (Scan logs them), so the result
//     never contains nil entries that callers could dereference.
//   - An error reported by rows.Err after iteration is logged. The portals
//     collected so far are still returned.
func (pq *PortalQuery) getAll(query string, args ...interface{}) (portals []*Portal) {
	rows, err := pq.db.Query(query, args...)
	if err != nil {
		pq.log.Warnfln("Failed to query portals: %v", err)
		return nil
	} else if rows == nil {
		return nil
	}
	defer rows.Close()
	for rows.Next() {
		if portal := pq.New().Scan(rows); portal != nil {
			portals = append(portals, portal)
		}
	}
	if err = rows.Err(); err != nil {
		pq.log.Warnfln("Error while iterating portal rows: %v", err)
	}
	return portals
}
2018-08-13 22:24:44 +02:00
// get runs a single-row query and scans the result into a new Portal. It
// returns nil when no row is available or scanning fails.
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
	if row := pq.db.QueryRow(query, args...); row != nil {
		return pq.New().Scan(row)
	}
	return nil
}
type Portal struct {
db * Database
2018-08-16 18:20:07 +02:00
log log . Logger
2018-08-13 22:24:44 +02:00
2018-08-28 23:40:54 +02:00
Key PortalKey
2020-05-08 21:32:22 +02:00
MXID id . RoomID
2018-08-18 21:57:08 +02:00
2019-06-01 19:03:29 +02:00
Name string
Topic string
Avatar string
2020-05-08 21:32:22 +02:00
AvatarURL id . ContentURI
2020-05-09 01:03:59 +02:00
Encrypted bool
2021-10-26 16:01:10 +02:00
FirstEventID id . EventID
NextBatchID id . BatchID
2018-08-13 22:24:44 +02:00
}
func ( portal * Portal ) Scan ( row Scannable ) * Portal {
2021-10-26 16:01:10 +02:00
var mxid , avatarURL , firstEventID , nextBatchID sql . NullString
err := row . Scan ( & portal . Key . JID , & portal . Key . Receiver , & mxid , & portal . Name , & portal . Topic , & portal . Avatar , & avatarURL , & portal . Encrypted , & firstEventID , & nextBatchID )
2018-08-13 22:24:44 +02:00
if err != nil {
2018-08-18 21:57:08 +02:00
if err != sql . ErrNoRows {
2018-08-26 00:55:21 +02:00
portal . log . Errorln ( "Database scan failed:" , err )
2018-08-18 21:57:08 +02:00
}
return nil
2018-08-13 22:24:44 +02:00
}
2020-05-08 21:32:22 +02:00
portal . MXID = id . RoomID ( mxid . String )
portal . AvatarURL , _ = id . ParseContentURI ( avatarURL . String )
2021-10-26 16:01:10 +02:00
portal . FirstEventID = id . EventID ( firstEventID . String )
portal . NextBatchID = id . BatchID ( nextBatchID . String )
2018-08-13 22:24:44 +02:00
return portal
}
2020-05-08 21:32:22 +02:00
func ( portal * Portal ) mxidPtr ( ) * id . RoomID {
2018-08-23 00:12:26 +02:00
if len ( portal . MXID ) > 0 {
2018-08-28 23:40:54 +02:00
return & portal . MXID
2018-08-23 00:12:26 +02:00
}
2018-08-28 23:40:54 +02:00
return nil
}
2019-01-21 22:55:16 +01:00
func ( portal * Portal ) Insert ( ) {
2021-10-26 16:01:10 +02:00
_ , err := portal . db . Exec ( "INSERT INTO portal (jid, receiver, mxid, name, topic, avatar, avatar_url, encrypted, first_event_id, next_batch_id) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" ,
portal . Key . JID , portal . Key . Receiver , portal . mxidPtr ( ) , portal . Name , portal . Topic , portal . Avatar , portal . AvatarURL . String ( ) , portal . Encrypted , portal . FirstEventID . String ( ) , portal . NextBatchID . String ( ) )
2018-08-18 21:57:08 +02:00
if err != nil {
2018-08-28 23:40:54 +02:00
portal . log . Warnfln ( "Failed to insert %s: %v" , portal . Key , err )
2018-08-18 21:57:08 +02:00
}
2018-08-13 22:24:44 +02:00
}
2019-01-21 22:55:16 +01:00
func ( portal * Portal ) Update ( ) {
2021-10-26 16:01:10 +02:00
_ , err := portal . db . Exec ( "UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, encrypted=$6, first_event_id=$7, next_batch_id=$8 WHERE jid=$9 AND receiver=$10" ,
portal . mxidPtr ( ) , portal . Name , portal . Topic , portal . Avatar , portal . AvatarURL . String ( ) , portal . Encrypted , portal . FirstEventID . String ( ) , portal . NextBatchID . String ( ) , portal . Key . JID , portal . Key . Receiver )
2018-08-18 21:57:08 +02:00
if err != nil {
2018-08-28 23:40:54 +02:00
portal . log . Warnfln ( "Failed to update %s: %v" , portal . Key , err )
2018-08-18 21:57:08 +02:00
}
2018-08-13 22:24:44 +02:00
}
2019-05-16 19:14:32 +02:00
// Delete removes the row identified by portal.Key. Failures are logged as
// warnings and not returned.
func (portal *Portal) Delete() {
	const query = "DELETE FROM portal WHERE jid=$1 AND receiver=$2"
	if _, err := portal.db.Exec(query, portal.Key.JID, portal.Key.Receiver); err != nil {
		portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err)
	}
}
2019-05-28 20:31:25 +02:00
2021-10-22 19:14:34 +02:00
//func (portal *Portal) GetUserIDs() []id.UserID {
// rows, err := portal.db.Query(`SELECT "user".mxid FROM "user", user_portal
// WHERE "user".jid=user_portal.user_jid
// AND user_portal.portal_jid=$1
// AND user_portal.portal_receiver=$2`,
// portal.Key.JID, portal.Key.Receiver)
// if err != nil {
// portal.log.Debugln("Failed to get portal user ids:", err)
// return nil
// }
// var userIDs []id.UserID
// for rows.Next() {
// var userID id.UserID
// err = rows.Scan(&userID)
// if err != nil {
// portal.log.Warnln("Failed to scan row:", err)
// continue
// }
// userIDs = append(userIDs, userID)
// }
// return userIDs
//}