mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-11 12:31:58 +01:00
Open up /events to anonymous users for room events only
Squash-merge of PR #345 from daniel/anonymousevents
This commit is contained in:
parent
7a369e8a55
commit
ca2f90742d
20 changed files with 299 additions and 82 deletions
|
@ -47,7 +47,8 @@ class BaseHandler(object):
|
|||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
||||
def _filter_events_for_client(self, user_id, events, is_guest=False,
|
||||
require_all_visible_for_guests=True):
|
||||
# Assumes that user has at some point joined the room if not is_guest.
|
||||
|
||||
def allowed(event, membership, visibility):
|
||||
|
@ -100,7 +101,9 @@ class BaseHandler(object):
|
|||
if should_include:
|
||||
events_to_return.append(event)
|
||||
|
||||
if is_guest and len(events_to_return) < len(events):
|
||||
if (require_all_visible_for_guests
|
||||
and is_guest
|
||||
and len(events_to_return) < len(events)):
|
||||
# This indicates that some events in the requested range were not
|
||||
# visible to guest users. To be safe, we reject the entire request,
|
||||
# so that we don't have to worry about interpreting visibility
|
||||
|
|
|
@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
|
|||
@log_function
|
||||
def get_stream(self, auth_user_id, pagin_config, timeout=0,
|
||||
as_client_event=True, affect_presence=True,
|
||||
only_room_events=False):
|
||||
only_room_events=False, room_id=None, is_guest=False):
|
||||
"""Fetches the events stream for a given user.
|
||||
|
||||
If `only_room_events` is `True` only room events will be returned.
|
||||
|
@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler):
|
|||
# thundering herds on restart.
|
||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
||||
|
||||
if is_guest:
|
||||
yield self.distributor.fire(
|
||||
"user_joined_room", user=auth_user, room_id=room_id
|
||||
)
|
||||
|
||||
events, tokens = yield self.notifier.get_events_for(
|
||||
auth_user, pagin_config, timeout,
|
||||
only_room_events=only_room_events
|
||||
only_room_events=only_room_events,
|
||||
is_guest=is_guest, guest_room_id=room_id
|
||||
)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.errors import SynapseError, AuthError, Codes
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.validator import EventValidator
|
||||
|
@ -229,7 +229,7 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
event_type=None, state_key=""):
|
||||
event_type=None, state_key="", is_guest=False):
|
||||
""" Get data from a room.
|
||||
|
||||
Args:
|
||||
|
@ -239,23 +239,42 @@ class MessageHandler(BaseHandler):
|
|||
Raises:
|
||||
SynapseError if something went wrong.
|
||||
"""
|
||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||
membership, membership_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id, is_guest
|
||||
)
|
||||
|
||||
if member_event.membership == Membership.JOIN:
|
||||
if membership == Membership.JOIN:
|
||||
data = yield self.state_handler.get_current_state(
|
||||
room_id, event_type, state_key
|
||||
)
|
||||
elif member_event.membership == Membership.LEAVE:
|
||||
elif membership == Membership.LEAVE:
|
||||
key = (event_type, state_key)
|
||||
room_state = yield self.store.get_state_for_events(
|
||||
[member_event.event_id], [key]
|
||||
[membership_event_id], [key]
|
||||
)
|
||||
data = room_state[member_event.event_id].get(key)
|
||||
data = room_state[membership_event_id].get(key)
|
||||
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_events(self, user_id, room_id):
|
||||
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
|
||||
if is_guest:
|
||||
visibility = yield self.state_handler.get_current_state(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if visibility.content["history_visibility"] == "world_readable":
|
||||
defer.returnValue((Membership.JOIN, None))
|
||||
return
|
||||
else:
|
||||
raise AuthError(
|
||||
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||
)
|
||||
else:
|
||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||
defer.returnValue((member_event.membership, member_event.event_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_state_events(self, user_id, room_id, is_guest=False):
|
||||
"""Retrieve all state events for a given room. If the user is
|
||||
joined to the room then return the current state. If the user has
|
||||
left the room return the state events from when they left.
|
||||
|
@ -266,15 +285,17 @@ class MessageHandler(BaseHandler):
|
|||
Returns:
|
||||
A list of dicts representing state events. [{}, {}, {}]
|
||||
"""
|
||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||
|
||||
if member_event.membership == Membership.JOIN:
|
||||
room_state = yield self.state_handler.get_current_state(room_id)
|
||||
elif member_event.membership == Membership.LEAVE:
|
||||
room_state = yield self.store.get_state_for_events(
|
||||
[member_event.event_id], None
|
||||
membership, membership_event_id = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id, is_guest
|
||||
)
|
||||
room_state = room_state[member_event.event_id]
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
room_state = yield self.state_handler.get_current_state(room_id)
|
||||
elif membership == Membership.LEAVE:
|
||||
room_state = yield self.store.get_state_for_events(
|
||||
[membership_event_id], None
|
||||
)
|
||||
room_state = room_state[membership_event_id]
|
||||
|
||||
now = self.clock.time_msec()
|
||||
defer.returnValue(
|
||||
|
|
|
@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(self, user, from_key, room_ids=None, **kwargs):
|
||||
from_key = int(from_key)
|
||||
room_ids = room_ids or []
|
||||
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
cachemap = presence._user_cachemap
|
||||
|
@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
|
|||
user_ids_to_check |= set(
|
||||
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||
)
|
||||
room_ids = yield presence.get_joined_rooms_for_user(user)
|
||||
for room_id in set(room_ids) & set(presence._room_serials):
|
||||
if presence._room_serials[room_id] > from_key:
|
||||
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||
|
|
|
@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
|
|||
return self.store.get_max_private_user_data_stream_id()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(self, user, from_key, **kwargs):
|
||||
user_id = user.to_string()
|
||||
last_stream_id = from_key
|
||||
|
||||
|
|
|
@ -164,17 +164,15 @@ class ReceiptEventSource(object):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||
from_key = int(from_key)
|
||||
to_key = yield self.get_current_key()
|
||||
|
||||
if from_key == to_key:
|
||||
defer.returnValue(([], to_key))
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
rooms = [room.room_id for room in rooms]
|
||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||
rooms,
|
||||
room_ids,
|
||||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
)
|
||||
|
|
|
@ -807,7 +807,14 @@ class RoomEventSource(object):
|
|||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(
|
||||
self,
|
||||
user,
|
||||
from_key,
|
||||
limit,
|
||||
room_ids,
|
||||
is_guest,
|
||||
):
|
||||
# We just ignore the key for now.
|
||||
|
||||
to_key = yield self.get_current_key()
|
||||
|
@ -828,6 +835,8 @@ class RoomEventSource(object):
|
|||
from_key=from_key,
|
||||
to_key=to_key,
|
||||
limit=limit,
|
||||
room_ids=room_ids,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
defer.returnValue((events, end_key))
|
||||
|
|
|
@ -295,11 +295,16 @@ class SyncHandler(BaseHandler):
|
|||
|
||||
typing_key = since_token.typing_key if since_token else "0"
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
|
||||
typing_source = self.event_sources.sources["typing"]
|
||||
typing, typing_key = yield typing_source.get_new_events_for_user(
|
||||
typing, typing_key = yield typing_source.get_new_events(
|
||||
user=sync_config.user,
|
||||
from_key=typing_key,
|
||||
limit=sync_config.filter.ephemeral_limit(),
|
||||
room_ids=room_ids,
|
||||
is_guest=False,
|
||||
)
|
||||
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
||||
|
||||
|
@ -312,10 +317,13 @@ class SyncHandler(BaseHandler):
|
|||
receipt_key = since_token.receipt_key if since_token else "0"
|
||||
|
||||
receipt_source = self.event_sources.sources["receipt"]
|
||||
receipts, receipt_key = yield receipt_source.get_new_events_for_user(
|
||||
receipts, receipt_key = yield receipt_source.get_new_events(
|
||||
user=sync_config.user,
|
||||
from_key=receipt_key,
|
||||
limit=sync_config.filter.ephemeral_limit(),
|
||||
room_ids=room_ids,
|
||||
# /sync doesn't support guest access, they can't get to this point in code
|
||||
is_guest=False,
|
||||
)
|
||||
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
|
||||
|
||||
|
@ -360,11 +368,17 @@ class SyncHandler(BaseHandler):
|
|||
"""
|
||||
now_token = yield self.event_sources.get_current_token()
|
||||
|
||||
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
|
||||
presence_source = self.event_sources.sources["presence"]
|
||||
presence, presence_key = yield presence_source.get_new_events_for_user(
|
||||
presence, presence_key = yield presence_source.get_new_events(
|
||||
user=sync_config.user,
|
||||
from_key=since_token.presence_key,
|
||||
limit=sync_config.filter.presence_limit(),
|
||||
room_ids=room_ids,
|
||||
# /sync doesn't support guest access, they can't get to this point in code
|
||||
is_guest=False,
|
||||
)
|
||||
now_token = now_token.copy_and_replace("presence_key", presence_key)
|
||||
|
||||
|
|
|
@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
|
|||
},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||
from_key = int(from_key)
|
||||
handler = self.handler()
|
||||
|
||||
joined_room_ids = (
|
||||
yield self.room_member_handler().get_joined_rooms_for_user(user)
|
||||
)
|
||||
|
||||
events = []
|
||||
for room_id in joined_room_ids:
|
||||
for room_id in room_ids:
|
||||
if room_id not in handler._room_serials:
|
||||
continue
|
||||
if handler._room_serials[room_id] <= from_key:
|
||||
|
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
|
|||
|
||||
events.append(self._make_event_for(room_id))
|
||||
|
||||
defer.returnValue((events, handler._latest_room_serial))
|
||||
return events, handler._latest_room_serial
|
||||
|
||||
def get_current_key(self):
|
||||
return self.handler()._latest_room_serial
|
||||
|
|
|
@ -269,7 +269,7 @@ class Notifier(object):
|
|||
logger.exception("Failed to notify listener")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wait_for_events(self, user, timeout, callback,
|
||||
def wait_for_events(self, user, timeout, callback, room_ids=None,
|
||||
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
||||
"""Wait until the callback returns a non empty response or the
|
||||
timeout fires.
|
||||
|
@ -279,11 +279,12 @@ class Notifier(object):
|
|||
if user_stream is None:
|
||||
appservice = yield self.store.get_app_service_by_user_id(user)
|
||||
current_token = yield self.event_sources.get_current_token()
|
||||
if room_ids is None:
|
||||
rooms = yield self.store.get_rooms_for_user(user)
|
||||
rooms = [room.room_id for room in rooms]
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
user_stream = _NotifierUserStream(
|
||||
user=user,
|
||||
rooms=rooms,
|
||||
rooms=room_ids,
|
||||
appservice=appservice,
|
||||
current_token=current_token,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
|
@ -329,7 +330,8 @@ class Notifier(object):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_for(self, user, pagination_config, timeout,
|
||||
only_room_events=False):
|
||||
only_room_events=False,
|
||||
is_guest=False, guest_room_id=None):
|
||||
""" For the given user and rooms, return any new events for them. If
|
||||
there are no new events wait for up to `timeout` milliseconds for any
|
||||
new events to happen before returning.
|
||||
|
@ -342,6 +344,16 @@ class Notifier(object):
|
|||
|
||||
limit = pagination_config.limit
|
||||
|
||||
room_ids = []
|
||||
if is_guest:
|
||||
# TODO(daniel): Deal with non-room events too
|
||||
only_room_events = True
|
||||
if guest_room_id:
|
||||
room_ids = [guest_room_id]
|
||||
else:
|
||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||
room_ids = [room.room_id for room in rooms]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_for_updates(before_token, after_token):
|
||||
if not after_token.is_after(before_token):
|
||||
|
@ -357,9 +369,23 @@ class Notifier(object):
|
|||
continue
|
||||
if only_room_events and name != "room":
|
||||
continue
|
||||
new_events, new_key = yield source.get_new_events_for_user(
|
||||
user, getattr(from_token, keyname), limit,
|
||||
new_events, new_key = yield source.get_new_events(
|
||||
user=user,
|
||||
from_key=getattr(from_token, keyname),
|
||||
limit=limit,
|
||||
is_guest=is_guest,
|
||||
room_ids=room_ids,
|
||||
)
|
||||
|
||||
if is_guest:
|
||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
new_events = yield room_member_handler._filter_events_for_client(
|
||||
user.to_string(),
|
||||
new_events,
|
||||
is_guest=is_guest,
|
||||
require_all_visible_for_guests=False
|
||||
)
|
||||
|
||||
events.extend(new_events)
|
||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||
|
||||
|
@ -369,7 +395,7 @@ class Notifier(object):
|
|||
defer.returnValue(None)
|
||||
|
||||
result = yield self.wait_for_events(
|
||||
user, timeout, check_for_updates, from_token=from_token
|
||||
user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
|
||||
)
|
||||
|
||||
if result is None:
|
||||
|
|
|
@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, is_guest = yield self.auth.get_user_by_req(
|
||||
request,
|
||||
allow_guest=True
|
||||
)
|
||||
room_id = None
|
||||
if is_guest:
|
||||
if "room_id" not in request.args:
|
||||
raise SynapseError(400, "Guest users must specify room_id param")
|
||||
room_id = request.args["room_id"][0]
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
|
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
chunk = yield handler.get_stream(
|
||||
auth_user.to_string(), pagin_config, timeout=timeout,
|
||||
as_client_event=as_client_event
|
||||
as_client_event=as_client_event, affect_presence=(not is_guest),
|
||||
room_id=room_id, is_guest=is_guest
|
||||
)
|
||||
except:
|
||||
logger.exception("Event stream failed")
|
||||
|
|
|
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
|
@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
is_guest=is_guest,
|
||||
)
|
||||
|
||||
if not data:
|
||||
|
@ -348,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
handler = self.handlers.message_handler
|
||||
# Get all the current state for this room
|
||||
events = yield handler.get_state_events(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
is_guest=is_guest,
|
||||
)
|
||||
defer.returnValue((200, events))
|
||||
|
||||
|
|
|
@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
|
|||
self._store_room_message_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||
self._store_history_visibility_txn(txn, event)
|
||||
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
|
|
|
@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
|
|||
txn, event, "content.body", event.content["body"]
|
||||
)
|
||||
|
||||
def _store_history_visibility_txn(self, txn, event):
|
||||
if hasattr(event, "content") and "history_visibility" in event.content:
|
||||
sql = (
|
||||
"INSERT INTO history_visibility"
|
||||
" (event_id, room_id, history_visibility)"
|
||||
" VALUES (?, ?, ?)"
|
||||
)
|
||||
txn.execute(sql, (
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.content["history_visibility"]
|
||||
))
|
||||
|
||||
def _store_event_search_txn(self, txn, event, key, value):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
sql = (
|
||||
|
|
26
synapse/storage/schema/delta/25/history_visibility.sql
Normal file
26
synapse/storage/schema/delta/25/history_visibility.sql
Normal file
|
@ -0,0 +1,26 @@
|
|||
/* Copyright 2015 OpenMarket Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
/*
|
||||
* This is a manual index of history_visibility content of state events,
|
||||
* so that we can join on them in SELECT statements.
|
||||
*/
|
||||
CREATE TABLE IF NOT EXISTS history_visibility(
|
||||
id INTEGER PRIMARY KEY,
|
||||
event_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
history_visibility TEXT NOT NULL,
|
||||
UNIQUE (event_id)
|
||||
);
|
|
@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore):
|
|||
defer.returnValue(results)
|
||||
|
||||
@log_function
|
||||
def get_room_events_stream(self, user_id, from_key, to_key, limit=0):
|
||||
def get_room_events_stream(
|
||||
self,
|
||||
user_id,
|
||||
from_key,
|
||||
to_key,
|
||||
limit=0,
|
||||
is_guest=False,
|
||||
room_ids=None
|
||||
):
|
||||
room_ids = room_ids or []
|
||||
room_ids = [r for r in room_ids]
|
||||
if is_guest:
|
||||
current_room_membership_sql = (
|
||||
"SELECT c.room_id FROM history_visibility AS h"
|
||||
" INNER JOIN current_state_events AS c"
|
||||
" ON h.event_id = c.event_id"
|
||||
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
|
||||
",".join(map(lambda _: "?", room_ids))
|
||||
)
|
||||
)
|
||||
current_room_membership_args = room_ids
|
||||
else:
|
||||
current_room_membership_sql = (
|
||||
"SELECT m.room_id FROM room_memberships as m "
|
||||
" INNER JOIN current_state_events as c"
|
||||
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
|
||||
" WHERE m.user_id = ? AND m.membership = 'join'"
|
||||
)
|
||||
current_room_membership_args = [user_id]
|
||||
if room_ids:
|
||||
current_room_membership_sql += " AND m.room_id in (%s)" % (
|
||||
",".join(map(lambda _: "?", room_ids))
|
||||
)
|
||||
current_room_membership_args = [user_id] + room_ids
|
||||
|
||||
# We also want to get any membership events about that user, e.g.
|
||||
# invites or leave notifications.
|
||||
|
@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore):
|
|||
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
|
||||
"WHERE m.user_id = ? "
|
||||
)
|
||||
membership_args = [user_id]
|
||||
|
||||
if limit:
|
||||
limit = max(limit, MAX_STREAM_SIZE)
|
||||
|
@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore):
|
|||
}
|
||||
|
||||
def f(txn):
|
||||
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
|
||||
args = ([False] + current_room_membership_args + membership_args +
|
||||
[from_id.stream, to_id.stream])
|
||||
txn.execute(sql, args)
|
||||
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
|
|
|
@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
{"presence": ONLINE}
|
||||
)
|
||||
|
||||
# Apple sees self-reflection even without room_id
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
self.assertEquals(events,
|
||||
[
|
||||
{"type": "m.presence",
|
||||
"content": {
|
||||
"user_id": "@apple:test",
|
||||
"presence": ONLINE,
|
||||
"last_active_ago": 0,
|
||||
}},
|
||||
],
|
||||
msg="Presence event should be visible to self-reflection"
|
||||
)
|
||||
|
||||
# Apple sees self-reflection
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id],
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
)
|
||||
|
||||
# Banana sees it because of presence subscription
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_banana, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_banana,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id],
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
)
|
||||
|
||||
# Elderberry sees it because of same room
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_elderberry, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_elderberry,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id],
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
)
|
||||
|
||||
# Durian is not in the room, should not see this event
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_durian, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_durian,
|
||||
from_key=0,
|
||||
room_ids=[],
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
"accepted": True},
|
||||
], presence)
|
||||
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 1, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=1,
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||
|
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
)
|
||||
)
|
||||
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id],
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id,]
|
||||
)
|
||||
self.assertEquals(events,
|
||||
[
|
||||
|
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
|
||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
room_ids=[self.room_id,]
|
||||
)
|
||||
self.assertEquals(events,
|
||||
[
|
||||
|
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
|||
|
||||
self.room_members.append(self.u_clementine)
|
||||
|
||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
||||
self.u_apple, 0, None
|
||||
(events, _) = yield self.event_source.get_new_events(
|
||||
user=self.u_apple,
|
||||
from_key=0,
|
||||
)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
|
|
|
@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=0,
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=0
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
yield put_json.await_calls()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=0,
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=0,
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
])
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=1,
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
|||
self.on_new_event.reset_mock()
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
room_ids=[self.room_id],
|
||||
from_key=0,
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
|
|
@ -47,7 +47,14 @@ class NullSource(object):
|
|||
def __init__(self, hs):
|
||||
pass
|
||||
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
def get_new_events(
|
||||
self,
|
||||
user,
|
||||
from_key,
|
||||
room_ids=None,
|
||||
limit=None,
|
||||
is_guest=None
|
||||
):
|
||||
return defer.succeed(([], from_key))
|
||||
|
||||
def get_current_key(self, direction='f'):
|
||||
|
|
|
@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
|
|||
self.assertEquals(200, code)
|
||||
|
||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||
events = yield self.event_source.get_new_events_for_user(self.user, 0, None)
|
||||
events = yield self.event_source.get_new_events(
|
||||
from_key=0,
|
||||
room_ids=[self.room_id],
|
||||
)
|
||||
self.assertEquals(
|
||||
events[0],
|
||||
[
|
||||
|
|
Loading…
Reference in a new issue