0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-13 23:43:45 +01:00

Open up /events to anonymous users for room events only

Squash-merge of PR #345 from daniel/anonymousevents
This commit is contained in:
Daniel Wagner-Hall 2015-11-05 14:32:26 +00:00
parent 7a369e8a55
commit ca2f90742d
20 changed files with 299 additions and 82 deletions

View file

@ -47,7 +47,8 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks @defer.inlineCallbacks
def _filter_events_for_client(self, user_id, events, is_guest=False): def _filter_events_for_client(self, user_id, events, is_guest=False,
require_all_visible_for_guests=True):
# Assumes that user has at some point joined the room if not is_guest. # Assumes that user has at some point joined the room if not is_guest.
def allowed(event, membership, visibility): def allowed(event, membership, visibility):
@ -100,7 +101,9 @@ class BaseHandler(object):
if should_include: if should_include:
events_to_return.append(event) events_to_return.append(event)
if is_guest and len(events_to_return) < len(events): if (require_all_visible_for_guests
and is_guest
and len(events_to_return) < len(events)):
# This indicates that some events in the requested range were not # This indicates that some events in the requested range were not
# visible to guest users. To be safe, we reject the entire request, # visible to guest users. To be safe, we reject the entire request,
# so that we don't have to worry about interpreting visibility # so that we don't have to worry about interpreting visibility

View file

@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
@log_function @log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0, def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True, as_client_event=True, affect_presence=True,
only_room_events=False): only_room_events=False, room_id=None, is_guest=False):
"""Fetches the events stream for a given user. """Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned. If `only_room_events` is `True` only room events will be returned.
@ -119,9 +119,15 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart. # thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1)) timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
if is_guest:
yield self.distributor.fire(
"user_joined_room", user=auth_user, room_id=room_id
)
events, tokens = yield self.notifier.get_events_for( events, tokens = yield self.notifier.get_events_for(
auth_user, pagin_config, timeout, auth_user, pagin_config, timeout,
only_room_events=only_room_events only_room_events=only_room_events,
is_guest=is_guest, guest_room_id=room_id
) )
time_now = self.clock.time_msec() time_now = self.clock.time_msec()

View file

@ -16,7 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, AuthError, Codes
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator from synapse.events.validator import EventValidator
@ -229,7 +229,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key=""): event_type=None, state_key="", is_guest=False):
""" Get data from a room. """ Get data from a room.
Args: Args:
@ -239,23 +239,42 @@ class MessageHandler(BaseHandler):
Raises: Raises:
SynapseError if something went wrong. SynapseError if something went wrong.
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
data = yield self.state_handler.get_current_state( data = yield self.state_handler.get_current_state(
room_id, event_type, state_key room_id, event_type, state_key
) )
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
key = (event_type, state_key) key = (event_type, state_key)
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], [key] [membership_event_id], [key]
) )
data = room_state[member_event.event_id].get(key) data = room_state[membership_event_id].get(key)
defer.returnValue(data) defer.returnValue(data)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_events(self, user_id, room_id): def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
if is_guest:
visibility = yield self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if visibility.content["history_visibility"] == "world_readable":
defer.returnValue((Membership.JOIN, None))
return
else:
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
)
else:
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
defer.returnValue((member_event.membership, member_event.event_id))
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id, is_guest=False):
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
left the room return the state events from when they left. left the room return the state events from when they left.
@ -266,15 +285,17 @@ class MessageHandler(BaseHandler):
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
""" """
member_event = yield self.auth.check_user_was_in_room(room_id, user_id) membership, membership_event_id = yield self._check_in_room_or_world_readable(
room_id, user_id, is_guest
)
if member_event.membership == Membership.JOIN: if membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id) room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE: elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events( room_state = yield self.store.get_state_for_events(
[member_event.event_id], None [membership_event_id], None
) )
room_state = room_state[member_event.event_id] room_state = room_state[membership_event_id]
now = self.clock.time_msec() now = self.clock.time_msec()
defer.returnValue( defer.returnValue(

View file

@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, room_ids=None, **kwargs):
from_key = int(from_key) from_key = int(from_key)
room_ids = room_ids or []
presence = self.hs.get_handlers().presence_handler presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap cachemap = presence._user_cachemap
@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
user_ids_to_check |= set( user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list UserID.from_string(p["observed_user_id"]) for p in presence_list
) )
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials): for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key: if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id) joined = yield presence.get_joined_users_for_room_id(room_id)

View file

@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
return self.store.get_max_private_user_data_stream_id() return self.store.get_max_private_user_data_stream_id()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string() user_id = user.to_string()
last_stream_id = from_key last_stream_id = from_key

View file

@ -164,17 +164,15 @@ class ReceiptEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
if from_key == to_key: if from_key == to_key:
defer.returnValue(([], to_key)) defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms( events = yield self.store.get_linearized_receipts_for_rooms(
rooms, room_ids,
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
) )

View file

@ -807,7 +807,14 @@ class RoomEventSource(object):
self.store = hs.get_datastore() self.store = hs.get_datastore()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
limit,
room_ids,
is_guest,
):
# We just ignore the key for now. # We just ignore the key for now.
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -828,6 +835,8 @@ class RoomEventSource(object):
from_key=from_key, from_key=from_key,
to_key=to_key, to_key=to_key,
limit=limit, limit=limit,
room_ids=room_ids,
is_guest=is_guest,
) )
defer.returnValue((events, end_key)) defer.returnValue((events, end_key))

View file

@ -295,11 +295,16 @@ class SyncHandler(BaseHandler):
typing_key = since_token.typing_key if since_token else "0" typing_key = since_token.typing_key if since_token else "0"
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
typing_source = self.event_sources.sources["typing"] typing_source = self.event_sources.sources["typing"]
typing, typing_key = yield typing_source.get_new_events_for_user( typing, typing_key = yield typing_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=typing_key, from_key=typing_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
is_guest=False,
) )
now_token = now_token.copy_and_replace("typing_key", typing_key) now_token = now_token.copy_and_replace("typing_key", typing_key)
@ -312,10 +317,13 @@ class SyncHandler(BaseHandler):
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"] receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = yield receipt_source.get_new_events_for_user( receipts, receipt_key = yield receipt_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=receipt_key, from_key=receipt_key,
limit=sync_config.filter.ephemeral_limit(), limit=sync_config.filter.ephemeral_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("receipt_key", receipt_key) now_token = now_token.copy_and_replace("receipt_key", receipt_key)
@ -360,11 +368,17 @@ class SyncHandler(BaseHandler):
""" """
now_token = yield self.event_sources.get_current_token() now_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
room_ids = [room.room_id for room in rooms]
presence_source = self.event_sources.sources["presence"] presence_source = self.event_sources.sources["presence"]
presence, presence_key = yield presence_source.get_new_events_for_user( presence, presence_key = yield presence_source.get_new_events(
user=sync_config.user, user=sync_config.user,
from_key=since_token.presence_key, from_key=since_token.presence_key,
limit=sync_config.filter.presence_limit(), limit=sync_config.filter.presence_limit(),
room_ids=room_ids,
# /sync doesn't support guest access, they can't get to this point in code
is_guest=False,
) )
now_token = now_token.copy_and_replace("presence_key", presence_key) now_token = now_token.copy_and_replace("presence_key", presence_key)

View file

@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
}, },
} }
@defer.inlineCallbacks def get_new_events(self, from_key, room_ids, **kwargs):
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key) from_key = int(from_key)
handler = self.handler() handler = self.handler()
joined_room_ids = (
yield self.room_member_handler().get_joined_rooms_for_user(user)
)
events = [] events = []
for room_id in joined_room_ids: for room_id in room_ids:
if room_id not in handler._room_serials: if room_id not in handler._room_serials:
continue continue
if handler._room_serials[room_id] <= from_key: if handler._room_serials[room_id] <= from_key:
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id)) events.append(self._make_event_for(room_id))
defer.returnValue((events, handler._latest_room_serial)) return events, handler._latest_room_serial
def get_current_key(self): def get_current_key(self):
return self.handler()._latest_room_serial return self.handler()._latest_room_serial

View file

@ -269,7 +269,7 @@ class Notifier(object):
logger.exception("Failed to notify listener") logger.exception("Failed to notify listener")
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, timeout, callback, def wait_for_events(self, user, timeout, callback, room_ids=None,
from_token=StreamToken("s0", "0", "0", "0", "0")): from_token=StreamToken("s0", "0", "0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
@ -279,11 +279,12 @@ class Notifier(object):
if user_stream is None: if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user) appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token() current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user) if room_ids is None:
rooms = [room.room_id for room in rooms] rooms = yield self.store.get_rooms_for_user(user)
room_ids = [room.room_id for room in rooms]
user_stream = _NotifierUserStream( user_stream = _NotifierUserStream(
user=user, user=user,
rooms=rooms, rooms=room_ids,
appservice=appservice, appservice=appservice,
current_token=current_token, current_token=current_token,
time_now_ms=self.clock.time_msec(), time_now_ms=self.clock.time_msec(),
@ -329,7 +330,8 @@ class Notifier(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_events_for(self, user, pagination_config, timeout, def get_events_for(self, user, pagination_config, timeout,
only_room_events=False): only_room_events=False,
is_guest=False, guest_room_id=None):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
@ -342,6 +344,16 @@ class Notifier(object):
limit = pagination_config.limit limit = pagination_config.limit
room_ids = []
if is_guest:
# TODO(daniel): Deal with non-room events too
only_room_events = True
if guest_room_id:
room_ids = [guest_room_id]
else:
rooms = yield self.store.get_rooms_for_user(user.to_string())
room_ids = [room.room_id for room in rooms]
@defer.inlineCallbacks @defer.inlineCallbacks
def check_for_updates(before_token, after_token): def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token): if not after_token.is_after(before_token):
@ -357,9 +369,23 @@ class Notifier(object):
continue continue
if only_room_events and name != "room": if only_room_events and name != "room":
continue continue
new_events, new_key = yield source.get_new_events_for_user( new_events, new_key = yield source.get_new_events(
user, getattr(from_token, keyname), limit, user=user,
from_key=getattr(from_token, keyname),
limit=limit,
is_guest=is_guest,
room_ids=room_ids,
) )
if is_guest:
room_member_handler = self.hs.get_handlers().room_member_handler
new_events = yield room_member_handler._filter_events_for_client(
user.to_string(),
new_events,
is_guest=is_guest,
require_all_visible_for_guests=False
)
events.extend(new_events) events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
@ -369,7 +395,7 @@ class Notifier(object):
defer.returnValue(None) defer.returnValue(None)
result = yield self.wait_for_events( result = yield self.wait_for_events(
user, timeout, check_for_updates, from_token=from_token user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
) )
if result is None: if result is None:

View file

@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request): def on_GET(self, request):
auth_user, _, _ = yield self.auth.get_user_by_req(request) auth_user, _, is_guest = yield self.auth.get_user_by_req(
request,
allow_guest=True
)
room_id = None
if is_guest:
if "room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
room_id = request.args["room_id"][0]
try: try:
handler = self.handlers.event_stream_handler handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request) pagin_config = PaginationConfig.from_request(request)
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
chunk = yield handler.get_stream( chunk = yield handler.get_stream(
auth_user.to_string(), pagin_config, timeout=timeout, auth_user.to_string(), pagin_config, timeout=timeout,
as_client_event=as_client_event as_client_event=as_client_event, affect_presence=(not is_guest),
room_id=room_id, is_guest=is_guest
) )
except: except:
logger.exception("Event stream failed") logger.exception("Event stream failed")

View file

@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key): def on_GET(self, request, room_id, event_type, state_key):
user, _, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
msg_handler = self.handlers.message_handler msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data( data = yield msg_handler.get_room_data(
@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
is_guest=is_guest,
) )
if not data: if not data:
@ -348,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, room_id): def on_GET(self, request, room_id):
user, _, _ = yield self.auth.get_user_by_req(request) user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
handler = self.handlers.message_handler handler = self.handlers.message_handler
# Get all the current state for this room # Get all the current state for this room
events = yield handler.get_state_events( events = yield handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=user.to_string(), user_id=user.to_string(),
is_guest=is_guest,
) )
defer.returnValue((200, events)) defer.returnValue((200, events))

View file

@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
self._store_room_message_txn(txn, event) self._store_room_message_txn(txn, event)
elif event.type == EventTypes.Redaction: elif event.type == EventTypes.Redaction:
self._store_redaction(txn, event) self._store_redaction(txn, event)
elif event.type == EventTypes.RoomHistoryVisibility:
self._store_history_visibility_txn(txn, event)
self._store_room_members_txn( self._store_room_members_txn(
txn, txn,

View file

@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
txn, event, "content.body", event.content["body"] txn, event, "content.body", event.content["body"]
) )
def _store_history_visibility_txn(self, txn, event):
if hasattr(event, "content") and "history_visibility" in event.content:
sql = (
"INSERT INTO history_visibility"
" (event_id, room_id, history_visibility)"
" VALUES (?, ?, ?)"
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content["history_visibility"]
))
def _store_event_search_txn(self, txn, event, key, value): def _store_event_search_txn(self, txn, event, key, value):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
sql = ( sql = (

View file

@ -0,0 +1,26 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* This is a manual index of history_visibility content of state events,
* so that we can join on them in SELECT statements.
*/
CREATE TABLE IF NOT EXISTS history_visibility(
id INTEGER PRIMARY KEY,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
history_visibility TEXT NOT NULL,
UNIQUE (event_id)
);

View file

@ -158,13 +158,40 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results) defer.returnValue(results)
@log_function @log_function
def get_room_events_stream(self, user_id, from_key, to_key, limit=0): def get_room_events_stream(
current_room_membership_sql = ( self,
"SELECT m.room_id FROM room_memberships as m " user_id,
" INNER JOIN current_state_events as c" from_key,
" ON m.event_id = c.event_id AND c.state_key = m.user_id" to_key,
" WHERE m.user_id = ? AND m.membership = 'join'" limit=0,
) is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
if room_ids:
current_room_membership_sql += " AND m.room_id in (%s)" % (
",".join(map(lambda _: "?", room_ids))
)
current_room_membership_args = [user_id] + room_ids
# We also want to get any membership events about that user, e.g. # We also want to get any membership events about that user, e.g.
# invites or leave notifications. # invites or leave notifications.
@ -173,6 +200,7 @@ class StreamStore(SQLBaseStore):
"INNER JOIN current_state_events as c ON m.event_id = c.event_id " "INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? " "WHERE m.user_id = ? "
) )
membership_args = [user_id]
if limit: if limit:
limit = max(limit, MAX_STREAM_SIZE) limit = max(limit, MAX_STREAM_SIZE)
@ -199,7 +227,9 @@ class StreamStore(SQLBaseStore):
} }
def f(txn): def f(txn):
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,)) args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)

View file

@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
{"presence": ONLINE} {"presence": ONLINE}
) )
# Apple sees self-reflection even without room_id
(events, _) = yield self.event_source.get_new_events(
user=self.u_apple,
from_key=0,
)
self.assertEquals(self.event_source.get_current_key(), 1)
self.assertEquals(events,
[
{"type": "m.presence",
"content": {
"user_id": "@apple:test",
"presence": ONLINE,
"last_active_ago": 0,
}},
],
msg="Presence event should be visible to self-reflection"
)
# Apple sees self-reflection # Apple sees self-reflection
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Banana sees it because of presence subscription # Banana sees it because of presence subscription
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_banana, 0, None user=self.u_banana,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Elderberry sees it because of same room # Elderberry sees it because of same room
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_elderberry, 0, None user=self.u_elderberry,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
# Durian is not in the room, should not see this event # Durian is not in the room, should not see this event
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_durian, 0, None user=self.u_durian,
from_key=0,
room_ids=[],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
"accepted": True}, "accepted": True},
], presence) ], presence)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 1, None user=self.u_apple,
from_key=1,
) )
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
) )
) )
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id],
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
room_ids=[self.room_id,]
) )
self.assertEquals(events, self.assertEquals(events,
[ [
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
self.room_members.append(self.u_clementine) self.room_members.append(self.u_clementine)
(events, _) = yield self.event_source.get_new_events_for_user( (events, _) = yield self.event_source.get_new_events(
self.u_apple, 0, None user=self.u_apple,
from_key=0,
) )
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)

View file

@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
yield put_json.await_calls() yield put_json.await_calls()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
]) ])
self.assertEquals(self.event_source.get_current_key(), 2) self.assertEquals(self.event_source.get_current_key(), 2)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=1,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
self.on_new_event.reset_mock() self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3) self.assertEquals(self.event_source.get_current_key(), 3)
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) events = yield self.event_source.get_new_events(
room_ids=[self.room_id],
from_key=0,
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [

View file

@ -47,7 +47,14 @@ class NullSource(object):
def __init__(self, hs): def __init__(self, hs):
pass pass
def get_new_events_for_user(self, user, from_key, limit): def get_new_events(
self,
user,
from_key,
room_ids=None,
limit=None,
is_guest=None
):
return defer.succeed(([], from_key)) return defer.succeed(([], from_key))
def get_current_key(self, direction='f'): def get_current_key(self, direction='f'):

View file

@ -116,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(200, code) self.assertEquals(200, code)
self.assertEquals(self.event_source.get_current_key(), 1) self.assertEquals(self.event_source.get_current_key(), 1)
events = yield self.event_source.get_new_events_for_user(self.user, 0, None) events = yield self.event_source.get_new_events(
from_key=0,
room_ids=[self.room_id],
)
self.assertEquals( self.assertEquals(
events[0], events[0],
[ [