package storage_test import ( "context" "reflect" "testing" "time" "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/matrix-org/util" "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, false) connStr, dbClose := test.PrepareDBConnectionString(t, dbType) ctx := context.Background() cm := sqlutil.NewConnectionManager(nil, config.DatabaseOptions{}) db, err := storage.NewDatabase(ctx, cm, &config.DatabaseOptions{ ConnectionString: config.DataSource(connStr), }, caches, func(server spec.ServerName) bool { return server == "localhost" }) if err != nil { t.Fatalf("NewDatabase returned %s", err) } return db, func() { dbClose() } } func TestExpireEDUs(t *testing.T) { var expireEDUTypes = map[string]time.Duration{ spec.MReceipt: 0, } ctx := context.Background() destinations := map[spec.ServerName]struct{}{"localhost": {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, close := mustCreateFederationDatabase(t, dbType) defer close() // insert some data for i := 0; i < 100; i++ { receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MReceipt, expireEDUTypes) assert.NoError(t, err) } // add data without expiry receipt, err := db.StoreJSON(ctx, "{}") assert.NoError(t, err) // m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes) assert.NoError(t, err) // Delete expired EDUs err = db.DeleteExpiredEDUs(ctx) assert.NoError(t, err) // verify the data is gone data, err := db.GetPendingEDUs(ctx, "localhost", 100) assert.NoError(t, err) assert.Equal(t, 1, len(data)) // check that m.direct_to_device is never expired receipt, err = db.StoreJSON(ctx, "{}") assert.NoError(t, err) err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, spec.MDirectToDevice, expireEDUTypes) assert.NoError(t, err) err = db.DeleteExpiredEDUs(ctx) assert.NoError(t, err) // We should get two EDUs, the m.read_marker and the m.direct_to_device data, err = db.GetPendingEDUs(ctx, "localhost", 100) assert.NoError(t, err) assert.Equal(t, 2, len(data)) }) } func TestOutboundPeeking(t *testing.T) { alice := test.NewUser(t) room := test.NewRoom(t, alice) _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() peekID := util.RandomString(8) var renewalInterval int64 = 1000 // Add outbound peek if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { t.Fatal(err) } // select the newly inserted peek outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) if err != nil { t.Fatal(err) } // Assert fields are set as expected if outboundPeek1.PeekID != peekID { t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) } if outboundPeek1.RoomID != room.ID { t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) } if outboundPeek1.ServerName != serverName { t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) } if outboundPeek1.RenewalInterval != renewalInterval { t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) } // Renew the peek if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { t.Fatal(err) } // verify the values changed outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) if err != nil { t.Fatal(err) } if reflect.DeepEqual(outboundPeek1, outboundPeek2) { t.Fatal("expected a change peek, but they are the same") } if outboundPeek1.ServerName != outboundPeek2.ServerName { t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) } if outboundPeek1.RoomID != outboundPeek2.RoomID { t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) } // insert some peeks peekIDs := []string{peekID} for i := 0; i < 5; i++ { peekID = util.RandomString(8) if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { t.Fatal(err) } peekIDs = append(peekIDs, peekID) } // Now select them outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) if err != nil { t.Fatal(err) } if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } gotPeekIDs := make([]string, 0, len(outboundPeeks)) for _, p := range outboundPeeks { gotPeekIDs = append(gotPeekIDs, p.PeekID) } assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } func TestInboundPeeking(t *testing.T) { alice := test.NewUser(t) room := test.NewRoom(t, alice) _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() peekID := util.RandomString(8) var renewalInterval int64 = 1000 // Add inbound peek if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { t.Fatal(err) } // select the newly inserted peek inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) if err != nil { t.Fatal(err) } // Assert fields are set as expected if inboundPeek1.PeekID != peekID { t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) } if inboundPeek1.RoomID != room.ID { t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) } if inboundPeek1.ServerName != serverName { t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) } if inboundPeek1.RenewalInterval != renewalInterval { t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) } // Renew the peek if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { t.Fatal(err) } // verify the values changed inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) if err != nil { t.Fatal(err) } if reflect.DeepEqual(inboundPeek1, inboundPeek2) { t.Fatal("expected a change peek, but they are the same") } if inboundPeek1.ServerName != inboundPeek2.ServerName { t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) } if inboundPeek1.RoomID != inboundPeek2.RoomID { t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) } // insert some peeks peekIDs := []string{peekID} for i := 0; i < 5; i++ { peekID = util.RandomString(8) if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { t.Fatal(err) } peekIDs = append(peekIDs, peekID) } // Now select them inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) if err != nil { t.Fatal(err) } if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } gotPeekIDs := make([]string, 0, len(inboundPeeks)) for _, p := range inboundPeeks { gotPeekIDs = append(gotPeekIDs, p.PeekID) } assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } func TestServersAssumedOffline(t *testing.T) { server1 := spec.ServerName("server1") server2 := spec.ServerName("server2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() // Set server1 & server2 as assumed offline. err := db.SetServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) err = db.SetServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) // Ensure both servers are assumed offline. isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.True(t, isOffline) isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) assert.True(t, isOffline) // Set server1 as not assumed offline. err = db.RemoveServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) // Ensure both servers have correct state. isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.False(t, isOffline) isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) assert.True(t, isOffline) // Re-set server1 as assumed offline. err = db.SetServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) // Ensure server1 is assumed offline. isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.True(t, isOffline) err = db.RemoveAllServersAssumedOffline(context.Background()) assert.Nil(t, err) // Ensure both servers have correct state. isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) assert.Nil(t, err) assert.False(t, isOffline) isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) assert.Nil(t, err) assert.False(t, isOffline) }) } func TestRelayServersStored(t *testing.T) { server := spec.ServerName("server") relayServer1 := spec.ServerName("relayserver1") relayServer2 := spec.ServerName("relayserver2") test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { db, closeDB := mustCreateFederationDatabase(t, dbType) defer closeDB() err := db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) err = db.P2PRemoveRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) err = db.P2PAddRelayServersForServer(context.Background(), server, []spec.ServerName{relayServer1, relayServer2}) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Equal(t, relayServer1, relayServers[0]) assert.Equal(t, relayServer2, relayServers[1]) err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) assert.Nil(t, err) relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) assert.Nil(t, err) assert.Zero(t, len(relayServers)) }) }