0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 16:33:53 +01:00

Merge pull request #2847 from matrix-org/erikj/separate_event_creation

Split event creation into a separate handler
This commit is contained in:
Erik Johnston 2018-02-06 17:01:17 +00:00 committed by GitHub
commit 617199d73d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 247 additions and 226 deletions

View file

@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler() self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self.federation.register_query_handler( self.federation.register_query_handler(
@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id): def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id) aliases = yield self.store.get_aliases_for_room(room_id)
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Aliases, "type": EventTypes.Aliases,
@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str: if not alias_event or alias_event.content.get("alias", "") != alias_str:
return return
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.CanonicalAlias, "type": EventTypes.CanonicalAlias,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -75,6 +76,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool() self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self) self.replication_layer.set_handler(self)
@ -1007,8 +1009,7 @@ class FederationHandler(BaseHandler):
}) })
try: try:
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
except AuthError as e: except AuthError as e:
@ -1248,8 +1249,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id, "state_key": user_id,
}) })
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -2130,8 +2130,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)): if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder builder=builder
) )
@ -2169,8 +2168,7 @@ class FederationHandler(BaseHandler):
""" """
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(
builder=builder, builder=builder,
) )
@ -2220,8 +2218,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict) builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder) EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler event, context = yield self.event_creation_handler.create_new_client_event(
event, context = yield message_handler._create_new_client_event(builder=builder) builder=builder,
)
defer.returnValue((event, context)) defer.returnValue((event, context))
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd # Copyright 2017 - 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -47,21 +47,9 @@ class MessageHandler(BaseHandler):
self.hs = hs self.hs = hs
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock() self.pagination_lock = ReadWriteLock()
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def purge_history(self, room_id, event_id): def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id) event = yield self.store.get_event(event_id)
@ -182,166 +170,6 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk) defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
yield self.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None, def get_room_data(self, user_id=None, room_id=None,
event_type=None, state_key="", is_guest=False): event_type=None, state_key="", is_guest=False):
@ -470,9 +298,197 @@ class MessageHandler(BaseHandler):
for user_id, profile in users_with_profile.iteritems() for user_id, profile in users_with_profile.iteritems()
}) })
@measure_func("_create_new_client_event")
class EventCreationHandler(object):
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.profile_handler = hs.get_profile_handler()
self.event_builder_factory = hs.get_event_builder_factory()
self.server_name = hs.hostname
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Limiter(max_count=5)
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks @defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
Given a dict from a client, create a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
etc.
Adds display names to Join membership events.
Args:
requester
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
with (yield self.limiter.queue(builder.room_id)):
self.validator.validate_new(builder)
if builder.type == EventTypes.Member:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
target, e
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
Args:
event (FrozenEvent) the event to send.
context (Context) the context of the event.
ratelimit (bool): Whether to rate limit this send.
is_guest (bool): Whether the sender is a guest.
"""
if event.type == EventTypes.Member:
raise SynapseError(
500,
"Tried to send member event through non-member codepath"
)
# We check here if we are currently being rate limited, so that we
# don't do unnecessary work. We check again just before we actually
# send the event.
yield self.base_handler.ratelimit(requester, update=False)
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
if prev_state is not None:
defer.returnValue(prev_state)
yield self.handle_new_client_event(
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
)
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):
"""
Checks whether event is in the latest resolved state in context.
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
if prev_event and event.user_id == prev_event.user_id:
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
defer.returnValue(prev_event)
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
self,
requester,
event_dict,
ratelimit=True,
txn_id=None
):
"""
Creates an event, then sends it.
See self.create_event and self.send_nonmember_event.
"""
event, context = yield self.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, basestring):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
)
yield self.send_nonmember_event(
requester,
event,
context,
ratelimit=ratelimit,
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids: if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids) prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids) prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@ -509,9 +525,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events builder.prev_events = prev_events
builder.depth = depth builder.depth = depth
state_handler = self.state_handler context = yield self.state.compute_event_context(builder)
context = yield state_handler.compute_event_context(builder)
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
@ -551,7 +565,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if ratelimit: if ratelimit:
yield self.ratelimit(requester) yield self.base_handler.ratelimit(requester)
try: try:
yield self.auth.check_from_context(event, context) yield self.auth.check_from_context(event, context)
@ -567,7 +581,7 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content) logger.exception("Failed to encode content: %r", event.content)
raise raise
yield self.maybe_kick_guest_users(event, context) yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias: if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least) # Check the alias is acually valid (at this time at least)

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
super(RoomCreationHandler, self).__init__(hs) super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True): def create_room(self, requester, config, ratelimit=True):
@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {}) creation_content = config.get("creation_content", {})
msg_handler = self.hs.get_handlers().message_handler
room_member_handler = self.hs.get_handlers().room_member_handler room_member_handler = self.hs.get_handlers().room_member_handler
yield self._send_events_for_new_room( yield self._send_events_for_new_room(
requester, requester,
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config=preset_config, preset_config=preset_config,
invite_list=invite_list, invite_list=invite_list,
@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config: if "name" in config:
name = config["name"] name = config["name"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Name, "type": EventTypes.Name,
@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config: if "topic" in config:
topic = config["topic"] topic = config["topic"]
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Topic, "type": EventTypes.Topic,
@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
self, self,
creator, # A Requester object. creator, # A Requester object.
room_id, room_id,
msg_handler,
room_member_handler, room_member_handler,
preset_config, preset_config,
invite_list, invite_list,
@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks @defer.inlineCallbacks
def send(etype, content, **kwargs): def send(etype, content, **kwargs):
event = create(etype, content, **kwargs) event = create(etype, content, **kwargs)
yield msg_handler.create_and_send_nonmember_event( yield self.event_creation_handler.create_and_send_nonmember_event(
creator, creator,
event, event,
ratelimit=False ratelimit=False

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd # Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
super(RoomMemberHandler, self).__init__(hs) super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler() self.profile_handler = hs.get_profile_handler()
self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member") self.member_linearizer = Linearizer(name="member")
@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
): ):
if content is None: if content is None:
content = {} content = {}
msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership content["membership"] = membership
if requester.is_guest: if requester.is_guest:
content["kind"] = "guest" content["kind"] = "guest"
event, context = yield msg_handler.create_event( event, context = yield self.event_creation_hander.create_event(
requester, requester,
{ {
"type": EventTypes.Member, "type": EventTypes.Member,
@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
) )
# Check if this event matches the previous membership event for the user. # Check if this event matches the previous membership event for the user.
duplicate = yield msg_handler.deduplicate_state_event(event, context) duplicate = yield self.event_creation_hander.deduplicate_state_event(
event, context,
)
if duplicate is not None: if duplicate is not None:
# Discard the new event since this membership change is a no-op. # Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate) defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
else: else:
requester = synapse.types.create_requester(target_user) requester = synapse.types.create_requester(target_user)
message_handler = self.hs.get_handlers().message_handler prev_event = yield self.event_creation_hander.deduplicate_state_event(
prev_event = yield message_handler.deduplicate_state_event(event, context) event, context,
)
if prev_event is not None: if prev_event is not None:
return return
@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked: if is_blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
yield message_handler.handle_new_client_event( yield self.event_creation_hander.handle_new_client_event(
requester, requester,
event, event,
context, context,
@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
) )
) )
msg_handler = self.hs.get_handlers().message_handler yield self.event_creation_hander.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.ThirdPartyInvite, "type": EventTypes.ThirdPartyInvite,

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -171,6 +172,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def on_POST(self, request, room_id): def on_POST(self, request, room_id):
@ -203,8 +205,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
) )
new_room_id = info["room_id"] new_room_id = info["room_id"]
msg_handler = self.handlers.message_handler yield self.event_creation_handler.create_and_send_nonmember_event(
yield msg_handler.create_and_send_nonmember_event(
room_creator_requester, room_creator_requester,
{ {
"type": "m.room.message", "type": "m.room.message",

View file

@ -1,5 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd # Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs) super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
# /room/$roomid/state/$eventtype # /room/$roomid/state/$eventtype
@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content, content=content,
) )
else: else:
msg_handler = self.handlers.message_handler event, context = yield self.event_creation_hander.create_event(
event, context = yield msg_handler.create_event(
requester, requester,
event_dict, event_dict,
token_id=requester.access_token_id, token_id=requester.access_token_id,
txn_id=txn_id, txn_id=txn_id,
) )
yield msg_handler.send_nonmember_event(requester, event, context) yield self.event_creation_hander.send_nonmember_event(
requester, event, context,
)
ret = {} ret = {}
if event: if event:
@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs) super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id] # /rooms/$roomid/send/$event_type[/$txn_id]
@ -205,8 +209,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if 'ts' in request.args and requester.app_service: if 'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
msg_handler = self.handlers.message_handler event = yield self.event_creation_hander.create_and_send_nonmember_event(
event = yield msg_handler.create_and_send_nonmember_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@ -670,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs): def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs) super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server): def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@ -680,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
msg_handler = self.handlers.message_handler event = yield self.event_creation_handler.create_and_send_nonmember_event(
event = yield msg_handler.create_and_send_nonmember_event(
requester, requester,
{ {
"type": EventTypes.Redaction, "type": EventTypes.Redaction,

View file

@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoryHandler from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler from synapse.handlers.profile import ProfileHandler
from synapse.handlers.message import EventCreationHandler
from synapse.groups.groups_server import GroupsServerHandler from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@ -118,6 +119,7 @@ class HomeServer(object):
'application_service_handler', 'application_service_handler',
'device_message_handler', 'device_message_handler',
'profile_handler', 'profile_handler',
'event_creation_handler',
'deactivate_account_handler', 'deactivate_account_handler',
'set_password_handler', 'set_password_handler',
'notifier', 'notifier',
@ -276,6 +278,9 @@ class HomeServer(object):
def build_profile_handler(self): def build_profile_handler(self):
return ProfileHandler(self) return ProfileHandler(self)
def build_event_creation_handler(self):
return EventCreationHandler(self)
def build_deactivate_account_handler(self): def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self) return DeactivateAccountHandler(self)

View file

@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler()
self.message_handler = self.handlers.message_handler
self.u_alice = UserID.from_string("@alice:test") self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test") self.u_bob = UserID.from_string("@bob:test")
@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
"content": content, "content": content,
}) })
event, context = yield self.message_handler._create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
) )
@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
"content": {"body": body, "msgtype": u"message"}, "content": {"body": body, "msgtype": u"message"},
}) })
event, context = yield self.message_handler._create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
) )
@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
"redacts": event_id, "redacts": event_id,
}) })
event, context = yield self.message_handler._create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
) )

View file

@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
# storage logic # storage logic
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory() self.event_builder_factory = hs.get_event_builder_factory()
self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler()
self.message_handler = self.handlers.message_handler
self.u_alice = UserID.from_string("@alice:test") self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test") self.u_bob = UserID.from_string("@bob:test")
@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
"content": {"membership": membership}, "content": {"membership": membership},
}) })
event, context = yield self.message_handler._create_new_client_event( event, context = yield self.event_creation_handler.create_new_client_event(
builder builder
) )