0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-14 13:33:50 +01:00

Make DomainSpecificString an attrs class (#9875)

This commit is contained in:
Erik Johnston 2021-04-23 15:46:29 +01:00 committed by GitHub
parent ceaa76970f
commit a15c003e5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 8 deletions

1
changelog.d/9875.misc Normal file
View file

@ -0,0 +1 @@
Make `DomainSpecificString` an `attrs` class.

View file

@ -957,6 +957,11 @@ class OidcProvider:
# and attempt to match it. # and attempt to match it.
attributes = await oidc_response_to_user_attributes(failures=0) attributes = await oidc_response_to_user_attributes(failures=0)
if attributes.localpart is None:
# If no localpart is returned then we will generate one, so
# there is no need to search for existing users.
return None
user_id = UserID(attributes.localpart, self._server_name).to_string() user_id = UserID(attributes.localpart, self._server_name).to_string()
users = await self._store.get_users_by_id_case_insensitive(user_id) users = await self._store.get_users_by_id_case_insensitive(user_id)
if users: if users:

View file

@ -61,6 +61,15 @@ class NewUserConsentResource(DirectServeHtmlResource):
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
return return
# It should be impossible to get here without having first been through
# the pick-a-username step, which ensures chosen_localpart gets set.
if not session.chosen_localpart:
logger.warning("Session has no user name selected")
self._sso_handler.render_error(
request, "no_user", "No user name has been selected.", code=400
)
return
user_id = UserID(session.chosen_localpart, self._server_name) user_id = UserID(session.chosen_localpart, self._server_name)
user_profile = { user_profile = {
"display_name": session.display_name, "display_name": session.display_name,

View file

@ -199,9 +199,8 @@ def get_localpart_from_id(string):
DS = TypeVar("DS", bound="DomainSpecificString") DS = TypeVar("DS", bound="DomainSpecificString")
class DomainSpecificString( @attr.s(slots=True, frozen=True, repr=False)
namedtuple("DomainSpecificString", ("localpart", "domain")), metaclass=abc.ABCMeta class DomainSpecificString(metaclass=abc.ABCMeta):
):
"""Common base class among ID/name strings that have a local part and a """Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil. domain name, prefixed with a sigil.
@ -213,11 +212,8 @@ class DomainSpecificString(
SIGIL = abc.abstractproperty() # type: str # type: ignore SIGIL = abc.abstractproperty() # type: str # type: ignore
# Deny iteration because it will bite you if you try to create a singleton localpart = attr.ib(type=str)
# set by: domain = attr.ib(type=str)
# users = set(user)
def __iter__(self):
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))
# Because this class is a namedtuple of strings and booleans, it is deeply # Because this class is a namedtuple of strings and booleans, it is deeply
# immutable. # immutable.
@ -272,30 +268,35 @@ class DomainSpecificString(
__repr__ = to_string __repr__ = to_string
@attr.s(slots=True, frozen=True, repr=False)
class UserID(DomainSpecificString): class UserID(DomainSpecificString):
"""Structure representing a user ID.""" """Structure representing a user ID."""
SIGIL = "@" SIGIL = "@"
@attr.s(slots=True, frozen=True, repr=False)
class RoomAlias(DomainSpecificString): class RoomAlias(DomainSpecificString):
"""Structure representing a room name.""" """Structure representing a room name."""
SIGIL = "#" SIGIL = "#"
@attr.s(slots=True, frozen=True, repr=False)
class RoomID(DomainSpecificString): class RoomID(DomainSpecificString):
"""Structure representing a room id. """ """Structure representing a room id. """
SIGIL = "!" SIGIL = "!"
@attr.s(slots=True, frozen=True, repr=False)
class EventID(DomainSpecificString): class EventID(DomainSpecificString):
"""Structure representing an event id. """ """Structure representing an event id. """
SIGIL = "$" SIGIL = "$"
@attr.s(slots=True, frozen=True, repr=False)
class GroupID(DomainSpecificString): class GroupID(DomainSpecificString):
"""Structure representing a group ID.""" """Structure representing a group ID."""