diff --git a/clientapi/clientapi_test.go b/clientapi/clientapi_test.go index 29ff58665..fffe4b6b8 100644 --- a/clientapi/clientapi_test.go +++ b/clientapi/clientapi_test.go @@ -958,7 +958,8 @@ func TestCapabilities(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) // Needed to create accounts - rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. @@ -1005,7 +1006,8 @@ func TestTurnserver(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) // Needed to create accounts - rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) //rsAPI.SetUserAPI(userAPI) @@ -1103,7 +1105,8 @@ func Test3PID(t *testing.T) { cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) // Needed to create accounts - rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, nil, caching.DisableMetrics) + caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, &natsInstance, caches, caching.DisableMetrics) rsAPI.SetFederationAPI(nil, nil) userAPI := userapi.NewInternalAPI(processCtx, cfg, cm, &natsInstance, rsAPI, nil, caching.DisableMetrics, testIsBlacklistedOrBackingOff) // We mostly need the rsAPI/userAPI for this test, so nil for other APIs etc. diff --git a/roomserver/acls/acls.go b/roomserver/acls/acls.go index e247c7553..660f4f3bb 100644 --- a/roomserver/acls/acls.go +++ b/roomserver/acls/acls.go @@ -23,7 +23,7 @@ import ( "strings" "sync" - "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/sirupsen/logrus" @@ -34,10 +34,10 @@ const MRoomServerACL = "m.room.server_acl" type ServerACLDatabase interface { // GetKnownRooms returns a list of all rooms we know about. GetKnownRooms(ctx context.Context) ([]string, error) - // GetStateEvent returns the state event of a given type for a given room with a given state key - // If no event could be found, returns nil - // If there was an issue during the retrieval, returns an error - GetStateEvent(ctx context.Context, roomID, evType, stateKey string) (*types.HeaderedEvent, error) + + // GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match. + // If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned. + GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) } type ServerACLs struct { @@ -58,15 +58,14 @@ func NewServerACLs(db ServerACLDatabase) *ServerACLs { // For each room, let's see if we have a server ACL state event. If we // do then we'll process it into memory so that we have the regexes to // hand. - for _, room := range rooms { - state, err := db.GetStateEvent(ctx, room, MRoomServerACL, "") - if err != nil { - logrus.WithError(err).Errorf("Failed to get server ACLs for room %q", room) - continue - } - if state != nil { - acls.OnServerACLUpdate(state.PDU) - } + + events, err := db.GetBulkStateContent(ctx, rooms, []gomatrixserverlib.StateKeyTuple{{EventType: MRoomServerACL, StateKey: ""}}, false) + if err != nil { + logrus.WithError(err).Errorf("Failed to get server ACLs for all rooms: %q", err) + } + + for _, event := range events { + acls.OnServerACLUpdate(event) } return acls } @@ -90,9 +89,9 @@ func compileACLRegex(orig string) (*regexp.Regexp, error) { return regexp.Compile(escaped) } -func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) { +func (s *ServerACLs) OnServerACLUpdate(strippedEvent tables.StrippedEvent) { acls := &serverACL{} - if err := json.Unmarshal(state.Content(), &acls.ServerACL); err != nil { + if err := json.Unmarshal([]byte(strippedEvent.ContentValue), &acls.ServerACL); err != nil { logrus.WithError(err).Errorf("Failed to unmarshal state content for server ACLs") return } @@ -118,10 +117,10 @@ func (s *ServerACLs) OnServerACLUpdate(state gomatrixserverlib.PDU) { "allow_ip_literals": acls.AllowIPLiterals, "num_allowed": len(acls.allowedRegexes), "num_denied": len(acls.deniedRegexes), - }).Debugf("Updating server ACLs for %q", state.RoomID()) + }).Debugf("Updating server ACLs for %q", strippedEvent.RoomID) s.aclsMutex.Lock() defer s.aclsMutex.Unlock() - s.acls[state.RoomID().String()] = acls + s.acls[strippedEvent.RoomID] = acls } func (s *ServerACLs) IsServerBannedFromRoom(serverName spec.ServerName, roomID string) bool { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 1d9208434..657ca8719 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,6 +24,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" @@ -509,7 +510,13 @@ func (r *Inputer) processRoomEvent( logrus.WithError(err).Error("failed to get server ACLs") } if aclEvent != nil { - r.ACLs.OnServerACLUpdate(aclEvent) + strippedEvent := tables.StrippedEvent{ + RoomID: aclEvent.RoomID().String(), + EventType: aclEvent.Type(), + StateKey: *aclEvent.StateKey(), + ContentValue: string(aclEvent.Content()), + } + r.ACLs.OnServerACLUpdate(strippedEvent) } } } diff --git a/roomserver/producers/roomevent.go b/roomserver/producers/roomevent.go index af7e10580..894e6d81b 100644 --- a/roomserver/producers/roomevent.go +++ b/roomserver/producers/roomevent.go @@ -17,6 +17,7 @@ package producers import ( "encoding/json" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -75,7 +76,13 @@ func (r *RoomEventProducer) ProduceRoomEvents(roomID string, updates []api.Outpu if eventType == acls.MRoomServerACL && update.NewRoomEvent.Event.StateKeyEquals("") { ev := update.NewRoomEvent.Event.PDU - defer r.ACLs.OnServerACLUpdate(ev) + strippedEvent := tables.StrippedEvent{ + RoomID: ev.RoomID().String(), + EventType: ev.Type(), + StateKey: *ev.StateKey(), + ContentValue: string(ev.Content()), + } + defer r.ACLs.OnServerACLUpdate(strippedEvent) } } logger.Tracef("Producing to topic '%s'", r.Topic) diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 0ae064e6b..b3cb31880 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -235,6 +235,10 @@ func ExtractContentValue(ev *types.HeaderedEvent) string { key = "topic" case "m.room.guest_access": key = "guest_access" + case "m.room.server_acl": + // We need the entire content and not only one key, so we can use it + // on startup to generate the ACLs. This is merely a workaround. + return string(content) } result := gjson.GetBytes(content, key) if !result.Exists() {