mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-17 00:23:52 +01:00
Merge implementation of /join by alias or ID
This code is kind of rough (passing the remote servers down a long chain), but is a step towards improvement.
This commit is contained in:
parent
dbeed36dec
commit
e71095801f
5 changed files with 72 additions and 71 deletions
|
@ -188,9 +188,12 @@ class BaseHandler(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def handle_new_client_event(self, event, context, extra_users=[]):
|
def handle_new_client_event(self, event, context, ratelimit=True, extra_users=[]):
|
||||||
# We now need to go and hit out to wherever we need to hit out to.
|
# We now need to go and hit out to wherever we need to hit out to.
|
||||||
|
|
||||||
|
if ratelimit:
|
||||||
|
self.ratelimit(event.sender)
|
||||||
|
|
||||||
self.auth.check(event, auth_events=context.current_state)
|
self.auth.check(event, auth_events=context.current_state)
|
||||||
|
|
||||||
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
yield self.maybe_kick_guest_users(event, context.current_state.values())
|
||||||
|
|
|
@ -216,7 +216,7 @@ class MessageHandler(BaseHandler):
|
||||||
defer.returnValue((event, context))
|
defer.returnValue((event, context))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_event(self, event, context, ratelimit=True, is_guest=False):
|
def send_event(self, event, context, ratelimit=True, is_guest=False, room_hosts=None):
|
||||||
"""
|
"""
|
||||||
Persists and notifies local clients and federation of an event.
|
Persists and notifies local clients and federation of an event.
|
||||||
|
|
||||||
|
@ -230,9 +230,6 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||||
|
|
||||||
if ratelimit:
|
|
||||||
self.ratelimit(event.sender)
|
|
||||||
|
|
||||||
if event.is_state():
|
if event.is_state():
|
||||||
prev_state = context.current_state.get((event.type, event.state_key))
|
prev_state = context.current_state.get((event.type, event.state_key))
|
||||||
if prev_state and event.user_id == prev_state.user_id:
|
if prev_state and event.user_id == prev_state.user_id:
|
||||||
|
@ -245,11 +242,18 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
member_handler = self.hs.get_handlers().room_member_handler
|
member_handler = self.hs.get_handlers().room_member_handler
|
||||||
yield member_handler.send_membership_event(event, context, is_guest=is_guest)
|
yield member_handler.send_membership_event(
|
||||||
|
event,
|
||||||
|
context,
|
||||||
|
is_guest=is_guest,
|
||||||
|
ratelimit=ratelimit,
|
||||||
|
room_hosts=room_hosts
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield self.handle_new_client_event(
|
yield self.handle_new_client_event(
|
||||||
event=event,
|
event=event,
|
||||||
context=context,
|
context=context,
|
||||||
|
ratelimit=ratelimit,
|
||||||
)
|
)
|
||||||
|
|
||||||
if event.type == EventTypes.Message:
|
if event.type == EventTypes.Message:
|
||||||
|
@ -259,7 +263,8 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def create_and_send_event(self, event_dict, ratelimit=True,
|
def create_and_send_event(self, event_dict, ratelimit=True,
|
||||||
token_id=None, txn_id=None, is_guest=False):
|
token_id=None, txn_id=None, is_guest=False,
|
||||||
|
room_hosts=None):
|
||||||
"""
|
"""
|
||||||
Creates an event, then sends it.
|
Creates an event, then sends it.
|
||||||
|
|
||||||
|
@ -274,7 +279,8 @@ class MessageHandler(BaseHandler):
|
||||||
event,
|
event,
|
||||||
context,
|
context,
|
||||||
ratelimit=ratelimit,
|
ratelimit=ratelimit,
|
||||||
is_guest=is_guest
|
is_guest=is_guest,
|
||||||
|
room_hosts=room_hosts,
|
||||||
)
|
)
|
||||||
defer.returnValue(event)
|
defer.returnValue(event)
|
||||||
|
|
||||||
|
|
|
@ -455,7 +455,9 @@ class RoomMemberHandler(BaseHandler):
|
||||||
yield self.forget(requester.user, room_id)
|
yield self.forget(requester.user, room_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def send_membership_event(self, event, context, is_guest=False, room_hosts=None):
|
def send_membership_event(
|
||||||
|
self, event, context, is_guest=False, room_hosts=None, ratelimit=True
|
||||||
|
):
|
||||||
""" Change the membership status of a user in a room.
|
""" Change the membership status of a user in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -527,8 +529,17 @@ class RoomMemberHandler(BaseHandler):
|
||||||
defer.returnValue({"room_id": room_id})
|
defer.returnValue({"room_id": room_id})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def join_room_alias(self, requester, room_alias, content={}):
|
def lookup_room_alias(self, room_alias):
|
||||||
joinee = requester.user
|
"""
|
||||||
|
Get the room ID associated with a room alias.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
room_alias (RoomAlias): The alias to look up.
|
||||||
|
Returns:
|
||||||
|
The room ID as a RoomID object.
|
||||||
|
Raises:
|
||||||
|
SynapseError if room alias could not be found.
|
||||||
|
"""
|
||||||
directory_handler = self.hs.get_handlers().directory_handler
|
directory_handler = self.hs.get_handlers().directory_handler
|
||||||
mapping = yield directory_handler.get_association(room_alias)
|
mapping = yield directory_handler.get_association(room_alias)
|
||||||
|
|
||||||
|
@ -540,28 +551,7 @@ class RoomMemberHandler(BaseHandler):
|
||||||
if not hosts:
|
if not hosts:
|
||||||
raise SynapseError(404, "No known servers")
|
raise SynapseError(404, "No known servers")
|
||||||
|
|
||||||
# If event doesn't include a display name, add one.
|
defer.returnValue((RoomID.from_string(room_id), hosts))
|
||||||
yield collect_presencelike_data(self.distributor, joinee, content)
|
|
||||||
|
|
||||||
content.update({"membership": Membership.JOIN})
|
|
||||||
builder = self.event_builder_factory.new({
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"state_key": joinee.to_string(),
|
|
||||||
"room_id": room_id,
|
|
||||||
"sender": joinee.to_string(),
|
|
||||||
"membership": Membership.JOIN,
|
|
||||||
"content": content,
|
|
||||||
})
|
|
||||||
event, context = yield self._create_new_client_event(builder)
|
|
||||||
|
|
||||||
yield self.send_membership_event(
|
|
||||||
event,
|
|
||||||
context,
|
|
||||||
is_guest=requester.is_guest,
|
|
||||||
room_hosts=hosts
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue({"room_id": room_id})
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _do_join(self, event, context, room_hosts=None):
|
def _do_join(self, event, context, room_hosts=None):
|
||||||
|
|
|
@ -229,46 +229,40 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
allow_guest=True,
|
allow_guest=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the identifier could be a room alias or a room id. Try one then the
|
if RoomID.is_valid(room_identifier):
|
||||||
# other if it fails to parse, without swallowing other valid
|
room_id = room_identifier
|
||||||
# SynapseErrors.
|
room_hosts = None
|
||||||
|
elif RoomAlias.is_valid(room_identifier):
|
||||||
identifier = None
|
|
||||||
is_room_alias = False
|
|
||||||
try:
|
|
||||||
identifier = RoomAlias.from_string(room_identifier)
|
|
||||||
is_room_alias = True
|
|
||||||
except SynapseError:
|
|
||||||
identifier = RoomID.from_string(room_identifier)
|
|
||||||
|
|
||||||
# TODO: Support for specifying the home server to join with?
|
|
||||||
|
|
||||||
if is_room_alias:
|
|
||||||
handler = self.handlers.room_member_handler
|
handler = self.handlers.room_member_handler
|
||||||
ret_dict = yield handler.join_room_alias(
|
room_alias = RoomAlias.from_string(room_identifier)
|
||||||
requester,
|
room_id, room_hosts = yield handler.lookup_room_alias(room_alias)
|
||||||
identifier,
|
room_id = room_id.to_string()
|
||||||
)
|
else:
|
||||||
defer.returnValue((200, ret_dict))
|
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||||
else: # room id
|
room_identifier,
|
||||||
msg_handler = self.handlers.message_handler
|
))
|
||||||
content = {"membership": Membership.JOIN}
|
|
||||||
if requester.is_guest:
|
|
||||||
content["kind"] = "guest"
|
|
||||||
yield msg_handler.create_and_send_event(
|
|
||||||
{
|
|
||||||
"type": EventTypes.Member,
|
|
||||||
"content": content,
|
|
||||||
"room_id": identifier.to_string(),
|
|
||||||
"sender": requester.user.to_string(),
|
|
||||||
"state_key": requester.user.to_string(),
|
|
||||||
},
|
|
||||||
token_id=requester.access_token_id,
|
|
||||||
txn_id=txn_id,
|
|
||||||
is_guest=requester.is_guest,
|
|
||||||
)
|
|
||||||
|
|
||||||
defer.returnValue((200, {"room_id": identifier.to_string()}))
|
msg_handler = self.handlers.message_handler
|
||||||
|
content = {"membership": Membership.JOIN}
|
||||||
|
if requester.is_guest:
|
||||||
|
content["kind"] = "guest"
|
||||||
|
yield msg_handler.create_and_send_event(
|
||||||
|
{
|
||||||
|
"type": EventTypes.Member,
|
||||||
|
"content": content,
|
||||||
|
"room_id": room_id,
|
||||||
|
"sender": requester.user.to_string(),
|
||||||
|
"state_key": requester.user.to_string(),
|
||||||
|
|
||||||
|
"membership": Membership.JOIN, # For backwards compatibility
|
||||||
|
},
|
||||||
|
token_id=requester.access_token_id,
|
||||||
|
txn_id=txn_id,
|
||||||
|
is_guest=requester.is_guest,
|
||||||
|
room_hosts=room_hosts,
|
||||||
|
)
|
||||||
|
|
||||||
|
defer.returnValue((200, {"room_id": room_id}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, room_identifier, txn_id):
|
def on_PUT(self, request, room_identifier, txn_id):
|
||||||
|
|
|
@ -73,6 +73,14 @@ class DomainSpecificString(
|
||||||
"""Return a string encoding the fields of the structure object."""
|
"""Return a string encoding the fields of the structure object."""
|
||||||
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_valid(cls, s):
|
||||||
|
try:
|
||||||
|
cls.from_string(s)
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
__str__ = to_string
|
__str__ = to_string
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
Loading…
Reference in a new issue