mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-01 22:04:02 +01:00
23740eaa3d
During the migration the automated script to update the copyright headers accidentally got rid of some of the existing copyright lines. Reinstate them.
249 lines
8.2 KiB
Python
249 lines
8.2 KiB
Python
#
|
|
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
|
#
|
|
# Copyright 2015, 2016 OpenMarket Ltd
|
|
# Copyright (C) 2023 New Vector, Ltd
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as
|
|
# published by the Free Software Foundation, either version 3 of the
|
|
# License, or (at your option) any later version.
|
|
#
|
|
# See the GNU Affero General Public License for more details:
|
|
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
#
|
|
# Originally licensed under the Apache License, Version 2.0:
|
|
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
|
#
|
|
# [This file includes modifications made by New Vector Limited]
|
|
#
|
|
#
|
|
from typing import Optional
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pymacaroons
|
|
|
|
from twisted.test.proto_helpers import MemoryReactor
|
|
|
|
from synapse.api.errors import AuthError, ResourceLimitError
|
|
from synapse.rest import admin
|
|
from synapse.rest.client import login
|
|
from synapse.server import HomeServer
|
|
from synapse.util import Clock
|
|
|
|
from tests import unittest
|
|
|
|
|
|
class AuthTestCase(unittest.HomeserverTestCase):
    """Exercises the auth handler: macaroon caveats, login tokens and MAU limits."""

    servlets = [
        admin.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        """Collect the objects under test, set the MAU cap and register a user."""
        self.small_number_of_users = 1
        self.large_number_of_users = 100

        self.auth_handler = hs.get_auth_handler()
        self.macaroon_generator = hs.get_macaroon_generator()

        # MAU tests
        # AuthBlocking reads from the hs' config on initialization, so its own
        # copy of the settings has to be patched rather than the hs' config.
        blocking = hs.get_auth_blocking()
        blocking._max_mau_value = 50
        self.auth_blocking = blocking

        self.user1 = self.register_user("a_user", "pass")

    def _set_monthly_active_count(self, count: int) -> None:
        """Swap the main store's `get_monthly_active_count` for a mock.

        Args:
            count: the value the mocked coroutine resolves to.
        """
        self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
            return_value=count
        )

    def _login_with_fresh_token(self) -> Optional[str]:
        """Mint a login token for `self.user1` and try to log in with it.

        Returns:
            Whatever `token_login` returns for the freshly created token.
        """
        fresh_token = self.get_success(
            self.auth_handler.create_login_token_for_user_id(self.user1)
        )
        return self.token_login(fresh_token)

    def token_login(self, token: str) -> Optional[str]:
        """Attempt an `m.login.token` login via the client login endpoint.

        Args:
            token: the login token to submit.

        Returns:
            The `user_id` field of the response body when the response code
            is 200, otherwise None.
        """
        response = self.make_request(
            "POST",
            "/_matrix/client/v3/login",
            {
                "type": "m.login.token",
                "token": token,
            },
        )

        if response.code != 200:
            return None

        return response.json_body["user_id"]

    def test_macaroon_caveats(self) -> None:
        """A generated guest access token verifies against the known caveats."""
        serialized = self.macaroon_generator.generate_guest_access_token("a_user")
        macaroon = pymacaroons.Macaroon.deserialize(serialized)

        def verify_gen(caveat: str) -> bool:
            return caveat == "gen = 1"

        def verify_user(caveat: str) -> bool:
            return caveat == "user_id = a_user"

        def verify_type(caveat: str) -> bool:
            return caveat == "type = access"

        def verify_nonce(caveat: str) -> bool:
            return caveat.startswith("nonce =")

        def verify_guest(caveat: str) -> bool:
            return caveat == "guest = true"

        verifier = pymacaroons.Verifier()
        # Register the predicates in the same order as they are defined above.
        for predicate in (
            verify_gen,
            verify_user,
            verify_type,
            verify_nonce,
            verify_guest,
        ):
            verifier.satisfy_general(predicate)

        verifier.verify(macaroon, self.hs.config.key.macaroon_secret_key)

    def test_login_token_gives_user_id(self) -> None:
        """Consuming a login token yields the user it was created for."""
        pending_token = self.auth_handler.create_login_token_for_user_id(
            self.user1,
            duration_ms=(5 * 1000),
        )
        login_token = self.get_success(pending_token)

        consumed = self.get_success(
            self.auth_handler.consume_login_token(login_token)
        )
        self.assertEqual(self.user1, consumed.user_id)
        # No auth provider was supplied when the token was created.
        self.assertEqual(None, consumed.auth_provider_id)

    def test_login_token_reuse_fails(self) -> None:
        """A login token can only be consumed once."""
        pending_token = self.auth_handler.create_login_token_for_user_id(
            self.user1,
            duration_ms=(5 * 1000),
        )
        login_token = self.get_success(pending_token)

        # First use succeeds...
        self.get_success(self.auth_handler.consume_login_token(login_token))

        # ...and the second use of the same token is refused.
        self.get_failure(
            self.auth_handler.consume_login_token(login_token),
            AuthError,
        )

    def test_login_token_expires(self) -> None:
        """A login token is rejected once the clock has moved past its duration."""
        pending_token = self.auth_handler.create_login_token_for_user_id(
            self.user1,
            duration_ms=(5 * 1000),
        )
        login_token = self.get_success(pending_token)

        # when we advance the clock, the token should be rejected
        self.reactor.advance(6)
        self.get_failure(
            self.auth_handler.consume_login_token(login_token),
            AuthError,
        )

    def test_login_token_gives_auth_provider(self) -> None:
        """Auth provider details given at creation come back on consumption."""
        pending_token = self.auth_handler.create_login_token_for_user_id(
            self.user1,
            auth_provider_id="my_idp",
            auth_provider_session_id="11-22-33-44",
            duration_ms=(5 * 1000),
        )
        login_token = self.get_success(pending_token)

        consumed = self.get_success(
            self.auth_handler.consume_login_token(login_token)
        )
        self.assertEqual(self.user1, consumed.user_id)
        self.assertEqual("my_idp", consumed.auth_provider_id)
        self.assertEqual("11-22-33-44", consumed.auth_provider_session_id)

    def test_mau_limits_disabled(self) -> None:
        """With MAU limiting switched off, token creation and login both work."""
        self.auth_blocking._limit_usage_by_mau = False

        # Ensure does not throw exception
        pending_access_token = self.auth_handler.create_access_token_for_user_id(
            self.user1, device_id=None, valid_until_ms=None
        )
        self.get_success(pending_access_token)

        self.assertIsNotNone(self._login_with_fresh_token())

    def test_mau_limits_exceeded_large(self) -> None:
        """A reported MAU count well over the cap blocks tokens and logins."""
        self.auth_blocking._limit_usage_by_mau = True
        self._set_monthly_active_count(self.large_number_of_users)

        pending_access_token = self.auth_handler.create_access_token_for_user_id(
            self.user1, device_id=None, valid_until_ms=None
        )
        self.get_failure(pending_access_token, ResourceLimitError)

        # Install a fresh mock before exercising the login path.
        self._set_monthly_active_count(self.large_number_of_users)
        self.assertIsNone(self._login_with_fresh_token())

    def test_mau_limits_parity(self) -> None:
        """At exactly the cap, only users already counted as active get in."""
        # Ensure we're not at the unix epoch.
        self.reactor.advance(1)
        self.auth_blocking._limit_usage_by_mau = True

        # Set the server to be at the edge of too many users.
        self._set_monthly_active_count(self.auth_blocking._max_mau_value)

        # If not in monthly active cohort
        pending_access_token = self.auth_handler.create_access_token_for_user_id(
            self.user1, device_id=None, valid_until_ms=None
        )
        self.get_failure(pending_access_token, ResourceLimitError)
        self.assertIsNone(self._login_with_fresh_token())

        # If in monthly active cohort
        last_seen_mock = AsyncMock(return_value=self.clock.time_msec())
        self.hs.get_datastores().main.user_last_seen_monthly_active = last_seen_mock

        pending_access_token = self.auth_handler.create_access_token_for_user_id(
            self.user1, device_id=None, valid_until_ms=None
        )
        self.get_success(pending_access_token)
        self.assertIsNotNone(self._login_with_fresh_token())

    def test_mau_limits_not_exceeded(self) -> None:
        """A reported MAU count below the cap lets tokens and logins through."""
        self.auth_blocking._limit_usage_by_mau = True
        self._set_monthly_active_count(self.small_number_of_users)

        # Ensure does not raise exception
        pending_access_token = self.auth_handler.create_access_token_for_user_id(
            self.user1, device_id=None, valid_until_ms=None
        )
        self.get_success(pending_access_token)

        # Install a fresh mock before exercising the login path.
        self._set_monthly_active_count(self.small_number_of_users)
        self.assertIsNotNone(self._login_with_fresh_token())
|