0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-06-14 16:48:18 +02:00

Parse Integer negative value validation (#16920)

This commit is contained in:
Gordan Trevis 2024-04-16 21:12:36 +02:00 committed by GitHub
parent 3a196b3227
commit f0d6f14047
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 89 additions and 158 deletions

1
changelog.d/16920.bugfix Normal file
View file

@ -0,0 +1 @@
Adds validation to ensure that the `limit` parameter on `/publicRooms` is non-negative.

View file

@ -19,7 +19,8 @@
# #
# #
""" This module contains base REST classes for constructing REST servlets. """ """This module contains base REST classes for constructing REST servlets."""
import enum import enum
import logging import logging
from http import HTTPStatus from http import HTTPStatus
@ -65,17 +66,49 @@ def parse_integer(request: Request, name: str, default: int) -> int: ...
@overload @overload
def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: ... def parse_integer(
request: Request, name: str, *, default: int, negative: bool
) -> int: ...
@overload @overload
def parse_integer( def parse_integer(
request: Request, name: str, default: Optional[int] = None, required: bool = False request: Request, name: str, *, default: int, negative: bool = False
) -> int: ...
@overload
def parse_integer(
request: Request, name: str, *, required: Literal[True], negative: bool = False
) -> int: ...
@overload
def parse_integer(
request: Request, name: str, *, default: Literal[None], negative: bool = False
) -> None: ...
@overload
def parse_integer(request: Request, name: str, *, negative: bool) -> Optional[int]: ...
@overload
def parse_integer(
request: Request,
name: str,
default: Optional[int] = None,
required: bool = False,
negative: bool = False,
) -> Optional[int]: ... ) -> Optional[int]: ...
def parse_integer( def parse_integer(
request: Request, name: str, default: Optional[int] = None, required: bool = False request: Request,
name: str,
default: Optional[int] = None,
required: bool = False,
negative: bool = False,
) -> Optional[int]: ) -> Optional[int]:
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
@ -85,16 +118,17 @@ def parse_integer(
default: value to use if the parameter is absent, defaults to None. default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent, required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False. defaults to False.
negative: whether to allow negative integers, defaults to True.
Returns: Returns:
An int value or the default. An int value or the default.
Raises: Raises:
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer. parameter is present and not an integer, or if the
parameter is illegitimate negative.
""" """
args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore
return parse_integer_from_args(args, name, default, required) return parse_integer_from_args(args, name, default, required, negative)
@overload @overload
@ -120,6 +154,7 @@ def parse_integer_from_args(
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
required: bool = False, required: bool = False,
negative: bool = False,
) -> Optional[int]: ... ) -> Optional[int]: ...
@ -128,6 +163,7 @@ def parse_integer_from_args(
name: str, name: str,
default: Optional[int] = None, default: Optional[int] = None,
required: bool = False, required: bool = False,
negative: bool = True,
) -> Optional[int]: ) -> Optional[int]:
"""Parse an integer parameter from the request string """Parse an integer parameter from the request string
@ -137,33 +173,37 @@ def parse_integer_from_args(
default: value to use if the parameter is absent, defaults to None. default: value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the parameter is absent, required: whether to raise a 400 SynapseError if the parameter is absent,
defaults to False. defaults to False.
negative: whether to allow negative integers, defaults to True.
Returns: Returns:
An int value or the default. An int value or the default.
Raises: Raises:
SynapseError: if the parameter is absent and required, or if the SynapseError: if the parameter is absent and required, if the
parameter is present and not an integer. parameter is present and not an integer, or if the
parameter is illegitimate negative.
""" """
name_bytes = name.encode("ascii") name_bytes = name.encode("ascii")
if name_bytes in args: if name_bytes not in args:
try: if not required:
return int(args[name_bytes][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM
)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
raise SynapseError(
HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM
)
else:
return default return default
message = f"Missing required integer query parameter {name}"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.MISSING_PARAM)
try:
integer = int(args[name_bytes][0])
except Exception:
message = f"Query parameter {name} must be an integer"
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
if not negative and integer < 0:
message = f"Query parameter {name} must be a positive integer."
raise SynapseError(HTTPStatus.BAD_REQUEST, message, errcode=Codes.INVALID_PARAM)
return integer
@overload @overload
def parse_boolean(request: Request, name: str, default: bool) -> bool: ... def parse_boolean(request: Request, name: str, default: bool) -> bool: ...

View file

@ -23,7 +23,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -61,22 +61,8 @@ class ListDestinationsRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request) await assert_requester_is_admin(self._auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
destination = parse_string(request, "destination") destination = parse_string(request, "destination")
@ -195,22 +181,8 @@ class DestinationMembershipRestServlet(RestServlet):
if not await self._store.is_destination_known(destination): if not await self._store.is_destination_known(destination):
raise NotFoundError("Unknown destination") raise NotFoundError("Unknown destination")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)

View file

@ -311,29 +311,17 @@ class DeleteMediaByDateSize(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True) before_ts = parse_integer(request, "before_ts", required=True, negative=False)
size_gt = parse_integer(request, "size_gt", default=0) size_gt = parse_integer(request, "size_gt", default=0, negative=False)
keep_profiles = parse_boolean(request, "keep_profiles", default=True) keep_profiles = parse_boolean(request, "keep_profiles", default=True)
if before_ts < 0: if before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter before_ts must be a positive integer.",
errcode=Codes.INVALID_PARAM,
)
elif before_ts < 30000000000: # Dec 1970 in milliseconds, Aug 2920 in seconds
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,
"Query parameter before_ts you provided is from the year 1970. " "Query parameter before_ts you provided is from the year 1970. "
+ "Double check that you are providing a timestamp in milliseconds.", + "Double check that you are providing a timestamp in milliseconds.",
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
if size_gt < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter size_gt must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# This check is useless, we keep it for the legacy endpoint only. # This check is useless, we keep it for the legacy endpoint only.
if server_name is not None and self.server_name != server_name: if server_name is not None and self.server_name != server_name:
@ -389,22 +377,8 @@ class UserMediaRestServlet(RestServlet):
if user is None: if user is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order # If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.
@ -447,22 +421,8 @@ class UserMediaRestServlet(RestServlet):
if user is None: if user is None:
raise NotFoundError("Unknown user") raise NotFoundError("Unknown user")
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
# If neither `order_by` nor `dir` is set, set the default order # If neither `order_by` nor `dir` is set, set the default order
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.

View file

@ -63,38 +63,12 @@ class UserMediaStatisticsRestServlet(RestServlet):
), ),
) )
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
if start < 0: limit = parse_integer(request, "limit", default=100, negative=False)
raise SynapseError( from_ts = parse_integer(request, "from_ts", default=0, negative=False)
HTTPStatus.BAD_REQUEST, until_ts = parse_integer(request, "until_ts", negative=False)
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
limit = parse_integer(request, "limit", default=100)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
from_ts = parse_integer(request, "from_ts", default=0)
if from_ts < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
until_ts = parse_integer(request, "until_ts")
if until_ts is not None: if until_ts is not None:
if until_ts < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter until_ts must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if until_ts <= from_ts: if until_ts <= from_ts:
raise SynapseError( raise SynapseError(
HTTPStatus.BAD_REQUEST, HTTPStatus.BAD_REQUEST,

View file

@ -90,22 +90,8 @@ class UsersRestServletV2(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0, negative=False)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100, negative=False)
if start < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter from must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
if limit < 0:
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Query parameter limit must be a string representing a positive integer.",
errcode=Codes.INVALID_PARAM,
)
user_id = parse_string(request, "user_id") user_id = parse_string(request, "user_id")
name = parse_string(request, "name", encoding="utf-8") name = parse_string(request, "name", encoding="utf-8")

View file

@ -499,7 +499,7 @@ class PublicRoomListRestServlet(RestServlet):
if server: if server:
raise e raise e
limit: Optional[int] = parse_integer(request, "limit", 0) limit: Optional[int] = parse_integer(request, "limit", 0, negative=False)
since_token = parse_string(request, "since") since_token = parse_string(request, "since")
if limit == 0: if limit == 0:

View file

@ -72,9 +72,6 @@ class PreviewUrlResource(RestServlet):
# XXX: if get_user_by_req fails, what should we do in an async render? # XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url", required=True) url = parse_string(request, "url", required=True)
ts = parse_integer(request, "ts") ts = parse_integer(request, "ts", default=self.clock.time_msec())
if ts is None:
ts = self.clock.time_msec()
og = await self.url_previewer.preview(url, requester.user, ts) og = await self.url_previewer.preview(url, requester.user, ts)
respond_with_json_bytes(request, 200, og, send_cors=True) respond_with_json_bytes(request, 200, og, send_cors=True)

View file

@ -277,7 +277,8 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
self.assertEqual( self.assertEqual(
"Missing integer query parameter 'before_ts'", channel.json_body["error"] "Missing required integer query parameter before_ts",
channel.json_body["error"],
) )
def test_invalid_parameter(self) -> None: def test_invalid_parameter(self) -> None:
@ -320,7 +321,7 @@ class DeleteMediaByDateSizeTestCase(_AdminMediaTests):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual( self.assertEqual(
"Query parameter size_gt must be a string representing a positive integer.", "Query parameter size_gt must be a positive integer.",
channel.json_body["error"], channel.json_body["error"],
) )