forked from MirrorHub/synapse
Allow guests to register and call /events?room_id=
This follows the same flows-based flow as regular registration, but as the only implemented flow has no requirements, it auto-succeeds. In the future, other flows (e.g. captcha) may be required, so clients should treat this like the regular registration flow choices.
This commit is contained in:
parent
f74f48e9e6
commit
f522f50a08
33 changed files with 271 additions and 166 deletions
|
@ -49,6 +49,7 @@ class Auth(object):
|
|||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||
"gen = ",
|
||||
"guest = ",
|
||||
"type = ",
|
||||
"time < ",
|
||||
"user_id = ",
|
||||
|
@ -183,15 +184,11 @@ class Auth(object):
|
|||
defer.returnValue(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_user_was_in_room(self, room_id, user_id, current_state=None):
|
||||
def check_user_was_in_room(self, room_id, user_id):
|
||||
"""Check if the user was in the room at some point.
|
||||
Args:
|
||||
room_id(str): The room to check.
|
||||
user_id(str): The user to check.
|
||||
current_state(dict): Optional map of the current state of the room.
|
||||
If provided then that map is used to check whether they are a
|
||||
member of the room. Otherwise the current membership is
|
||||
loaded from the database.
|
||||
Raises:
|
||||
AuthError if the user was never in the room.
|
||||
Returns:
|
||||
|
@ -199,17 +196,11 @@ class Auth(object):
|
|||
room. This will be the join event if they are currently joined to
|
||||
the room. This will be the leave event if they have left the room.
|
||||
"""
|
||||
if current_state:
|
||||
member = current_state.get(
|
||||
(EventTypes.Member, user_id),
|
||||
None
|
||||
)
|
||||
else:
|
||||
member = yield self.state.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Member,
|
||||
state_key=user_id
|
||||
)
|
||||
member = yield self.state.get_current_state(
|
||||
room_id=room_id,
|
||||
event_type=EventTypes.Member,
|
||||
state_key=user_id
|
||||
)
|
||||
membership = member.membership if member else None
|
||||
|
||||
if membership not in (Membership.JOIN, Membership.LEAVE):
|
||||
|
@ -497,7 +488,7 @@ class Auth(object):
|
|||
return default
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_req(self, request):
|
||||
def get_user_by_req(self, request, allow_guest=False):
|
||||
""" Get a registered user's ID.
|
||||
|
||||
Args:
|
||||
|
@ -535,7 +526,7 @@ class Auth(object):
|
|||
|
||||
request.authenticated_entity = user_id
|
||||
|
||||
defer.returnValue((UserID.from_string(user_id), ""))
|
||||
defer.returnValue((UserID.from_string(user_id), "", False))
|
||||
return
|
||||
except KeyError:
|
||||
pass # normal users won't have the user_id query parameter set.
|
||||
|
@ -543,6 +534,7 @@ class Auth(object):
|
|||
user_info = yield self._get_user_by_access_token(access_token)
|
||||
user = user_info["user"]
|
||||
token_id = user_info["token_id"]
|
||||
is_guest = user_info["is_guest"]
|
||||
|
||||
ip_addr = self.hs.get_ip_from_request(request)
|
||||
user_agent = request.requestHeaders.getRawHeaders(
|
||||
|
@ -557,9 +549,14 @@ class Auth(object):
|
|||
user_agent=user_agent
|
||||
)
|
||||
|
||||
if is_guest and not allow_guest:
|
||||
raise AuthError(
|
||||
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||
)
|
||||
|
||||
request.authenticated_entity = user.to_string()
|
||||
|
||||
defer.returnValue((user, token_id,))
|
||||
defer.returnValue((user, token_id, is_guest,))
|
||||
except KeyError:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
|
@ -592,31 +589,45 @@ class Auth(object):
|
|||
self._validate_macaroon(macaroon)
|
||||
|
||||
user_prefix = "user_id = "
|
||||
user = None
|
||||
guest = False
|
||||
for caveat in macaroon.caveats:
|
||||
if caveat.caveat_id.startswith(user_prefix):
|
||||
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
||||
# This codepath exists so that we can actually return a
|
||||
# token ID, because we use token IDs in place of device
|
||||
# identifiers throughout the codebase.
|
||||
# TODO(daniel): Remove this fallback when device IDs are
|
||||
# properly implemented.
|
||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||
if ret["user"] != user:
|
||||
logger.error(
|
||||
"Macaroon user (%s) != DB user (%s)",
|
||||
user,
|
||||
ret["user"]
|
||||
)
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||
"User mismatch in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
elif caveat.caveat_id == "guest = true":
|
||||
guest = True
|
||||
|
||||
if user is None:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
|
||||
if guest:
|
||||
ret = {
|
||||
"user": user,
|
||||
"is_guest": True,
|
||||
"token_id": None,
|
||||
}
|
||||
else:
|
||||
# This codepath exists so that we can actually return a
|
||||
# token ID, because we use token IDs in place of device
|
||||
# identifiers throughout the codebase.
|
||||
# TODO(daniel): Remove this fallback when device IDs are
|
||||
# properly implemented.
|
||||
ret = yield self._look_up_user_by_access_token(macaroon_str)
|
||||
if ret["user"] != user:
|
||||
logger.error(
|
||||
"Macaroon user (%s) != DB user (%s)",
|
||||
user,
|
||||
ret["user"]
|
||||
)
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||
"User mismatch in macaroon",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
defer.returnValue(ret)
|
||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
|
||||
|
@ -629,6 +640,7 @@ class Auth(object):
|
|||
v.satisfy_exact("type = access")
|
||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||
v.satisfy_general(self._verify_expiry)
|
||||
v.satisfy_exact("guest = true")
|
||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
|
@ -666,6 +678,7 @@ class Auth(object):
|
|||
user_info = {
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
"is_guest": False,
|
||||
}
|
||||
defer.returnValue(user_info)
|
||||
|
||||
|
|
|
@ -33,6 +33,7 @@ class Codes(object):
|
|||
NOT_FOUND = "M_NOT_FOUND"
|
||||
MISSING_TOKEN = "M_MISSING_TOKEN"
|
||||
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
||||
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
|
||||
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||
|
|
|
@ -34,6 +34,7 @@ class RegistrationConfig(Config):
|
|||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
self.macaroon_secret_key = config.get("macaroon_secret_key")
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
self.allow_guest_access = config.get("allow_guest_access", False)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
registration_shared_secret = random_string_with_symbols(50)
|
||||
|
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
|||
# Larger numbers increase the work factor needed to generate the hash.
|
||||
# The default number of rounds is 12.
|
||||
bcrypt_rounds: 12
|
||||
|
||||
# Allows users to register as guests without a password/email/etc, and
|
||||
# participate in rooms hosted on this server which have been made
|
||||
# accessible to anonymous users.
|
||||
allow_guest_access: False
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
|
|
|
@ -47,37 +47,23 @@ class BaseHandler(object):
|
|||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _filter_events_for_client(self, user_id, events):
|
||||
event_id_to_state = yield self.store.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, user_id),
|
||||
)
|
||||
)
|
||||
def _filter_events_for_client(self, user_id, events, is_guest=False):
|
||||
# Assumes that user has at some point joined the room if not is_guest.
|
||||
|
||||
def allowed(event, state):
|
||||
if event.type == EventTypes.RoomHistoryVisibility:
|
||||
def allowed(event, membership, visibility):
|
||||
if visibility == "world_readable":
|
||||
return True
|
||||
|
||||
membership_ev = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_ev:
|
||||
membership = membership_ev.membership
|
||||
else:
|
||||
membership = Membership.LEAVE
|
||||
if is_guest:
|
||||
return False
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
||||
if history:
|
||||
visibility = history.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
if event.type == EventTypes.RoomHistoryVisibility:
|
||||
return not is_guest
|
||||
|
||||
if visibility == "public":
|
||||
return True
|
||||
elif visibility == "shared":
|
||||
if visibility == "shared":
|
||||
return True
|
||||
elif visibility == "joined":
|
||||
return membership == Membership.JOIN
|
||||
|
@ -86,11 +72,44 @@ class BaseHandler(object):
|
|||
|
||||
return True
|
||||
|
||||
defer.returnValue([
|
||||
event
|
||||
for event in events
|
||||
if allowed(event, event_id_to_state[event.event_id])
|
||||
])
|
||||
event_id_to_state = yield self.store.get_state_for_events(
|
||||
frozenset(e.event_id for e in events),
|
||||
types=(
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
(EventTypes.Member, user_id),
|
||||
)
|
||||
)
|
||||
|
||||
events_to_return = []
|
||||
for event in events:
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||
if membership_event:
|
||||
membership = membership_event.membership
|
||||
else:
|
||||
membership = None
|
||||
|
||||
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||
else:
|
||||
visibility = "shared"
|
||||
|
||||
should_include = allowed(event, membership, visibility)
|
||||
if should_include:
|
||||
events_to_return.append(event)
|
||||
|
||||
if is_guest and len(events_to_return) < len(events):
|
||||
# This indicates that some events in the requested range were not
|
||||
# visible to guest users. To be safe, we reject the entire request,
|
||||
# so that we don't have to worry about interpreting visibility
|
||||
# boundaries.
|
||||
raise AuthError(403, "User %s does not have permission" % (
|
||||
user_id
|
||||
))
|
||||
|
||||
defer.returnValue(events_to_return)
|
||||
|
||||
def ratelimit(self, user_id):
|
||||
time_now = self.clock.time()
|
||||
|
|
|
@ -372,12 +372,15 @@ class AuthHandler(BaseHandler):
|
|||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||
defer.returnValue(refresh_token)
|
||||
|
||||
def generate_access_token(self, user_id):
|
||||
def generate_access_token(self, user_id, extra_caveats=None):
|
||||
extra_caveats = extra_caveats or []
|
||||
macaroon = self._generate_base_macaroon(user_id)
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
now = self.hs.get_clock().time_msec()
|
||||
expiry = now + (60 * 60 * 1000)
|
||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||
for caveat in extra_caveats:
|
||||
macaroon.add_first_party_caveat(caveat)
|
||||
return macaroon.serialize()
|
||||
|
||||
def generate_refresh_token(self, user_id):
|
||||
|
|
|
@ -71,20 +71,20 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||
as_client_event=True):
|
||||
as_client_event=True, is_guest=False):
|
||||
"""Get messages in a room.
|
||||
|
||||
Args:
|
||||
user_id (str): The user requesting messages.
|
||||
room_id (str): The room they want messages from.
|
||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||
config rules to apply, if any.
|
||||
config rules to apply, if any.
|
||||
as_client_event (bool): True to get events in client-server format.
|
||||
is_guest (bool): Whether the requesting user is a guest (as opposed
|
||||
to a fully registered user).
|
||||
Returns:
|
||||
dict: Pagination API results
|
||||
"""
|
||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||
|
||||
data_source = self.hs.get_event_sources().sources["room"]
|
||||
|
||||
if pagin_config.from_token:
|
||||
|
@ -107,23 +107,27 @@ class MessageHandler(BaseHandler):
|
|||
|
||||
source_config = pagin_config.get_source_config("room")
|
||||
|
||||
if member_event.membership == Membership.LEAVE:
|
||||
# If they have left the room then clamp the token to be before
|
||||
# they left the room
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
member_event.event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < room_token.topological:
|
||||
source_config.from_key = str(leave_token)
|
||||
if not is_guest:
|
||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||
if member_event.membership == Membership.LEAVE:
|
||||
# If they have left the room then clamp the token to be before
|
||||
# they left the room.
|
||||
# If they're a guest, we'll just 403 them if they're asking for
|
||||
# events they can't see.
|
||||
leave_token = yield self.store.get_topological_token_for_event(
|
||||
member_event.event_id
|
||||
)
|
||||
leave_token = RoomStreamToken.parse(leave_token)
|
||||
if leave_token.topological < room_token.topological:
|
||||
source_config.from_key = str(leave_token)
|
||||
|
||||
if source_config.direction == "f":
|
||||
if source_config.to_key is None:
|
||||
source_config.to_key = str(leave_token)
|
||||
else:
|
||||
to_token = RoomStreamToken.parse(source_config.to_key)
|
||||
if leave_token.topological < to_token.topological:
|
||||
if source_config.direction == "f":
|
||||
if source_config.to_key is None:
|
||||
source_config.to_key = str(leave_token)
|
||||
else:
|
||||
to_token = RoomStreamToken.parse(source_config.to_key)
|
||||
if leave_token.topological < to_token.topological:
|
||||
source_config.to_key = str(leave_token)
|
||||
|
||||
yield self.hs.get_handlers().federation_handler.maybe_backfill(
|
||||
room_id, room_token.topological
|
||||
|
@ -146,7 +150,7 @@ class MessageHandler(BaseHandler):
|
|||
"end": next_token.to_string(),
|
||||
})
|
||||
|
||||
events = yield self._filter_events_for_client(user_id, events)
|
||||
events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
|
|||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, localpart=None, password=None):
|
||||
def register(self, localpart=None, password=None, generate_token=True):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
Args:
|
||||
|
@ -89,7 +89,9 @@ class RegistrationHandler(BaseHandler):
|
|||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
token = None
|
||||
if generate_token:
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
|
@ -102,14 +104,14 @@ class RegistrationHandler(BaseHandler):
|
|||
attempts = 0
|
||||
user_id = None
|
||||
token = None
|
||||
while not user_id and not token:
|
||||
while not user_id:
|
||||
try:
|
||||
localpart = self._generate_user_id()
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
yield self.check_user_id_is_valid(user_id)
|
||||
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
if generate_token:
|
||||
token = self.auth_handler().generate_access_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
token=token,
|
||||
|
|
|
@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(auth_user)
|
||||
|
||||
if not is_admin and target_user != auth_user:
|
||||
|
|
|
@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
|||
|
||||
try:
|
||||
# try to auth as a user
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
user_id = user.to_string()
|
||||
yield dir_handler.create_association(
|
||||
|
@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
|||
# fallback to default user behaviour if they aren't an AS
|
||||
pass
|
||||
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(user)
|
||||
if not is_admin:
|
||||
|
|
|
@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
try:
|
||||
handler = self.handlers.event_stream_handler
|
||||
pagin_config = PaginationConfig.from_request(request)
|
||||
|
@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, event_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.event_handler
|
||||
event = yield handler.get_event(auth_user, event_id)
|
||||
|
||||
|
|
|
@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
as_client_event = "raw" not in request.args
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
handler = self.handlers.message_handler
|
||||
|
|
|
@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
state = yield self.handlers.presence_handler.get_state(
|
||||
|
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
state = {}
|
||||
|
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
|
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
|
|
|
@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
try:
|
||||
|
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
try:
|
||||
|
|
|
@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
except InvalidRuleException as e:
|
||||
raise SynapseError(400, e.message)
|
||||
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||
raise SynapseError(400, "rule_id may not contain slashes")
|
||||
|
@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
def on_DELETE(self, request):
|
||||
spec = _rule_spec_from_path(request.postpath)
|
||||
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||
|
||||
|
@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# we build up the full structure and then decide which bits of it
|
||||
# to send which means doing unnecessary work sometimes but is
|
||||
|
|
|
@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
|
|
@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_config = self.get_room_config(request)
|
||||
info = yield self.make_room(room_config, auth_user, None)
|
||||
|
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_type, state_key):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
data = yield msg_handler.get_room_data(
|
||||
|
@ -143,7 +143,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
@ -175,7 +175,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
content = _parse_json(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
|
@ -220,7 +220,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_identifier, txn_id=None):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# the identifier could be a room alias or a room id. Try one then the
|
||||
# other if it fails to parse, without swallowing other valid
|
||||
|
@ -289,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
# TODO support Pagination stream API (limit/tokens)
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.message_handler
|
||||
events = yield handler.get_state_events(
|
||||
room_id=room_id,
|
||||
|
@ -325,7 +325,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
pagination_config = PaginationConfig.from_request(
|
||||
request, default_limit=10,
|
||||
)
|
||||
|
@ -334,6 +334,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
|||
msgs = yield handler.get_messages(
|
||||
room_id=room_id,
|
||||
user_id=user.to_string(),
|
||||
is_guest=is_guest,
|
||||
pagin_config=pagination_config,
|
||||
as_client_event=as_client_event
|
||||
)
|
||||
|
@ -347,7 +348,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
handler = self.handlers.message_handler
|
||||
# Get all the current state for this room
|
||||
events = yield handler.get_state_events(
|
||||
|
@ -363,7 +364,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
pagination_config = PaginationConfig.from_request(request)
|
||||
content = yield self.handlers.message_handler.room_initial_sync(
|
||||
room_id=room_id,
|
||||
|
@ -443,7 +444,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
@ -524,7 +525,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
content = _parse_json(request)
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
|
@ -564,7 +565,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, user_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
room_id = urllib.unquote(room_id)
|
||||
target_user = UserID.from_string(urllib.unquote(user_id))
|
||||
|
@ -597,7 +598,7 @@ class SearchRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
turnUris = self.hs.config.turn_uris
|
||||
turnSecret = self.hs.config.turn_shared_secret
|
||||
|
|
|
@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
|
|||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
user_id = auth_user.to_string()
|
||||
|
@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
def on_GET(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||
auth_user.to_string()
|
||||
|
@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
|
|||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||
threePidCreds = body['threePidCreds']
|
||||
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||
|
||||
|
|
|
@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, filter_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(403, "Cannot get filters for other users")
|
||||
|
@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
|
|||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if target_user != auth_user:
|
||||
raise AuthError(403, "Cannot create filters for other users")
|
||||
|
|
|
@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, device_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
# TODO: Check that the device_id matches that in the authentication
|
||||
# or derive the device_id from the authentication instead.
|
||||
|
@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, device_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
user_id = auth_user.to_string()
|
||||
|
||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||
|
@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, device_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user_id = auth_user.to_string()
|
||||
user_id = user_id if user_id else auth_user_id
|
||||
device_ids = [device_id] if device_id else []
|
||||
|
|
|
@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
if receipt_type != "m.read":
|
||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||
|
@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet):
|
|||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
kind = "user"
|
||||
if "kind" in request.args:
|
||||
kind = request.args["kind"][0]
|
||||
|
||||
if kind == "guest":
|
||||
ret = yield self._do_guest_registration()
|
||||
defer.returnValue(ret)
|
||||
return
|
||||
elif kind != "user":
|
||||
raise UnrecognizedRequestError(
|
||||
"Do not understand membership kind: %s" % (kind,)
|
||||
)
|
||||
|
||||
if '/register/email/requestToken' in request.path:
|
||||
ret = yield self.onEmailTokenRequest(request)
|
||||
defer.returnValue(ret)
|
||||
|
@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet):
|
|||
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_guest_registration(self):
|
||||
if not self.hs.config.allow_guest_access:
|
||||
defer.returnValue((403, "Guest access is disabled"))
|
||||
user_id, _ = yield self.registration_handler.register(generate_token=False)
|
||||
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
||||
defer.returnValue((200, {
|
||||
"user_id": user_id,
|
||||
"access_token": access_token,
|
||||
"home_server": self.hs.hostname,
|
||||
}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
|
|
|
@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
user, token_id = yield self.auth.get_user_by_req(request)
|
||||
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
since = parse_string(request, "since")
|
||||
|
|
|
@ -42,7 +42,7 @@ class TagListServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id, room_id):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
raise AuthError(403, "Cannot get tags for other users.")
|
||||
|
||||
|
@ -68,7 +68,7 @@ class TagServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id, room_id, tag):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
||||
|
@ -88,7 +88,7 @@ class TagServlet(RestServlet):
|
|||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, user_id, room_id, tag):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
if user_id != auth_user.to_string():
|
||||
raise AuthError(403, "Cannot add tags for other users.")
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
|
|||
@defer.inlineCallbacks
|
||||
def map_request_to_name(self, request):
|
||||
# auth the user
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
# namespace all file uploads on the user
|
||||
prefix = base64.urlsafe_b64encode(
|
||||
|
|
|
@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
|
|||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
|
|
|
@ -102,13 +102,14 @@ class RegistrationStore(SQLBaseStore):
|
|||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||
)
|
||||
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
txn.execute(
|
||||
"INSERT INTO access_tokens(id, user_id, token)"
|
||||
" VALUES (?,?,?)",
|
||||
(next_id, user_id, token,)
|
||||
)
|
||||
if token:
|
||||
# it's possible for this to get a conflict, but only for a single user
|
||||
# since tokens are namespaced based on their user ID
|
||||
txn.execute(
|
||||
"INSERT INTO access_tokens(id, user_id, token)"
|
||||
" VALUES (?,?,?)",
|
||||
(next_id, user_id, token,)
|
||||
)
|
||||
|
||||
def get_user_by_id(self, user_id):
|
||||
return self._simple_select_one(
|
||||
|
|
|
@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _) = yield self.auth.get_user_by_req(request)
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_user_bad_token(self):
|
||||
|
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request = Mock(args={})
|
||||
request.args["access_token"] = [self.test_token]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _) = yield self.auth.get_user_by_req(request)
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), self.test_user)
|
||||
|
||||
def test_get_user_by_req_appservice_bad_token(self):
|
||||
|
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
|
|||
request.args["access_token"] = [self.test_token]
|
||||
request.args["user_id"] = [masquerading_user_id]
|
||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||
(user, _) = yield self.auth.get_user_by_req(request)
|
||||
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||
self.assertEquals(user.to_string(), masquerading_user_id)
|
||||
|
||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||
|
@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase):
|
|||
user = user_info["user"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_guest_user_from_macaroon(self):
|
||||
user_id = "@baldrick:matrix.org"
|
||||
macaroon = pymacaroons.Macaroon(
|
||||
location=self.hs.config.server_name,
|
||||
identifier="key",
|
||||
key=self.hs.config.macaroon_secret_key)
|
||||
macaroon.add_first_party_caveat("gen = 1")
|
||||
macaroon.add_first_party_caveat("type = access")
|
||||
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||
macaroon.add_first_party_caveat("guest = true")
|
||||
serialized = macaroon.serialize()
|
||||
|
||||
user_info = yield self.auth._get_user_from_macaroon(serialized)
|
||||
user = user_info["user"]
|
||||
is_guest = user_info["is_guest"]
|
||||
self.assertEqual(UserID.from_string(user_id), user)
|
||||
self.assertTrue(is_guest)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_user_from_macaroon_user_db_mismatch(self):
|
||||
self.store.get_user_by_access_token = Mock(
|
||||
|
|
|
@ -86,10 +86,11 @@ class PresenceStateTestCase(unittest.TestCase):
|
|||
return defer.succeed([])
|
||||
self.datastore.get_presence_list = get_presence_list
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(myid),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
@ -173,10 +174,11 @@ class PresenceListTestCase(unittest.TestCase):
|
|||
)
|
||||
self.datastore.has_presence_state = has_presence_state
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(myid),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
hs.handlers.room_member_handler = Mock(
|
||||
|
@ -291,8 +293,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
|||
|
||||
hs.get_clock().time_msec.return_value = 1000000
|
||||
|
||||
def _get_user_by_req(req=None):
|
||||
return (UserID.from_string(myid), "")
|
||||
def _get_user_by_req(req=None, allow_guest=False):
|
||||
return (UserID.from_string(myid), "", False)
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
|
|
|
@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
|
|||
replication_layer=Mock(),
|
||||
)
|
||||
|
||||
def _get_user_by_req(request=None):
|
||||
return (UserID.from_string(myid), "")
|
||||
def _get_user_by_req(request=None, allow_guest=False):
|
||||
return (UserID.from_string(myid), "", False)
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
|
|
|
@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
|
|||
|
||||
self.auth_user_id = self.user_id
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
|
|
@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
|
|||
|
||||
hs.get_handlers().federation_handler = Mock()
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.auth_user_id),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
|
|
@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
|
|||
resource_for_federation=self.mock_resource,
|
||||
)
|
||||
|
||||
def _get_user_by_access_token(token=None):
|
||||
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||
return {
|
||||
"user": UserID.from_string(self.USER_ID),
|
||||
"token_id": 1,
|
||||
"is_guest": False,
|
||||
}
|
||||
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
|
||||
|
||||
|
|
Loading…
Reference in a new issue