0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-19 16:32:24 +01:00

Merge pull request #110 from matrix-org/fix_ban

Fix ban
This commit is contained in:
Erik Johnston 2015-03-16 15:36:52 +00:00
commit cd2539ab2a
6 changed files with 32 additions and 52 deletions

View file

@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
class Auth(object): class Auth(object):
def __init__(self, hs): def __init__(self, hs):
@ -166,6 +172,7 @@ class Auth(object):
target = auth_events.get(key) target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", ) key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key) join_rule_event = auth_events.get(key)
@ -194,6 +201,7 @@ class Auth(object):
{ {
"caller_in_room": caller_in_room, "caller_in_room": caller_in_room,
"caller_invited": caller_invited, "caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room, "target_in_room": target_in_room,
"membership": membership, "membership": membership,
"join_rule": join_rule, "join_rule": join_rule,
@ -202,6 +210,11 @@ class Auth(object):
} }
) )
if ban_level:
ban_level = int(ban_level)
else:
ban_level = 50 # FIXME (erikj): What should we do here?
if Membership.INVITE == membership: if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently # TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules. # PRIVATE join rules.
@ -212,6 +225,10 @@ class Auth(object):
403, 403,
"%s not in room %s." % (event.user_id, event.room_id,) "%s not in room %s." % (event.user_id, event.room_id,)
) )
elif target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room. elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." % raise AuthError(403, "%s is already in the room." %
target_user_id) target_user_id)
@ -221,6 +238,8 @@ class Auth(object):
# joined: It's a NOOP # joined: It's a NOOP
if event.user_id != target_user_id: if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.") raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC: elif join_rule == JoinRules.PUBLIC:
pass pass
elif join_rule == JoinRules.INVITE: elif join_rule == JoinRules.INVITE:
@ -238,6 +257,10 @@ class Auth(object):
403, 403,
"%s not in room %s." % (target_user_id, event.room_id,) "%s not in room %s." % (target_user_id, event.room_id,)
) )
elif target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id: elif target_user_id != event.user_id:
if kick_level: if kick_level:
kick_level = int(kick_level) kick_level = int(kick_level)
@ -249,11 +272,6 @@ class Auth(object):
403, "You cannot kick user %s." % target_user_id 403, "You cannot kick user %s." % target_user_id
) )
elif Membership.BAN == membership: elif Membership.BAN == membership:
if ban_level:
ban_level = int(ban_level)
else:
ban_level = 50 # FIXME (erikj): What should we do here?
if user_level < ban_level: if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban") raise AuthError(403, "You don't have permission to ban")
else: else:
@ -412,12 +430,6 @@ class Auth(object):
builder.auth_events = auth_events_entries builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, current_state): def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create: if event.type == EventTypes.Create:
return [] return []

View file

@ -16,8 +16,7 @@
class EventContext(object): class EventContext(object):
def __init__(self, current_state=None, auth_events=None): def __init__(self, current_state=None):
self.current_state = current_state self.current_state = current_state
self.auth_events = auth_events
self.state_group = None self.state_group = None
self.rejected = False self.rejected = False

View file

@ -90,8 +90,8 @@ class BaseHandler(object):
event = builder.build() event = builder.build()
logger.debug( logger.debug(
"Created event %s with auth_events: %s, current state: %s", "Created event %s with current state: %s",
event.event_id, context.auth_events, context.current_state, event.event_id, context.current_state,
) )
defer.returnValue( defer.returnValue(
@ -106,7 +106,7 @@ class BaseHandler(object):
# We now need to go and hit out to wherever we need to hit out to. # We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth: if not suppress_auth:
self.auth.check(event, auth_events=context.auth_events) self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context) yield self.store.persist_event(event, context=context)

View file

@ -464,11 +464,9 @@ class FederationHandler(BaseHandler):
builder=builder, builder=builder,
) )
self.auth.check(event, auth_events=context.auth_events) self.auth.check(event, auth_events=context.current_state)
pdu = event defer.returnValue(event)
defer.returnValue(pdu)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
@ -705,7 +703,7 @@ class FederationHandler(BaseHandler):
) )
if not auth_events: if not auth_events:
auth_events = context.auth_events auth_events = context.current_state
logger.debug( logger.debug(
"_handle_new_event: %s, auth_events: %s", "_handle_new_event: %s, auth_events: %s",

View file

@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from collections import namedtuple from collections import namedtuple
@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules,
)
SIZE_OF_CACHE = 1000 SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20 EVICTION_TIMEOUT_SECONDS = 20
@ -139,18 +134,6 @@ class StateHandler(object):
} }
context.state_group = None context.state_group = None
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
if event.is_state(): if event.is_state():
key = (event.type, event.state_key) key = (event.type, event.state_key)
if key in context.current_state: if key in context.current_state:
@ -187,18 +170,6 @@ class StateHandler(object):
replaces = context.current_state[key] replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id event.unsigned["replaces_state"] = replaces.event_id
if hasattr(event, "auth_events") and event.auth_events:
auth_ids = self.hs.get_auth().compute_auth_events(
event, context.current_state
)
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
else:
context.auth_events = {}
context.prev_state_events = prev_state context.prev_state_events = prev_state
defer.returnValue(context) defer.returnValue(context)

View file

@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
if context.current_state is None: if context.current_state is None:
return return
state_events = context.current_state state_events = dict(context.current_state)
if event.is_state(): if event.is_state():
state_events[(event.type, event.state_key)] = event state_events[(event.type, event.state_key)] = event