0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-22 07:20:15 +01:00

Merge branch 'erikj/events_move' of github.com:matrix-org/synapse into erikj/perf

This commit is contained in:
Erik Johnston 2015-05-18 10:23:37 +01:00
commit 131bdf9bb1
9 changed files with 317 additions and 280 deletions

View file

@ -32,9 +32,9 @@ from synapse.server import HomeServer
from twisted.internet import reactor from twisted.internet import reactor
from twisted.application import service from twisted.application import service
from twisted.enterprise import adbapi from twisted.enterprise import adbapi
from twisted.web.resource import Resource from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.static import File from twisted.web.static import File
from twisted.web.server import Site from twisted.web.server import Site, GzipEncoderFactory
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v0.content_repository import ContentRepoResource
@ -69,16 +69,26 @@ import subprocess
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer): class SynapseHomeServer(HomeServer):
def build_http_client(self): def build_http_client(self):
return MatrixFederationHttpClient(self) return MatrixFederationHttpClient(self)
def build_resource_for_client(self): def build_resource_for_client(self):
return ClientV1RestResource(self) return gz_wrap(ClientV1RestResource(self))
def build_resource_for_client_v2_alpha(self): def build_resource_for_client_v2_alpha(self):
return ClientV2AlphaRestResource(self) return gz_wrap(ClientV2AlphaRestResource(self))
def build_resource_for_federation(self): def build_resource_for_federation(self):
return JsonResource(self) return JsonResource(self)
@ -87,9 +97,10 @@ class SynapseHomeServer(HomeServer):
import syweb import syweb
syweb_path = os.path.dirname(syweb.__file__) syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient") webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable? return GzipFile(webclient_path) # TODO configurable?
def build_resource_for_static_content(self): def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File("static") return File("static")
def build_resource_for_content_repo(self): def build_resource_for_content_repo(self):

View file

@ -22,6 +22,7 @@ from synapse.api.constants import EventTypes
from synapse.types import RoomAlias from synapse.types import RoomAlias
import logging import logging
import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,6 +41,10 @@ class DirectoryHandler(BaseHandler):
def _create_association(self, room_alias, room_id, servers=None): def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services # general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias): if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local") raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this. # TODO(erikj): Change this.

View file

@ -317,6 +317,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def user_joined_room(self, user, room_id): def user_joined_room(self, user, room_id):
"""Called via the distributor whenever a user joins a room.
Notifies the new member of the presence of the current members.
Notifies the current members of the room of the new member's presence.
Args:
user(UserID): The user who joined the room.
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user): if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user) statuscache = self._get_or_make_usercache(user)
@ -344,6 +352,7 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send_invite(self, observer_user, observed_user): def send_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -378,6 +387,15 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user): def invite_presence(self, observed_user, observer_user):
"""Handles a m.presence_invite EDU. A remote or local user has
requested presence updates for a local user. If the invite is accepted
then allow the local or remote user to see the presence of the local
user.
Args:
observed_user(UserID): The local user whose presence is requested.
observer_user(UserID): The remote or local user requesting presence.
"""
accept = yield self._should_accept_invite(observed_user, observer_user) accept = yield self._should_accept_invite(observed_user, observer_user)
if accept: if accept:
@ -404,6 +422,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user): def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted( yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
@ -414,6 +440,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user): def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list( yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string() observer_user.localpart, observed_user.to_string()
) )
@ -422,6 +458,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def drop(self, observed_user, observer_user): def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -435,6 +481,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None): def get_presence_list(self, observer_user, accepted=None):
"""Get the presence list for a local user. The retured list includes
the current presence state for each user listed.
Args:
observer_user(UserID): The local user whose presence list to fetch.
accepted(bool or None): If not none then only include users who
have or have not accepted the presence invite request.
Returns:
A Deferred list of presence state events.
"""
if not self.hs.is_mine(observer_user): if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server") raise SynapseError(400, "User is not hosted on this Home Server")
@ -456,6 +512,23 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def start_polling_presence(self, user, target_user=None, state=None): def start_polling_presence(self, user, target_user=None, state=None):
"""Subscribe a local user to presence updates from a local or remote
user. If no target_user is supplied then subscribe to all users stored
in the presence list for the local user.
Additonally this pushes the current presence state of this user to all
target_users. That state can be provided directly or will be read from
the stored state for the local user.
Also this attempts to notify the local user of the current state of
any local target users.
Args:
user(UserID): The local user that whishes for presence updates.
target_user(UserID): The local or remote user whose updates are
wanted.
state(dict): Optional presence state for the local user.
"""
logger.debug("Start polling for presence from %s", user) logger.debug("Start polling for presence from %s", user)
if target_user: if target_user:
@ -496,9 +569,7 @@ class PresenceHandler(BaseHandler):
# We want to tell the person that just came online # We want to tell the person that just came online
# presence state of people they are interested in? # presence state of people they are interested in?
self.push_update_to_clients( self.push_update_to_clients(
observed_user=target_user,
users_to_push=[user], users_to_push=[user],
statuscache=self._get_or_offline_usercache(target_user),
) )
deferreds = [] deferreds = []
@ -515,6 +586,12 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True) yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user): def _start_polling_local(self, user, target_user):
"""Subscribe a local user to presence updates for a local user
Args:
user(UserId): The local user that wishes for updates.
target_user(UserId): The local users whose updates are wanted.
"""
target_localpart = target_user.localpart target_localpart = target_user.localpart
if target_localpart not in self._local_pushmap: if target_localpart not in self._local_pushmap:
@ -523,6 +600,17 @@ class PresenceHandler(BaseHandler):
self._local_pushmap[target_localpart].add(user) self._local_pushmap[target_localpart].add(user)
def _start_polling_remote(self, user, domain, remoteusers): def _start_polling_remote(self, user, domain, remoteusers):
"""Subscribe a local user to presence updates for remote users on a
given remote domain.
Args:
user(UserID): The local user that wishes for updates.
domain(str): The remote server the local user wants updates from.
remoteusers(UserID): The remote users that local user wants to be
told about.
Returns:
A Deferred.
"""
to_poll = set() to_poll = set()
for u in remoteusers: for u in remoteusers:
@ -543,6 +631,17 @@ class PresenceHandler(BaseHandler):
@log_function @log_function
def stop_polling_presence(self, user, target_user=None): def stop_polling_presence(self, user, target_user=None):
"""Unsubscribe a local user from presence updates from a local or
remote user. If no target user is supplied then unsubscribe the user
from all presence updates that the user had subscribed to.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID or None): The user whose updates are no longer
wanted.
Returns:
A Deferred.
"""
logger.debug("Stop polling for presence from %s", user) logger.debug("Stop polling for presence from %s", user)
if not target_user or self.hs.is_mine(target_user): if not target_user or self.hs.is_mine(target_user):
@ -571,6 +670,13 @@ class PresenceHandler(BaseHandler):
return defer.DeferredList(deferreds, consumeErrors=True) return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user): def _stop_polling_local(self, user, target_user):
"""Unsubscribe a local user from presence updates from a local user on
this server.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID): The user whose updates are no longer wanted.
"""
for localpart in self._local_pushmap.keys(): for localpart in self._local_pushmap.keys():
if target_user and localpart != target_user.localpart: if target_user and localpart != target_user.localpart:
continue continue
@ -583,6 +689,17 @@ class PresenceHandler(BaseHandler):
@log_function @log_function
def _stop_polling_remote(self, user, domain, remoteusers): def _stop_polling_remote(self, user, domain, remoteusers):
"""Unsubscribe a local user from presence updates from remote users on
a given domain.
Args:
user(UserID): The local user that no longer wishes for updates.
domain(str): The remote server to unsubscribe from.
remoteusers([UserID]): The users on that remote server that the
local user no longer wishes to be updated about.
Returns:
A Deferred.
"""
to_unpoll = set() to_unpoll = set()
for u in remoteusers: for u in remoteusers:
@ -604,6 +721,19 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def push_presence(self, user, statuscache): def push_presence(self, user, statuscache):
"""
Notify local and remote users of a change in presence of a local user.
Pushes the update to local clients and remote domains that are directly
subscribed to the presence of the local user.
Also pushes that update to any local user or remote domain that shares
a room with the local user.
Args:
user(UserID): The local user whose presence was updated.
statuscache(UserPresenceCache): Cache of the user's presence state
Returns:
A Deferred.
"""
assert(self.hs.is_mine(user)) assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user) logger.debug("Pushing presence update from %s", user)
@ -630,45 +760,24 @@ class PresenceHandler(BaseHandler):
) )
yield self.distributor.fire("user_presence_changed", user, statuscache) yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {
"user_id": user.to_string(),
}
user_state.update(**state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={
"push": [
user_state,
],
}
)
@defer.inlineCallbacks @defer.inlineCallbacks
def incoming_presence(self, origin, content): def incoming_presence(self, origin, content):
"""Handle an incoming m.presence EDU.
For each presence update in the "push" list update our local cache and
notify the appropriate local clients. Only clients that share a room
or are directly subscribed to the presence for a user should be
notified of the update.
For each subscription request in the "poll" list start pushing presence
updates to the remote server.
For unsubscribe request in the "unpoll" list stop pushing presence
updates to the remote server.
Args:
orgin(str): The source of this m.presence EDU.
content(dict): The content of this m.presence EDU.
Returns:
A Deferred.
"""
deferreds = [] deferreds = []
for push in content.get("push", []): for push in content.get("push", []):
@ -712,10 +821,7 @@ class PresenceHandler(BaseHandler):
continue continue
self.push_update_to_clients( self.push_update_to_clients(
observed_user=user, users_to_push=observers, room_ids=room_ids
users_to_push=observers,
room_ids=room_ids,
statuscache=statuscache,
) )
user_id = user.to_string() user_id = user.to_string()
@ -770,6 +876,23 @@ class PresenceHandler(BaseHandler):
def push_update_to_local_and_remote(self, observed_user, statuscache, def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[], users_to_push=[], room_ids=[],
remote_domains=[]): remote_domains=[]):
"""Notify local clients and remote servers of a change in the presence
of a user.
Args:
observed_user(UserID): The user to push the presence state for.
statuscache(UserPresenceCache): The cache for the presence state to
push.
users_to_push([UserID]): A list of local and remote users to
notify.
room_ids([str]): Notify the local and remote occupants of these
rooms.
remote_domains([str]): A list of remote servers to notify in
addition to those implied by the users_to_push and the
room_ids.
Returns:
A Deferred.
"""
localusers, remoteusers = partitionbool( localusers, remoteusers = partitionbool(
users_to_push, users_to_push,
@ -779,10 +902,7 @@ class PresenceHandler(BaseHandler):
localusers = set(localusers) localusers = set(localusers)
self.push_update_to_clients( self.push_update_to_clients(
observed_user=observed_user, users_to_push=localusers, room_ids=room_ids
users_to_push=localusers,
room_ids=room_ids,
statuscache=statuscache,
) )
remote_domains = set(remote_domains) remote_domains = set(remote_domains)
@ -807,14 +927,65 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains)) defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, observed_user, users_to_push=[], def push_update_to_clients(self, users_to_push=[], room_ids=[]):
room_ids=[], statuscache=None): """Notify clients of a new presence event.
Args:
users_to_push([UserID]): List of users to notify.
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext(): with PreserveLoggingContext():
self.notifier.on_new_user_event( self.notifier.on_new_user_event(
users_to_push, users_to_push,
room_ids, room_ids,
) )
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
"""Push a user's presence to a remote server. If a presence state event
that event is sent. Otherwise a new state event is constructed from the
stored presence state.
The last_active is replaced with last_active_ago in case the wallclock
time on the remote server is different to the time on this server.
Sends an EDU to the remote server with the current presence state.
Args:
user(UserID): The user to push the presence state for.
destination(str): The remote server to send state to.
state(dict): The state to push, or None to use the current stored
state.
Returns:
A Deferred.
"""
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {"user_id": user.to_string(), }
user_state.update(state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={"push": [user_state, ], }
)
class PresenceEventSource(object): class PresenceEventSource(object):
def __init__(self, hs): def __init__(self, hs):

View file

@ -88,6 +88,9 @@ class ProfileHandler(BaseHandler):
if target_user != auth_user: if target_user != auth_user:
raise AuthError(400, "Cannot set another user's displayname") raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
new_displayname = None
yield self.store.set_profile_displayname( yield self.store.set_profile_displayname(
target_user.localpart, new_displayname target_user.localpart, new_displayname
) )

View file

@ -26,6 +26,7 @@ from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
import logging import logging
import string
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -50,6 +51,10 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id) self.ratelimit(user_id)
if "room_alias_name" in config: if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias.create( room_alias = RoomAlias.create(
config["room_alias_name"], config["room_alias_name"],
self.hs.hostname, self.hs.hostname,

View file

@ -16,6 +16,7 @@
from twisted.internet import defer from twisted.internet import defer
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken from synapse.types import StreamToken
import synapse.metrics import synapse.metrics
@ -49,13 +50,9 @@ class _NotificationListener(object):
so that it can remove itself from the indexes in the Notifier class. so that it can remove itself from the indexes in the Notifier class.
""" """
def __init__(self, user, rooms, from_token, limit, timeout, deferred, def __init__(self, user, rooms, deferred, appservice=None):
appservice=None):
self.user = user self.user = user
self.appservice = appservice self.appservice = appservice
self.from_token = from_token
self.limit = limit
self.timeout = timeout
self.deferred = deferred self.deferred = deferred
self.rooms = rooms self.rooms = rooms
self.timer = None self.timer = None
@ -63,17 +60,14 @@ class _NotificationListener(object):
def notified(self): def notified(self):
return self.deferred.called return self.deferred.called
def notify(self, notifier, events, start_token, end_token): def notify(self, notifier):
""" Inform whoever is listening about the new events. This will """ Inform whoever is listening about the new events. This will
also remove this listener from all the indexes in the Notifier also remove this listener from all the indexes in the Notifier
it knows about. it knows about.
""" """
result = (events, (start_token, end_token))
try: try:
self.deferred.callback(result) self.deferred.callback(None)
notified_events_counter.inc_by(len(events))
except defer.AlreadyCalledError: except defer.AlreadyCalledError:
pass pass
@ -160,6 +154,7 @@ class Notifier(object):
listening to the room, and any listeners for the users in the listening to the room, and any listeners for the users in the
`extra_users` param. `extra_users` param.
""" """
yield run_on_reactor()
# poke any interested application service. # poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services( self.hs.get_handlers().appservice_handler.notify_interested_services(
event event
@ -167,8 +162,6 @@ class Notifier(object):
room_id = event.room_id room_id = event.room_id
room_source = self.event_sources.sources["room"]
room_listeners = self.room_to_listeners.get(room_id, set()) room_listeners = self.room_to_listeners.get(room_id, set())
_discard_if_notified(room_listeners) _discard_if_notified(room_listeners)
@ -199,33 +192,11 @@ class Notifier(object):
logger.debug("on_new_room_event listeners %s", listeners) logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the for listener in listeners:
# db once? try:
listener.notify(self)
@defer.inlineCallbacks except:
def notify(listener): logger.exception("Failed to notify listener")
events, end_key = yield room_source.get_new_events_for_user(
listener.user,
listener.from_token.room_key,
listener.limit,
)
if events:
end_token = listener.from_token.copy_and_replace(
"room_key", end_key
)
listener.notify(
self, events, listener.from_token, end_token
)
def eb(failure):
logger.exception("Failed to notify listener", failure)
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -235,11 +206,7 @@ class Notifier(object):
Will wake up all listeners for the given users and rooms. Will wake up all listeners for the given users and rooms.
""" """
# TODO(paul): This is horrible, having to manually list every event yield run_on_reactor()
# source here individually
presence_source = self.event_sources.sources["presence"]
typing_source = self.event_sources.sources["typing"]
listeners = set() listeners = set()
for user in users: for user in users:
@ -256,68 +223,29 @@ class Notifier(object):
listeners |= room_listeners listeners |= room_listeners
@defer.inlineCallbacks for listener in listeners:
def notify(listener): try:
presence_events, presence_end_key = ( listener.notify(self)
yield presence_source.get_new_events_for_user( except:
listener.user, logger.exception("Failed to notify listener")
listener.from_token.presence_key,
listener.limit,
)
)
typing_events, typing_end_key = (
yield typing_source.get_new_events_for_user(
listener.user,
listener.from_token.typing_key,
listener.limit,
)
)
if presence_events or typing_events:
end_token = listener.from_token.copy_and_replace(
"presence_key", presence_end_key
).copy_and_replace(
"typing_key", typing_end_key
)
listener.notify(
self,
presence_events + typing_events,
listener.from_token,
end_token
)
def eb(failure):
logger.error(
"Failed to notify listener",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject())
)
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def wait_for_events(self, user, rooms, filter, timeout, callback): def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0")):
"""Wait until the callback returns a non empty response or the """Wait until the callback returns a non empty response or the
timeout fires. timeout fires.
""" """
deferred = defer.Deferred() deferred = defer.Deferred()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
from_token = StreamToken("s0", "0", "0") user.to_string()
)
listener = [_NotificationListener( listener = [_NotificationListener(
user=user, user=user,
rooms=rooms, rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred, deferred=deferred,
appservice=appservice,
)] )]
if timeout: if timeout:
@ -332,7 +260,7 @@ class Notifier(object):
def _timeout_listener(): def _timeout_listener():
timed_out[0] = True timed_out[0] = True
timer[0] = None timer[0] = None
listener[0].notify(self, [], from_token, from_token) listener[0].notify(self)
# We create multiple notification listeners so we have to manage # We create multiple notification listeners so we have to manage
# canceling the timeout ourselves. # canceling the timeout ourselves.
@ -344,10 +272,8 @@ class Notifier(object):
listener[0] = _NotificationListener( listener[0] = _NotificationListener(
user=user, user=user,
rooms=rooms, rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred, deferred=deferred,
appservice=appservice,
) )
self._register_with_keys(listener[0]) self._register_with_keys(listener[0])
result = yield callback() result = yield callback()
@ -360,65 +286,43 @@ class Notifier(object):
defer.returnValue(result) defer.returnValue(result)
@defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout): def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If """ For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning. new events to happen before returning.
""" """
deferred = defer.Deferred() from_token = pagination_config.from_token
self._get_events(
deferred, user, rooms, pagination_config.from_token,
pagination_config.limit, timeout
).addErrback(deferred.errback)
return deferred
@defer.inlineCallbacks
def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
if not from_token: if not from_token:
from_token = yield self.event_sources.get_current_token() from_token = yield self.event_sources.get_current_token()
appservice = yield self.hs.get_datastore().get_app_service_by_user_id( limit = pagination_config.limit
user.to_string()
@defer.inlineCallbacks
def check_for_updates():
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
stuff, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit,
) )
events.extend(stuff)
end_token = end_token.copy_and_replace(keyname, new_key)
listener = _NotificationListener( if events:
user, defer.returnValue((events, (from_token, end_token)))
rooms,
from_token,
limit,
timeout,
deferred,
appservice=appservice
)
def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token.
# Remove the timer from the listener so we don't try to cancel it.
listener.timer = None
listener.notify(
self,
[],
listener.from_token,
listener.from_token,
)
if timeout:
self._register_with_keys(listener)
yield self._check_for_updates(listener)
if not timeout:
_timeout_listener()
else: else:
# Only add the timer if the listener hasn't been notified defer.returnValue(None)
if not listener.notified():
listener.timer = self.clock.call_later( result = yield self.wait_for_events(
timeout/1000.0, _timeout_listener user, rooms, timeout, check_for_updates, from_token=from_token
) )
return
if result is None:
result = ([], (from_token, from_token))
defer.returnValue(result)
@log_function @log_function
def _register_with_keys(self, listener): def _register_with_keys(self, listener):
@ -433,36 +337,6 @@ class Notifier(object):
listener.appservice, set() listener.appservice, set()
).add(listener) ).add(listener)
@defer.inlineCallbacks
@log_function
def _check_for_updates(self, listener):
# TODO (erikj): We need to think about limits across multiple sources
events = []
from_token = listener.from_token
limit = listener.limit
# TODO (erikj): DeferredList?
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
stuff, new_key = yield source.get_new_events_for_user(
listener.user,
getattr(from_token, keyname),
limit,
)
events.extend(stuff)
from_token = from_token.copy_and_replace(keyname, new_key)
end_token = from_token
if events:
listener.notify(self, events, listener.from_token, end_token)
defer.returnValue(listener)
def _user_joined_room(self, user, room_id): def _user_joined_room(self, user, room_id):
new_listeners = self.user_to_listeners.get(user, set()) new_listeners = self.user_to_listeners.get(user, set())

View file

@ -82,8 +82,10 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY] [LoginType.EMAIL_IDENTITY]
] ]
result = None
if service: if service:
is_application_server = True is_application_server = True
params = body
elif 'mac' in body: elif 'mac' in body:
# Check registration-specific shared secret auth # Check registration-specific shared secret auth
if 'username' not in body: if 'username' not in body:
@ -92,6 +94,7 @@ class RegisterRestServlet(RestServlet):
body['username'], body['mac'] body['username'], body['mac']
) )
is_using_shared_secret = True is_using_shared_secret = True
params = body
else: else:
authed, result, params = yield self.auth_handler.check_auth( authed, result, params = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request) flows, body, self.hs.get_ip_from_request(request)
@ -118,7 +121,7 @@ class RegisterRestServlet(RestServlet):
password=new_password password=new_password
) )
if LoginType.EMAIL_IDENTITY in result: if result and LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY] threepid = result[LoginType.EMAIL_IDENTITY]
for reqd in ['medium', 'address', 'validated_at']: for reqd in ['medium', 'address', 'validated_at']:

View file

@ -1097,12 +1097,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# apple should see both banana and clementine currently offline # apple should see both banana and clementine currently offline
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=[self.u_apple], call(users_to_push=[self.u_apple]),
observed_user=self.u_banana, call(users_to_push=[self.u_apple]),
statuscache=ANY),
call(users_to_push=[self.u_apple],
observed_user=self.u_clementine,
statuscache=ANY),
], any_order=True) ], any_order=True)
# Gut-wrenching tests # Gut-wrenching tests
@ -1121,13 +1117,8 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# apple and banana should now both see each other online # apple and banana should now both see each other online
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple]), call(users_to_push=set([self.u_apple]), room_ids=[]),
observed_user=self.u_banana, call(users_to_push=[self.u_banana]),
room_ids=[],
statuscache=ANY),
call(users_to_push=[self.u_banana],
observed_user=self.u_apple,
statuscache=ANY),
], any_order=True) ], any_order=True)
self.assertTrue("apple" in self.handler._local_pushmap) self.assertTrue("apple" in self.handler._local_pushmap)
@ -1143,10 +1134,7 @@ class PresencePollingTestCase(MockedDatastorePresenceTestCase):
# banana should now be told apple is offline # banana should now be told apple is offline
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_banana, self.u_apple]), call(users_to_push=set([self.u_banana, self.u_apple]), room_ids=[]),
observed_user=self.u_apple,
room_ids=[],
statuscache=ANY),
], any_order=True) ], any_order=True)
self.assertFalse("banana" in self.handler._local_pushmap) self.assertFalse("banana" in self.handler._local_pushmap)

View file

@ -209,20 +209,12 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
], presence) ], presence)
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]), call(
room_ids=[], users_to_push={self.u_apple, self.u_banana, self.u_clementine},
observed_user=self.u_apple, room_ids=[]
statuscache=ANY), # self-reflection ),
], any_order=True) ], any_order=True)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({
"presence": ONLINE,
"last_active": 1000000, # MockClock
"displayname": "Frank",
"avatar_url": "http://foo",
}, statuscache.state)
self.mock_update_client.reset_mock() self.mock_update_client.reset_mock()
self.datastore.set_profile_displayname.return_value = defer.succeed( self.datastore.set_profile_displayname.return_value = defer.succeed(
@ -232,21 +224,12 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.u_apple, "I am an Apple") self.u_apple, "I am an Apple")
self.mock_update_client.assert_has_calls([ self.mock_update_client.assert_has_calls([
call(users_to_push=set([self.u_apple, self.u_banana, self.u_clementine]), call(
users_to_push={self.u_apple, self.u_banana, self.u_clementine},
room_ids=[], room_ids=[],
observed_user=self.u_apple, ),
statuscache=ANY), # self-reflection
], any_order=True) ], any_order=True)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({
"presence": ONLINE,
"last_active": 1000000, # MockClock
"displayname": "I am an Apple",
"avatar_url": "http://foo",
}, statuscache.state)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_push_remote(self): def test_push_remote(self):
self.presence_list = [ self.presence_list = [
@ -314,13 +297,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.mock_update_client.assert_called_with( self.mock_update_client.assert_called_with(
users_to_push=set([self.u_apple]), users_to_push=set([self.u_apple]),
room_ids=[], room_ids=[],
observed_user=self.u_potato, )
statuscache=ANY)
statuscache = self.mock_update_client.call_args[1]["statuscache"]
self.assertEquals({"presence": ONLINE,
"displayname": "Frank",
"avatar_url": "http://foo"}, statuscache.state)
state = yield self.handlers.presence_handler.get_state(self.u_potato, state = yield self.handlers.presence_handler.get_state(self.u_potato,
self.u_apple) self.u_apple)