mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 13:13:50 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into develop
This commit is contained in:
commit
347146be29
55 changed files with 571 additions and 489 deletions
|
@ -49,6 +49,7 @@ class Auth(object):
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||||
self._KNOWN_CAVEAT_PREFIXES = set([
|
self._KNOWN_CAVEAT_PREFIXES = set([
|
||||||
"gen = ",
|
"gen = ",
|
||||||
|
"guest = ",
|
||||||
"type = ",
|
"type = ",
|
||||||
"time < ",
|
"time < ",
|
||||||
"user_id = ",
|
"user_id = ",
|
||||||
|
@ -183,15 +184,11 @@ class Auth(object):
|
||||||
defer.returnValue(member)
|
defer.returnValue(member)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_user_was_in_room(self, room_id, user_id, current_state=None):
|
def check_user_was_in_room(self, room_id, user_id):
|
||||||
"""Check if the user was in the room at some point.
|
"""Check if the user was in the room at some point.
|
||||||
Args:
|
Args:
|
||||||
room_id(str): The room to check.
|
room_id(str): The room to check.
|
||||||
user_id(str): The user to check.
|
user_id(str): The user to check.
|
||||||
current_state(dict): Optional map of the current state of the room.
|
|
||||||
If provided then that map is used to check whether they are a
|
|
||||||
member of the room. Otherwise the current membership is
|
|
||||||
loaded from the database.
|
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if the user was never in the room.
|
AuthError if the user was never in the room.
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -199,12 +196,6 @@ class Auth(object):
|
||||||
room. This will be the join event if they are currently joined to
|
room. This will be the join event if they are currently joined to
|
||||||
the room. This will be the leave event if they have left the room.
|
the room. This will be the leave event if they have left the room.
|
||||||
"""
|
"""
|
||||||
if current_state:
|
|
||||||
member = current_state.get(
|
|
||||||
(EventTypes.Member, user_id),
|
|
||||||
None
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
member = yield self.state.get_current_state(
|
member = yield self.state.get_current_state(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_type=EventTypes.Member,
|
event_type=EventTypes.Member,
|
||||||
|
@ -497,7 +488,7 @@ class Auth(object):
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_user_by_req(self, request):
|
def get_user_by_req(self, request, allow_guest=False):
|
||||||
""" Get a registered user's ID.
|
""" Get a registered user's ID.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -535,7 +526,7 @@ class Auth(object):
|
||||||
|
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
|
|
||||||
defer.returnValue((UserID.from_string(user_id), ""))
|
defer.returnValue((UserID.from_string(user_id), "", False))
|
||||||
return
|
return
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass # normal users won't have the user_id query parameter set.
|
pass # normal users won't have the user_id query parameter set.
|
||||||
|
@ -543,6 +534,7 @@ class Auth(object):
|
||||||
user_info = yield self._get_user_by_access_token(access_token)
|
user_info = yield self._get_user_by_access_token(access_token)
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
token_id = user_info["token_id"]
|
token_id = user_info["token_id"]
|
||||||
|
is_guest = user_info["is_guest"]
|
||||||
|
|
||||||
ip_addr = self.hs.get_ip_from_request(request)
|
ip_addr = self.hs.get_ip_from_request(request)
|
||||||
user_agent = request.requestHeaders.getRawHeaders(
|
user_agent = request.requestHeaders.getRawHeaders(
|
||||||
|
@ -557,9 +549,14 @@ class Auth(object):
|
||||||
user_agent=user_agent
|
user_agent=user_agent
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_guest and not allow_guest:
|
||||||
|
raise AuthError(
|
||||||
|
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user.to_string()
|
||||||
|
|
||||||
defer.returnValue((user, token_id,))
|
defer.returnValue((user, token_id, is_guest,))
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||||
|
@ -592,9 +589,27 @@ class Auth(object):
|
||||||
self._validate_macaroon(macaroon)
|
self._validate_macaroon(macaroon)
|
||||||
|
|
||||||
user_prefix = "user_id = "
|
user_prefix = "user_id = "
|
||||||
|
user = None
|
||||||
|
guest = False
|
||||||
for caveat in macaroon.caveats:
|
for caveat in macaroon.caveats:
|
||||||
if caveat.caveat_id.startswith(user_prefix):
|
if caveat.caveat_id.startswith(user_prefix):
|
||||||
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
|
||||||
|
elif caveat.caveat_id == "guest = true":
|
||||||
|
guest = True
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
raise AuthError(
|
||||||
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
||||||
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
|
)
|
||||||
|
|
||||||
|
if guest:
|
||||||
|
ret = {
|
||||||
|
"user": user,
|
||||||
|
"is_guest": True,
|
||||||
|
"token_id": None,
|
||||||
|
}
|
||||||
|
else:
|
||||||
# This codepath exists so that we can actually return a
|
# This codepath exists so that we can actually return a
|
||||||
# token ID, because we use token IDs in place of device
|
# token ID, because we use token IDs in place of device
|
||||||
# identifiers throughout the codebase.
|
# identifiers throughout the codebase.
|
||||||
|
@ -613,10 +628,6 @@ class Auth(object):
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
errcode=Codes.UNKNOWN_TOKEN
|
||||||
)
|
)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
raise AuthError(
|
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
|
|
||||||
errcode=Codes.UNKNOWN_TOKEN
|
|
||||||
)
|
|
||||||
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
|
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
|
||||||
|
@ -629,6 +640,7 @@ class Auth(object):
|
||||||
v.satisfy_exact("type = access")
|
v.satisfy_exact("type = access")
|
||||||
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
v.satisfy_general(lambda c: c.startswith("user_id = "))
|
||||||
v.satisfy_general(self._verify_expiry)
|
v.satisfy_general(self._verify_expiry)
|
||||||
|
v.satisfy_exact("guest = true")
|
||||||
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
v.verify(macaroon, self.hs.config.macaroon_secret_key)
|
||||||
|
|
||||||
v = pymacaroons.Verifier()
|
v = pymacaroons.Verifier()
|
||||||
|
@ -666,6 +678,7 @@ class Auth(object):
|
||||||
user_info = {
|
user_info = {
|
||||||
"user": UserID.from_string(ret.get("name")),
|
"user": UserID.from_string(ret.get("name")),
|
||||||
"token_id": ret.get("token_id", None),
|
"token_id": ret.get("token_id", None),
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
defer.returnValue(user_info)
|
defer.returnValue(user_info)
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ class Codes(object):
|
||||||
NOT_FOUND = "M_NOT_FOUND"
|
NOT_FOUND = "M_NOT_FOUND"
|
||||||
MISSING_TOKEN = "M_MISSING_TOKEN"
|
MISSING_TOKEN = "M_MISSING_TOKEN"
|
||||||
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
||||||
|
GUEST_ACCESS_FORBIDDEN = "M_GUEST_ACCESS_FORBIDDEN"
|
||||||
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||||
|
|
|
@ -50,11 +50,11 @@ class Filtering(object):
|
||||||
# many definitions.
|
# many definitions.
|
||||||
|
|
||||||
top_level_definitions = [
|
top_level_definitions = [
|
||||||
"public_user_data", "private_user_data", "server_data"
|
"presence"
|
||||||
]
|
]
|
||||||
|
|
||||||
room_level_definitions = [
|
room_level_definitions = [
|
||||||
"state", "timeline", "ephemeral"
|
"state", "timeline", "ephemeral", "private_user_data"
|
||||||
]
|
]
|
||||||
|
|
||||||
for key in top_level_definitions:
|
for key in top_level_definitions:
|
||||||
|
@ -114,22 +114,6 @@ class Filtering(object):
|
||||||
if not isinstance(event_type, basestring):
|
if not isinstance(event_type, basestring):
|
||||||
raise SynapseError(400, "Event type should be a string")
|
raise SynapseError(400, "Event type should be a string")
|
||||||
|
|
||||||
if "format" in definition:
|
|
||||||
event_format = definition["format"]
|
|
||||||
if event_format not in ["federation", "events"]:
|
|
||||||
raise SynapseError(400, "Invalid format: %s" % (event_format,))
|
|
||||||
|
|
||||||
if "select" in definition:
|
|
||||||
event_select_list = definition["select"]
|
|
||||||
for select_key in event_select_list:
|
|
||||||
if select_key not in ["event_id", "origin_server_ts",
|
|
||||||
"thread_id", "content", "content.body"]:
|
|
||||||
raise SynapseError(400, "Bad select: %s" % (select_key,))
|
|
||||||
|
|
||||||
if ("bundle_updates" in definition and
|
|
||||||
type(definition["bundle_updates"]) != bool):
|
|
||||||
raise SynapseError(400, "Bad bundle_updates: expected bool.")
|
|
||||||
|
|
||||||
|
|
||||||
class FilterCollection(object):
|
class FilterCollection(object):
|
||||||
def __init__(self, filter_json):
|
def __init__(self, filter_json):
|
||||||
|
|
|
@ -34,6 +34,7 @@ class RegistrationConfig(Config):
|
||||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||||
self.macaroon_secret_key = config.get("macaroon_secret_key")
|
self.macaroon_secret_key = config.get("macaroon_secret_key")
|
||||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||||
|
self.allow_guest_access = config.get("allow_guest_access", False)
|
||||||
|
|
||||||
def default_config(self, **kwargs):
|
def default_config(self, **kwargs):
|
||||||
registration_shared_secret = random_string_with_symbols(50)
|
registration_shared_secret = random_string_with_symbols(50)
|
||||||
|
@ -54,6 +55,11 @@ class RegistrationConfig(Config):
|
||||||
# Larger numbers increase the work factor needed to generate the hash.
|
# Larger numbers increase the work factor needed to generate the hash.
|
||||||
# The default number of rounds is 12.
|
# The default number of rounds is 12.
|
||||||
bcrypt_rounds: 12
|
bcrypt_rounds: 12
|
||||||
|
|
||||||
|
# Allows users to register as guests without a password/email/etc, and
|
||||||
|
# participate in rooms hosted on this server which have been made
|
||||||
|
# accessible to anonymous users.
|
||||||
|
allow_guest_access: False
|
||||||
""" % locals()
|
""" % locals()
|
||||||
|
|
||||||
def add_arguments(self, parser):
|
def add_arguments(self, parser):
|
||||||
|
|
|
@ -47,37 +47,24 @@ class BaseHandler(object):
|
||||||
self.event_builder_factory = hs.get_event_builder_factory()
|
self.event_builder_factory = hs.get_event_builder_factory()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _filter_events_for_client(self, user_id, events):
|
def _filter_events_for_client(self, user_id, events, is_guest=False,
|
||||||
event_id_to_state = yield self.store.get_state_for_events(
|
require_all_visible_for_guests=True):
|
||||||
frozenset(e.event_id for e in events),
|
# Assumes that user has at some point joined the room if not is_guest.
|
||||||
types=(
|
|
||||||
(EventTypes.RoomHistoryVisibility, ""),
|
|
||||||
(EventTypes.Member, user_id),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def allowed(event, state):
|
def allowed(event, membership, visibility):
|
||||||
if event.type == EventTypes.RoomHistoryVisibility:
|
if visibility == "world_readable":
|
||||||
return True
|
return True
|
||||||
|
|
||||||
membership_ev = state.get((EventTypes.Member, user_id), None)
|
if is_guest:
|
||||||
if membership_ev:
|
return False
|
||||||
membership = membership_ev.membership
|
|
||||||
else:
|
|
||||||
membership = Membership.LEAVE
|
|
||||||
|
|
||||||
if membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
|
if event.type == EventTypes.RoomHistoryVisibility:
|
||||||
if history:
|
return not is_guest
|
||||||
visibility = history.content.get("history_visibility", "shared")
|
|
||||||
else:
|
|
||||||
visibility = "shared"
|
|
||||||
|
|
||||||
if visibility == "public":
|
if visibility == "shared":
|
||||||
return True
|
|
||||||
elif visibility == "shared":
|
|
||||||
return True
|
return True
|
||||||
elif visibility == "joined":
|
elif visibility == "joined":
|
||||||
return membership == Membership.JOIN
|
return membership == Membership.JOIN
|
||||||
|
@ -86,11 +73,46 @@ class BaseHandler(object):
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
defer.returnValue([
|
event_id_to_state = yield self.store.get_state_for_events(
|
||||||
event
|
frozenset(e.event_id for e in events),
|
||||||
for event in events
|
types=(
|
||||||
if allowed(event, event_id_to_state[event.event_id])
|
(EventTypes.RoomHistoryVisibility, ""),
|
||||||
])
|
(EventTypes.Member, user_id),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
events_to_return = []
|
||||||
|
for event in events:
|
||||||
|
state = event_id_to_state[event.event_id]
|
||||||
|
|
||||||
|
membership_event = state.get((EventTypes.Member, user_id), None)
|
||||||
|
if membership_event:
|
||||||
|
membership = membership_event.membership
|
||||||
|
else:
|
||||||
|
membership = None
|
||||||
|
|
||||||
|
visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
|
||||||
|
if visibility_event:
|
||||||
|
visibility = visibility_event.content.get("history_visibility", "shared")
|
||||||
|
else:
|
||||||
|
visibility = "shared"
|
||||||
|
|
||||||
|
should_include = allowed(event, membership, visibility)
|
||||||
|
if should_include:
|
||||||
|
events_to_return.append(event)
|
||||||
|
|
||||||
|
if (require_all_visible_for_guests
|
||||||
|
and is_guest
|
||||||
|
and len(events_to_return) < len(events)):
|
||||||
|
# This indicates that some events in the requested range were not
|
||||||
|
# visible to guest users. To be safe, we reject the entire request,
|
||||||
|
# so that we don't have to worry about interpreting visibility
|
||||||
|
# boundaries.
|
||||||
|
raise AuthError(403, "User %s does not have permission" % (
|
||||||
|
user_id
|
||||||
|
))
|
||||||
|
|
||||||
|
defer.returnValue(events_to_return)
|
||||||
|
|
||||||
def ratelimit(self, user_id):
|
def ratelimit(self, user_id):
|
||||||
time_now = self.clock.time()
|
time_now = self.clock.time()
|
||||||
|
|
|
@ -372,12 +372,15 @@ class AuthHandler(BaseHandler):
|
||||||
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
|
||||||
defer.returnValue(refresh_token)
|
defer.returnValue(refresh_token)
|
||||||
|
|
||||||
def generate_access_token(self, user_id):
|
def generate_access_token(self, user_id, extra_caveats=None):
|
||||||
|
extra_caveats = extra_caveats or []
|
||||||
macaroon = self._generate_base_macaroon(user_id)
|
macaroon = self._generate_base_macaroon(user_id)
|
||||||
macaroon.add_first_party_caveat("type = access")
|
macaroon.add_first_party_caveat("type = access")
|
||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
expiry = now + (60 * 60 * 1000)
|
expiry = now + (60 * 60 * 1000)
|
||||||
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
macaroon.add_first_party_caveat("time < %d" % (expiry,))
|
||||||
|
for caveat in extra_caveats:
|
||||||
|
macaroon.add_first_party_caveat(caveat)
|
||||||
return macaroon.serialize()
|
return macaroon.serialize()
|
||||||
|
|
||||||
def generate_refresh_token(self, user_id):
|
def generate_refresh_token(self, user_id):
|
||||||
|
|
|
@ -100,7 +100,7 @@ class EventStreamHandler(BaseHandler):
|
||||||
@log_function
|
@log_function
|
||||||
def get_stream(self, auth_user_id, pagin_config, timeout=0,
|
def get_stream(self, auth_user_id, pagin_config, timeout=0,
|
||||||
as_client_event=True, affect_presence=True,
|
as_client_event=True, affect_presence=True,
|
||||||
only_room_events=False):
|
only_room_events=False, room_id=None, is_guest=False):
|
||||||
"""Fetches the events stream for a given user.
|
"""Fetches the events stream for a given user.
|
||||||
|
|
||||||
If `only_room_events` is `True` only room events will be returned.
|
If `only_room_events` is `True` only room events will be returned.
|
||||||
|
@ -111,17 +111,6 @@ class EventStreamHandler(BaseHandler):
|
||||||
if affect_presence:
|
if affect_presence:
|
||||||
yield self.started_stream(auth_user)
|
yield self.started_stream(auth_user)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
|
||||||
auth_user.to_string()
|
|
||||||
)
|
|
||||||
if app_service:
|
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
|
||||||
room_ids = set(r.room_id for r in rooms)
|
|
||||||
else:
|
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
|
|
||||||
|
|
||||||
if timeout:
|
if timeout:
|
||||||
# If they've set a timeout set a minimum limit.
|
# If they've set a timeout set a minimum limit.
|
||||||
timeout = max(timeout, 500)
|
timeout = max(timeout, 500)
|
||||||
|
@ -130,9 +119,15 @@ class EventStreamHandler(BaseHandler):
|
||||||
# thundering herds on restart.
|
# thundering herds on restart.
|
||||||
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
|
||||||
|
|
||||||
|
if is_guest:
|
||||||
|
yield self.distributor.fire(
|
||||||
|
"user_joined_room", user=auth_user, room_id=room_id
|
||||||
|
)
|
||||||
|
|
||||||
events, tokens = yield self.notifier.get_events_for(
|
events, tokens = yield self.notifier.get_events_for(
|
||||||
auth_user, room_ids, pagin_config, timeout,
|
auth_user, pagin_config, timeout,
|
||||||
only_room_events=only_room_events
|
only_room_events=only_room_events,
|
||||||
|
is_guest=is_guest, guest_room_id=room_id
|
||||||
)
|
)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
|
@ -72,8 +72,6 @@ class FederationHandler(BaseHandler):
|
||||||
self.server_name = hs.hostname
|
self.server_name = hs.hostname
|
||||||
self.keyring = hs.get_keyring()
|
self.keyring = hs.get_keyring()
|
||||||
|
|
||||||
self.lock_manager = hs.get_room_lock_manager()
|
|
||||||
|
|
||||||
self.replication_layer.set_handler(self)
|
self.replication_layer.set_handler(self)
|
||||||
|
|
||||||
# When joining a room we need to queue any events for that room up
|
# When joining a room we need to queue any events for that room up
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes, Membership
|
from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, AuthError, Codes
|
||||||
from synapse.streams.config import PaginationConfig
|
from synapse.streams.config import PaginationConfig
|
||||||
from synapse.events.utils import serialize_event
|
from synapse.events.utils import serialize_event
|
||||||
from synapse.events.validator import EventValidator
|
from synapse.events.validator import EventValidator
|
||||||
|
@ -71,7 +71,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
|
||||||
as_client_event=True):
|
as_client_event=True, is_guest=False):
|
||||||
"""Get messages in a room.
|
"""Get messages in a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -80,11 +80,11 @@ class MessageHandler(BaseHandler):
|
||||||
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
pagin_config (synapse.api.streams.PaginationConfig): The pagination
|
||||||
config rules to apply, if any.
|
config rules to apply, if any.
|
||||||
as_client_event (bool): True to get events in client-server format.
|
as_client_event (bool): True to get events in client-server format.
|
||||||
|
is_guest (bool): Whether the requesting user is a guest (as opposed
|
||||||
|
to a fully registered user).
|
||||||
Returns:
|
Returns:
|
||||||
dict: Pagination API results
|
dict: Pagination API results
|
||||||
"""
|
"""
|
||||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
|
||||||
|
|
||||||
data_source = self.hs.get_event_sources().sources["room"]
|
data_source = self.hs.get_event_sources().sources["room"]
|
||||||
|
|
||||||
if pagin_config.from_token:
|
if pagin_config.from_token:
|
||||||
|
@ -107,9 +107,13 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
source_config = pagin_config.get_source_config("room")
|
source_config = pagin_config.get_source_config("room")
|
||||||
|
|
||||||
|
if not is_guest:
|
||||||
|
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||||
if member_event.membership == Membership.LEAVE:
|
if member_event.membership == Membership.LEAVE:
|
||||||
# If they have left the room then clamp the token to be before
|
# If they have left the room then clamp the token to be before
|
||||||
# they left the room
|
# they left the room.
|
||||||
|
# If they're a guest, we'll just 403 them if they're asking for
|
||||||
|
# events they can't see.
|
||||||
leave_token = yield self.store.get_topological_token_for_event(
|
leave_token = yield self.store.get_topological_token_for_event(
|
||||||
member_event.event_id
|
member_event.event_id
|
||||||
)
|
)
|
||||||
|
@ -146,7 +150,7 @@ class MessageHandler(BaseHandler):
|
||||||
"end": next_token.to_string(),
|
"end": next_token.to_string(),
|
||||||
})
|
})
|
||||||
|
|
||||||
events = yield self._filter_events_for_client(user_id, events)
|
events = yield self._filter_events_for_client(user_id, events, is_guest=is_guest)
|
||||||
|
|
||||||
time_now = self.clock.time_msec()
|
time_now = self.clock.time_msec()
|
||||||
|
|
||||||
|
@ -225,7 +229,7 @@ class MessageHandler(BaseHandler):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_room_data(self, user_id=None, room_id=None,
|
def get_room_data(self, user_id=None, room_id=None,
|
||||||
event_type=None, state_key=""):
|
event_type=None, state_key="", is_guest=False):
|
||||||
""" Get data from a room.
|
""" Get data from a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -235,23 +239,42 @@ class MessageHandler(BaseHandler):
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError if something went wrong.
|
SynapseError if something went wrong.
|
||||||
"""
|
"""
|
||||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
membership, membership_event_id = yield self._check_in_room_or_world_readable(
|
||||||
|
room_id, user_id, is_guest
|
||||||
|
)
|
||||||
|
|
||||||
if member_event.membership == Membership.JOIN:
|
if membership == Membership.JOIN:
|
||||||
data = yield self.state_handler.get_current_state(
|
data = yield self.state_handler.get_current_state(
|
||||||
room_id, event_type, state_key
|
room_id, event_type, state_key
|
||||||
)
|
)
|
||||||
elif member_event.membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
key = (event_type, state_key)
|
key = (event_type, state_key)
|
||||||
room_state = yield self.store.get_state_for_events(
|
room_state = yield self.store.get_state_for_events(
|
||||||
[member_event.event_id], [key]
|
[membership_event_id], [key]
|
||||||
)
|
)
|
||||||
data = room_state[member_event.event_id].get(key)
|
data = room_state[membership_event_id].get(key)
|
||||||
|
|
||||||
defer.returnValue(data)
|
defer.returnValue(data)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_state_events(self, user_id, room_id):
|
def _check_in_room_or_world_readable(self, room_id, user_id, is_guest):
|
||||||
|
if is_guest:
|
||||||
|
visibility = yield self.state_handler.get_current_state(
|
||||||
|
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||||
|
)
|
||||||
|
if visibility.content["history_visibility"] == "world_readable":
|
||||||
|
defer.returnValue((Membership.JOIN, None))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise AuthError(
|
||||||
|
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
||||||
|
defer.returnValue((member_event.membership, member_event.event_id))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_state_events(self, user_id, room_id, is_guest=False):
|
||||||
"""Retrieve all state events for a given room. If the user is
|
"""Retrieve all state events for a given room. If the user is
|
||||||
joined to the room then return the current state. If the user has
|
joined to the room then return the current state. If the user has
|
||||||
left the room return the state events from when they left.
|
left the room return the state events from when they left.
|
||||||
|
@ -262,15 +285,17 @@ class MessageHandler(BaseHandler):
|
||||||
Returns:
|
Returns:
|
||||||
A list of dicts representing state events. [{}, {}, {}]
|
A list of dicts representing state events. [{}, {}, {}]
|
||||||
"""
|
"""
|
||||||
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
|
membership, membership_event_id = yield self._check_in_room_or_world_readable(
|
||||||
|
room_id, user_id, is_guest
|
||||||
if member_event.membership == Membership.JOIN:
|
|
||||||
room_state = yield self.state_handler.get_current_state(room_id)
|
|
||||||
elif member_event.membership == Membership.LEAVE:
|
|
||||||
room_state = yield self.store.get_state_for_events(
|
|
||||||
[member_event.event_id], None
|
|
||||||
)
|
)
|
||||||
room_state = room_state[member_event.event_id]
|
|
||||||
|
if membership == Membership.JOIN:
|
||||||
|
room_state = yield self.state_handler.get_current_state(room_id)
|
||||||
|
elif membership == Membership.LEAVE:
|
||||||
|
room_state = yield self.store.get_state_for_events(
|
||||||
|
[membership_event_id], None
|
||||||
|
)
|
||||||
|
room_state = room_state[membership_event_id]
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
defer.returnValue(
|
defer.returnValue(
|
||||||
|
|
|
@ -1142,8 +1142,9 @@ class PresenceEventSource(object):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@log_function
|
@log_function
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events(self, user, from_key, room_ids=None, **kwargs):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
|
room_ids = room_ids or []
|
||||||
|
|
||||||
presence = self.hs.get_handlers().presence_handler
|
presence = self.hs.get_handlers().presence_handler
|
||||||
cachemap = presence._user_cachemap
|
cachemap = presence._user_cachemap
|
||||||
|
@ -1161,7 +1162,6 @@ class PresenceEventSource(object):
|
||||||
user_ids_to_check |= set(
|
user_ids_to_check |= set(
|
||||||
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
UserID.from_string(p["observed_user_id"]) for p in presence_list
|
||||||
)
|
)
|
||||||
room_ids = yield presence.get_joined_rooms_for_user(user)
|
|
||||||
for room_id in set(room_ids) & set(presence._room_serials):
|
for room_id in set(room_ids) & set(presence._room_serials):
|
||||||
if presence._room_serials[room_id] > from_key:
|
if presence._room_serials[room_id] > from_key:
|
||||||
joined = yield presence.get_joined_users_for_room_id(room_id)
|
joined = yield presence.get_joined_users_for_room_id(room_id)
|
||||||
|
|
|
@ -24,7 +24,7 @@ class PrivateUserDataEventSource(object):
|
||||||
return self.store.get_max_private_user_data_stream_id()
|
return self.store.get_max_private_user_data_stream_id()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events(self, user, from_key, **kwargs):
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
last_stream_id = from_key
|
last_stream_id = from_key
|
||||||
|
|
||||||
|
|
|
@ -164,17 +164,15 @@ class ReceiptEventSource(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
to_key = yield self.get_current_key()
|
to_key = yield self.get_current_key()
|
||||||
|
|
||||||
if from_key == to_key:
|
if from_key == to_key:
|
||||||
defer.returnValue(([], to_key))
|
defer.returnValue(([], to_key))
|
||||||
|
|
||||||
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
|
||||||
rooms = [room.room_id for room in rooms]
|
|
||||||
events = yield self.store.get_linearized_receipts_for_rooms(
|
events = yield self.store.get_linearized_receipts_for_rooms(
|
||||||
rooms,
|
room_ids,
|
||||||
from_key=from_key,
|
from_key=from_key,
|
||||||
to_key=to_key,
|
to_key=to_key,
|
||||||
)
|
)
|
||||||
|
|
|
@ -64,7 +64,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def register(self, localpart=None, password=None):
|
def register(self, localpart=None, password=None, generate_token=True):
|
||||||
"""Registers a new client on the server.
|
"""Registers a new client on the server.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -89,6 +89,8 @@ class RegistrationHandler(BaseHandler):
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
|
token = None
|
||||||
|
if generate_token:
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
@ -102,13 +104,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
attempts = 0
|
attempts = 0
|
||||||
user_id = None
|
user_id = None
|
||||||
token = None
|
token = None
|
||||||
while not user_id and not token:
|
while not user_id:
|
||||||
try:
|
try:
|
||||||
localpart = self._generate_user_id()
|
localpart = self._generate_user_id()
|
||||||
user = UserID(localpart, self.hs.hostname)
|
user = UserID(localpart, self.hs.hostname)
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield self.check_user_id_is_valid(user_id)
|
yield self.check_user_id_is_valid(user_id)
|
||||||
|
if generate_token:
|
||||||
token = self.auth_handler().generate_access_token(user_id)
|
token = self.auth_handler().generate_access_token(user_id)
|
||||||
yield self.store.register(
|
yield self.store.register(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|
|
@ -807,7 +807,14 @@ class RoomEventSource(object):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events(
|
||||||
|
self,
|
||||||
|
user,
|
||||||
|
from_key,
|
||||||
|
limit,
|
||||||
|
room_ids,
|
||||||
|
is_guest,
|
||||||
|
):
|
||||||
# We just ignore the key for now.
|
# We just ignore the key for now.
|
||||||
|
|
||||||
to_key = yield self.get_current_key()
|
to_key = yield self.get_current_key()
|
||||||
|
@ -827,8 +834,9 @@ class RoomEventSource(object):
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
from_key=from_key,
|
from_key=from_key,
|
||||||
to_key=to_key,
|
to_key=to_key,
|
||||||
room_id=None,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
|
room_ids=room_ids,
|
||||||
|
is_guest=is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
defer.returnValue((events, end_key))
|
defer.returnValue((events, end_key))
|
||||||
|
|
|
@ -143,21 +143,8 @@ class SyncHandler(BaseHandler):
|
||||||
def current_sync_callback(before_token, after_token):
|
def current_sync_callback(before_token, after_token):
|
||||||
return self.current_sync_for_user(sync_config, since_token)
|
return self.current_sync_for_user(sync_config, since_token)
|
||||||
|
|
||||||
rm_handler = self.hs.get_handlers().room_member_handler
|
|
||||||
|
|
||||||
app_service = yield self.store.get_app_service_by_user_id(
|
|
||||||
sync_config.user.to_string()
|
|
||||||
)
|
|
||||||
if app_service:
|
|
||||||
rooms = yield self.store.get_app_service_rooms(app_service)
|
|
||||||
room_ids = set(r.room_id for r in rooms)
|
|
||||||
else:
|
|
||||||
room_ids = yield rm_handler.get_joined_rooms_for_user(
|
|
||||||
sync_config.user
|
|
||||||
)
|
|
||||||
|
|
||||||
result = yield self.notifier.wait_for_events(
|
result = yield self.notifier.wait_for_events(
|
||||||
sync_config.user, room_ids, timeout, current_sync_callback,
|
sync_config.user, timeout, current_sync_callback,
|
||||||
from_token=since_token
|
from_token=since_token
|
||||||
)
|
)
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -308,11 +295,16 @@ class SyncHandler(BaseHandler):
|
||||||
|
|
||||||
typing_key = since_token.typing_key if since_token else "0"
|
typing_key = since_token.typing_key if since_token else "0"
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||||
|
room_ids = [room.room_id for room in rooms]
|
||||||
|
|
||||||
typing_source = self.event_sources.sources["typing"]
|
typing_source = self.event_sources.sources["typing"]
|
||||||
typing, typing_key = yield typing_source.get_new_events_for_user(
|
typing, typing_key = yield typing_source.get_new_events(
|
||||||
user=sync_config.user,
|
user=sync_config.user,
|
||||||
from_key=typing_key,
|
from_key=typing_key,
|
||||||
limit=sync_config.filter.ephemeral_limit(),
|
limit=sync_config.filter.ephemeral_limit(),
|
||||||
|
room_ids=room_ids,
|
||||||
|
is_guest=False,
|
||||||
)
|
)
|
||||||
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
now_token = now_token.copy_and_replace("typing_key", typing_key)
|
||||||
|
|
||||||
|
@ -325,10 +317,13 @@ class SyncHandler(BaseHandler):
|
||||||
receipt_key = since_token.receipt_key if since_token else "0"
|
receipt_key = since_token.receipt_key if since_token else "0"
|
||||||
|
|
||||||
receipt_source = self.event_sources.sources["receipt"]
|
receipt_source = self.event_sources.sources["receipt"]
|
||||||
receipts, receipt_key = yield receipt_source.get_new_events_for_user(
|
receipts, receipt_key = yield receipt_source.get_new_events(
|
||||||
user=sync_config.user,
|
user=sync_config.user,
|
||||||
from_key=receipt_key,
|
from_key=receipt_key,
|
||||||
limit=sync_config.filter.ephemeral_limit(),
|
limit=sync_config.filter.ephemeral_limit(),
|
||||||
|
room_ids=room_ids,
|
||||||
|
# /sync doesn't support guest access, they can't get to this point in code
|
||||||
|
is_guest=False,
|
||||||
)
|
)
|
||||||
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
|
now_token = now_token.copy_and_replace("receipt_key", receipt_key)
|
||||||
|
|
||||||
|
@ -373,11 +368,17 @@ class SyncHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
now_token = yield self.event_sources.get_current_token()
|
now_token = yield self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
rooms = yield self.store.get_rooms_for_user(sync_config.user.to_string())
|
||||||
|
room_ids = [room.room_id for room in rooms]
|
||||||
|
|
||||||
presence_source = self.event_sources.sources["presence"]
|
presence_source = self.event_sources.sources["presence"]
|
||||||
presence, presence_key = yield presence_source.get_new_events_for_user(
|
presence, presence_key = yield presence_source.get_new_events(
|
||||||
user=sync_config.user,
|
user=sync_config.user,
|
||||||
from_key=since_token.presence_key,
|
from_key=since_token.presence_key,
|
||||||
limit=sync_config.filter.presence_limit(),
|
limit=sync_config.filter.presence_limit(),
|
||||||
|
room_ids=room_ids,
|
||||||
|
# /sync doesn't support guest access, they can't get to this point in code
|
||||||
|
is_guest=False,
|
||||||
)
|
)
|
||||||
now_token = now_token.copy_and_replace("presence_key", presence_key)
|
now_token = now_token.copy_and_replace("presence_key", presence_key)
|
||||||
|
|
||||||
|
@ -403,7 +404,6 @@ class SyncHandler(BaseHandler):
|
||||||
sync_config.user.to_string(),
|
sync_config.user.to_string(),
|
||||||
from_key=since_token.room_key,
|
from_key=since_token.room_key,
|
||||||
to_key=now_token.room_key,
|
to_key=now_token.room_key,
|
||||||
room_id=None,
|
|
||||||
limit=timeline_limit + 1,
|
limit=timeline_limit + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -246,17 +246,12 @@ class TypingNotificationEventSource(object):
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
def get_new_events(self, from_key, room_ids, **kwargs):
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
|
||||||
from_key = int(from_key)
|
from_key = int(from_key)
|
||||||
handler = self.handler()
|
handler = self.handler()
|
||||||
|
|
||||||
joined_room_ids = (
|
|
||||||
yield self.room_member_handler().get_joined_rooms_for_user(user)
|
|
||||||
)
|
|
||||||
|
|
||||||
events = []
|
events = []
|
||||||
for room_id in joined_room_ids:
|
for room_id in room_ids:
|
||||||
if room_id not in handler._room_serials:
|
if room_id not in handler._room_serials:
|
||||||
continue
|
continue
|
||||||
if handler._room_serials[room_id] <= from_key:
|
if handler._room_serials[room_id] <= from_key:
|
||||||
|
@ -264,7 +259,7 @@ class TypingNotificationEventSource(object):
|
||||||
|
|
||||||
events.append(self._make_event_for(room_id))
|
events.append(self._make_event_for(room_id))
|
||||||
|
|
||||||
defer.returnValue((events, handler._latest_room_serial))
|
return events, handler._latest_room_serial
|
||||||
|
|
||||||
def get_current_key(self):
|
def get_current_key(self):
|
||||||
return self.handler()._latest_room_serial
|
return self.handler()._latest_room_serial
|
||||||
|
|
|
@ -269,7 +269,7 @@ class Notifier(object):
|
||||||
logger.exception("Failed to notify listener")
|
logger.exception("Failed to notify listener")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_events(self, user, rooms, timeout, callback,
|
def wait_for_events(self, user, timeout, callback, room_ids=None,
|
||||||
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
from_token=StreamToken("s0", "0", "0", "0", "0")):
|
||||||
"""Wait until the callback returns a non empty response or the
|
"""Wait until the callback returns a non empty response or the
|
||||||
timeout fires.
|
timeout fires.
|
||||||
|
@ -279,11 +279,12 @@ class Notifier(object):
|
||||||
if user_stream is None:
|
if user_stream is None:
|
||||||
appservice = yield self.store.get_app_service_by_user_id(user)
|
appservice = yield self.store.get_app_service_by_user_id(user)
|
||||||
current_token = yield self.event_sources.get_current_token()
|
current_token = yield self.event_sources.get_current_token()
|
||||||
|
if room_ids is None:
|
||||||
rooms = yield self.store.get_rooms_for_user(user)
|
rooms = yield self.store.get_rooms_for_user(user)
|
||||||
rooms = [room.room_id for room in rooms]
|
room_ids = [room.room_id for room in rooms]
|
||||||
user_stream = _NotifierUserStream(
|
user_stream = _NotifierUserStream(
|
||||||
user=user,
|
user=user,
|
||||||
rooms=rooms,
|
rooms=room_ids,
|
||||||
appservice=appservice,
|
appservice=appservice,
|
||||||
current_token=current_token,
|
current_token=current_token,
|
||||||
time_now_ms=self.clock.time_msec(),
|
time_now_ms=self.clock.time_msec(),
|
||||||
|
@ -328,8 +329,9 @@ class Notifier(object):
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_events_for(self, user, rooms, pagination_config, timeout,
|
def get_events_for(self, user, pagination_config, timeout,
|
||||||
only_room_events=False):
|
only_room_events=False,
|
||||||
|
is_guest=False, guest_room_id=None):
|
||||||
""" For the given user and rooms, return any new events for them. If
|
""" For the given user and rooms, return any new events for them. If
|
||||||
there are no new events wait for up to `timeout` milliseconds for any
|
there are no new events wait for up to `timeout` milliseconds for any
|
||||||
new events to happen before returning.
|
new events to happen before returning.
|
||||||
|
@ -342,6 +344,16 @@ class Notifier(object):
|
||||||
|
|
||||||
limit = pagination_config.limit
|
limit = pagination_config.limit
|
||||||
|
|
||||||
|
room_ids = []
|
||||||
|
if is_guest:
|
||||||
|
# TODO(daniel): Deal with non-room events too
|
||||||
|
only_room_events = True
|
||||||
|
if guest_room_id:
|
||||||
|
room_ids = [guest_room_id]
|
||||||
|
else:
|
||||||
|
rooms = yield self.store.get_rooms_for_user(user.to_string())
|
||||||
|
room_ids = [room.room_id for room in rooms]
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_for_updates(before_token, after_token):
|
def check_for_updates(before_token, after_token):
|
||||||
if not after_token.is_after(before_token):
|
if not after_token.is_after(before_token):
|
||||||
|
@ -357,9 +369,23 @@ class Notifier(object):
|
||||||
continue
|
continue
|
||||||
if only_room_events and name != "room":
|
if only_room_events and name != "room":
|
||||||
continue
|
continue
|
||||||
new_events, new_key = yield source.get_new_events_for_user(
|
new_events, new_key = yield source.get_new_events(
|
||||||
user, getattr(from_token, keyname), limit,
|
user=user,
|
||||||
|
from_key=getattr(from_token, keyname),
|
||||||
|
limit=limit,
|
||||||
|
is_guest=is_guest,
|
||||||
|
room_ids=room_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_guest:
|
||||||
|
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||||
|
new_events = yield room_member_handler._filter_events_for_client(
|
||||||
|
user.to_string(),
|
||||||
|
new_events,
|
||||||
|
is_guest=is_guest,
|
||||||
|
require_all_visible_for_guests=False
|
||||||
|
)
|
||||||
|
|
||||||
events.extend(new_events)
|
events.extend(new_events)
|
||||||
end_token = end_token.copy_and_replace(keyname, new_key)
|
end_token = end_token.copy_and_replace(keyname, new_key)
|
||||||
|
|
||||||
|
@ -369,7 +395,7 @@ class Notifier(object):
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
result = yield self.wait_for_events(
|
result = yield self.wait_for_events(
|
||||||
user, rooms, timeout, check_for_updates, from_token=from_token
|
user, timeout, check_for_updates, room_ids=room_ids, from_token=from_token
|
||||||
)
|
)
|
||||||
|
|
||||||
if result is None:
|
if result is None:
|
||||||
|
|
|
@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
is_admin = yield self.auth.is_server_admin(auth_user)
|
is_admin = yield self.auth.is_server_admin(auth_user)
|
||||||
|
|
||||||
if not is_admin and target_user != auth_user:
|
if not is_admin and target_user != auth_user:
|
||||||
|
|
|
@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# try to auth as a user
|
# try to auth as a user
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
try:
|
try:
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
yield dir_handler.create_association(
|
yield dir_handler.create_association(
|
||||||
|
@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
|
||||||
# fallback to default user behaviour if they aren't an AS
|
# fallback to default user behaviour if they aren't an AS
|
||||||
pass
|
pass
|
||||||
|
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
is_admin = yield self.auth.is_server_admin(user)
|
is_admin = yield self.auth.is_server_admin(user)
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
|
|
|
@ -34,7 +34,15 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, is_guest = yield self.auth.get_user_by_req(
|
||||||
|
request,
|
||||||
|
allow_guest=True
|
||||||
|
)
|
||||||
|
room_id = None
|
||||||
|
if is_guest:
|
||||||
|
if "room_id" not in request.args:
|
||||||
|
raise SynapseError(400, "Guest users must specify room_id param")
|
||||||
|
room_id = request.args["room_id"][0]
|
||||||
try:
|
try:
|
||||||
handler = self.handlers.event_stream_handler
|
handler = self.handlers.event_stream_handler
|
||||||
pagin_config = PaginationConfig.from_request(request)
|
pagin_config = PaginationConfig.from_request(request)
|
||||||
|
@ -49,7 +57,8 @@ class EventStreamRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
chunk = yield handler.get_stream(
|
chunk = yield handler.get_stream(
|
||||||
auth_user.to_string(), pagin_config, timeout=timeout,
|
auth_user.to_string(), pagin_config, timeout=timeout,
|
||||||
as_client_event=as_client_event
|
as_client_event=as_client_event, affect_presence=(not is_guest),
|
||||||
|
room_id=room_id, is_guest=is_guest
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
logger.exception("Event stream failed")
|
logger.exception("Event stream failed")
|
||||||
|
@ -71,7 +80,7 @@ class EventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, event_id):
|
def on_GET(self, request, event_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
handler = self.handlers.event_handler
|
handler = self.handlers.event_handler
|
||||||
event = yield handler.get_event(auth_user, event_id)
|
event = yield handler.get_event(auth_user, event_id)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
as_client_event = "raw" not in request.args
|
as_client_event = "raw" not in request.args
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
|
|
|
@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
state = yield self.handlers.presence_handler.get_state(
|
state = yield self.handlers.presence_handler.get_state(
|
||||||
|
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
state = {}
|
state = {}
|
||||||
|
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
|
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id):
|
def on_POST(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
|
|
|
@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id):
|
def on_PUT(self, request, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user = UserID.from_string(user_id)
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -43,7 +43,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
except InvalidRuleException as e:
|
except InvalidRuleException as e:
|
||||||
raise SynapseError(400, e.message)
|
raise SynapseError(400, e.message)
|
||||||
|
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
|
||||||
raise SynapseError(400, "rule_id may not contain slashes")
|
raise SynapseError(400, "rule_id may not contain slashes")
|
||||||
|
@ -92,7 +92,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
def on_DELETE(self, request):
|
def on_DELETE(self, request):
|
||||||
spec = _rule_spec_from_path(request.postpath)
|
spec = _rule_spec_from_path(request.postpath)
|
||||||
|
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
# we build up the full structure and then decide which bits of it
|
# we build up the full structure and then decide which bits of it
|
||||||
# to send which means doing unnecessary work sometimes but is
|
# to send which means doing unnecessary work sometimes but is
|
||||||
|
|
|
@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
|
|
@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
room_config = self.get_room_config(request)
|
room_config = self.get_room_config(request)
|
||||||
info = yield self.make_room(room_config, auth_user, None)
|
info = yield self.make_room(room_config, auth_user, None)
|
||||||
|
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id, event_type, state_key):
|
def on_GET(self, request, room_id, event_type, state_key):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
data = yield msg_handler.get_room_data(
|
data = yield msg_handler.get_room_data(
|
||||||
|
@ -133,6 +133,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
state_key=state_key,
|
state_key=state_key,
|
||||||
|
is_guest=is_guest,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
|
@ -143,7 +144,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
@ -175,7 +176,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_type, txn_id=None):
|
def on_POST(self, request, room_id, event_type, txn_id=None):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
|
@ -220,7 +221,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_identifier, txn_id=None):
|
def on_POST(self, request, room_identifier, txn_id=None):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
# the identifier could be a room alias or a room id. Try one then the
|
# the identifier could be a room alias or a room id. Try one then the
|
||||||
# other if it fails to parse, without swallowing other valid
|
# other if it fails to parse, without swallowing other valid
|
||||||
|
@ -289,7 +290,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
# TODO support Pagination stream API (limit/tokens)
|
# TODO support Pagination stream API (limit/tokens)
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
events = yield handler.get_state_events(
|
events = yield handler.get_state_events(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -325,7 +326,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
pagination_config = PaginationConfig.from_request(
|
pagination_config = PaginationConfig.from_request(
|
||||||
request, default_limit=10,
|
request, default_limit=10,
|
||||||
)
|
)
|
||||||
|
@ -334,6 +335,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
|
||||||
msgs = yield handler.get_messages(
|
msgs = yield handler.get_messages(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
|
is_guest=is_guest,
|
||||||
pagin_config=pagination_config,
|
pagin_config=pagination_config,
|
||||||
as_client_event=as_client_event
|
as_client_event=as_client_event
|
||||||
)
|
)
|
||||||
|
@ -347,12 +349,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, is_guest = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
handler = self.handlers.message_handler
|
handler = self.handlers.message_handler
|
||||||
# Get all the current state for this room
|
# Get all the current state for this room
|
||||||
events = yield handler.get_state_events(
|
events = yield handler.get_state_events(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
|
is_guest=is_guest,
|
||||||
)
|
)
|
||||||
defer.returnValue((200, events))
|
defer.returnValue((200, events))
|
||||||
|
|
||||||
|
@ -363,7 +366,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, room_id):
|
def on_GET(self, request, room_id):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
pagination_config = PaginationConfig.from_request(request)
|
pagination_config = PaginationConfig.from_request(request)
|
||||||
content = yield self.handlers.message_handler.room_initial_sync(
|
content = yield self.handlers.message_handler.room_initial_sync(
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
|
@ -443,7 +446,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
def on_POST(self, request, room_id, membership_action, txn_id=None):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
@ -524,7 +527,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, event_id, txn_id=None):
|
def on_POST(self, request, room_id, event_id, txn_id=None):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
msg_handler = self.handlers.message_handler
|
msg_handler = self.handlers.message_handler
|
||||||
|
@ -564,7 +567,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, room_id, user_id):
|
def on_PUT(self, request, room_id, user_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
room_id = urllib.unquote(room_id)
|
room_id = urllib.unquote(room_id)
|
||||||
target_user = UserID.from_string(urllib.unquote(user_id))
|
target_user = UserID.from_string(urllib.unquote(user_id))
|
||||||
|
@ -597,7 +600,7 @@ class SearchRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
content = _parse_json(request)
|
content = _parse_json(request)
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
turnUris = self.hs.config.turn_uris
|
turnUris = self.hs.config.turn_uris
|
||||||
turnSecret = self.hs.config.turn_shared_secret
|
turnSecret = self.hs.config.turn_shared_secret
|
||||||
|
|
|
@ -55,7 +55,7 @@ class PasswordRestServlet(RestServlet):
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
if LoginType.PASSWORD in result:
|
||||||
# if using password, they should also be logged in
|
# if using password, they should also be logged in
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
||||||
raise LoginError(400, "", Codes.UNKNOWN)
|
raise LoginError(400, "", Codes.UNKNOWN)
|
||||||
user_id = auth_user.to_string()
|
user_id = auth_user.to_string()
|
||||||
|
@ -102,7 +102,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||||
auth_user.to_string()
|
auth_user.to_string()
|
||||||
|
@ -120,7 +120,7 @@ class ThreepidRestServlet(RestServlet):
|
||||||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||||
threePidCreds = body['threePidCreds']
|
threePidCreds = body['threePidCreds']
|
||||||
|
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class GetFilterRestServlet(RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, filter_id):
|
def on_GET(self, request, user_id, filter_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if target_user != auth_user:
|
if target_user != auth_user:
|
||||||
raise AuthError(403, "Cannot get filters for other users")
|
raise AuthError(403, "Cannot get filters for other users")
|
||||||
|
@ -76,7 +76,7 @@ class CreateFilterRestServlet(RestServlet):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, user_id):
|
def on_POST(self, request, user_id):
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if target_user != auth_user:
|
if target_user != auth_user:
|
||||||
raise AuthError(403, "Cannot create filters for other users")
|
raise AuthError(403, "Cannot create filters for other users")
|
||||||
|
|
|
@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, device_id):
|
def on_POST(self, request, device_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user_id = auth_user.to_string()
|
user_id = auth_user.to_string()
|
||||||
# TODO: Check that the device_id matches that in the authentication
|
# TODO: Check that the device_id matches that in the authentication
|
||||||
# or derive the device_id from the authentication instead.
|
# or derive the device_id from the authentication instead.
|
||||||
|
@ -109,7 +109,7 @@ class KeyUploadServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, device_id):
|
def on_GET(self, request, device_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
user_id = auth_user.to_string()
|
user_id = auth_user.to_string()
|
||||||
|
|
||||||
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
|
||||||
|
@ -181,7 +181,7 @@ class KeyQueryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, device_id):
|
def on_GET(self, request, user_id, device_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
auth_user_id = auth_user.to_string()
|
auth_user_id = auth_user.to_string()
|
||||||
user_id = user_id if user_id else auth_user_id
|
user_id = user_id if user_id else auth_user_id
|
||||||
device_ids = [device_id] if device_id else []
|
device_ids = [device_id] if device_id else []
|
||||||
|
|
|
@ -40,7 +40,7 @@ class ReceiptRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, room_id, receipt_type, event_id):
|
def on_POST(self, request, room_id, receipt_type, event_id):
|
||||||
user, _ = yield self.auth.get_user_by_req(request)
|
user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
if receipt_type != "m.read":
|
if receipt_type != "m.read":
|
||||||
raise SynapseError(400, "Receipt type must be 'm.read'")
|
raise SynapseError(400, "Receipt type must be 'm.read'")
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import SynapseError, Codes
|
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||||
from synapse.http.servlet import RestServlet
|
from synapse.http.servlet import RestServlet
|
||||||
|
|
||||||
from ._base import client_v2_pattern, parse_json_dict_from_request
|
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||||
|
@ -55,6 +55,19 @@ class RegisterRestServlet(RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
|
||||||
|
kind = "user"
|
||||||
|
if "kind" in request.args:
|
||||||
|
kind = request.args["kind"][0]
|
||||||
|
|
||||||
|
if kind == "guest":
|
||||||
|
ret = yield self._do_guest_registration()
|
||||||
|
defer.returnValue(ret)
|
||||||
|
return
|
||||||
|
elif kind != "user":
|
||||||
|
raise UnrecognizedRequestError(
|
||||||
|
"Do not understand membership kind: %s" % (kind,)
|
||||||
|
)
|
||||||
|
|
||||||
if '/register/email/requestToken' in request.path:
|
if '/register/email/requestToken' in request.path:
|
||||||
ret = yield self.onEmailTokenRequest(request)
|
ret = yield self.onEmailTokenRequest(request)
|
||||||
defer.returnValue(ret)
|
defer.returnValue(ret)
|
||||||
|
@ -236,6 +249,18 @@ class RegisterRestServlet(RestServlet):
|
||||||
ret = yield self.identity_handler.requestEmailToken(**body)
|
ret = yield self.identity_handler.requestEmailToken(**body)
|
||||||
defer.returnValue((200, ret))
|
defer.returnValue((200, ret))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _do_guest_registration(self):
|
||||||
|
if not self.hs.config.allow_guest_access:
|
||||||
|
defer.returnValue((403, "Guest access is disabled"))
|
||||||
|
user_id, _ = yield self.registration_handler.register(generate_token=False)
|
||||||
|
access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"])
|
||||||
|
defer.returnValue((200, {
|
||||||
|
"user_id": user_id,
|
||||||
|
"access_token": access_token,
|
||||||
|
"home_server": self.hs.hostname,
|
||||||
|
}))
|
||||||
|
|
||||||
|
|
||||||
def register_servlets(hs, http_server):
|
def register_servlets(hs, http_server):
|
||||||
RegisterRestServlet(hs).register(http_server)
|
RegisterRestServlet(hs).register(http_server)
|
||||||
|
|
|
@ -81,7 +81,7 @@ class SyncRestServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
user, token_id = yield self.auth.get_user_by_req(request)
|
user, token_id, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
timeout = parse_integer(request, "timeout", default=0)
|
timeout = parse_integer(request, "timeout", default=0)
|
||||||
since = parse_string(request, "since")
|
since = parse_string(request, "since")
|
||||||
|
|
|
@ -42,7 +42,7 @@ class TagListServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id, room_id):
|
def on_GET(self, request, user_id, room_id):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != auth_user.to_string():
|
||||||
raise AuthError(403, "Cannot get tags for other users.")
|
raise AuthError(403, "Cannot get tags for other users.")
|
||||||
|
|
||||||
|
@ -68,7 +68,7 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_PUT(self, request, user_id, room_id, tag):
|
def on_PUT(self, request, user_id, room_id, tag):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != auth_user.to_string():
|
||||||
raise AuthError(403, "Cannot add tags for other users.")
|
raise AuthError(403, "Cannot add tags for other users.")
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ class TagServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, user_id, room_id, tag):
|
def on_DELETE(self, request, user_id, room_id, tag):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
if user_id != auth_user.to_string():
|
if user_id != auth_user.to_string():
|
||||||
raise AuthError(403, "Cannot add tags for other users.")
|
raise AuthError(403, "Cannot add tags for other users.")
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def map_request_to_name(self, request):
|
def map_request_to_name(self, request):
|
||||||
# auth the user
|
# auth the user
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
# namespace all file uploads on the user
|
# namespace all file uploads on the user
|
||||||
prefix = base64.urlsafe_b64encode(
|
prefix = base64.urlsafe_b64encode(
|
||||||
|
|
|
@ -70,7 +70,7 @@ class UploadResource(BaseMediaResource):
|
||||||
@request_handler
|
@request_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _async_render_POST(self, request):
|
def _async_render_POST(self, request):
|
||||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
auth_user, _, _ = yield self.auth.get_user_by_req(request)
|
||||||
# TODO: The checks here are a bit late. The content will have
|
# TODO: The checks here are a bit late. The content will have
|
||||||
# already been uploaded to a tmp file at this point
|
# already been uploaded to a tmp file at this point
|
||||||
content_length = request.getHeader("Content-Length")
|
content_length = request.getHeader("Content-Length")
|
||||||
|
|
|
@ -29,7 +29,6 @@ from synapse.state import StateHandler
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.distributor import Distributor
|
from synapse.util.distributor import Distributor
|
||||||
from synapse.util.lockutils import LockManager
|
|
||||||
from synapse.streams.events import EventSources
|
from synapse.streams.events import EventSources
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.crypto.keyring import Keyring
|
from synapse.crypto.keyring import Keyring
|
||||||
|
@ -70,7 +69,6 @@ class BaseHomeServer(object):
|
||||||
'auth',
|
'auth',
|
||||||
'rest_servlet_factory',
|
'rest_servlet_factory',
|
||||||
'state_handler',
|
'state_handler',
|
||||||
'room_lock_manager',
|
|
||||||
'notifier',
|
'notifier',
|
||||||
'distributor',
|
'distributor',
|
||||||
'resource_for_client',
|
'resource_for_client',
|
||||||
|
@ -201,9 +199,6 @@ class HomeServer(BaseHomeServer):
|
||||||
def build_state_handler(self):
|
def build_state_handler(self):
|
||||||
return StateHandler(self)
|
return StateHandler(self)
|
||||||
|
|
||||||
def build_room_lock_manager(self):
|
|
||||||
return LockManager()
|
|
||||||
|
|
||||||
def build_distributor(self):
|
def build_distributor(self):
|
||||||
return Distributor()
|
return Distributor()
|
||||||
|
|
||||||
|
|
|
@ -311,6 +311,8 @@ class EventsStore(SQLBaseStore):
|
||||||
self._store_room_message_txn(txn, event)
|
self._store_room_message_txn(txn, event)
|
||||||
elif event.type == EventTypes.Redaction:
|
elif event.type == EventTypes.Redaction:
|
||||||
self._store_redaction(txn, event)
|
self._store_redaction(txn, event)
|
||||||
|
elif event.type == EventTypes.RoomHistoryVisibility:
|
||||||
|
self._store_history_visibility_txn(txn, event)
|
||||||
|
|
||||||
self._store_room_members_txn(
|
self._store_room_members_txn(
|
||||||
txn,
|
txn,
|
||||||
|
|
|
@ -102,6 +102,7 @@ class RegistrationStore(SQLBaseStore):
|
||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if token:
|
||||||
# it's possible for this to get a conflict, but only for a single user
|
# it's possible for this to get a conflict, but only for a single user
|
||||||
# since tokens are namespaced based on their user ID
|
# since tokens are namespaced based on their user ID
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
|
|
@ -202,6 +202,19 @@ class RoomStore(SQLBaseStore):
|
||||||
txn, event, "content.body", event.content["body"]
|
txn, event, "content.body", event.content["body"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _store_history_visibility_txn(self, txn, event):
|
||||||
|
if hasattr(event, "content") and "history_visibility" in event.content:
|
||||||
|
sql = (
|
||||||
|
"INSERT INTO history_visibility"
|
||||||
|
" (event_id, room_id, history_visibility)"
|
||||||
|
" VALUES (?, ?, ?)"
|
||||||
|
)
|
||||||
|
txn.execute(sql, (
|
||||||
|
event.event_id,
|
||||||
|
event.room_id,
|
||||||
|
event.content["history_visibility"]
|
||||||
|
))
|
||||||
|
|
||||||
def _store_event_search_txn(self, txn, event, key, value):
|
def _store_event_search_txn(self, txn, event, key, value):
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
sql = (
|
sql = (
|
||||||
|
|
26
synapse/storage/schema/delta/25/history_visibility.sql
Normal file
26
synapse/storage/schema/delta/25/history_visibility.sql
Normal file
|
@ -0,0 +1,26 @@
|
||||||
|
/* Copyright 2015 OpenMarket Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
* This is a manual index of history_visibility content of state events,
|
||||||
|
* so that we can join on them in SELECT statements.
|
||||||
|
*/
|
||||||
|
CREATE TABLE IF NOT EXISTS history_visibility(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
event_id TEXT NOT NULL,
|
||||||
|
room_id TEXT NOT NULL,
|
||||||
|
history_visibility TEXT NOT NULL,
|
||||||
|
UNIQUE (event_id)
|
||||||
|
);
|
|
@ -158,14 +158,40 @@ class StreamStore(SQLBaseStore):
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def get_room_events_stream(self, user_id, from_key, to_key, room_id,
|
def get_room_events_stream(
|
||||||
limit=0):
|
self,
|
||||||
|
user_id,
|
||||||
|
from_key,
|
||||||
|
to_key,
|
||||||
|
limit=0,
|
||||||
|
is_guest=False,
|
||||||
|
room_ids=None
|
||||||
|
):
|
||||||
|
room_ids = room_ids or []
|
||||||
|
room_ids = [r for r in room_ids]
|
||||||
|
if is_guest:
|
||||||
|
current_room_membership_sql = (
|
||||||
|
"SELECT c.room_id FROM history_visibility AS h"
|
||||||
|
" INNER JOIN current_state_events AS c"
|
||||||
|
" ON h.event_id = c.event_id"
|
||||||
|
" WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
|
||||||
|
",".join(map(lambda _: "?", room_ids))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
current_room_membership_args = room_ids
|
||||||
|
else:
|
||||||
current_room_membership_sql = (
|
current_room_membership_sql = (
|
||||||
"SELECT m.room_id FROM room_memberships as m "
|
"SELECT m.room_id FROM room_memberships as m "
|
||||||
" INNER JOIN current_state_events as c"
|
" INNER JOIN current_state_events as c"
|
||||||
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
|
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
|
||||||
" WHERE m.user_id = ? AND m.membership = 'join'"
|
" WHERE m.user_id = ? AND m.membership = 'join'"
|
||||||
)
|
)
|
||||||
|
current_room_membership_args = [user_id]
|
||||||
|
if room_ids:
|
||||||
|
current_room_membership_sql += " AND m.room_id in (%s)" % (
|
||||||
|
",".join(map(lambda _: "?", room_ids))
|
||||||
|
)
|
||||||
|
current_room_membership_args = [user_id] + room_ids
|
||||||
|
|
||||||
# We also want to get any membership events about that user, e.g.
|
# We also want to get any membership events about that user, e.g.
|
||||||
# invites or leave notifications.
|
# invites or leave notifications.
|
||||||
|
@ -174,6 +200,7 @@ class StreamStore(SQLBaseStore):
|
||||||
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
|
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
|
||||||
"WHERE m.user_id = ? "
|
"WHERE m.user_id = ? "
|
||||||
)
|
)
|
||||||
|
membership_args = [user_id]
|
||||||
|
|
||||||
if limit:
|
if limit:
|
||||||
limit = max(limit, MAX_STREAM_SIZE)
|
limit = max(limit, MAX_STREAM_SIZE)
|
||||||
|
@ -200,7 +227,9 @@ class StreamStore(SQLBaseStore):
|
||||||
}
|
}
|
||||||
|
|
||||||
def f(txn):
|
def f(txn):
|
||||||
txn.execute(sql, (False, user_id, user_id, from_id.stream, to_id.stream,))
|
args = ([False] + current_room_membership_args + membership_args +
|
||||||
|
[from_id.stream, to_id.stream])
|
||||||
|
txn.execute(sql, args)
|
||||||
|
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
|
|
@ -1,74 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014, 2015 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class Lock(object):
|
|
||||||
|
|
||||||
def __init__(self, deferred, key):
|
|
||||||
self._deferred = deferred
|
|
||||||
self.released = False
|
|
||||||
self.key = key
|
|
||||||
|
|
||||||
def release(self):
|
|
||||||
self.released = True
|
|
||||||
self._deferred.callback(None)
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
if not self.released:
|
|
||||||
logger.critical("Lock was destructed but never released!")
|
|
||||||
self.release()
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, type, value, traceback):
|
|
||||||
logger.debug("Releasing lock for key=%r", self.key)
|
|
||||||
self.release()
|
|
||||||
|
|
||||||
|
|
||||||
class LockManager(object):
|
|
||||||
""" Utility class that allows us to lock based on a `key` """
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._lock_deferreds = {}
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def lock(self, key):
|
|
||||||
""" Allows us to block until it is our turn.
|
|
||||||
Args:
|
|
||||||
key (str)
|
|
||||||
Returns:
|
|
||||||
Lock
|
|
||||||
"""
|
|
||||||
new_deferred = defer.Deferred()
|
|
||||||
old_deferred = self._lock_deferreds.get(key)
|
|
||||||
self._lock_deferreds[key] = new_deferred
|
|
||||||
|
|
||||||
if old_deferred:
|
|
||||||
logger.debug("Queueing on lock for key=%r", key)
|
|
||||||
yield old_deferred
|
|
||||||
logger.debug("Obtained lock for key=%r", key)
|
|
||||||
else:
|
|
||||||
logger.debug("Entering uncontended lock for key=%r", key)
|
|
||||||
|
|
||||||
defer.returnValue(Lock(new_deferred, key))
|
|
|
@ -51,7 +51,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _) = yield self.auth.get_user_by_req(request)
|
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), self.test_user)
|
self.assertEquals(user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_user_bad_token(self):
|
def test_get_user_by_req_user_bad_token(self):
|
||||||
|
@ -86,7 +86,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request = Mock(args={})
|
request = Mock(args={})
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _) = yield self.auth.get_user_by_req(request)
|
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), self.test_user)
|
self.assertEquals(user.to_string(), self.test_user)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_bad_token(self):
|
def test_get_user_by_req_appservice_bad_token(self):
|
||||||
|
@ -121,7 +121,7 @@ class AuthTestCase(unittest.TestCase):
|
||||||
request.args["access_token"] = [self.test_token]
|
request.args["access_token"] = [self.test_token]
|
||||||
request.args["user_id"] = [masquerading_user_id]
|
request.args["user_id"] = [masquerading_user_id]
|
||||||
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
request.requestHeaders.getRawHeaders = Mock(return_value=[""])
|
||||||
(user, _) = yield self.auth.get_user_by_req(request)
|
(user, _, _) = yield self.auth.get_user_by_req(request)
|
||||||
self.assertEquals(user.to_string(), masquerading_user_id)
|
self.assertEquals(user.to_string(), masquerading_user_id)
|
||||||
|
|
||||||
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
|
||||||
|
@ -158,6 +158,25 @@ class AuthTestCase(unittest.TestCase):
|
||||||
user = user_info["user"]
|
user = user_info["user"]
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_get_guest_user_from_macaroon(self):
|
||||||
|
user_id = "@baldrick:matrix.org"
|
||||||
|
macaroon = pymacaroons.Macaroon(
|
||||||
|
location=self.hs.config.server_name,
|
||||||
|
identifier="key",
|
||||||
|
key=self.hs.config.macaroon_secret_key)
|
||||||
|
macaroon.add_first_party_caveat("gen = 1")
|
||||||
|
macaroon.add_first_party_caveat("type = access")
|
||||||
|
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
|
||||||
|
macaroon.add_first_party_caveat("guest = true")
|
||||||
|
serialized = macaroon.serialize()
|
||||||
|
|
||||||
|
user_info = yield self.auth._get_user_from_macaroon(serialized)
|
||||||
|
user = user_info["user"]
|
||||||
|
is_guest = user_info["is_guest"]
|
||||||
|
self.assertEqual(UserID.from_string(user_id), user)
|
||||||
|
self.assertTrue(is_guest)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_from_macaroon_user_db_mismatch(self):
|
def test_get_user_from_macaroon_user_db_mismatch(self):
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
|
|
|
@ -650,9 +650,30 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
{"presence": ONLINE}
|
{"presence": ONLINE}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Apple sees self-reflection even without room_id
|
||||||
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
self.assertEquals(events,
|
||||||
|
[
|
||||||
|
{"type": "m.presence",
|
||||||
|
"content": {
|
||||||
|
"user_id": "@apple:test",
|
||||||
|
"presence": ONLINE,
|
||||||
|
"last_active_ago": 0,
|
||||||
|
}},
|
||||||
|
],
|
||||||
|
msg="Presence event should be visible to self-reflection"
|
||||||
|
)
|
||||||
|
|
||||||
# Apple sees self-reflection
|
# Apple sees self-reflection
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 0, None
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -684,8 +705,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Banana sees it because of presence subscription
|
# Banana sees it because of presence subscription
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_banana, 0, None
|
user=self.u_banana,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -702,8 +725,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Elderberry sees it because of same room
|
# Elderberry sees it because of same room
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_elderberry, 0, None
|
user=self.u_elderberry,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -720,8 +745,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Durian is not in the room, should not see this event
|
# Durian is not in the room, should not see this event
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_durian, 0, None
|
user=self.u_durian,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -767,8 +794,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
"accepted": True},
|
"accepted": True},
|
||||||
], presence)
|
], presence)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 1, None
|
user=self.u_apple,
|
||||||
|
from_key=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
|
@ -858,8 +886,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 0, None
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
@ -905,8 +935,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 0, None
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id,]
|
||||||
)
|
)
|
||||||
self.assertEquals(events,
|
self.assertEquals(events,
|
||||||
[
|
[
|
||||||
|
@ -932,8 +964,10 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 0, None
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id,]
|
||||||
)
|
)
|
||||||
self.assertEquals(events,
|
self.assertEquals(events,
|
||||||
[
|
[
|
||||||
|
@ -966,8 +1000,9 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
|
||||||
|
|
||||||
self.room_members.append(self.u_clementine)
|
self.room_members.append(self.u_clementine)
|
||||||
|
|
||||||
(events, _) = yield self.event_source.get_new_events_for_user(
|
(events, _) = yield self.event_source.get_new_events(
|
||||||
self.u_apple, 0, None
|
user=self.u_apple,
|
||||||
|
from_key=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
|
|
|
@ -187,7 +187,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=0,
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -250,7 +253,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=0
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -306,7 +312,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
yield put_json.await_calls()
|
yield put_json.await_calls()
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=0,
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -337,7 +346,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.on_new_event.reset_mock()
|
self.on_new_event.reset_mock()
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=0,
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -356,7 +368,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
])
|
])
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 2)
|
self.assertEquals(self.event_source.get_current_key(), 2)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 1, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=1,
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
@ -383,7 +398,10 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
self.on_new_event.reset_mock()
|
self.on_new_event.reset_mock()
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 3)
|
self.assertEquals(self.event_source.get_current_key(), 3)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
from_key=0,
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
|
|
@ -47,7 +47,14 @@ class NullSource(object):
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_new_events_for_user(self, user, from_key, limit):
|
def get_new_events(
|
||||||
|
self,
|
||||||
|
user,
|
||||||
|
from_key,
|
||||||
|
room_ids=None,
|
||||||
|
limit=None,
|
||||||
|
is_guest=None
|
||||||
|
):
|
||||||
return defer.succeed(([], from_key))
|
return defer.succeed(([], from_key))
|
||||||
|
|
||||||
def get_current_key(self, direction='f'):
|
def get_current_key(self, direction='f'):
|
||||||
|
@ -86,10 +93,11 @@ class PresenceStateTestCase(unittest.TestCase):
|
||||||
return defer.succeed([])
|
return defer.succeed([])
|
||||||
self.datastore.get_presence_list = get_presence_list
|
self.datastore.get_presence_list = get_presence_list
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(myid),
|
"user": UserID.from_string(myid),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
@ -173,10 +181,11 @@ class PresenceListTestCase(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.datastore.has_presence_state = has_presence_state
|
self.datastore.has_presence_state = has_presence_state
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(myid),
|
"user": UserID.from_string(myid),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.handlers.room_member_handler = Mock(
|
hs.handlers.room_member_handler = Mock(
|
||||||
|
@ -291,8 +300,8 @@ class PresenceEventStreamTestCase(unittest.TestCase):
|
||||||
|
|
||||||
hs.get_clock().time_msec.return_value = 1000000
|
hs.get_clock().time_msec.return_value = 1000000
|
||||||
|
|
||||||
def _get_user_by_req(req=None):
|
def _get_user_by_req(req=None, allow_guest=False):
|
||||||
return (UserID.from_string(myid), "")
|
return (UserID.from_string(myid), "", False)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
|
@ -52,8 +52,8 @@ class ProfileTestCase(unittest.TestCase):
|
||||||
replication_layer=Mock(),
|
replication_layer=Mock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_req(request=None):
|
def _get_user_by_req(request=None, allow_guest=False):
|
||||||
return (UserID.from_string(myid), "")
|
return (UserID.from_string(myid), "", False)
|
||||||
|
|
||||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||||
|
|
||||||
|
|
|
@ -54,10 +54,11 @@ class RoomPermissionsTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
@ -439,10 +440,11 @@ class RoomsMemberListTestCase(RestTestCase):
|
||||||
|
|
||||||
self.auth_user_id = self.user_id
|
self.auth_user_id = self.user_id
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
@ -517,10 +519,11 @@ class RoomsCreateTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
@ -608,10 +611,11 @@ class RoomTopicTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
@ -713,10 +717,11 @@ class RoomMemberStateTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
@ -838,10 +843,11 @@ class RoomMessagesTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
@ -933,10 +939,11 @@ class RoomInitialSyncTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
|
|
@ -61,10 +61,11 @@ class RoomTypingTestCase(RestTestCase):
|
||||||
|
|
||||||
hs.get_handlers().federation_handler = Mock()
|
hs.get_handlers().federation_handler = Mock()
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.auth_user_id),
|
"user": UserID.from_string(self.auth_user_id),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
|
|
||||||
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_v1auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
@ -115,7 +116,10 @@ class RoomTypingTestCase(RestTestCase):
|
||||||
self.assertEquals(200, code)
|
self.assertEquals(200, code)
|
||||||
|
|
||||||
self.assertEquals(self.event_source.get_current_key(), 1)
|
self.assertEquals(self.event_source.get_current_key(), 1)
|
||||||
events = yield self.event_source.get_new_events_for_user(self.user, 0, None)
|
events = yield self.event_source.get_new_events(
|
||||||
|
from_key=0,
|
||||||
|
room_ids=[self.room_id],
|
||||||
|
)
|
||||||
self.assertEquals(
|
self.assertEquals(
|
||||||
events[0],
|
events[0],
|
||||||
[
|
[
|
||||||
|
|
|
@ -43,10 +43,11 @@ class V2AlphaRestTestCase(unittest.TestCase):
|
||||||
resource_for_federation=self.mock_resource,
|
resource_for_federation=self.mock_resource,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_user_by_access_token(token=None):
|
def _get_user_by_access_token(token=None, allow_guest=False):
|
||||||
return {
|
return {
|
||||||
"user": UserID.from_string(self.USER_ID),
|
"user": UserID.from_string(self.USER_ID),
|
||||||
"token_id": 1,
|
"token_id": 1,
|
||||||
|
"is_guest": False,
|
||||||
}
|
}
|
||||||
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
|
hs.get_auth()._get_user_by_access_token = _get_user_by_access_token
|
||||||
|
|
||||||
|
|
|
@ -120,7 +120,6 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
self.u_alice.to_string(),
|
self.u_alice.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
@ -149,7 +148,6 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
self.u_alice.to_string(),
|
self.u_alice.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
@ -199,7 +197,6 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
self.u_alice.to_string(),
|
self.u_alice.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
@ -228,7 +225,6 @@ class RedactionTestCase(unittest.TestCase):
|
||||||
self.u_alice.to_string(),
|
self.u_alice.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
|
|
@ -68,7 +68,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||||
self.u_bob.to_string(),
|
self.u_bob.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
@ -105,7 +104,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||||
self.u_alice.to_string(),
|
self.u_alice.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertEqual(1, len(results))
|
self.assertEqual(1, len(results))
|
||||||
|
@ -147,7 +145,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||||
self.u_bob.to_string(),
|
self.u_bob.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should not get the message, as it happened *after* bob left.
|
# We should not get the message, as it happened *after* bob left.
|
||||||
|
@ -175,7 +172,6 @@ class StreamStoreTestCase(unittest.TestCase):
|
||||||
self.u_bob.to_string(),
|
self.u_bob.to_string(),
|
||||||
start,
|
start,
|
||||||
end,
|
end,
|
||||||
None, # Is currently ignored
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# We should not get the message, as it happened *after* bob left.
|
# We should not get the message, as it happened *after* bob left.
|
||||||
|
|
|
@ -1,108 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
# Copyright 2014 OpenMarket Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
from tests import unittest
|
|
||||||
|
|
||||||
from synapse.util.lockutils import LockManager
|
|
||||||
|
|
||||||
|
|
||||||
class LockManagerTestCase(unittest.TestCase):
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.lock_manager = LockManager()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_one_lock(self):
|
|
||||||
key = "test"
|
|
||||||
deferred_lock1 = self.lock_manager.lock(key)
|
|
||||||
|
|
||||||
self.assertTrue(deferred_lock1.called)
|
|
||||||
|
|
||||||
lock1 = yield deferred_lock1
|
|
||||||
|
|
||||||
self.assertFalse(lock1.released)
|
|
||||||
|
|
||||||
lock1.release()
|
|
||||||
|
|
||||||
self.assertTrue(lock1.released)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_concurrent_locks(self):
|
|
||||||
key = "test"
|
|
||||||
deferred_lock1 = self.lock_manager.lock(key)
|
|
||||||
deferred_lock2 = self.lock_manager.lock(key)
|
|
||||||
|
|
||||||
self.assertTrue(deferred_lock1.called)
|
|
||||||
self.assertFalse(deferred_lock2.called)
|
|
||||||
|
|
||||||
lock1 = yield deferred_lock1
|
|
||||||
|
|
||||||
self.assertFalse(lock1.released)
|
|
||||||
self.assertFalse(deferred_lock2.called)
|
|
||||||
|
|
||||||
lock1.release()
|
|
||||||
|
|
||||||
self.assertTrue(lock1.released)
|
|
||||||
self.assertTrue(deferred_lock2.called)
|
|
||||||
|
|
||||||
lock2 = yield deferred_lock2
|
|
||||||
|
|
||||||
lock2.release()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_sequential_locks(self):
|
|
||||||
key = "test"
|
|
||||||
deferred_lock1 = self.lock_manager.lock(key)
|
|
||||||
|
|
||||||
self.assertTrue(deferred_lock1.called)
|
|
||||||
|
|
||||||
lock1 = yield deferred_lock1
|
|
||||||
|
|
||||||
self.assertFalse(lock1.released)
|
|
||||||
|
|
||||||
lock1.release()
|
|
||||||
|
|
||||||
self.assertTrue(lock1.released)
|
|
||||||
|
|
||||||
deferred_lock2 = self.lock_manager.lock(key)
|
|
||||||
|
|
||||||
self.assertTrue(deferred_lock2.called)
|
|
||||||
|
|
||||||
lock2 = yield deferred_lock2
|
|
||||||
|
|
||||||
self.assertFalse(lock2.released)
|
|
||||||
|
|
||||||
lock2.release()
|
|
||||||
|
|
||||||
self.assertTrue(lock2.released)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_with_statement(self):
|
|
||||||
key = "test"
|
|
||||||
with (yield self.lock_manager.lock(key)) as lock:
|
|
||||||
self.assertFalse(lock.released)
|
|
||||||
|
|
||||||
self.assertTrue(lock.released)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_two_with_statement(self):
|
|
||||||
key = "test"
|
|
||||||
with (yield self.lock_manager.lock(key)):
|
|
||||||
pass
|
|
||||||
|
|
||||||
with (yield self.lock_manager.lock(key)):
|
|
||||||
pass
|
|
|
@ -335,7 +335,7 @@ class MemoryDataStore(object):
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
|
def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
|
||||||
room_id=None, limit=0, with_feedback=False):
|
limit=0, with_feedback=False):
|
||||||
return ([], from_key) # TODO
|
return ([], from_key) # TODO
|
||||||
|
|
||||||
def get_joined_hosts_for_room(self, room_id):
|
def get_joined_hosts_for_room(self, room_id):
|
||||||
|
|
Loading…
Reference in a new issue