0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-11-19 16:32:24 +01:00

Merge pull request #4666 from matrix-org/erikj/register_login_split

Split out registration to worker
This commit is contained in:
Erik Johnston 2019-02-18 17:18:06 +00:00 committed by GitHub
commit fc2c245a1f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 450 additions and 255 deletions

1
changelog.d/4666.feature Normal file
View file

@ -0,0 +1 @@
Allow registration to be handled by a worker instance.

View file

@ -223,6 +223,12 @@ following regular expressions::
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/members$
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$ ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/state$
Additionally, the following REST endpoints can be handled, but all requests must
be routed to the same instance::
^/_matrix/client/(api/v1|r0|unstable)/register$
``synapse.app.user_dir`` ``synapse.app.user_dir``
~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~

View file

@ -47,6 +47,7 @@ from synapse.rest.client.v1.room import (
RoomMemberListRestServlet, RoomMemberListRestServlet,
RoomStateRestServlet, RoomStateRestServlet,
) )
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree from synapse.util.httpresourcetree import create_resource_tree
@ -92,6 +93,7 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource) JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource) RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource) RoomEventContextServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
resources.update({ resources.update({
"/_matrix/client/r0": resource, "/_matrix/client/r0": resource,

View file

@ -27,6 +27,8 @@ from synapse.api.errors import (
SynapseError, SynapseError,
) )
from synapse.http.client import CaptchaServerHttpClient from synapse.http.client import CaptchaServerHttpClient
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import ReplicationRegisterServlet
from synapse.types import RoomAlias, RoomID, UserID, create_requester from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.threepids import check_3pid_allowed from synapse.util.threepids import check_3pid_allowed
@ -61,6 +63,14 @@ class RegistrationHandler(BaseHandler):
) )
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = (
RegisterDeviceReplicationServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None, def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None): assigned_user_id=None):
@ -155,7 +165,7 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid) yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None password_hash = None
if password: if password:
password_hash = yield self.auth_handler().hash(password) password_hash = yield self._auth_handler.hash(password)
if localpart: if localpart:
yield self.check_username(localpart, guest_access_token=guest_access_token) yield self.check_username(localpart, guest_access_token=guest_access_token)
@ -185,7 +195,7 @@ class RegistrationHandler(BaseHandler):
token = None token = None
if generate_token: if generate_token:
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
yield self.store.register( yield self._register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
@ -217,7 +227,7 @@ class RegistrationHandler(BaseHandler):
if default_display_name is None: if default_display_name is None:
default_display_name = localpart default_display_name = localpart
try: try:
yield self.store.register( yield self._register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
@ -316,7 +326,7 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service user_id, allowed_appservice=service
) )
yield self.store.register( yield self._register_with_store(
user_id=user_id, user_id=user_id,
password_hash="", password_hash="",
appservice_id=service_id, appservice_id=service_id,
@ -494,7 +504,7 @@ class RegistrationHandler(BaseHandler):
token = self.macaroon_gen.generate_access_token(user_id) token = self.macaroon_gen.generate_access_token(user_id)
if need_register: if need_register:
yield self.store.register( yield self._register_with_store(
user_id=user_id, user_id=user_id,
token=token, token=token,
password_hash=password_hash, password_hash=password_hash,
@ -512,9 +522,6 @@ class RegistrationHandler(BaseHandler):
defer.returnValue((user_id, token)) defer.returnValue((user_id, token))
def auth_handler(self):
return self.hs.get_auth_handler()
@defer.inlineCallbacks @defer.inlineCallbacks
def get_or_register_3pid_guest(self, medium, address, inviter_user_id): def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
"""Get a guest access token for a 3PID, creating a guest account if """Get a guest access token for a 3PID, creating a guest account if
@ -573,3 +580,94 @@ class RegistrationHandler(BaseHandler):
action="join", action="join",
ratelimit=False, ratelimit=False,
) )
def _register_with_store(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False,
user_type=None):
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
Returns:
Deferred
"""
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
else:
return self.store.register(
user_id=user_id,
token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
appservice_id=appservice_id,
create_profile_with_displayname=create_profile_with_displayname,
admin=admin,
user_type=user_type,
)
@defer.inlineCallbacks
def register_device(self, user_id, device_id, initial_display_name,
is_guest=False):
"""Register a device for a user and generate an access token.
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
a new one.
initial_display_name (str|None): An optional display name for the
device.
is_guest (bool): Whether this is a guest account
Returns:
defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
"""
if self.hs.config.worker_app:
r = yield self._register_device_client(
user_id=user_id,
device_id=device_id,
initial_display_name=initial_display_name,
is_guest=is_guest,
)
defer.returnValue((r["device_id"], r["access_token"]))
else:
device_id = yield self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
user_id, device_id=device_id,
)
defer.returnValue((device_id, access_token))

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.replication.http import federation, membership, send_event from synapse.replication.http import federation, login, membership, register, send_event
REPLICATION_PREFIX = "/_synapse/replication" REPLICATION_PREFIX = "/_synapse/replication"
@ -28,3 +28,5 @@ class ReplicationRestResource(JsonResource):
send_event.register_servlets(hs, self) send_event.register_servlets(hs, self)
membership.register_servlets(hs, self) membership.register_servlets(hs, self)
federation.register_servlets(hs, self) federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)

View file

@ -0,0 +1,74 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"""Ensure a device is registered, generating a new access token for the
device.
Used during registration and login.
"""
NAME = "device_check_registered"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(RegisterDeviceReplicationServlet, self).__init__(hs)
self.registration_handler = hs.get_handlers().registration_handler
@staticmethod
def _serialize_payload(user_id, device_id, initial_display_name, is_guest):
"""
Args:
device_id (str|None): Device ID to use, if None a new one is
generated.
initial_display_name (str|None)
is_guest (bool)
"""
return {
"device_id": device_id,
"initial_display_name": initial_display_name,
"is_guest": is_guest,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest,
)
defer.returnValue((200, {
"device_id": device_id,
"access_token": access_token,
}))
def register_servlets(hs, http_server):
RegisterDeviceReplicationServlet(hs).register(http_server)

View file

@ -0,0 +1,91 @@
# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
logger = logging.getLogger(__name__)
class ReplicationRegisterServlet(ReplicationEndpoint):
"""Register a new user
"""
NAME = "register_user"
PATH_ARGS = ("user_id",)
def __init__(self, hs):
super(ReplicationRegisterServlet, self).__init__(hs)
self.store = hs.get_datastore()
@staticmethod
def _serialize_payload(
user_id, token, password_hash, was_guest, make_guest, appservice_id,
create_profile_with_displayname, admin, user_type,
):
"""
Args:
user_id (str): The desired user ID to register.
token (str): The desired access token to use for this user. If this
is not None, the given access token is associated with the user
id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
false to add a regular user account.
appservice_id (str|None): The ID of the appservice registering the user.
create_profile_with_displayname (unicode|None): Optionally create a
profile for the user, setting their displayname to the given value
admin (boolean): is an admin user?
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
"""
return {
"token": token,
"password_hash": password_hash,
"was_guest": was_guest,
"make_guest": make_guest,
"appservice_id": appservice_id,
"create_profile_with_displayname": create_profile_with_displayname,
"admin": admin,
"user_type": user_type,
}
@defer.inlineCallbacks
def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
yield self.store.register(
user_id=user_id,
token=content["token"],
password_hash=content["password_hash"],
was_guest=content["was_guest"],
make_guest=content["make_guest"],
appservice_id=content["appservice_id"],
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReplicationRegisterServlet(hs).register(http_server)

View file

@ -94,7 +94,7 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler() self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler() self.registration_handler = hs.get_handlers().registration_handler
self.handlers = hs.get_handlers() self.handlers = hs.get_handlers()
self._well_known_builder = WellKnownBuilder(hs) self._well_known_builder = WellKnownBuilder(hs)
@ -220,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission, login_submission,
) )
device_id = yield self._register_device( device_id = login_submission.get("device_id")
canonical_user_id, login_submission, initial_display_name = login_submission.get("initial_device_display_name")
) device_id, access_token = yield self.registration_handler.register_device(
access_token = yield auth_handler.get_access_token_for_user_id( canonical_user_id, device_id, initial_display_name,
canonical_user_id, device_id,
) )
result = { result = {
@ -246,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = ( user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token) yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
) )
device_id = yield self._register_device(user_id, login_submission)
access_token = yield auth_handler.get_access_token_for_user_id( device_id = login_submission.get("device_id")
user_id, device_id, initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name,
) )
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
@ -286,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.auth_handler auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id) registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id: if registered_user_id:
device_id = yield self._register_device( device_id = login_submission.get("device_id")
registered_user_id, login_submission initial_display_name = login_submission.get("initial_device_display_name")
) device_id, access_token = yield self.registration_handler.register_device(
access_token = yield auth_handler.get_access_token_for_user_id( registered_user_id, device_id, initial_display_name,
registered_user_id, device_id,
) )
result = { result = {
@ -299,12 +300,16 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
else: else:
# TODO: we should probably check that the register isn't going
# to fonx/change our user_id before registering the device
device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = ( user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user) yield self.handlers.registration_handler.register(localpart=user)
) )
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
registered_user_id, device_id, initial_display_name,
)
result = { result = {
"user_id": user_id, # may have changed "user_id": user_id, # may have changed
"access_token": access_token, "access_token": access_token,
@ -313,26 +318,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue(result) defer.returnValue(result)
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) login_submission: dictionary supplied to /login call, from
which we pull device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get(
"initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
class CasRedirectServlet(RestServlet): class CasRedirectServlet(RestServlet):
PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")

View file

@ -190,7 +190,6 @@ class RegisterRestServlet(RestServlet):
self.registration_handler = hs.get_handlers().registration_handler self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler self.identity_handler = hs.get_handlers().identity_handler
self.room_member_handler = hs.get_room_member_handler() self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
@interactive_auth_handler @interactive_auth_handler
@ -633,12 +632,10 @@ class RegisterRestServlet(RestServlet):
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
if not params.get("inhibit_login", False): if not params.get("inhibit_login", False):
device_id = yield self._register_device(user_id, params) device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
access_token = ( device_id, access_token = yield self.registration_handler.register_device(
yield self.auth_handler.get_access_token_for_user_id( user_id, device_id, initial_display_name, is_guest=False,
user_id, device_id=device_id,
)
) )
result.update({ result.update({
@ -647,26 +644,6 @@ class RegisterRestServlet(RestServlet):
}) })
defer.returnValue(result) defer.returnValue(result)
def _register_device(self, user_id, params):
"""Register a device for a user.
This is called after the user's credentials have been validated, but
before the access token has been issued.
Args:
(str) user_id: full canonical @user:id
(object) params: registration parameters, from which we pull
device_id and initial_device_name
Returns:
defer.Deferred: (str) device_id
"""
# register the user's device
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
return self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
@defer.inlineCallbacks @defer.inlineCallbacks
def _do_guest_registration(self, params): def _do_guest_registration(self, params):
if not self.hs.config.allow_guest_access: if not self.hs.config.allow_guest_access:
@ -680,13 +657,10 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it. # we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name") initial_display_name = params.get("initial_device_display_name")
yield self.device_handler.check_device_registered( device_id, access_token = yield self.registration_handler.register_device(
user_id, device_id, initial_display_name user_id, device_id, initial_display_name, is_guest=True,
) )
access_token = self.macaroon_gen.generate_access_token(
user_id, ["guest = true"]
)
defer.returnValue((200, { defer.returnValue((200, {
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,

View file

@ -139,6 +139,121 @@ class RegistrationWorkerStore(SQLBaseStore):
) )
return True if res == UserTypes.SUPPORT else False return True if res == UserTypes.SUPPORT else False
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
class RegistrationStore(RegistrationWorkerStore, class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore): background_updates.BackgroundUpdateStore):
@ -326,20 +441,6 @@ class RegistrationStore(RegistrationWorkerStore,
) )
txn.call_after(self.is_guest.invalidate, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,))
def get_users_by_id_case_insensitive(self, user_id):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
return self.runInteraction("get_users_by_id_case_insensitive", f)
def user_set_password_hash(self, user_id, password_hash): def user_set_password_hash(self, user_id, password_hash):
""" """
NB. This does *not* evict any cache because the one use for this NB. This does *not* evict any cache because the one use for this
@ -564,107 +665,6 @@ class RegistrationStore(RegistrationWorkerStore,
desc="user_delete_threepids", desc="user_delete_threepids",
) )
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
def count_daily_user_type(self):
"""
Counts 1) native non guest users
2) native guests users
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
sql = """
SELECT user_type, COALESCE(count(*), 0) AS count FROM (
SELECT
CASE
WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
END AS user_type
FROM users
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
results = {'native': 0, 'guest': 0, 'bridged': 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
count, = txn.fetchone()
return count
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
Generated user IDs are integers, and we aim for them to be as small as
we can. Unfortunately, it's possible some of them are already taken by
existing users, and there may be gaps in the already taken range. This
function returns the start of the first allocatable gap. This is to
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
regex = re.compile(r"^@(\d+):")
found = set()
for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
for i in range(len(found) + 1):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
)
if ret:
defer.returnValue(ret["guest_access_token"])
defer.returnValue(None)
@defer.inlineCallbacks @defer.inlineCallbacks
def save_or_get_3pid_guest_access_token( def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id self, medium, address, access_token, inviter_user_id

View file

@ -1,10 +1,7 @@
import json import json
from mock import Mock from synapse.api.constants import LoginType
from synapse.appservice import ApplicationService
from twisted.python import failure
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.rest.client.v2_alpha.register import register_servlets from synapse.rest.client.v2_alpha.register import register_servlets
from tests import unittest from tests import unittest
@ -18,50 +15,28 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.url = b"/_matrix/client/r0/register" self.url = b"/_matrix/client/r0/register"
self.appservice = None
self.auth = Mock(
get_appservice_by_req=Mock(side_effect=lambda x: self.appservice)
)
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
self.auth_handler = Mock(
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
get_session_data=Mock(return_value=None),
)
self.registration_handler = Mock()
self.identity_handler = Mock()
self.login_handler = Mock()
self.device_handler = Mock()
self.device_handler.check_device_registered = Mock(return_value="FAKE")
self.datastore = Mock(return_value=Mock())
self.datastore.get_current_state_deltas = Mock(return_value=[])
# do the dance to hook it up to the hs global
self.handlers = Mock(
registration_handler=self.registration_handler,
identity_handler=self.identity_handler,
login_handler=self.login_handler,
)
self.hs = self.setup_test_homeserver() self.hs = self.setup_test_homeserver()
self.hs.get_auth = Mock(return_value=self.auth)
self.hs.get_handlers = Mock(return_value=self.handlers)
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.get_datastore = Mock(return_value=self.datastore)
self.hs.config.enable_registration = True self.hs.config.enable_registration = True
self.hs.config.registrations_require_3pid = [] self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = [] self.hs.config.auto_join_rooms = []
self.hs.config.enable_registration_captcha = False
return self.hs return self.hs
def test_POST_appservice_registration_valid(self): def test_POST_appservice_registration_valid(self):
user_id = "@kermit:muppet" user_id = "@as_user_kermit:test"
token = "kermits_access_token" as_token = "i_am_an_app_service"
self.appservice = {"id": "1234"}
self.registration_handler.appservice_register = Mock(return_value=user_id) appservice = ApplicationService(
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) as_token, self.hs.config.hostname,
request_data = json.dumps({"username": "kermit"}) id="1234",
namespaces={
"users": [{"regex": r"@as_user.*", "exclusive": True}],
},
)
self.hs.get_datastore().services_cache.append(appservice)
request_data = json.dumps({"username": "as_user_kermit"})
request, channel = self.make_request( request, channel = self.make_request(
b"POST", self.url + b"?access_token=i_am_an_app_service", request_data b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
@ -71,7 +46,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
} }
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
@ -103,39 +77,30 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Invalid username") self.assertEquals(channel.json_body["error"], "Invalid username")
def test_POST_user_valid(self): def test_POST_user_valid(self):
user_id = "@kermit:muppet" user_id = "@kermit:test"
token = "kermits_access_token"
device_id = "frogfone" device_id = "frogfone"
request_data = json.dumps( params = {
{"username": "kermit", "password": "monkey", "device_id": device_id} "username": "kermit",
) "password": "monkey",
self.registration_handler.check_username = Mock(return_value=True) "device_id": device_id,
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) "auth": {"type": LoginType.DUMMY},
self.registration_handler.register = Mock(return_value=(user_id, None)) }
self.auth_handler.get_access_token_for_user_id = Mock(return_value=token) request_data = json.dumps(params)
self.device_handler.check_device_registered = Mock(return_value=device_id)
request, channel = self.make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
det_data = { det_data = {
"user_id": user_id, "user_id": user_id,
"access_token": token,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": device_id, "device_id": device_id,
} }
self.assertEquals(channel.result["code"], b"200", channel.result) self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body) self.assertDictContainsSubset(det_data, channel.json_body)
self.auth_handler.get_login_tuple_for_user_id(
user_id, device_id=device_id, initial_device_display_name=None
)
def test_POST_disabled_registration(self): def test_POST_disabled_registration(self):
self.hs.config.enable_registration = False self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"}) request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.registration_handler.check_username = Mock(return_value=True)
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None) self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
self.registration_handler.register = Mock(return_value=("@user:id", "t"))
request, channel = self.make_request(b"POST", self.url, request_data) request, channel = self.make_request(b"POST", self.url, request_data)
self.render(request) self.render(request)
@ -144,16 +109,13 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.json_body["error"], "Registration has been disabled") self.assertEquals(channel.json_body["error"], "Registration has been disabled")
def test_POST_guest_registration(self): def test_POST_guest_registration(self):
user_id = "a@b"
self.hs.config.macaroon_secret_key = "test" self.hs.config.macaroon_secret_key = "test"
self.hs.config.allow_guest_access = True self.hs.config.allow_guest_access = True
self.registration_handler.register = Mock(return_value=(user_id, None))
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request) self.render(request)
det_data = { det_data = {
"user_id": user_id,
"home_server": self.hs.hostname, "home_server": self.hs.hostname,
"device_id": "guest_device", "device_id": "guest_device",
} }