0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-22 03:14:02 +01:00

Misc typing fixes for tests, part 1 of N (#11323)

* Annotate HomeserverTestCase.servlets
* Correct annotation of federation_auth_origin
* Use AnyStr custom_headers instead of a Union

This allows (str, str) and (bytes, bytes).
This disallows (str, bytes) and (bytes, str)

* DomainSpecificString.SIGIL is a ClassVar
This commit is contained in:
David Robertson 2021-11-12 15:50:54 +00:00 committed by GitHub
parent 95547e5300
commit 4c96ce396e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 53 additions and 29 deletions

1
changelog.d/11323.misc Normal file
View file

@ -0,0 +1 @@
Improve type annotations in Synapse's test suite.

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from synapse.http.server import HttpServer, JsonResource
from synapse.rest import admin
@ -62,6 +62,8 @@ from synapse.rest.client import (
if TYPE_CHECKING:
from synapse.server import HomeServer
RegisterServletsFunc = Callable[["HomeServer", HttpServer], None]
class ClientRestResource(JsonResource):
"""Matrix Client API REST resource.

View file

@ -19,6 +19,7 @@ from collections import namedtuple
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Dict,
Mapping,
MutableMapping,
@ -219,7 +220,7 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
'domain' : The domain part of the name
"""
SIGIL: str = abc.abstractproperty() # type: ignore
SIGIL: ClassVar[str] = abc.abstractproperty() # type: ignore
localpart = attr.ib(type=str)
domain = attr.ib(type=str)

View file

@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.protocol import Protocol
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
@ -220,8 +219,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
servlets: List[Callable[[HomeServer, JsonResource], None]] = []
def setUp(self):
super().setUp()

View file

@ -19,7 +19,17 @@ import json
import re
import time
import urllib.parse
from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple, Union
from typing import (
Any,
AnyStr,
Dict,
Iterable,
Mapping,
MutableMapping,
Optional,
Tuple,
Union,
)
from unittest.mock import patch
import attr
@ -53,9 +63,7 @@ class RestHelper:
tok: Optional[str] = None,
expect_code: int = 200,
extra_content: Optional[Dict] = None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
) -> str:
"""
Create a room.
@ -227,9 +235,7 @@ class RestHelper:
txn_id=None,
tok=None,
expect_code=200,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
if body is None:
body = "body_text_here"
@ -418,7 +424,7 @@ class RestHelper:
path,
content=image_data,
access_token=tok,
custom_headers=[(b"Content-Length", str(image_length))],
custom_headers=[("Content-Length", str(image_length))],
)
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (

View file

@ -16,7 +16,16 @@ import json
import logging
from collections import deque
from io import SEEK_END, BytesIO
from typing import Callable, Dict, Iterable, MutableMapping, Optional, Tuple, Union
from typing import (
AnyStr,
Callable,
Dict,
Iterable,
MutableMapping,
Optional,
Tuple,
Union,
)
import attr
from typing_extensions import Deque
@ -222,9 +231,7 @@ def make_request(
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""

View file

@ -20,7 +20,20 @@ import inspect
import logging
import secrets
import time
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
from typing import (
Any,
AnyStr,
Callable,
ClassVar,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
from unittest.mock import Mock, patch
from canonicaljson import json
@ -45,6 +58,7 @@ from synapse.logging.context import (
current_context,
set_current_context,
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
@ -204,15 +218,15 @@ class HomeserverTestCase(TestCase):
config dict.
Attributes:
servlets (list[function]): List of servlet registration function.
servlets: List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth (bool): Whether to hijack auth to return the user specified
in user_id.
"""
servlets = []
hijack_auth = True
needs_threadpool = False
servlets: ClassVar[List[RegisterServletsFunc]] = []
def __init__(self, methodName, *args, **kwargs):
super().__init__(methodName, *args, **kwargs)
@ -405,12 +419,10 @@ class HomeserverTestCase(TestCase):
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
federation_auth_origin: Optional[bytes] = None,
content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
client_ip: str = "127.0.0.1",
) -> FakeChannel:
"""
@ -425,7 +437,7 @@ class HomeserverTestCase(TestCase):
a dict.
shorthand: Whether to try and be helpful and prefix the given URL
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
@ -639,9 +651,7 @@ class HomeserverTestCase(TestCase):
username,
password,
device_id=None,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
] = None,
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
):
"""
Log in a user, and get an access token. Requires the Login API be