0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 11:33:53 +01:00

Extract the id token of the token when authing users, include the token and device_id in the internal meta data for the event along with the transaction id when sending events

This commit is contained in:
Mark Haines 2015-01-28 16:58:23 +00:00
parent c59bcabf0b
commit 388581e087
18 changed files with 92 additions and 48 deletions

View file

@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.types import UserID
from synapse.types import UserID, ClientID
import logging
@ -292,7 +292,7 @@ class Auth(object):
Returns:
Tuple of UserID and device string:
User ID object of the user making the request
Device ID string of the device the user is using
Client ID object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@ -302,6 +302,7 @@ class Auth(object):
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
@ -317,7 +318,7 @@ class Auth(object):
user_agent=user_agent
)
defer.returnValue((user, device_id))
defer.returnValue((user, ClientID(device_id, token_id)))
except KeyError:
raise AuthError(403, "Missing access token.")
@ -342,6 +343,7 @@ class Auth(object):
"admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
defer.returnValue(user_info)

View file

@ -114,7 +114,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True):
def create_and_send_event(self, event_dict, ratelimit=True,
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@ -148,6 +149,15 @@ class MessageHandler(BaseHandler):
builder.content
)
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
event, context = yield self._create_new_client_event(
builder=builder,
)

View file

@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:

View file

@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
if not "room_id" in content:
@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:

View file

@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)

View file

@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)

View file

@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = {}
@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):

View file

@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:

View file

@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
@ -142,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key):
user, device_id = yield self.auth.get_user_by_req(request)
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@ -158,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
yield msg_handler.create_and_send_event(event_dict)
yield msg_handler.create_and_send_event(
event_dict, client=client, txn_id=txn_id,
)
defer.returnValue((200, {}))
@ -172,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_type):
user, device_id = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, event_type, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@ -183,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"content": content,
"room_id": room_id,
"sender": user.to_string(),
}
},
client=client,
txn_id=txn_id,
)
defer.returnValue((200, {"event_id": event.event_id}))
@ -200,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_type)
response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@ -215,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_identifier):
user, device_id = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_identifier, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
@ -245,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
}
},
client=client,
txn_id=txn_id,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@ -259,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_identifier)
response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@ -283,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id,
@ -311,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@ -335,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
@ -351,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
user, device_id = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
@ -395,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action):
user, device_id = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, membership_action, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@ -418,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
}
},
client=client,
txn_id=txn_id,
)
defer.returnValue((200, {}))
@ -432,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_id, membership_action)
response = yield self.on_POST(
request, room_id, membership_action, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@ -444,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
user, device_id = yield self.auth.get_user_by_req(request)
def on_POST(self, request, room_id, event_id, txn_id=None):
user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@ -456,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
}
},
client=client,
txn_id=txn_id,
)
defer.returnValue((200, {"event_id": event.event_id}))
@ -470,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
response = yield self.on_POST(request, room_id, event_id)
response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@ -483,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))

View file

@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret

View file

@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(

View file

@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource):
@defer.inlineCallbacks
def _async_render_POST(self, request):
try:
auth_user, device_id = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")

View file

@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.admin, access_tokens.device_id"
"SELECT users.name, users.admin,"
" access_tokens.device_id, access_tokens.id as token_id"
" FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id"
" WHERE token = ?"

View file

@ -119,3 +119,6 @@ class StreamToken(
d = self._asdict()
d[key] = new_value
return StreamToken(**d)
ClientID = namedtuple("ClientID", ("device_id", "token_id"))

View file

@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase):
"user": UserID.from_string(myid),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
"user": UserID.from_string(myid),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.handlers.room_member_handler = Mock(

View file

@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token
@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token

View file

@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase):
"user": UserID.from_string(self.auth_user_id),
"admin": False,
"device_id": None,
"token_id": 1,
}
hs.get_auth().get_user_by_token = _get_user_by_token

View file

@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
)
self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id},
{"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 1},
(yield self.store.get_user_by_token(self.tokens[0]))
)
@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase):
yield self.store.add_access_token_to_user(self.user_id, self.tokens[1])
self.assertEquals(
{"admin": 0, "device_id": None, "name": self.user_id},
{"admin": 0,
"device_id": None,
"name": self.user_id,
"token_id": 2},
(yield self.store.get_user_by_token(self.tokens[1]))
)