diff --git a/database/database.go b/database/database.go index 5e9b995..3062341 100644 --- a/database/database.go +++ b/database/database.go @@ -23,6 +23,8 @@ import ( _ "github.com/mattn/go-sqlite3" log "maunium.net/go/maulogger/v2" + + "maunium.net/go/mautrix-whatsapp/database/upgrades" ) type Database struct { @@ -64,24 +66,8 @@ func New(dbType string, uri string) (*Database, error) { return db, nil } -func (db *Database) CreateTables(dbType string) error { - err := db.User.CreateTable(dbType) - if err != nil { - return err - } - err = db.Portal.CreateTable(dbType) - if err != nil { - return err - } - err = db.Puppet.CreateTable(dbType) - if err != nil { - return err - } - err = db.Message.CreateTable(dbType) - if err != nil { - return err - } - return nil +func (db *Database) Init(dialectName string) error { + return upgrades.Run(db.log.Sub("Upgrade"), dialectName, db.DB) } type Scannable interface { diff --git a/database/message.go b/database/message.go index b7910a2..2004dc4 100644 --- a/database/message.go +++ b/database/message.go @@ -18,7 +18,6 @@ package database import ( "bytes" - "strings" "database/sql" "encoding/json" @@ -34,36 +33,6 @@ type MessageQuery struct { log log.Logger } -func (mq *MessageQuery) CreateTable(dbType string) error { - if strings.ToLower(dbType) == "postgres" { - _, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message ( - chat_jid VARCHAR(255), - chat_receiver VARCHAR(255), - jid VARCHAR(255), - mxid VARCHAR(255) NOT NULL UNIQUE, - sender VARCHAR(255) NOT NULL, - content bytea NOT NULL, - - PRIMARY KEY (chat_jid, chat_receiver, jid), - FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) - )`) - return err - } else { - _, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message ( - chat_jid VARCHAR(255), - chat_receiver VARCHAR(255), - jid VARCHAR(255), - mxid VARCHAR(255) NOT NULL UNIQUE, - sender VARCHAR(255) NOT NULL, - content BLOB NOT NULL, - - PRIMARY KEY (chat_jid, chat_receiver, jid), - FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) - )`) - return err - } -} - func (mq *MessageQuery) New() *Message { return &Message{ db: mq.db, diff --git a/database/portal.go b/database/portal.go index 008bf6f..25e1dc9 100644 --- a/database/portal.go +++ b/database/portal.go @@ -59,21 +59,6 @@ type PortalQuery struct { log log.Logger } -func (pq *PortalQuery) CreateTable(dbType string) error { - _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal ( - jid VARCHAR(255), - receiver VARCHAR(255), - mxid VARCHAR(255) UNIQUE, - - name VARCHAR(255) NOT NULL, - topic VARCHAR(255) NOT NULL, - avatar VARCHAR(255) NOT NULL, - - PRIMARY KEY (jid, receiver) - )`) - return err -} - func (pq *PortalQuery) New() *Portal { return &Portal{ db: pq.db, @@ -160,3 +145,10 @@ func (portal *Portal) Update() { portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) } } + +func (portal *Portal) Delete() { + _, err := portal.db.Exec("DELETE FROM portal WHERE jid=$1 AND receiver=$2", portal.Key.JID, portal.Key.Receiver) + if err != nil { + portal.log.Warnfln("Failed to delete %s: %v", portal.Key, err) + } +} diff --git a/database/puppet.go b/database/puppet.go index 439be4a..c411420 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -29,16 +29,6 @@ type PuppetQuery struct { log log.Logger } -func (pq *PuppetQuery) CreateTable(dbType string) error { - _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet ( - jid VARCHAR(255) PRIMARY KEY, - avatar VARCHAR(255), - displayname VARCHAR(255), - name_quality SMALLINT - )`) - return err -} - func (pq *PuppetQuery) New() *Puppet { return &Puppet{ db: pq.db, diff --git a/database/upgrades/2018-09-01-initial-schema.go b/database/upgrades/2018-09-01-initial-schema.go new file mode 100644 index 0000000..ba3ef01 --- /dev/null +++ b/database/upgrades/2018-09-01-initial-schema.go @@ -0,0 +1,74 @@ +package upgrades + +import ( + "database/sql" + "fmt" +) + +func init() { + upgrades[0] = upgrade{"Initial schema", func(dialect Dialect, tx *sql.Tx) error { + var byteType string + if dialect == SQLite { + byteType = "BLOB" + } else { + byteType = "bytea" + } + _, err := tx.Exec(`CREATE TABLE IF NOT EXISTS portal ( + jid VARCHAR(255), + receiver VARCHAR(255), + mxid VARCHAR(255) UNIQUE, + + name VARCHAR(255) NOT NULL, + topic VARCHAR(255) NOT NULL, + avatar VARCHAR(255) NOT NULL, + + PRIMARY KEY (jid, receiver) + )`) + if err != nil { + return err + } + + _, err = tx.Exec(`CREATE TABLE IF NOT EXISTS puppet ( + jid VARCHAR(255) PRIMARY KEY, + avatar VARCHAR(255), + displayname VARCHAR(255), + name_quality SMALLINT + )`) + if err != nil { + return err + } + + _, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS "user" ( + mxid VARCHAR(255) PRIMARY KEY, + jid VARCHAR(255) UNIQUE, + + management_room VARCHAR(255), + + client_id VARCHAR(255), + client_token VARCHAR(255), + server_token VARCHAR(255), + enc_key %[1]s, + mac_key %[1]s + )`, byteType)) + if err != nil { + return err + } + + _, err = tx.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS message ( + chat_jid VARCHAR(255), + chat_receiver VARCHAR(255), + jid VARCHAR(255), + mxid VARCHAR(255) NOT NULL UNIQUE, + sender VARCHAR(255) NOT NULL, + content %[1]s NOT NULL, + + PRIMARY KEY (chat_jid, chat_receiver, jid), + FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) + )`, byteType)) + if err != nil { + return err + } + + return nil + }} +} diff --git a/database/upgrades/2019-05-16-message-delete-cascade.go b/database/upgrades/2019-05-16-message-delete-cascade.go new file mode 100644 index 0000000..cb31cd4 --- /dev/null +++ b/database/upgrades/2019-05-16-message-delete-cascade.go @@ -0,0 +1,25 @@ +package upgrades + +import ( + "database/sql" +) + +func init() { + upgrades[1] = upgrade{"Add ON DELETE CASCADE to message table", func(dialect Dialect, tx *sql.Tx) error { + if dialect == SQLite { + // SQLite doesn't support constraint updates, but it isn't that careful with constraints anyway. + return nil + } + _, err := tx.Exec("ALTER TABLE message DROP CONSTRAINT message_chat_jid_fkey") + if err != nil { + return err + } + _, err = tx.Exec(`ALTER TABLE message ADD CONSTRAINT message_chat_jid_fkey + FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) + ON DELETE CASCADE`) + if err != nil { + return err + } + return nil + }} +} diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go new file mode 100644 index 0000000..ade1565 --- /dev/null +++ b/database/upgrades/upgrades.go @@ -0,0 +1,87 @@ +package upgrades + +import ( + "database/sql" + "fmt" + "strings" + + log "maunium.net/go/maulogger/v2" +) + +type Dialect int + +const ( + Postgres Dialect = iota + SQLite +) + +type upgradeFunc func(Dialect, *sql.Tx) error + +type upgrade struct { + message string + fn upgradeFunc +} + +var upgrades [2]upgrade + +func getVersion(dialect Dialect, db *sql.DB) (int, error) { + _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)") + if err != nil { + return -1, err + } + + version := 0 + row := db.QueryRow("SELECT version FROM version LIMIT 1") + if row != nil { + _ = row.Scan(&version) + } + return version, nil +} + +func setVersion(dialect Dialect, tx *sql.Tx, version int) error { + _, err := tx.Exec("DELETE FROM version") + if err != nil { + return err + } + _, err = tx.Exec("INSERT INTO version (version) VALUES ($1)", version) + return err +} + +func Run(log log.Logger, dialectName string, db *sql.DB) error { + var dialect Dialect + switch strings.ToLower(dialectName) { + case "postgres": + dialect = Postgres + case "sqlite3": + dialect = SQLite + default: + return fmt.Errorf("unknown dialect %s", dialectName) + } + + version, err := getVersion(dialect, db) + if err != nil { + return err + } + + log.Infofln("Database currently on v%d, latest: v%d", version, len(upgrades)) + for i, upgrade := range upgrades[version:] { + log.Infofln("Upgrading database to v%d: %s", i+1, upgrade.message) + tx, err := db.Begin() + if err != nil { + return err + } + err = upgrade.fn(dialect, tx) + if err != nil { + return err + } + err = setVersion(dialect, tx, i+1) + if err != nil { + return err + } + err = tx.Commit() + if err != nil { + return err + } + } + return nil +} diff --git a/database/user.go b/database/user.go index a9c0dfd..e831280 100644 --- a/database/user.go +++ b/database/user.go @@ -33,38 +33,6 @@ type UserQuery struct { log log.Logger } -func (uq *UserQuery) CreateTable(dbType string) error { - if strings.ToLower(dbType) == "postgres" { - _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" ( - mxid VARCHAR(255) PRIMARY KEY, - jid VARCHAR(255) UNIQUE, - - management_room VARCHAR(255), - - client_id VARCHAR(255), - client_token VARCHAR(255), - server_token VARCHAR(255), - enc_key bytea, - mac_key bytea - )`) - return err - } else { - _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" ( - mxid VARCHAR(255) PRIMARY KEY, - jid VARCHAR(255) UNIQUE, - - management_room VARCHAR(255), - - client_id VARCHAR(255), - client_token VARCHAR(255), - server_token VARCHAR(255), - enc_key BLOB, - mac_key BLOB - )`) - return err - } -} - func (uq *UserQuery) New() *User { return &User{ db: uq.db, diff --git a/go.mod b/go.mod index e0526e5..6455139 100644 --- a/go.mod +++ b/go.mod @@ -19,3 +19,9 @@ require ( ) replace gopkg.in/russross/blackfriday.v2 => github.com/russross/blackfriday/v2 v2.0.1 + +replace maunium.net/go/mautrix-appservice => ../mautrix-appservice-go + +replace maunium.net/go/mautrix => ../mautrix-go + +replace github.com/Rhymen/go-whatsapp => ../../Go/go-whatsapp diff --git a/main.go b/main.go index 2a75259..7779101 100644 --- a/main.go +++ b/main.go @@ -147,9 +147,9 @@ func (bridge *Bridge) Init() { } func (bridge *Bridge) Start() { - err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type) + err := bridge.DB.Init(bridge.Config.AppService.Database.Type) if err != nil { - bridge.Log.Fatalln("Failed to create database tables:", err) + bridge.Log.Fatalln("Failed to initialize database:", err) os.Exit(15) } bridge.Log.Debugln("Starting application service HTTP server") diff --git a/matrix.go b/matrix.go index 38f1a72..c2c5c8d 100644 --- a/matrix.go +++ b/matrix.go @@ -111,6 +111,26 @@ func (mx *MatrixHandler) HandleMembership(evt *mautrix.Event) { if evt.Content.Membership == "invite" && evt.GetStateKey() == mx.as.BotMXID() { mx.HandleBotInvite(evt) } + + portal := mx.bridge.GetPortalByMXID(evt.RoomID) + if portal == nil { + return + } + + user := mx.bridge.GetUserByMXID(types.MatrixUserID(evt.Sender)) + if user == nil || !user.Whitelisted || !user.IsLoggedIn() { + return + } + + if evt.Content.Membership == "leave" { + if evt.GetStateKey() == evt.Sender { + if portal.IsPrivateChat() || evt.Unsigned.PrevContent.Membership == "join" { + portal.HandleMatrixLeave(user) + } + } else { + portal.HandleMatrixKick(user, evt) + } + } } func (mx *MatrixHandler) HandleRoomMetadata(evt *mautrix.Event) { diff --git a/portal.go b/portal.go index a9ecb9b..0e60525 100644 --- a/portal.go +++ b/portal.go @@ -991,3 +991,53 @@ func (portal *Portal) HandleMatrixRedaction(sender *User, evt *mautrix.Event) { portal.log.Debugln("Handled Matrix redaction:", evt) } } + +func (portal *Portal) Delete() { + portal.Portal.Delete() + delete(portal.bridge.portalsByJID, portal.Key) + if len(portal.MXID) > 0 { + delete(portal.bridge.portalsByMXID, portal.MXID) + } +} + +func (portal *Portal) Cleanup(puppetsOnly bool) { + if len(portal.MXID) == 0 { + return + } + if portal.IsPrivateChat() { + _, err := portal.MainIntent().LeaveRoom(portal.MXID) + if err != nil { + portal.log.Warnln("Failed to leave private chat portal with main intent:", err) + } + return + } + intent := portal.MainIntent() + members, err := intent.JoinedMembers(portal.MXID) + if err != nil { + portal.log.Errorln("Failed to get portal members for cleanup:", err) + return + } + for member, _ := range members.Joined { + puppet := portal.bridge.GetPuppetByMXID(member) + if puppet != nil { + _, err = puppet.Intent().LeaveRoom(portal.MXID) + portal.log.Errorln("Error leaving as puppet while cleaning up portal:", err) + } else if !puppetsOnly { + _, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + portal.log.Errorln("Error kicking user while cleaning up portal:", err) + } + } +} + +func (portal *Portal) HandleMatrixLeave(sender *User) { + if portal.IsPrivateChat() { + portal.log.Debugln("User left private chat portal, cleaning up and deleting...") + portal.Delete() + portal.Cleanup(false) + return + } +} + +func (portal *Portal) HandleMatrixKick(sender *User, event *mautrix.Event) { + // TODO +} diff --git a/user.go b/user.go index 01e7891..fd14117 100644 --- a/user.go +++ b/user.go @@ -47,6 +47,10 @@ type User struct { } func (bridge *Bridge) GetUserByMXID(userID types.MatrixUserID) *User { + _, isPuppet := bridge.ParsePuppetMXID(userID) + if isPuppet || userID == bridge.Bot.UserID { + return nil + } bridge.usersLock.Lock() defer bridge.usersLock.Unlock() user, ok := bridge.usersByMXID[userID]