0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-18 10:38:21 +02:00

A third batch of Pydantic validation for rest/client/account.py (#13736)

This commit is contained in:
David Robertson 2022-09-15 18:36:02 +01:00 committed by GitHub
parent 918c74bfb5
commit 742f9f9d78
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 78 additions and 45 deletions

View file

@ -0,0 +1 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).

View file

@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal
from twisted.web.server import Request
@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.rest.client.models import (
AuthenticationData,
ClientSecretStr,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData] = None
client_secret: ClientSecretStr
sid: StrictStr
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["client_secret", "sid"])
sid = body["sid"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
body = parse_and_validate_json_object_from_request(request, self.PostBody)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
body.dict(exclude_unset=True),
"add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
body.client_secret, body.sid
)
if validation_session:
await self.auth_handler.add_threepid(
@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
class PostBody(RequestBodyModel):
client_secret: ClientSecretStr
id_access_token: StrictStr
id_server: StrictStr
sid: StrictStr
assert_params_in_dict(
body, ["id_server", "sid", "id_access_token", "client_secret"]
)
id_server = body["id_server"]
sid = body["sid"]
id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
)
return 200, {}
@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
medium = body.get("medium")
address = body.get("address")
id_server = body.get("id_server")
body = parse_and_validate_json_object_from_request(request, self.PostBody)
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
{
"address": body.address,
"medium": body.medium,
"id_server": body.id_server,
},
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
user_id, body.medium, body.address, body.id_server
)
except Exception:
# NB. This endpoint should succeed if there is nothing to

View file

@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel):
type: Optional[StrictStr] = None
class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
# See also assert_valid_client_secret()
client_secret: constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=0,
max_length=255,
strict=True,
)
if TYPE_CHECKING:
ClientSecretStr = StrictStr
else:
# See also assert_valid_client_secret()
ClientSecretStr = constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=1,
max_length=255,
strict=True,
)
class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretStr
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel):
return token
class EmailRequestTokenBody(ThreePidRequestTokenBody):
class EmailRequestTokenBody(ThreepidRequestTokenBody):
email: StrictStr
# Canonicalise the email address. The addresses are all stored canonicalised
@ -80,6 +82,6 @@ else:
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr

View file

@ -11,14 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import unittest as stdlib_unittest
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from synapse.rest.client.models import EmailRequestTokenBody
class EmailRequestTokenBodyTestCase(unittest.TestCase):
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: Literal["email", "msisdn"]
def test_accepts_valid_medium_string(self) -> None:
"""Sanity check that Pydantic behaves sensibly with an enum-of-str
This is arguably more of a test of a class that inherits from str and Enum
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertEqual(model.medium, "email")
def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": "interpretive_dance"})
def test_rejects_invalid_medium_type(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": 123})
class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
base_request = {
"client_secret": "hunter2",
"email": "alice@wonderland.com",