diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 9e912fdfb..d3e9837c8 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -22,7 +22,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError -from synapse.types import Requester, RoomID, UserID, EventID +from synapse.types import Requester, UserID, get_domian_from_id from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_context_over_fn from synapse.util.metrics import Measure @@ -91,8 +91,8 @@ class Auth(object): "Room %r does not exist" % (event.room_id,) ) - creating_domain = RoomID.from_string(event.room_id).domain - originating_domain = UserID.from_string(event.sender).domain + creating_domain = get_domian_from_id(event.room_id) + originating_domain = get_domian_from_id(event.sender) if creating_domain != originating_domain: if not self.can_federate(event, auth_events): raise AuthError( @@ -219,7 +219,7 @@ class Auth(object): for event in curr_state.values(): if event.type == EventTypes.Member: try: - if UserID.from_string(event.state_key).domain != host: + if get_domian_from_id(event.state_key) != host: continue except: logger.warn("state_key not user_id: %s", event.state_key) @@ -266,8 +266,8 @@ class Auth(object): target_user_id = event.state_key - creating_domain = RoomID.from_string(event.room_id).domain - target_domain = UserID.from_string(target_user_id).domain + creating_domain = get_domian_from_id(event.room_id) + target_domain = get_domian_from_id(target_user_id) if creating_domain != target_domain: if not self.can_federate(event, auth_events): raise AuthError( @@ -889,8 +889,8 @@ class Auth(object): if user_level >= redact_level: return False - redacter_domain = EventID.from_string(event.event_id).domain - redactee_domain = EventID.from_string(event.redacts).domain + redacter_domain = get_domian_from_id(event.event_id) + redactee_domain = get_domian_from_id(event.redacts) if redacter_domain == redactee_domain: return True diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index 81f7929b5..745c8901e 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.api.constants import Membership, EventTypes -from synapse.types import UserID, RoomAlias, Requester +from synapse.types import UserID, RoomAlias, Requester, get_domian_from_id from synapse.push.action_generator import ActionGenerator from synapse.util.logcontext import PreserveLoggingContext, preserve_fn @@ -312,7 +312,7 @@ class BaseHandler(object): return True for (state_key, membership) in room_members: if ( - UserID.from_string(state_key).domain == self.hs.hostname + self.hs.is_mine_id(state_key) and membership == Membership.JOIN ): return True @@ -437,9 +437,7 @@ class BaseHandler(object): try: if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) + destinations.add(get_domian_from_id(s.state_key)) except SynapseError: logger.warn( "Failed to get destination from event %s", s.event_id diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index d95e0b23b..f38c6a871 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze from synapse.crypto.event_signing import ( compute_event_signature, add_hashes_and_signatures, ) -from synapse.types import UserID +from synapse.types import UserID, get_domian_from_id from synapse.events.utils import prune_event @@ -453,7 +453,7 @@ class FederationHandler(BaseHandler): joined_domains = {} for u, d in joined_users: try: - dom = UserID.from_string(u).domain + dom = get_domian_from_id(u) old_d = joined_domains.get(dom) if old_d: joined_domains[dom] = min(d, old_d) @@ -743,9 +743,7 @@ class FederationHandler(BaseHandler): try: if k[0] == EventTypes.Member: if s.content["membership"] == Membership.JOIN: - destinations.add( - UserID.from_string(s.state_key).domain - ) + destinations.add(get_domian_from_id(s.state_key)) except: logger.warn( "Failed to get destination from event %s", s.event_id @@ -970,9 +968,7 @@ class FederationHandler(BaseHandler): try: if k[0] == EventTypes.Member: if s.content["membership"] == Membership.LEAVE: - destinations.add( - UserID.from_string(s.state_key).domain - ) + destinations.add(get_domian_from_id(s.state_key)) except: logger.warn( "Failed to get destination from event %s", s.event_id diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 639567953..a8529cce4 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -33,7 +33,7 @@ from synapse.util.logcontext import preserve_fn from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer -from synapse.types import UserID +from synapse.types import UserID, get_domian_from_id import synapse.metrics from ._base import BaseHandler @@ -440,7 +440,7 @@ class PresenceHandler(BaseHandler): if not local_states: continue - host = UserID.from_string(user_id).domain + host = get_domian_from_id(user_id) hosts_to_states.setdefault(host, []).extend(local_states) # TODO: de-dup hosts_to_states, as a single host might have multiple diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 08a54cbdd..9d6bfd524 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -21,7 +21,7 @@ from ._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.api.constants import Membership -from synapse.types import UserID +from synapse.types import get_domian_from_id import logging @@ -273,10 +273,7 @@ class RoomMemberStore(SQLBaseStore): room_id, membership=Membership.JOIN ) - joined_domains = set( - UserID.from_string(r["user_id"]).domain - for r in rows - ) + joined_domains = set(get_domian_from_id(r["user_id"]) for r in rows) return joined_domains diff --git a/synapse/types.py b/synapse/types.py index 5b166835b..42fd9c720 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -21,6 +21,10 @@ from collections import namedtuple Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) +def get_domian_from_id(string): + return string.split(":", 1)[1] + + class DomainSpecificString( namedtuple("DomainSpecificString", ("localpart", "domain")) ):