diff --git a/CHANGELOG.md b/CHANGELOG.md index 607403c..cc68164 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ [@abmantis] in [#452]). This can be used to enable incoming typing notifications without enabling Matrix presence (WhatsApp only sends typing notifications if you're online). +* Added checks to prevent sharing the database with unrelated software. * Exposed maximum database connection idle time and lifetime options. * Fixed syncing group topics. To get topics into existing portals on Matrix, you can use `!wa sync groups`. diff --git a/database/database.go b/database/database.go index 67c0912..7640871 100644 --- a/database/database.go +++ b/database/database.go @@ -23,10 +23,9 @@ import ( "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" - + "go.mau.fi/whatsmeow/store/sqlstore" log "maunium.net/go/maulogger/v2" - "go.mau.fi/whatsmeow/store/sqlstore" "maunium.net/go/mautrix-whatsapp/config" "maunium.net/go/mautrix-whatsapp/database/upgrades" ) diff --git a/database/upgrades/upgrades.go b/database/upgrades/upgrades.go index d32639c..f4b1ded 100644 --- a/database/upgrades/upgrades.go +++ b/database/upgrades/upgrades.go @@ -44,7 +44,12 @@ const NumberOfUpgrades = 39 var upgrades [NumberOfUpgrades]upgrade -var UnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version") +var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version") +var ErrForeignTables = fmt.Errorf("the database contains foreign tables") +var ErrNotOwned = fmt.Errorf("the database is owned by") +var IgnoreForeignTables = false + +const databaseOwner = "mautrix-whatsapp" func GetVersion(db *sql.DB) (int, error) { _, err := db.Exec("CREATE TABLE IF NOT EXISTS version (version INTEGER)") @@ -60,6 +65,49 @@ func GetVersion(db *sql.DB) (int, error) { return version, nil } +const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)" +const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND table_name=$1)" + +func tableExists(dialect Dialect, db *sql.DB, table string) (exists bool) { + if dialect == SQLite { + _ = db.QueryRow(tableExistsSQLite, table).Scan(&exists) + } else if dialect == Postgres { + _ = db.QueryRow(tableExistsPostgres, table).Scan(&exists) + } + return +} + +const createOwnerTable = ` +CREATE TABLE IF NOT EXISTS database_owner ( + key INTEGER PRIMARY KEY DEFAULT 0, + owner TEXT NOT NULL +) +` + +func CheckDatabaseOwner(dialect Dialect, db *sql.DB) error { + var owner string + if !IgnoreForeignTables { + if tableExists(dialect, db, "state_groups_state") { + return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables) + } else if tableExists(dialect, db, "goose_db_version") { + return fmt.Errorf("%w (found goose_db_version, possibly belonging to Dendrite)", ErrForeignTables) + } + } + if _, err := db.Exec(createOwnerTable); err != nil { + return fmt.Errorf("failed to ensure database owner table exists: %w", err) + } else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) { + _, err = db.Exec("INSERT INTO database_owner (owner) VALUES ($1)", databaseOwner) + if err != nil { + return fmt.Errorf("failed to insert database owner: %w", err) + } + } else if err != nil { + return fmt.Errorf("failed to check database owner: %w", err) + } else if owner != databaseOwner { + return fmt.Errorf("%w %s", ErrNotOwned, owner) + } + return nil +} + func SetVersion(tx *sql.Tx, version int) error { _, err := tx.Exec("DELETE FROM version") if err != nil { @@ -90,13 +138,18 @@ func Run(log log.Logger, dialectName string, db *sql.DB) error { return fmt.Errorf("unknown dialect %s", dialectName) } + err := CheckDatabaseOwner(dialect, db) + if err != nil { + return err + } + version, err := GetVersion(db) if err != nil { return err } if version > NumberOfUpgrades { - return fmt.Errorf("%w: currently on v%d, latest known: v%d", UnsupportedDatabaseVersion, version, NumberOfUpgrades) + return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, NumberOfUpgrades) } log.Infofln("Database currently on v%d, latest: v%d", version, NumberOfUpgrades) diff --git a/main.go b/main.go index 6a68b01..981e5e5 100644 --- a/main.go +++ b/main.go @@ -103,7 +103,8 @@ var dontSaveConfig = flag.MakeFull("n", "no-update", "Don't save updated config var registrationPath = flag.MakeFull("r", "registration", "The path where to save the appservice registration.", "registration.yaml").String() var generateRegistration = flag.MakeFull("g", "generate-registration", "Generate registration and quit.", "false").Bool() var version = flag.MakeFull("v", "version", "View bridge version and quit.", "false").Bool() -var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if database is too new").Default("false").Bool() +var ignoreUnsupportedDatabase = flag.Make().LongKey("ignore-unsupported-database").Usage("Run even if the database schema is too new").Default("false").Bool() +var ignoreForeignTables = flag.Make().LongKey("ignore-foreign-tables").Usage("Run even if the database contains tables from other programs (like Synapse)").Default("false").Bool() var migrateFrom = flag.Make().LongKey("migrate-db").Usage("Source database type and URI to migrate from.").Bool() var wantHelp, _ = flag.MakeHelpFlag() @@ -299,8 +300,15 @@ func (bridge *Bridge) Init() { func (bridge *Bridge) Start() { bridge.Log.Debugln("Running database upgrades") err := bridge.DB.Init() - if err != nil && (err != upgrades.UnsupportedDatabaseVersion || !*ignoreUnsupportedDatabase) { + if err != nil && (!errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) || !*ignoreUnsupportedDatabase) { bridge.Log.Fatalln("Failed to initialize database:", err) + if errors.Is(err, upgrades.ErrForeignTables) { + bridge.Log.Infoln("You can use --ignore-foreign-tables to ignore this error") + } else if errors.Is(err, upgrades.ErrNotOwned) { + bridge.Log.Infoln("Sharing the same database with different programs is not supported") + } else if errors.Is(err, upgrades.ErrUnsupportedDatabaseVersion) { + bridge.Log.Infoln("Downgrading the bridge is not supported") + } os.Exit(15) } bridge.Log.Debugln("Checking connection to homeserver") @@ -517,6 +525,7 @@ func main() { fmt.Println(VersionString) return } + upgrades.IgnoreForeignTables = *ignoreForeignTables (&Bridge{ usersByMXID: make(map[id.UserID]*User),