Merge pull request #43 from RennerDev/master

Implemented postgres
This commit is contained in:
Tulir Asokan 2019-03-14 00:37:00 +02:00 committed by GitHub
commit 67a041c06d
7 changed files with 90 additions and 56 deletions

View file

@ -19,6 +19,7 @@ package database
import ( import (
"database/sql" "database/sql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
@ -34,8 +35,8 @@ type Database struct {
Message *MessageQuery Message *MessageQuery
} }
func New(file string) (*Database, error) { func New(dbType string, uri string) (*Database, error) {
conn, err := sql.Open("sqlite3", file) conn, err := sql.Open(dbType, uri)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,20 +64,20 @@ func New(file string) (*Database, error) {
return db, nil return db, nil
} }
func (db *Database) CreateTables() error { func (db *Database) CreateTables(dbType string) error {
err := db.User.CreateTable() err := db.User.CreateTable(dbType)
if err != nil { if err != nil {
return err return err
} }
err = db.Portal.CreateTable() err = db.Portal.CreateTable(dbType)
if err != nil { if err != nil {
return err return err
} }
err = db.Puppet.CreateTable() err = db.Puppet.CreateTable(dbType)
if err != nil { if err != nil {
return err return err
} }
err = db.Message.CreateTable() err = db.Message.CreateTable(dbType)
if err != nil { if err != nil {
return err return err
} }

View file

@ -18,6 +18,7 @@ package database
import ( import (
"bytes" "bytes"
"strings"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
@ -33,19 +34,34 @@ type MessageQuery struct {
log log.Logger log log.Logger
} }
func (mq *MessageQuery) CreateTable() error { func (mq *MessageQuery) CreateTable(dbType string) error {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message ( if strings.ToLower(dbType) == "postgres" {
chat_jid VARCHAR(25), _, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_receiver VARCHAR(25), chat_jid VARCHAR(255),
jid VARCHAR(255), chat_receiver VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE, jid VARCHAR(255),
sender VARCHAR(25) NOT NULL, mxid VARCHAR(255) NOT NULL UNIQUE,
content BLOB NOT NULL, sender VARCHAR(255) NOT NULL,
content bytea NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid), PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver) FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`) )`)
return err
} else {
_, err := mq.db.Exec(`CREATE TABLE IF NOT EXISTS message (
chat_jid VARCHAR(255),
chat_receiver VARCHAR(255),
jid VARCHAR(255),
mxid VARCHAR(255) NOT NULL UNIQUE,
sender VARCHAR(255) NOT NULL,
content BLOB NOT NULL,
PRIMARY KEY (chat_jid, chat_receiver, jid),
FOREIGN KEY (chat_jid, chat_receiver) REFERENCES portal(jid, receiver)
)`)
return err return err
}
} }
func (mq *MessageQuery) New() *Message { func (mq *MessageQuery) New() *Message {
@ -56,7 +72,7 @@ func (mq *MessageQuery) New() *Message {
} }
func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) { func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=?", chat.JID, chat.Receiver) rows, err := mq.db.Query("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2", chat.JID, chat.Receiver)
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -68,11 +84,11 @@ func (mq *MessageQuery) GetAll(chat PortalKey) (messages []*Message) {
} }
func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message { func (mq *MessageQuery) GetByJID(chat PortalKey, jid types.WhatsAppMessageID) *Message {
return mq.get("SELECT * FROM message WHERE chat_jid=? AND chat_receiver=? AND jid=?", chat.JID, chat.Receiver, jid) return mq.get("SELECT * FROM message WHERE chat_jid=$1 AND chat_receiver=$2 AND jid=$3", chat.JID, chat.Receiver, jid)
} }
func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message { func (mq *MessageQuery) GetByMXID(mxid types.MatrixEventID) *Message {
return mq.get("SELECT * FROM message WHERE mxid=?", mxid) return mq.get("SELECT * FROM message WHERE mxid=$1", mxid)
} }
func (mq *MessageQuery) get(query string, args ...interface{}) *Message { func (mq *MessageQuery) get(query string, args ...interface{}) *Message {
@ -130,7 +146,7 @@ func (msg *Message) encodeBinaryContent() []byte {
} }
func (msg *Message) Insert() { func (msg *Message) Insert() {
_, err := msg.db.Exec("INSERT INTO message VALUES (?, ?, ?, ?, ?, ?)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent()) _, err := msg.db.Exec("INSERT INTO message VALUES ($1, $2, $3, $4, $5, $6)", msg.Chat.JID, msg.Chat.Receiver, msg.JID, msg.MXID, msg.Sender, msg.encodeBinaryContent())
if err != nil { if err != nil {
msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err) msg.log.Warnfln("Failed to insert %s@%s: %v", msg.Chat, msg.JID, err)
} }

View file

@ -59,18 +59,17 @@ type PortalQuery struct {
log log.Logger log log.Logger
} }
func (pq *PortalQuery) CreateTable() error { func (pq *PortalQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal ( _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS portal (
jid VARCHAR(25), jid VARCHAR(255),
receiver VARCHAR(25), receiver VARCHAR(255),
mxid VARCHAR(255) UNIQUE, mxid VARCHAR(255) UNIQUE,
name VARCHAR(255) NOT NULL, name VARCHAR(255) NOT NULL,
topic VARCHAR(255) NOT NULL, topic VARCHAR(255) NOT NULL,
avatar VARCHAR(255) NOT NULL, avatar VARCHAR(255) NOT NULL,
PRIMARY KEY (jid, receiver), PRIMARY KEY (jid, receiver)
FOREIGN KEY (receiver) REFERENCES user(mxid)
)`) )`)
return err return err
} }
@ -95,11 +94,11 @@ func (pq *PortalQuery) GetAll() (portals []*Portal) {
} }
func (pq *PortalQuery) GetByJID(key PortalKey) *Portal { func (pq *PortalQuery) GetByJID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE jid=? AND receiver=?", key.JID, key.Receiver) return pq.get("SELECT * FROM portal WHERE jid=$1 AND receiver=$2", key.JID, key.Receiver)
} }
func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal { func (pq *PortalQuery) GetByMXID(mxid types.MatrixRoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=?", mxid) return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
} }
func (pq *PortalQuery) get(query string, args ...interface{}) *Portal { func (pq *PortalQuery) get(query string, args ...interface{}) *Portal {
@ -143,7 +142,7 @@ func (portal *Portal) mxidPtr() *string {
} }
func (portal *Portal) Insert() { func (portal *Portal) Insert() {
_, err := portal.db.Exec("INSERT INTO portal VALUES (?, ?, ?, ?, ?, ?)", _, err := portal.db.Exec("INSERT INTO portal VALUES ($1, $2, $3, $4, $5, $6)",
portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar) portal.Key.JID, portal.Key.Receiver, portal.mxidPtr(), portal.Name, portal.Topic, portal.Avatar)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err) portal.log.Warnfln("Failed to insert %s: %v", portal.Key, err)
@ -155,7 +154,7 @@ func (portal *Portal) Update() {
if len(portal.MXID) > 0 { if len(portal.MXID) > 0 {
mxid = &portal.MXID mxid = &portal.MXID
} }
_, err := portal.db.Exec("UPDATE portal SET mxid=?, name=?, topic=?, avatar=? WHERE jid=? AND receiver=?", _, err := portal.db.Exec("UPDATE portal SET mxid=$1, name=$2, topic=$3, avatar=$4 WHERE jid=$5 AND receiver=$6",
mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver) mxid, portal.Name, portal.Topic, portal.Avatar, portal.Key.JID, portal.Key.Receiver)
if err != nil { if err != nil {
portal.log.Warnfln("Failed to update %s: %v", portal.Key, err) portal.log.Warnfln("Failed to update %s: %v", portal.Key, err)

View file

@ -29,12 +29,12 @@ type PuppetQuery struct {
log log.Logger log log.Logger
} }
func (pq *PuppetQuery) CreateTable() error { func (pq *PuppetQuery) CreateTable(dbType string) error {
_, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet ( _, err := pq.db.Exec(`CREATE TABLE IF NOT EXISTS puppet (
jid VARCHAR(25) PRIMARY KEY, jid VARCHAR(255) PRIMARY KEY,
avatar VARCHAR(255), avatar VARCHAR(255),
displayname VARCHAR(255), displayname VARCHAR(255),
name_quality TINYINT name_quality SMALLINT
)`) )`)
return err return err
} }
@ -59,7 +59,7 @@ func (pq *PuppetQuery) GetAll() (puppets []*Puppet) {
} }
func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet { func (pq *PuppetQuery) Get(jid types.WhatsAppID) *Puppet {
row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=?", jid) row := pq.db.QueryRow("SELECT * FROM puppet WHERE jid=$1", jid)
if row == nil { if row == nil {
return nil return nil
} }
@ -93,7 +93,7 @@ func (puppet *Puppet) Scan(row Scannable) *Puppet {
} }
func (puppet *Puppet) Insert() { func (puppet *Puppet) Insert() {
_, err := puppet.db.Exec("INSERT INTO puppet VALUES (?, ?, ?, ?)", _, err := puppet.db.Exec("INSERT INTO puppet VALUES ($1, $2, $3, $4)",
puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality) puppet.JID, puppet.Avatar, puppet.Displayname, puppet.NameQuality)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to insert %s: %v", puppet.JID, err)
@ -101,7 +101,7 @@ func (puppet *Puppet) Insert() {
} }
func (puppet *Puppet) Update() { func (puppet *Puppet) Update() {
_, err := puppet.db.Exec("UPDATE puppet SET displayname=?, name_quality=?, avatar=? WHERE jid=?", _, err := puppet.db.Exec("UPDATE puppet SET displayname=$1, name_quality=$2, avatar=$3 WHERE jid=$4",
puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID) puppet.Displayname, puppet.NameQuality, puppet.Avatar, puppet.JID)
if err != nil { if err != nil {
puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err) puppet.log.Warnfln("Failed to update %s->%s: %v", puppet.JID, err)

View file

@ -33,20 +33,36 @@ type UserQuery struct {
log log.Logger log log.Logger
} }
func (uq *UserQuery) CreateTable() error { func (uq *UserQuery) CreateTable(dbType string) error {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS user ( if strings.ToLower(dbType) == "postgres" {
mxid VARCHAR(255) PRIMARY KEY, _, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
jid VARCHAR(25) UNIQUE, mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255), management_room VARCHAR(255),
client_id VARCHAR(255), client_id VARCHAR(255),
client_token VARCHAR(255), client_token VARCHAR(255),
server_token VARCHAR(255), server_token VARCHAR(255),
enc_key BLOB, enc_key bytea,
mac_key BLOB mac_key bytea
)`) )`)
return err return err
} else {
_, err := uq.db.Exec(`CREATE TABLE IF NOT EXISTS "user" (
mxid VARCHAR(255) PRIMARY KEY,
jid VARCHAR(255) UNIQUE,
management_room VARCHAR(255),
client_id VARCHAR(255),
client_token VARCHAR(255),
server_token VARCHAR(255),
enc_key BLOB,
mac_key BLOB
)`)
return err
}
} }
func (uq *UserQuery) New() *User { func (uq *UserQuery) New() *User {
@ -57,7 +73,7 @@ func (uq *UserQuery) New() *User {
} }
func (uq *UserQuery) GetAll() (users []*User) { func (uq *UserQuery) GetAll() (users []*User) {
rows, err := uq.db.Query("SELECT * FROM user") rows, err := uq.db.Query(`SELECT * FROM "user"`)
if err != nil || rows == nil { if err != nil || rows == nil {
return nil return nil
} }
@ -69,7 +85,7 @@ func (uq *UserQuery) GetAll() (users []*User) {
} }
func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User { func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE mxid=?", userID) row := uq.db.QueryRow(`SELECT * FROM "user" WHERE mxid=$1`, userID)
if row == nil { if row == nil {
return nil return nil
} }
@ -77,7 +93,7 @@ func (uq *UserQuery) GetByMXID(userID types.MatrixUserID) *User {
} }
func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User { func (uq *UserQuery) GetByJID(userID types.WhatsAppID) *User {
row := uq.db.QueryRow("SELECT * FROM user WHERE jid=?", stripSuffix(userID)) row := uq.db.QueryRow(`SELECT * FROM "user" WHERE jid=$1`, stripSuffix(userID))
if row == nil { if row == nil {
return nil return nil
} }
@ -150,7 +166,7 @@ func (user *User) sessionUnptr() (sess whatsapp.Session) {
func (user *User) Insert() { func (user *User) Insert() {
sess := user.sessionUnptr() sess := user.sessionUnptr()
_, err := user.db.Exec("INSERT INTO user VALUES (?, ?, ?, ?, ?, ?, ?, ?)", user.MXID, user.jidPtr(), _, err := user.db.Exec(`INSERT INTO "user" VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, user.MXID, user.jidPtr(),
user.ManagementRoom, user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey) sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey)
if err != nil { if err != nil {
@ -160,7 +176,7 @@ func (user *User) Insert() {
func (user *User) Update() { func (user *User) Update() {
sess := user.sessionUnptr() sess := user.sessionUnptr()
_, err := user.db.Exec("UPDATE user SET jid=?, management_room=?, client_id=?, client_token=?, server_token=?, enc_key=?, mac_key=? WHERE mxid=?", _, err := user.db.Exec(`UPDATE "user" SET jid=$1, management_room=$2, client_id=$3, client_token=$4, server_token=$5, enc_key=$6, mac_key=$7 WHERE mxid=$8`,
user.jidPtr(), user.ManagementRoom, user.jidPtr(), user.ManagementRoom,
sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey, sess.ClientId, sess.ClientToken, sess.ServerToken, sess.EncKey, sess.MacKey,
user.MXID) user.MXID)

View file

@ -20,6 +20,7 @@ appservice:
# The database type. Only "sqlite3" is supported. # The database type. Only "sqlite3" is supported.
type: sqlite3 type: sqlite3
# The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string # The database URI. Usually file name. https://github.com/mattn/go-sqlite3#connection-string
# postres example: postgres://synapse:changeme@db/whatsapp?sslmode=disable
uri: mautrix-whatsapp.db uri: mautrix-whatsapp.db
# Path to the Matrix room state store. # Path to the Matrix room state store.
state_store_path: ./mx-state.json state_store_path: ./mx-state.json

View file

@ -133,7 +133,7 @@ func (bridge *Bridge) Init() {
bridge.AS.StateStore = bridge.StateStore bridge.AS.StateStore = bridge.StateStore
bridge.Log.Debugln("Initializing database") bridge.Log.Debugln("Initializing database")
bridge.DB, err = database.New(bridge.Config.AppService.Database.URI) bridge.DB, err = database.New(bridge.Config.AppService.Database.Type, bridge.Config.AppService.Database.URI)
if err != nil { if err != nil {
bridge.Log.Fatalln("Failed to initialize database:", err) bridge.Log.Fatalln("Failed to initialize database:", err)
os.Exit(14) os.Exit(14)
@ -147,7 +147,7 @@ func (bridge *Bridge) Init() {
} }
func (bridge *Bridge) Start() { func (bridge *Bridge) Start() {
err := bridge.DB.CreateTables() err := bridge.DB.CreateTables(bridge.Config.AppService.Database.Type)
if err != nil { if err != nil {
bridge.Log.Fatalln("Failed to create database tables:", err) bridge.Log.Fatalln("Failed to create database tables:", err)
os.Exit(15) os.Exit(15)
@ -185,6 +185,7 @@ func (bridge *Bridge) UpdateBotProfile() {
} }
func (bridge *Bridge) StartUsers() { func (bridge *Bridge) StartUsers() {
bridge.Log.Debugln("Starting users")
for _, user := range bridge.GetAllUsers() { for _, user := range bridge.GetAllUsers() {
go user.Connect(false) go user.Connect(false)
} }