mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 19:13:51 +01:00
A second batch of Pydantic models for rest/client/account.py (#13687)
This commit is contained in:
parent
d3d9ca156e
commit
b58386e37e
4 changed files with 64 additions and 34 deletions
1
changelog.d/13687.feature
Normal file
1
changelog.d/13687.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/msisdn/requestToken`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidmsisdnrequesttoken) and [`/org.matrix.msc3720/account_status`](https://github.com/matrix-org/matrix-spec-proposals/blob/babolivier/user_status/proposals/3720-account-status.md#post-_matrixclientv1account_status).
|
|
@ -28,7 +28,8 @@ from typing import (
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
|
||||||
|
from pydantic.error_wrappers import ErrorWrapper
|
||||||
from typing_extensions import Literal
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from twisted.web.server import Request
|
from twisted.web.server import Request
|
||||||
|
@ -714,7 +715,21 @@ def parse_and_validate_json_object_from_request(
|
||||||
try:
|
try:
|
||||||
instance = model_type.parse_obj(content)
|
instance = model_type.parse_obj(content)
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=Codes.BAD_JSON)
|
# Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
|
||||||
|
# more specific error if possible (which occasionally helps us to be spec-
|
||||||
|
# compliant) This is a bit awkward because the spec's error codes aren't very
|
||||||
|
# clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
|
||||||
|
errcode = Codes.BAD_JSON
|
||||||
|
|
||||||
|
raw_errors = e.raw_errors
|
||||||
|
if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
|
||||||
|
raw_error = raw_errors[0].exc
|
||||||
|
if isinstance(raw_error, MissingError):
|
||||||
|
errcode = Codes.MISSING_PARAM
|
||||||
|
elif isinstance(raw_error, PydanticValueError):
|
||||||
|
errcode = Codes.INVALID_PARAM
|
||||||
|
|
||||||
|
raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
from typing import TYPE_CHECKING, Optional, Tuple
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from pydantic import StrictBool, StrictStr, constr
|
from pydantic import StrictBool, StrictStr, constr
|
||||||
|
@ -41,7 +41,11 @@ from synapse.http.servlet import (
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
from synapse.metrics import threepid_send_requests
|
from synapse.metrics import threepid_send_requests
|
||||||
from synapse.push.mailer import Mailer
|
from synapse.push.mailer import Mailer
|
||||||
from synapse.rest.client.models import AuthenticationData, EmailRequestTokenBody
|
from synapse.rest.client.models import (
|
||||||
|
AuthenticationData,
|
||||||
|
EmailRequestTokenBody,
|
||||||
|
MsisdnRequestTokenBody,
|
||||||
|
)
|
||||||
from synapse.rest.models import RequestBodyModel
|
from synapse.rest.models import RequestBodyModel
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
@ -400,23 +404,16 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
self.identity_handler = hs.get_identity_handler()
|
self.identity_handler = hs.get_identity_handler()
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_and_validate_json_object_from_request(
|
||||||
assert_params_in_dict(
|
request, MsisdnRequestTokenBody
|
||||||
body, ["client_secret", "country", "phone_number", "send_attempt"]
|
|
||||||
)
|
)
|
||||||
client_secret = body["client_secret"]
|
msisdn = phone_number_to_msisdn(body.country, body.phone_number)
|
||||||
assert_valid_client_secret(client_secret)
|
|
||||||
|
|
||||||
country = body["country"]
|
|
||||||
phone_number = body["phone_number"]
|
|
||||||
send_attempt = body["send_attempt"]
|
|
||||||
next_link = body.get("next_link") # Optional param
|
|
||||||
|
|
||||||
msisdn = phone_number_to_msisdn(country, phone_number)
|
|
||||||
|
|
||||||
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
|
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
403,
|
403,
|
||||||
|
# TODO: is this error message accurate? Looks like we've only rejected
|
||||||
|
# this phone number, not necessarily all phone numbers
|
||||||
"Account phone numbers are not authorized on this server",
|
"Account phone numbers are not authorized on this server",
|
||||||
Codes.THREEPID_DENIED,
|
Codes.THREEPID_DENIED,
|
||||||
)
|
)
|
||||||
|
@ -425,9 +422,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
request, "msisdn", msisdn
|
request, "msisdn", msisdn
|
||||||
)
|
)
|
||||||
|
|
||||||
if next_link:
|
if body.next_link:
|
||||||
# Raise if the provided next_link value isn't valid
|
# Raise if the provided next_link value isn't valid
|
||||||
assert_valid_next_link(self.hs, next_link)
|
assert_valid_next_link(self.hs, body.next_link)
|
||||||
|
|
||||||
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
|
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
|
||||||
|
|
||||||
|
@ -454,15 +451,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||||
|
|
||||||
ret = await self.identity_handler.requestMsisdnToken(
|
ret = await self.identity_handler.requestMsisdnToken(
|
||||||
self.hs.config.registration.account_threepid_delegate_msisdn,
|
self.hs.config.registration.account_threepid_delegate_msisdn,
|
||||||
country,
|
body.country,
|
||||||
phone_number,
|
body.phone_number,
|
||||||
client_secret,
|
body.client_secret,
|
||||||
send_attempt,
|
body.send_attempt,
|
||||||
next_link,
|
body.next_link,
|
||||||
)
|
)
|
||||||
|
|
||||||
threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
|
threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
|
||||||
send_attempt
|
body.send_attempt
|
||||||
)
|
)
|
||||||
|
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
@ -845,17 +842,18 @@ class AccountStatusRestServlet(RestServlet):
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._account_handler = hs.get_account_handler()
|
self._account_handler = hs.get_account_handler()
|
||||||
|
|
||||||
|
class PostBody(RequestBodyModel):
|
||||||
|
# TODO: we could validate that each user id is an mxid here, and/or parse it
|
||||||
|
# as a UserID
|
||||||
|
user_ids: List[StrictStr]
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
await self._auth.get_user_by_req(request)
|
await self._auth.get_user_by_req(request)
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||||
if "user_ids" not in body:
|
|
||||||
raise SynapseError(
|
|
||||||
400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
|
|
||||||
)
|
|
||||||
|
|
||||||
statuses, failures = await self._account_handler.get_account_statuses(
|
statuses, failures = await self._account_handler.get_account_statuses(
|
||||||
body["user_ids"],
|
body.user_ids,
|
||||||
allow_remote=True,
|
allow_remote=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -25,8 +25,8 @@ class AuthenticationData(RequestBodyModel):
|
||||||
|
|
||||||
(The name "Authentication Data" is taken directly from the spec.)
|
(The name "Authentication Data" is taken directly from the spec.)
|
||||||
|
|
||||||
Additional keys will be present, depending on the `type` field. Use `.dict()` to
|
Additional keys will be present, depending on the `type` field. Use
|
||||||
access them.
|
`.dict(exclude_unset=True)` to access them.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
@ -36,7 +36,7 @@ class AuthenticationData(RequestBodyModel):
|
||||||
type: Optional[StrictStr] = None
|
type: Optional[StrictStr] = None
|
||||||
|
|
||||||
|
|
||||||
class EmailRequestTokenBody(RequestBodyModel):
|
class ThreePidRequestTokenBody(RequestBodyModel):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
client_secret: StrictStr
|
client_secret: StrictStr
|
||||||
else:
|
else:
|
||||||
|
@ -47,7 +47,7 @@ class EmailRequestTokenBody(RequestBodyModel):
|
||||||
max_length=255,
|
max_length=255,
|
||||||
strict=True,
|
strict=True,
|
||||||
)
|
)
|
||||||
email: StrictStr
|
|
||||||
id_server: Optional[StrictStr]
|
id_server: Optional[StrictStr]
|
||||||
id_access_token: Optional[StrictStr]
|
id_access_token: Optional[StrictStr]
|
||||||
next_link: Optional[StrictStr]
|
next_link: Optional[StrictStr]
|
||||||
|
@ -61,9 +61,25 @@ class EmailRequestTokenBody(RequestBodyModel):
|
||||||
raise ValueError("id_access_token is required if an id_server is supplied.")
|
raise ValueError("id_access_token is required if an id_server is supplied.")
|
||||||
return token
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class EmailRequestTokenBody(ThreePidRequestTokenBody):
|
||||||
|
email: StrictStr
|
||||||
|
|
||||||
# Canonicalise the email address. The addresses are all stored canonicalised
|
# Canonicalise the email address. The addresses are all stored canonicalised
|
||||||
# in the database. This allows the user to reset his password without having to
|
# in the database. This allows the user to reset his password without having to
|
||||||
# know the exact spelling (eg. upper and lower case) of address in the database.
|
# know the exact spelling (eg. upper and lower case) of address in the database.
|
||||||
# Without this, an email stored in the database as "foo@bar.com" would cause
|
# Without this, an email stored in the database as "foo@bar.com" would cause
|
||||||
# user requests for "FOO@bar.com" to raise a Not Found error.
|
# user requests for "FOO@bar.com" to raise a Not Found error.
|
||||||
_email_validator = validator("email", allow_reuse=True)(validate_email)
|
_email_validator = validator("email", allow_reuse=True)(validate_email)
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
ISO3116_1_Alpha_2 = StrictStr
|
||||||
|
else:
|
||||||
|
# Per spec: two-letter uppercase ISO-3166-1-alpha-2
|
||||||
|
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
|
||||||
|
|
||||||
|
|
||||||
|
class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
|
||||||
|
country: ISO3116_1_Alpha_2
|
||||||
|
phone_number: StrictStr
|
||||||
|
|
Loading…
Reference in a new issue