Update black, and run auto formatting over the codebase (#9381)

- Update black version to the latest
 - Run black auto formatting over the codebase
    - Run autoformatting according to [`docs/code_style.md
`](80d6dc9783/docs/code_style.md)
 - Update `code_style.md` docs around installing black to use the correct version
This commit is contained in:
Eric Eastwood 2021-02-16 16:32:34 -06:00 committed by GitHub
parent 5636e597c3
commit 0a00b7ff14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
271 changed files with 2802 additions and 1713 deletions

1
changelog.d/9381.misc Normal file
View File

@ -0,0 +1 @@
Update the version of black used to 20.8b1.

View File

@ -92,7 +92,7 @@ class SynapseCmd(cmd.Cmd):
return self.config["user"].split(":")[1]
def do_config(self, line):
""" Show the config for this client: "config"
"""Show the config for this client: "config"
Edit a key value mapping: "config key value" e.g. "config token 1234"
Config variables:
user: The username to auth with.
@ -360,7 +360,7 @@ class SynapseCmd(cmd.Cmd):
print(e)
def do_topic(self, line):
""""topic [set|get] <roomid> [<newtopic>]"
""" "topic [set|get] <roomid> [<newtopic>]"
Set the topic for a room: topic set <roomid> <newtopic>
Get the topic for a room: topic get <roomid>
"""
@ -690,7 +690,7 @@ class SynapseCmd(cmd.Cmd):
self._do_presence_state(2, line)
def _parse(self, line, keys, force_keys=False):
""" Parses the given line.
"""Parses the given line.
Args:
line : The line to parse
@ -721,7 +721,7 @@ class SynapseCmd(cmd.Cmd):
query_params={"access_token": None},
alt_text=None,
):
""" Runs an HTTP request and pretty prints the output.
"""Runs an HTTP request and pretty prints the output.
Args:
method: HTTP method

View File

@ -23,11 +23,10 @@ from twisted.web.http_headers import Headers
class HttpClient:
""" Interface for talking json over http
"""
"""Interface for talking json over http"""
def put_json(self, url, data):
""" Sends the specifed json data using PUT
"""Sends the specifed json data using PUT
Args:
url (str): The URL to PUT data to.
@ -41,7 +40,7 @@ class HttpClient:
pass
def get_json(self, url, args=None):
""" Gets some json from the given host homeserver and path
"""Gets some json from the given host homeserver and path
Args:
url (str): The URL to GET data from.
@ -58,7 +57,7 @@ class HttpClient:
class TwistedHttpClient(HttpClient):
""" Wrapper around the twisted HTTP client api.
"""Wrapper around the twisted HTTP client api.
Attributes:
agent (twisted.web.client.Agent): The twisted Agent used to send the
@ -87,8 +86,7 @@ class TwistedHttpClient(HttpClient):
defer.returnValue(json.loads(body))
def _create_put_request(self, url, json_data, headers_dict={}):
""" Wrapper of _create_request to issue a PUT request
"""
"""Wrapper of _create_request to issue a PUT request"""
if "Content-Type" not in headers_dict:
raise defer.error(RuntimeError("Must include Content-Type header for PUTs"))
@ -98,8 +96,7 @@ class TwistedHttpClient(HttpClient):
)
def _create_get_request(self, url, headers_dict={}):
""" Wrapper of _create_request to issue a GET request
"""
"""Wrapper of _create_request to issue a GET request"""
return self._create_request("GET", url, headers_dict=headers_dict)
@defer.inlineCallbacks
@ -127,8 +124,7 @@ class TwistedHttpClient(HttpClient):
@defer.inlineCallbacks
def _create_request(self, method, url, producer=None, headers_dict={}):
""" Creates and sends a request to the given url
"""
"""Creates and sends a request to the given url"""
headers_dict["User-Agent"] = ["Synapse Cmd Client"]
retries_left = 5
@ -185,8 +181,7 @@ class _RawProducer:
class _JsonProducer:
""" Used by the twisted http client to create the HTTP body from json
"""
"""Used by the twisted http client to create the HTTP body from json"""
def __init__(self, jsn):
self.data = jsn

View File

@ -63,8 +63,7 @@ class CursesStdIO:
self.redraw()
def redraw(self):
""" method for redisplaying lines
based on internal list of lines """
"""method for redisplaying lines based on internal list of lines"""
self.stdscr.clear()
self.paintStatus(self.statusText)

View File

@ -56,7 +56,7 @@ def excpetion_errback(failure):
class InputOutput:
""" This is responsible for basic I/O so that a user can interact with
"""This is responsible for basic I/O so that a user can interact with
the example app.
"""
@ -68,8 +68,7 @@ class InputOutput:
self.server = server
def on_line(self, line):
""" This is where we process commands.
"""
"""This is where we process commands."""
try:
m = re.match(r"^join (\S+)$", line)
@ -133,7 +132,7 @@ class IOLoggerHandler(logging.Handler):
class Room:
""" Used to store (in memory) the current membership state of a room, and
"""Used to store (in memory) the current membership state of a room, and
which home servers we should send PDUs associated with the room to.
"""
@ -148,8 +147,7 @@ class Room:
self.have_got_metadata = False
def add_participant(self, participant):
""" Someone has joined the room
"""
"""Someone has joined the room"""
self.participants.add(participant)
self.invited.discard(participant)
@ -160,14 +158,13 @@ class Room:
self.oldest_server = server
def add_invited(self, invitee):
""" Someone has been invited to the room
"""
"""Someone has been invited to the room"""
self.invited.add(invitee)
self.servers.add(origin_from_ucid(invitee))
class HomeServer(ReplicationHandler):
""" A very basic home server implentation that allows people to join a
"""A very basic home server implentation that allows people to join a
room and then invite other people.
"""
@ -181,8 +178,7 @@ class HomeServer(ReplicationHandler):
self.output = output
def on_receive_pdu(self, pdu):
""" We just received a PDU
"""
"""We just received a PDU"""
pdu_type = pdu.pdu_type
if pdu_type == "sy.room.message":
@ -199,23 +195,20 @@ class HomeServer(ReplicationHandler):
)
def _on_message(self, pdu):
""" We received a message
"""
"""We received a message"""
self.output.print_line(
"#%s %s %s" % (pdu.context, pdu.content["sender"], pdu.content["body"])
)
def _on_join(self, context, joinee):
""" Someone has joined a room, either a remote user or a local user
"""
"""Someone has joined a room, either a remote user or a local user"""
room = self._get_or_create_room(context)
room.add_participant(joinee)
self.output.print_line("#%s %s %s" % (context, joinee, "*** JOINED"))
def _on_invite(self, origin, context, invitee):
""" Someone has been invited
"""
"""Someone has been invited"""
room = self._get_or_create_room(context)
room.add_invited(invitee)
@ -228,8 +221,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def send_message(self, room_name, sender, body):
""" Send a message to a room!
"""
"""Send a message to a room!"""
destinations = yield self.get_servers_for_context(room_name)
try:
@ -247,8 +239,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def join_room(self, room_name, sender, joinee):
""" Join a room!
"""
"""Join a room!"""
self._on_join(room_name, joinee)
destinations = yield self.get_servers_for_context(room_name)
@ -269,8 +260,7 @@ class HomeServer(ReplicationHandler):
@defer.inlineCallbacks
def invite_to_room(self, room_name, sender, invitee):
""" Invite someone to a room!
"""
"""Invite someone to a room!"""
self._on_invite(self.server_name, room_name, invitee)
destinations = yield self.get_servers_for_context(room_name)

View File

@ -193,15 +193,12 @@ class TrivialXmppClient:
time.sleep(7)
print("SSRC spammer started")
while self.running:
ssrcMsg = (
"<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>"
% {
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
"nick": self.userId,
"assrc": self.ssrcs["audio"],
"vssrc": self.ssrcs["video"],
}
)
ssrcMsg = "<presence to='%(tojid)s' xmlns='jabber:client'><x xmlns='http://jabber.org/protocol/muc'/><c xmlns='http://jabber.org/protocol/caps' hash='sha-1' node='http://jitsi.org/jitsimeet' ver='0WkSdhFnAUxrz4ImQQLdB80GFlE='/><nick xmlns='http://jabber.org/protocol/nick'>%(nick)s</nick><stats xmlns='http://jitsi.org/jitmeet/stats'><stat name='bitrate_download' value='175'/><stat name='bitrate_upload' value='176'/><stat name='packetLoss_total' value='0'/><stat name='packetLoss_download' value='0'/><stat name='packetLoss_upload' value='0'/></stats><media xmlns='http://estos.de/ns/mjs'><source type='audio' ssrc='%(assrc)s' direction='sendre'/><source type='video' ssrc='%(vssrc)s' direction='sendre'/></media></presence>" % {
"tojid": "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid),
"nick": self.userId,
"assrc": self.ssrcs["audio"],
"vssrc": self.ssrcs["video"],
}
res = self.sendIq(ssrcMsg)
print("reply from ssrc announce: ", res)
time.sleep(10)

View File

@ -8,16 +8,16 @@ errors in code.
The necessary tools are detailed below.
First install them with:
pip install -e ".[lint,mypy]"
- **black**
The Synapse codebase uses [black](https://pypi.org/project/black/)
as an opinionated code formatter, ensuring all comitted code is
properly formatted.
First install `black` with:
pip install --upgrade black
Have `black` auto-format your code (it shouldn't change any
functionality) with:
@ -28,10 +28,6 @@ The necessary tools are detailed below.
`flake8` is a code checking tool. We require code to pass `flake8`
before being merged into the codebase.
Install `flake8` with:
pip install --upgrade flake8 flake8-comprehensions
Check all application and test code with:
flake8 synapse tests
@ -41,10 +37,6 @@ The necessary tools are detailed below.
`isort` ensures imports are nicely formatted, and can suggest and
auto-fix issues such as double-importing.
Install `isort` with:
pip install --upgrade isort
Auto-fix imports with:
isort -rc synapse tests

View File

@ -87,7 +87,9 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
signature = signature.copy_modified(
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
)
return signature

View File

@ -97,7 +97,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.7.0",
"black==19.10b0",
"black==20.8b1",
"flake8-comprehensions",
"flake8",
]

View File

@ -89,12 +89,16 @@ class SortedDict(Dict[_KT, _VT]):
def __reduce__(
self,
) -> Tuple[
Type[SortedDict[_KT, _VT]], Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
Type[SortedDict[_KT, _VT]],
Tuple[Callable[[_KT], Any], List[Tuple[_KT, _VT]]],
]: ...
def __repr__(self) -> str: ...
def _check(self) -> None: ...
def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
reverse=bool,
) -> Iterator[_KT]: ...
def bisect_left(self, value: _KT) -> int: ...
def bisect_right(self, value: _KT) -> int: ...

View File

@ -31,7 +31,9 @@ class SortedList(MutableSequence[_T]):
DEFAULT_LOAD_FACTOR: int = ...
def __init__(
self, iterable: Optional[Iterable[_T]] = ..., key: Optional[_Key[_T]] = ...,
self,
iterable: Optional[Iterable[_T]] = ...,
key: Optional[_Key[_T]] = ...,
): ...
# NB: currently mypy does not honour return type, see mypy #3307
@overload
@ -76,10 +78,18 @@ class SortedList(MutableSequence[_T]):
def __len__(self) -> int: ...
def reverse(self) -> None: ...
def islice(
self, start: Optional[int] = ..., stop: Optional[int] = ..., reverse=bool,
self,
start: Optional[int] = ...,
stop: Optional[int] = ...,
reverse=bool,
) -> Iterator[_T]: ...
def _islice(
self, min_pos: int, min_idx: int, max_pos: int, max_idx: int, reverse: bool,
self,
min_pos: int,
min_idx: int,
max_pos: int,
max_idx: int,
reverse: bool,
) -> Iterator[_T]: ...
def irange(
self,

View File

@ -168,7 +168,7 @@ class Auth:
rights: str = "access",
allow_expired: bool = False,
) -> synapse.types.Requester:
""" Get a registered user's ID.
"""Get a registered user's ID.
Args:
request: An HTTP request with an access_token query parameter.
@ -294,9 +294,12 @@ class Auth:
return user_id, app_service
async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False,
self,
token: str,
rights: str = "access",
allow_expired: bool = False,
) -> TokenLookupResult:
""" Validate access token and get user_id from it
"""Validate access token and get user_id from it
Args:
token: The access token to get the user by
@ -489,7 +492,7 @@ class Auth:
return service
async def is_server_admin(self, user: UserID) -> bool:
""" Check if the given user is a local server admin.
"""Check if the given user is a local server admin.
Args:
user: user to check
@ -500,7 +503,10 @@ class Auth:
return await self.store.is_server_admin(user)
def compute_auth_events(
self, event, current_state_ids: StateMap[str], for_verification: bool = False,
self,
event,
current_state_ids: StateMap[str],
for_verification: bool = False,
) -> List[str]:
"""Given an event and current state return the list of event IDs used
to auth an event.

View File

@ -128,8 +128,7 @@ class UserTypes:
class RelationTypes:
"""The types of relations known to this server.
"""
"""The types of relations known to this server."""
ANNOTATION = "m.annotation"
REPLACE = "m.replace"

View File

@ -390,8 +390,7 @@ class InvalidCaptchaError(SynapseError):
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
"""A client has sent too many requests and is being throttled."""
def __init__(
self,
@ -408,8 +407,7 @@ class LimitExceededError(SynapseError):
class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store
"""
"""A client has tried to upload to a non-current version of the room_keys store"""
def __init__(self, current_version: str):
"""
@ -426,7 +424,9 @@ class UnsupportedRoomVersionError(SynapseError):
def __init__(self, msg: str = "Homeserver does not support this room version"):
super().__init__(
code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
code=400,
msg=msg,
errcode=Codes.UNSUPPORTED_ROOM_VERSION,
)
@ -461,8 +461,7 @@ class IncompatibleRoomVersionError(SynapseError):
class PasswordRefusedError(SynapseError):
"""A password has been refused, either during password reset/change or registration.
"""
"""A password has been refused, either during password reset/change or registration."""
def __init__(
self,
@ -470,7 +469,9 @@ class PasswordRefusedError(SynapseError):
errcode: str = Codes.WEAK_PASSWORD,
):
super().__init__(
code=400, msg=msg, errcode=errcode,
code=400,
msg=msg,
errcode=errcode,
)
@ -493,7 +494,7 @@ class RequestSendFailed(RuntimeError):
def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
""" Utility method for constructing an error response for client-server
"""Utility method for constructing an error response for client-server
interactions.
Args:
@ -510,7 +511,7 @@ def cs_error(msg: str, code: str = Codes.UNKNOWN, **kwargs):
class FederationError(RuntimeError):
""" This class is used to inform remote homeservers about erroneous
"""This class is used to inform remote homeservers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.

View File

@ -56,8 +56,7 @@ class UserPresenceState(
@classmethod
def default(cls, user_id):
"""Returns a default presence state.
"""
"""Returns a default presence state."""
return cls(
user_id=user_id,
state=PresenceState.OFFLINE,

View File

@ -58,7 +58,7 @@ def register_sighup(func, *args, **kwargs):
def start_worker_reactor(appname, config, run_command=reactor.run):
""" Run the reactor in the main process
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor. Pulls configuration from the 'worker' settings in 'config'.
@ -93,7 +93,7 @@ def start_reactor(
logger,
run_command=reactor.run,
):
""" Run the reactor in the main process
"""Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
the reactor
@ -313,9 +313,7 @@ async def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerCon
refresh_certificate(hs)
# Start the tracer
synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs
)
synapse.logging.opentracing.init_tracer(hs) # type: ignore[attr-defined] # noqa
# It is now safe to start your Synapse.
hs.start_listening(listeners)
@ -370,8 +368,7 @@ def setup_sentry(hs):
def setup_sdnotify(hs):
"""Adds process state hooks to tell systemd what we are up to.
"""
"""Adds process state hooks to tell systemd what we are up to."""
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
@ -405,8 +402,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
class _LimitedHostnameResolver:
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
"""
"""Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups."""
def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver

View File

@ -421,8 +421,7 @@ class GenericWorkerPresence(BasePresenceHandler):
]
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
"""Set the presence state of the user."""
presence = state["presence"]
valid_presence = (

View File

@ -166,7 +166,10 @@ class ApplicationService:
@cached(num_args=1, cache_context=True)
async def matches_user_in_member_list(
self, room_id: str, store: "DataStore", cache_context: _CacheContext,
self,
room_id: str,
store: "DataStore",
cache_context: _CacheContext,
) -> bool:
"""Check if this service is interested a room based upon it's membership

View File

@ -227,7 +227,9 @@ class ApplicationServiceApi(SimpleHttpClient):
try:
await self.put_json(
uri=uri, json_body=body, args={"access_token": service.hs_token},
uri=uri,
json_body=body,
args={"access_token": service.hs_token},
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))

View File

@ -68,7 +68,7 @@ MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the
"""Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
"""

View File

@ -224,7 +224,9 @@ class Config:
return self.read_templates([filename])[0]
def read_templates(
self, filenames: List[str], custom_template_directory: Optional[str] = None,
self,
filenames: List[str],
custom_template_directory: Optional[str] = None,
) -> List[jinja2.Template]:
"""Load a list of template files from disk using the given variables.
@ -264,7 +266,10 @@ class Config:
# TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(search_directories)
env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),)
env = jinja2.Environment(
loader=loader,
autoescape=jinja2.select_autoescape(),
)
# Update the environment with our custom filters
env.filters.update(
@ -825,8 +830,7 @@ class ShardedWorkerHandlingConfig:
instances = attr.ib(type=List[str])
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
"""Whether this instance is responsible for handling the given key."""
# If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True

View File

@ -18,8 +18,7 @@ from ._base import Config
class AuthConfig(Config):
"""Password and login configuration
"""
"""Password and login configuration"""
section = "auth"

View File

@ -207,8 +207,7 @@ class DatabaseConfig(Config):
)
def get_single_database(self) -> DatabaseConnectionConfig:
"""Returns the database if there is only one, useful for e.g. tests
"""
"""Returns the database if there is only one, useful for e.g. tests"""
if not self.databases:
raise Exception("More than one database exists")

View File

@ -289,7 +289,8 @@ class EmailConfig(Config):
self.email_notif_template_html,
self.email_notif_template_text,
) = self.read_templates(
[notif_template_html, notif_template_text], template_dir,
[notif_template_html, notif_template_text],
template_dir,
)
self.email_notif_for_new_users = email_config.get(
@ -311,7 +312,8 @@ class EmailConfig(Config):
self.account_validity_template_html,
self.account_validity_template_text,
) = self.read_templates(
[expiry_template_html, expiry_template_text], template_dir,
[expiry_template_html, expiry_template_text],
template_dir,
)
subjects_config = email_config.get("subjects", {})

View File

@ -162,7 +162,10 @@ class LoggingConfig(Config):
)
logging_group.add_argument(
"-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
"-f",
"--log-file",
dest="log_file",
help=argparse.SUPPRESS,
)
def generate_files(self, config, config_dir_path):

View File

@ -355,9 +355,10 @@ def _parse_oidc_config_dict(
ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
ump_config.setdefault("config", {})
(user_mapping_provider_class, user_mapping_provider_config,) = load_module(
ump_config, config_path + ("user_mapping_provider",)
)
(
user_mapping_provider_class,
user_mapping_provider_config,
) = load_module(ump_config, config_path + ("user_mapping_provider",))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
@ -372,7 +373,11 @@ def _parse_oidc_config_dict(
if missing_methods:
raise ConfigError(
"Class %s is missing required "
"methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),),
"methods: %s"
% (
user_mapping_provider_class,
", ".join(missing_methods),
),
config_path + ("user_mapping_provider", "module"),
)

View File

@ -52,7 +52,7 @@ MediaStorageProviderConfig = namedtuple(
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
"""Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate

View File

@ -52,7 +52,12 @@ def _6to4(network: IPNetwork) -> IPNetwork:
hex_network = hex(network.first)[2:]
hex_network = ("0" * (8 - len(hex_network))) + hex_network
return IPNetwork(
"2002:%s:%s::/%d" % (hex_network[:4], hex_network[4:], 16 + network.prefixlen,)
"2002:%s:%s::/%d"
% (
hex_network[:4],
hex_network[4:],
16 + network.prefixlen,
)
)
@ -254,7 +259,8 @@ class ServerConfig(Config):
# Whether to require sharing a room with a user to retrieve their
# profile data
self.limit_profile_requests_to_users_who_share_rooms = config.get(
"limit_profile_requests_to_users_who_share_rooms", False,
"limit_profile_requests_to_users_who_share_rooms",
False,
)
if "restrict_public_rooms_to_local_users" in config and (
@ -614,7 +620,9 @@ class ServerConfig(Config):
if manhole:
self.listeners.append(
ListenerConfig(
port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
port=manhole,
bind_addresses=["127.0.0.1"],
type="manhole",
)
)
@ -650,7 +658,8 @@ class ServerConfig(Config):
# and letting the client know which email address is bound to an account and
# which one isn't.
self.request_token_inhibit_3pid_errors = config.get(
"request_token_inhibit_3pid_errors", False,
"request_token_inhibit_3pid_errors",
False,
)
# List of users trialing the new experimental default push rules. This setting is

View File

@ -35,8 +35,7 @@ class SsoAttributeRequirement:
class SSOConfig(Config):
"""SSO Configuration
"""
"""SSO Configuration"""
section = "sso"

View File

@ -33,8 +33,7 @@ def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
"""
"""The host and port to talk to an instance via HTTP replication."""
host = attr.ib(type=str)
port = attr.ib(type=int)
@ -54,13 +53,19 @@ class WriterLocations:
)
typing = attr.ib(default="master", type=str)
to_device = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
default=["master"],
type=List[str],
converter=_instance_to_list_converter,
)
account_data = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
default=["master"],
type=List[str],
converter=_instance_to_list_converter,
)
receipts = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter,
default=["master"],
type=List[str],
converter=_instance_to_list_converter,
)
@ -107,7 +112,9 @@ class WorkerConfig(Config):
if manhole:
self.worker_listeners.append(
ListenerConfig(
port=manhole, bind_addresses=["127.0.0.1"], type="manhole",
port=manhole,
bind_addresses=["127.0.0.1"],
type="manhole",
)
)

View File

@ -42,7 +42,7 @@ def check(
do_sig_check: bool = True,
do_size_check: bool = True,
) -> None:
""" Checks if this event is correctly authed.
"""Checks if this event is correctly authed.
Args:
room_version_obj: the version of the room
@ -423,7 +423,9 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
def check_redaction(
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
) -> bool:
"""Check whether the event sender is allowed to redact the target event.
@ -459,7 +461,9 @@ def check_redaction(
def _check_power_levels(
room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
room_version_obj: RoomVersion,
event: EventBase,
auth_events: StateMap[EventBase],
) -> None:
user_list = event.content.get("users", {})
# Validate users

View File

@ -98,7 +98,9 @@ class EventBuilder:
return self._state_key is not None
async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
self,
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
) -> EventBase:
"""Transform into a fully signed and hashed event

View File

@ -341,8 +341,7 @@ def _encode_state_dict(state_dict):
def _decode_state_dict(input):
"""Decodes a state dict encoded using `_encode_state_dict` above
"""
"""Decodes a state dict encoded using `_encode_state_dict` above"""
if input is None:
return None

View File

@ -40,7 +40,8 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
config=config, module_api=hs.get_module_api(),
config=config,
module_api=hs.get_module_api(),
)
async def check_event_allowed(

View File

@ -34,7 +34,7 @@ SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event: EventBase) -> EventBase:
""" Returns a pruned version of the given event, which removes all keys we
"""Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
This is used when we "redact" an event. We want to remove all fields that

View File

@ -750,7 +750,11 @@ class FederationClient(FederationBase):
return resp[1]
async def send_invite(
self, destination: str, room_id: str, event_id: str, pdu: EventBase,
self,
destination: str,
room_id: str,
event_id: str,
pdu: EventBase,
) -> EventBase:
room_version = await self.store.get_room_version(room_id)

View File

@ -85,7 +85,8 @@ received_queries_counter = Counter(
)
pdu_process_time = Histogram(
"synapse_federation_server_pdu_process_time", "Time taken to process an event",
"synapse_federation_server_pdu_process_time",
"Time taken to process an event",
)
@ -204,7 +205,7 @@ class FederationServer(FederationBase):
async def _handle_incoming_transaction(
self, origin: str, transaction: Transaction, request_time: int
) -> Tuple[int, Dict[str, Any]]:
""" Process an incoming transaction and return the HTTP response
"""Process an incoming transaction and return the HTTP response
Args:
origin: the server making the request
@ -373,8 +374,7 @@ class FederationServer(FederationBase):
return pdu_results
async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
"""Process the EDUs in a received transaction.
"""
"""Process the EDUs in a received transaction."""
async def _process_edu(edu_dict):
received_edus_counter.inc()
@ -437,7 +437,10 @@ class FederationServer(FederationBase):
raise AuthError(403, "Host not in room.")
resp = await self._state_ids_resp_cache.wrap(
(room_id, event_id), self._on_state_ids_request_compute, room_id, event_id,
(room_id, event_id),
self._on_state_ids_request_compute,
room_id,
event_id,
)
return 200, resp
@ -679,7 +682,7 @@ class FederationServer(FederationBase):
)
async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
""" Process a PDU received in a federation /send/ transaction.
"""Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
(The error will then be logged and sent back to the sender (which
@ -906,13 +909,11 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str):
"""Register that the EDU handler is on a different instance than master.
"""
"""Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name]
def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
"""Register that the EDU handler is on multiple instances.
"""
"""Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict):

View File

@ -30,8 +30,7 @@ logger = logging.getLogger(__name__)
class TransactionActions:
""" Defines persistence actions that relate to handling Transactions.
"""
"""Defines persistence actions that relate to handling Transactions."""
def __init__(self, datastore):
self.store = datastore
@ -57,8 +56,7 @@ class TransactionActions:
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
"""Persist how we responded to a transaction.
"""
"""Persist how we responded to a transaction."""
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")

View File

@ -468,8 +468,7 @@ class KeyedEduRow(
class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
"""Streams EDUs that don't have keys. See KeyedEduRow
"""
"""Streams EDUs that don't have keys. See KeyedEduRow"""
TypeId = "e"
@ -519,7 +518,10 @@ def process_rows_for_federation(transaction_queue, rows):
# them into the appropriate collection and then send them off.
buff = ParsedFederationStreamData(
presence=[], presence_destinations=[], keyed_edus={}, edus={},
presence=[],
presence_destinations=[],
keyed_edus={},
edus={},
)
# Parse the rows in the stream and add to the buffer

View File

@ -328,7 +328,9 @@ class FederationSender:
# to allow us to perform catch-up later on if the remote is unreachable
# for a while.
await self.store.store_destination_rooms_entries(
destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
destinations,
pdu.room_id,
pdu.internal_metadata.stream_ordering,
)
for destination in destinations:
@ -475,7 +477,7 @@ class FederationSender:
self, states: List[UserPresenceState], destinations: List[str]
) -> None:
"""Send the given presence states to the given destinations.
destinations (list[str])
destinations (list[str])
"""
if not states or not self.hs.config.use_presence:
@ -616,8 +618,8 @@ class FederationSender:
last_processed = None # type: Optional[str]
while True:
destinations_to_wake = await self.store.get_catch_up_outstanding_destinations(
last_processed
destinations_to_wake = (
await self.store.get_catch_up_outstanding_destinations(last_processed)
)
if not destinations_to_wake:

View File

@ -85,7 +85,8 @@ class PerDestinationQueue:
# processing. We have a guard in `attempt_new_transaction` that
# ensure we don't start sending stuff.
logger.error(
"Create a per destination queue for %s on wrong worker", destination,
"Create a per destination queue for %s on wrong worker",
destination,
)
self._should_send_on_this_instance = False
@ -440,8 +441,10 @@ class PerDestinationQueue:
if first_catch_up_check:
# first catchup so get last_successful_stream_ordering from database
self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
self._destination
self._last_successful_stream_ordering = (
await self._store.get_destination_last_successful_stream_ordering(
self._destination
)
)
if self._last_successful_stream_ordering is None:
@ -457,7 +460,8 @@ class PerDestinationQueue:
# get at most 50 catchup room/PDUs
while True:
event_ids = await self._store.get_catch_up_room_event_ids(
self._destination, self._last_successful_stream_ordering,
self._destination,
self._last_successful_stream_ordering,
)
if not event_ids:

View File

@ -65,7 +65,10 @@ class TransactionManager:
@measure_func("_send_new_transaction")
async def send_new_transaction(
self, destination: str, pdus: List[EventBase], edus: List[Edu],
self,
destination: str,
pdus: List[EventBase],
edus: List[Edu],
) -> bool:
"""
Args:

View File

@ -39,7 +39,7 @@ class TransportLayerClient:
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
@ -63,7 +63,7 @@ class TransportLayerClient:
@log_function
def get_event(self, destination, event_id, timeout=None):
""" Requests the pdu with give id and origin from the given server.
"""Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote homeserver we want
@ -84,7 +84,7 @@ class TransportLayerClient:
@log_function
def backfill(self, destination, room_id, event_tuples, limit):
""" Requests `limit` previous PDUs in a given context before list of
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
Args:
@ -118,7 +118,7 @@ class TransportLayerClient:
@log_function
async def send_transaction(self, transaction, json_data_callback=None):
""" Sends the given Transaction to its destination
"""Sends the given Transaction to its destination
Args:
transaction (Transaction)
@ -551,8 +551,7 @@ class TransportLayerClient:
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
"""Get a group profile"""
path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json(
@ -584,8 +583,7 @@ class TransportLayerClient:
@log_function
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
"""Get a group summary"""
path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json(
@ -597,8 +595,7 @@ class TransportLayerClient:
@log_function
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
"""Get all rooms in a group"""
path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json(
@ -611,8 +608,7 @@ class TransportLayerClient:
def add_room_to_group(
self, destination, group_id, requester_user_id, room_id, content
):
"""Add a room to a group
"""
"""Add a room to a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json(
@ -626,8 +622,7 @@ class TransportLayerClient:
def update_room_in_group(
self, destination, group_id, requester_user_id, room_id, config_key, content
):
"""Update room in group
"""
"""Update room in group"""
path = _create_v1_path(
"/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
@ -641,8 +636,7 @@ class TransportLayerClient:
)
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
"""Remove a room from a group"""
path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json(
@ -654,8 +648,7 @@ class TransportLayerClient:
@log_function
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
"""Get users in a group"""
path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json(
@ -667,8 +660,7 @@ class TransportLayerClient:
@log_function
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
"""Get users that have been invited to a group"""
path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json(
@ -680,8 +672,7 @@ class TransportLayerClient:
@log_function
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
"""Accept a group invite"""
path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json(
@ -690,8 +681,7 @@ class TransportLayerClient:
@log_function
def join_group(self, destination, group_id, user_id, content):
"""Attempts to join a group
"""
"""Attempts to join a group"""
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
@ -702,8 +692,7 @@ class TransportLayerClient:
def invite_to_group(
self, destination, group_id, user_id, requester_user_id, content
):
"""Invite a user to a group
"""
"""Invite a user to a group"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
@ -730,8 +719,7 @@ class TransportLayerClient:
def remove_user_from_group(
self, destination, group_id, requester_user_id, user_id, content
):
"""Remove a user from a group
"""
"""Remove a user from a group"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
@ -772,8 +760,7 @@ class TransportLayerClient:
def update_group_summary_room(
self, destination, group_id, user_id, room_id, category_id, content
):
"""Update a room entry in a group summary
"""
"""Update a room entry in a group summary"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
@ -796,8 +783,7 @@ class TransportLayerClient:
def delete_group_summary_room(
self, destination, group_id, user_id, room_id, category_id
):
"""Delete a room entry in a group summary
"""
"""Delete a room entry in a group summary"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
@ -817,8 +803,7 @@ class TransportLayerClient:
@log_function
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
"""Get all categories in a group"""
path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json(
@ -830,8 +815,7 @@ class TransportLayerClient:
@log_function
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
"""Get category info in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json(
@ -845,8 +829,7 @@ class TransportLayerClient:
def update_group_category(
self, destination, group_id, requester_user_id, category_id, content
):
"""Update a category in a group
"""
"""Update a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json(
@ -861,8 +844,7 @@ class TransportLayerClient:
def delete_group_category(
self, destination, group_id, requester_user_id, category_id
):
"""Delete a category in a group
"""
"""Delete a category in a group"""
path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json(
@ -874,8 +856,7 @@ class TransportLayerClient:
@log_function
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
"""Get all roles in a group"""
path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json(
@ -887,8 +868,7 @@ class TransportLayerClient:
@log_function
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
"""Get a roles info"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json(
@ -902,8 +882,7 @@ class TransportLayerClient:
def update_group_role(
self, destination, group_id, requester_user_id, role_id, content
):
"""Update a role in a group
"""
"""Update a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json(
@ -916,8 +895,7 @@ class TransportLayerClient:
@log_function
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
"""Delete a role in a group"""
path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json(
@ -931,8 +909,7 @@ class TransportLayerClient:
def update_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id, content
):
"""Update a users entry in a group
"""
"""Update a users entry in a group"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@ -950,8 +927,7 @@ class TransportLayerClient:
@log_function
def set_group_join_policy(self, destination, group_id, requester_user_id, content):
"""Sets the join policy for a group
"""
"""Sets the join policy for a group"""
path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json(
@ -966,8 +942,7 @@ class TransportLayerClient:
def delete_group_summary_user(
self, destination, group_id, requester_user_id, user_id, role_id
):
"""Delete a users entry in a group
"""
"""Delete a users entry in a group"""
if role_id:
path = _create_v1_path(
"/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
@ -983,8 +958,7 @@ class TransportLayerClient:
)
def bulk_get_publicised_groups(self, destination, user_ids):
"""Get the groups a list of users are publicising
"""
"""Get the groups a list of users are publicising"""
path = _create_v1_path("/get_groups_publicised")

View File

@ -364,7 +364,10 @@ class BaseFederationServlet:
continue
server.register_paths(
method, (pattern,), self._wrap(code), self.__class__.__name__,
method,
(pattern,),
self._wrap(code),
self.__class__.__name__,
)
@ -381,7 +384,7 @@ class FederationSendServlet(BaseFederationServlet):
# This is when someone is trying to send us a bunch of data.
async def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
"""Called on PUT /send/<transaction_id>/
Args:
request (twisted.web.http.Request): The HTTP request.
@ -855,8 +858,7 @@ class FederationVersionServlet(BaseFederationServlet):
class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get/set the basic profile of a group on behalf of a user
"""
"""Get/set the basic profile of a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@ -895,8 +897,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user
"""
"""Get the rooms in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@ -911,8 +912,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
"""Add/remove room from group"""
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@ -940,8 +940,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group
"""
"""Update room config in group"""
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@ -961,8 +960,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
"""Get the users in a group on behalf of a user"""
PATH = "/groups/(?P<group_id>[^/]*)/users"
@ -977,8 +975,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group
"""
"""Get the users that have been invited to a group"""
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@ -995,8 +992,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group
"""
"""Ask a group server to invite someone to the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1013,8 +1009,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server
"""
"""Accept an invitation from the group server"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@ -1028,8 +1023,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group
"""
"""Attempt to join a group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@ -1043,8 +1037,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
"""Leave or kick a user from the group"""
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1061,8 +1054,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user
"""
"""A group server has invited a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@ -1076,8 +1068,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user
"""
"""A group server has removed a local user"""
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@ -1093,8 +1084,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation
"""
"""A group or user's server renews their attestation"""
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
@ -1156,8 +1146,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
"""Get all categories for a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@ -1172,8 +1161,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group
"""
"""Add/remove/get a category in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@ -1218,8 +1206,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
"""Get roles in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@ -1234,8 +1221,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group
"""
"""Add/remove/get a role in a group"""
PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@ -1325,8 +1311,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
"""Get roles in a group"""
PATH = "/get_groups_publicised"
@ -1339,8 +1324,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock
"""
"""Sets whether a group is joinable without an invite or knock"""
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"

View File

@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True)
class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another.
"""An Edu represents a piece of data sent from one homeserver to another.
In comparison to Pdus, Edus are not persisted for a long time on disk, are
not meaningful beyond a given pair of homeservers, and don't have an
@ -63,7 +63,7 @@ class Edu(JsonEncodedObject):
class Transaction(JsonEncodedObject):
""" A transaction is a list of Pdus and Edus to be sent to a remote home
"""A transaction is a list of Pdus and Edus to be sent to a remote home
server with some extra metadata.
Example transaction::
@ -99,7 +99,7 @@ class Transaction(JsonEncodedObject):
]
def __init__(self, transaction_id=None, pdus=[], **kwargs):
""" If we include a list of pdus then we decode then as PDU's
"""If we include a list of pdus then we decode then as PDU's
automatically.
"""
@ -111,7 +111,7 @@ class Transaction(JsonEncodedObject):
@staticmethod
def create_new(pdus, **kwargs):
""" Used to create a new transaction. Will auto fill out
"""Used to create a new transaction. Will auto fill out
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:

View File

@ -61,8 +61,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning:
"""Creates and verifies group attestations.
"""
"""Creates and verifies group attestations."""
def __init__(self, hs):
self.keyring = hs.get_keyring()
@ -125,8 +124,7 @@ class GroupAttestationSigning:
class GroupAttestionRenewer:
"""Responsible for sending and receiving attestation updates.
"""
"""Responsible for sending and receiving attestation updates."""
def __init__(self, hs):
self.clock = hs.get_clock()
@ -142,8 +140,7 @@ class GroupAttestionRenewer:
)
async def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation
"""
"""When a remote updates an attestation"""
attestation = content["attestation"]
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
@ -161,8 +158,7 @@ class GroupAttestionRenewer:
return run_as_background_process("renew_attestations", self._renew_attestations)
async def _renew_attestations(self):
"""Called periodically to check if we need to update any of our attestations
"""
"""Called periodically to check if we need to update any of our attestations"""
now = self.clock.time_msec()

View File

@ -165,16 +165,14 @@ class GroupsServerWorkerHandler:
}
async def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user)
"""
"""Get all categories in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = await self.store.get_group_categories(group_id=group_id)
return {"categories": categories}
async def get_group_category(self, group_id, requester_user_id, category_id):
"""Get a specific category in a group (as seen by user)
"""
"""Get a specific category in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_category(
@ -186,24 +184,21 @@ class GroupsServerWorkerHandler:
return res
async def get_group_roles(self, group_id, requester_user_id):
"""Get all roles in a group (as seen by user)
"""
"""Get all roles in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = await self.store.get_group_roles(group_id=group_id)
return {"roles": roles}
async def get_group_role(self, group_id, requester_user_id, role_id):
"""Get a specific role in a group (as seen by user)
"""
"""Get a specific role in a group (as seen by user)"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
return res
async def get_group_profile(self, group_id, requester_user_id):
"""Get the group profile as seen by requester_user_id
"""
"""Get the group profile as seen by requester_user_id"""
await self.check_group_is_ours(group_id, requester_user_id)
@ -350,8 +345,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_summary_room(
self, group_id, requester_user_id, room_id, category_id, content
):
"""Add/update a room to the group summary
"""
"""Add/update a room to the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -375,8 +369,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def delete_group_summary_room(
self, group_id, requester_user_id, room_id, category_id
):
"""Remove a room from the summary
"""
"""Remove a room from the summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -409,8 +402,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_category(
self, group_id, requester_user_id, category_id, content
):
"""Add/Update a group category
"""
"""Add/Update a group category"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -428,8 +420,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def delete_group_category(self, group_id, requester_user_id, category_id):
"""Delete a group category
"""
"""Delete a group category"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -441,8 +432,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def update_group_role(self, group_id, requester_user_id, role_id, content):
"""Add/update a role in a group
"""
"""Add/update a role in a group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -458,8 +448,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def delete_group_role(self, group_id, requester_user_id, role_id):
"""Remove role from group
"""
"""Remove role from group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -471,8 +460,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_group_summary_user(
self, group_id, requester_user_id, user_id, role_id, content
):
"""Add/update a users entry in the group summary
"""
"""Add/update a users entry in the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -494,8 +482,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def delete_group_summary_user(
self, group_id, requester_user_id, user_id, role_id
):
"""Remove a user from the group summary
"""
"""Remove a user from the group summary"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -507,8 +494,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def update_group_profile(self, group_id, requester_user_id, content):
"""Update the group profile
"""
"""Update the group profile"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -539,8 +525,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
await self.store.update_group_profile(group_id, profile)
async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
"""Add room to group
"""
"""Add room to group"""
RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours(
@ -556,8 +541,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
async def update_room_in_group(
self, group_id, requester_user_id, room_id, config_key, content
):
"""Update room in group
"""
"""Update room in group"""
RoomID.from_string(room_id) # Ensure valid room id
await self.check_group_is_ours(
@ -576,8 +560,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def remove_room_from_group(self, group_id, requester_user_id, room_id):
"""Remove room from group
"""
"""Remove room from group"""
await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
@ -587,8 +570,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {}
async def invite_to_group(self, group_id, user_id, requester_user_id, content):
"""Invite user to group
"""
"""Invite user to group"""
group = await self.check_group_is_ours(
group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@ -724,8 +706,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
return {"state": "join", "attestation": local_attestation}
async def knock(self, group_id, requester_user_id, content):
"""A user requests becoming a member of the group
"""
"""A user requests becoming a member of the group"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
raise NotImplementedError()
@ -918,8 +899,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
def _parse_join_policy_from_contents(content):
"""Given a content for a request, return the specified join policy or None
"""
"""Given a content for a request, return the specified join policy or None"""
join_policy_dict = content.get("m.join_policy")
if join_policy_dict:
@ -929,8 +909,7 @@ def _parse_join_policy_from_contents(content):
def _parse_join_policy_dict(join_policy_dict):
"""Given a dict for the "m.join_policy" config return the join policy specified
"""
"""Given a dict for the "m.join_policy" config return the join policy specified"""
join_policy_type = join_policy_dict.get("type")
if not join_policy_type:
return "invite"

View File

@ -203,13 +203,11 @@ class AdminHandler(BaseHandler):
class ExfiltrationWriter(metaclass=abc.ABCMeta):
"""Interface used to specify how to write exported data.
"""
"""Interface used to specify how to write exported data."""
@abc.abstractmethod
def write_events(self, room_id: str, events: List[EventBase]) -> None:
"""Write a batch of events for a room.
"""
"""Write a batch of events for a room."""
raise NotImplementedError()
@abc.abstractmethod

View File

@ -290,7 +290,9 @@ class ApplicationServicesHandler:
if not interested:
continue
presence_events, _ = await presence_source.get_new_events(
user=user, service=service, from_key=from_key,
user=user,
service=service,
from_key=from_key,
)
time_now = self.clock.time_msec()
events.extend(

View File

@ -120,7 +120,9 @@ def convert_client_dict_legacy_fields_to_identifier(
# Ensure the identifier has a type
if "type" not in identifier:
raise SynapseError(
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
400,
"'identifier' dict has no key 'type'",
errcode=Codes.MISSING_PARAM,
)
return identifier
@ -351,7 +353,11 @@ class AuthHandler(BaseHandler):
try:
result, params, session_id = await self.check_ui_auth(
flows, request, request_body, description, get_new_session_data,
flows,
request,
request_body,
description,
get_new_session_data,
)
except LoginError:
# Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
@ -379,8 +385,7 @@ class AuthHandler(BaseHandler):
return params, session_id
async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
"""Get a list of the authentication types this user can use
"""
"""Get a list of the authentication types this user can use"""
ui_auth_types = set()
@ -723,7 +728,9 @@ class AuthHandler(BaseHandler):
}
def _auth_dict_for_flows(
self, flows: List[List[str]], session_id: str,
self,
flows: List[List[str]],
session_id: str,
) -> Dict[str, Any]:
public_flows = []
for f in flows:
@ -880,7 +887,9 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
self,
login_submission: Dict[str, Any],
ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
@ -1023,7 +1032,9 @@ class AuthHandler(BaseHandler):
raise
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
self,
username: str,
login_submission: Dict[str, Any],
) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
@ -1446,7 +1457,8 @@ class AuthHandler(BaseHandler):
# is considered OK since the newest SSO attributes should be most valid.
if extra_attributes:
self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes(
self._clock.time_msec(), extra_attributes,
self._clock.time_msec(),
extra_attributes,
)
# Create a login token
@ -1702,5 +1714,9 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
await maybe_awaitable(
g(user_id=user_id, device_id=device_id, access_token=access_token,)
g(
user_id=user_id,
device_id=device_id,
access_token=access_token,
)
)

View File

@ -33,8 +33,7 @@ logger = logging.getLogger(__name__)
class CasError(Exception):
"""Used to catch errors when validating the CAS ticket.
"""
"""Used to catch errors when validating the CAS ticket."""
def __init__(self, error, error_description=None):
self.error = error
@ -100,7 +99,10 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
return "%s?%s" % (
self._cas_service_url,
urllib.parse.urlencode(args),
)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]
@ -296,7 +298,10 @@ class CasHandler:
# first check if we're doing a UIA
if session:
return await self._sso_handler.complete_sso_ui_auth_request(
self.idp_id, cas_response.username, session, request,
self.idp_id,
cas_response.username,
session,
request,
)
# otherwise, we're handling a login request.
@ -366,7 +371,8 @@ class CasHandler:
user_id = UserID(localpart, self._hostname).to_string()
logger.debug(
"Looking for existing account based on mapped %s", user_id,
"Looking for existing account based on mapped %s",
user_id,
)
users = await self._store.get_users_by_id_case_insensitive(user_id)

View File

@ -196,8 +196,7 @@ class DeactivateAccountHandler(BaseHandler):
run_as_background_process("user_parter_loop", self._user_parter_loop)
async def _user_parter_loop(self) -> None:
"""Loop that parts deactivated users from rooms
"""
"""Loop that parts deactivated users from rooms"""
self._user_parter_running = True
logger.info("Starting user parter")
try:
@ -214,8 +213,7 @@ class DeactivateAccountHandler(BaseHandler):
self._user_parter_running = False
async def _part_user(self, user_id: str) -> None:
"""Causes the given user_id to leave all the rooms they're joined to
"""
"""Causes the given user_id to leave all the rooms they're joined to"""
user = UserID.from_string(user_id)
rooms_for_user = await self.store.get_rooms_for_user(user_id)

View File

@ -86,7 +86,7 @@ class DeviceWorkerHandler(BaseHandler):
@trace
async def get_device(self, user_id: str, device_id: str) -> JsonDict:
""" Retrieve the given device
"""Retrieve the given device
Args:
user_id: The user to get the device from
@ -341,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
async def delete_device(self, user_id: str, device_id: str) -> None:
""" Delete the given device
"""Delete the given device
Args:
user_id: The user to delete the device from.
@ -386,7 +386,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.delete_devices(user_id, device_ids)
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
""" Delete several devices
"""Delete several devices
Args:
user_id: The user to delete devices from.
@ -417,7 +417,7 @@ class DeviceHandler(DeviceWorkerHandler):
await self.notify_device_update(user_id, device_ids)
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
""" Update the given device
"""Update the given device
Args:
user_id: The user to update devices of.
@ -534,7 +534,9 @@ class DeviceHandler(DeviceWorkerHandler):
device id of the dehydrated device
"""
device_id = await self.check_device_registered(
user_id, None, initial_device_display_name,
user_id,
None,
initial_device_display_name,
)
old_device_id = await self.store.store_dehydrated_device(
user_id, device_id, device_data
@ -803,7 +805,8 @@ class DeviceListUpdater:
try:
# Try to resync the current user's devices list.
result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
user_id=user_id,
mark_failed_as_stale=False,
)
# user_device_resync only returns a result if it managed to
@ -813,14 +816,17 @@ class DeviceListUpdater:
# self.store.update_remote_device_list_cache).
if result:
logger.debug(
"Successfully resynced the device list for %s", user_id,
"Successfully resynced the device list for %s",
user_id,
)
except Exception as e:
# If there was an issue resyncing this user, e.g. if the remote
# server sent a malformed result, just log the error instead of
# aborting all the subsequent resyncs.
logger.debug(
"Could not resync the device list for %s: %s", user_id, e,
"Could not resync the device list for %s: %s",
user_id,
e,
)
finally:
# Allow future calls to retry resyncinc out of sync device lists.
@ -855,7 +861,9 @@ class DeviceListUpdater:
return None
except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
"Failed to handle device list update for %s: %s", user_id, e,
"Failed to handle device list update for %s: %s",
user_id,
e,
)
if mark_failed_as_stale:
@ -931,7 +939,9 @@ class DeviceListUpdater:
# Handle cross-signing keys.
cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
user_id,
master_key,
self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids

View File

@ -62,7 +62,8 @@ class DeviceMessageHandler:
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.direct_to_device", hs.config.worker.writers.to_device,
"m.direct_to_device",
hs.config.worker.writers.to_device,
)
# The handler to call when we think a user's device list might be out of
@ -73,8 +74,8 @@ class DeviceMessageHandler:
hs.get_device_handler().device_list_updater.user_device_resync
)
else:
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:

View File

@ -61,8 +61,8 @@ class E2eKeysHandler:
self._is_master = hs.config.worker_app is None
if not self._is_master:
self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
hs
self._user_device_resync_client = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
else:
# Only register this edu handler on master as it requires writing
@ -85,7 +85,7 @@ class E2eKeysHandler:
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str
) -> JsonDict:
""" Handle a device key query from a client
"""Handle a device key query from a client
{
"device_keys": {
@ -391,8 +391,7 @@ class E2eKeysHandler:
async def on_federation_query_client_keys(
self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
) -> JsonDict:
""" Handle a device key query from a federated server
"""
"""Handle a device key query from a federated server"""
device_keys_query = query_body.get(
"device_keys", {}
) # type: Dict[str, Optional[List[str]]]
@ -1065,7 +1064,9 @@ class E2eKeysHandler:
return key, key_id, verify_key
async def _retrieve_cross_signing_keys_for_remote_user(
self, user: UserID, desired_key_type: str,
self,
user: UserID,
desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database
@ -1269,8 +1270,7 @@ def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool:
@attr.s(slots=True)
class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
"""An item in the signature list as used by upload_signatures_for_device_keys."""
signing_key_id = attr.ib(type=str)
target_user_id = attr.ib(type=str)
@ -1355,8 +1355,12 @@ class SigningKeyEduUpdater:
logger.info("pending updates: %r", pending_updates)
for master_key, self_signing_key in pending_updates:
new_device_ids = await device_list_updater.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
new_device_ids = (
await device_list_updater.process_cross_signing_key_update(
user_id,
master_key,
self_signing_key,
)
)
device_ids = device_ids + new_device_ids

View File

@ -57,8 +57,7 @@ class EventStreamHandler(BaseHandler):
room_id: Optional[str] = None,
is_guest: bool = False,
) -> JsonDict:
"""Fetches the events stream for a given user.
"""
"""Fetches the events stream for a given user."""
if room_id:
blocked = await self.store.is_room_blocked(room_id)

View File

@ -111,13 +111,13 @@ class _NewEventInfo:
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
a) handling received Pdus before handing them on as Events to the rest
of the homeserver (including auth and state conflict resolutions)
b) converting events that were produced by local clients that may need
to be sent to remote homeservers.
c) doing the necessary dances to invite remote users and join remote
rooms.
Responsible for:
a) handling received Pdus before handing them on as Events to the rest
of the homeserver (including auth and state conflict resolutions)
b) converting events that were produced by local clients that may need
to be sent to remote homeservers.
c) doing the necessary dances to invite remote users and join remote
rooms.
"""
def __init__(self, hs: "HomeServer"):
@ -150,11 +150,11 @@ class FederationHandler(BaseHandler):
)
if hs.config.worker_app:
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
self._user_device_resync = (
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs
self._maybe_store_room_on_outlier_membership = (
ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(hs)
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
@ -172,7 +172,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
@ -368,7 +368,8 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
"Requesting state at missing prev_event %s", event_id,
"Requesting state at missing prev_event %s",
event_id,
)
with nested_logging_context(p):
@ -388,12 +389,14 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
room_version = await self.store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
state_res_store=StateResolutionStore(self.store),
state_map = (
await self._state_resolution_handler.resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,
state_res_store=StateResolutionStore(self.store),
)
)
# We need to give _process_received_pdu the actual state events
@ -687,9 +690,12 @@ class FederationHandler(BaseHandler):
return fetched_events
async def _process_received_pdu(
self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
):
""" Called when we have a new pdu. We need to do auth checks and put it
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
Args:
@ -801,7 +807,7 @@ class FederationHandler(BaseHandler):
@log_function
async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
has no new events to offer, this will return an empty list.
@ -1204,11 +1210,16 @@ class FederationHandler(BaseHandler):
with nested_logging_context(event_id):
try:
event = await self.federation_client.get_pdu(
[destination], event_id, room_version, outlier=True,
[destination],
event_id,
room_version,
outlier=True,
)
if event is None:
logger.warning(
"Server %s didn't return event %s", destination, event_id,
"Server %s didn't return event %s",
destination,
event_id,
)
return
@ -1235,7 +1246,8 @@ class FederationHandler(BaseHandler):
if aid not in event_map
]
persisted_events = await self.store.get_events(
auth_events, allow_rejected=True,
auth_events,
allow_rejected=True,
)
event_infos = []
@ -1251,7 +1263,9 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
destination, room_id, event_infos,
destination,
room_id,
event_infos,
)
def _sanity_check_event(self, ev):
@ -1287,7 +1301,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
async def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
"""
@ -1310,7 +1324,7 @@ class FederationHandler(BaseHandler):
async def do_invite_join(
self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
) -> Tuple[str, int]:
""" Attempts to join the `joinee` to the room `room_id` via the
"""Attempts to join the `joinee` to the room `room_id` via the
servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
@ -1388,7 +1402,8 @@ class FederationHandler(BaseHandler):
# so we can rely on it now.
#
await self.store.upsert_room_on_join(
room_id=room_id, room_version=room_version_obj,
room_id=room_id,
room_version=room_version_obj,
)
max_stream_id = await self._persist_auth_tree(
@ -1458,7 +1473,7 @@ class FederationHandler(BaseHandler):
async def on_make_join_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
""" We've received a /make_join/ request, so we create a partial
"""We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@ -1483,7 +1498,8 @@ class FederationHandler(BaseHandler):
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"Got /make_join request for room %s we are no longer in", room_id,
"Got /make_join request for room %s we are no longer in",
room_id,
)
raise NotFoundError("Not an active room on this server")
@ -1517,7 +1533,7 @@ class FederationHandler(BaseHandler):
return event
async def on_send_join_request(self, origin, pdu):
""" We have received a join event for a room. Fully process it and
"""We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
event = pdu
@ -1573,7 +1589,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
):
""" We've got an invite event. Process and persist it. Sign it.
"""We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
"""
@ -1700,7 +1716,7 @@ class FederationHandler(BaseHandler):
async def on_make_leave_request(
self, origin: str, room_id: str, user_id: str
) -> EventBase:
""" We've received a /make_leave/ request, so we create a partial
"""We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
@ -1776,8 +1792,7 @@ class FederationHandler(BaseHandler):
return None
async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
@ -1803,8 +1818,7 @@ class FederationHandler(BaseHandler):
return []
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@ -2010,7 +2024,11 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events:
m_ev = await self.federation_client.get_pdu(
[origin], e_id, room_version=room_version, outlier=True, timeout=10000,
[origin],
e_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@ -2160,7 +2178,9 @@ class FederationHandler(BaseHandler):
)
logger.debug(
"Doing soft-fail check for %s: state %s", event.event_id, current_state_ids,
"Doing soft-fail check for %s: state %s",
event.event_id,
current_state_ids,
)
# Now check if event pass auth against said current state
@ -2513,7 +2533,7 @@ class FederationHandler(BaseHandler):
async def construct_auth_difference(
self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
) -> Dict:
""" Given a local and remote auth chain, find the differences. This
"""Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
Params:

View File

@ -146,8 +146,7 @@ class GroupsLocalWorkerHandler:
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get users in a group
"""
"""Get users in a group"""
if self.is_mine_id(group_id):
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
@ -283,8 +282,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Create a group
"""
"""Create a group"""
logger.info("Asking to create group with ID: %r", group_id)
@ -314,8 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Request to join a group
"""
"""Request to join a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
@ -361,8 +358,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Accept an invite to a group
"""
"""Accept an invite to a group"""
if self.is_mine_id(group_id):
await self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
@ -408,8 +404,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
"""Invite a user to a group
"""
"""Invite a user to a group"""
content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = await self.groups_server_handler.invite_to_group(
@ -434,8 +429,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""One of our users were invited to a group
"""
"""One of our users were invited to a group"""
# TODO: Support auto join and rejection
if not self.is_mine_id(user_id):
@ -466,8 +460,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def remove_user_from_group(
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Remove a user from a group
"""
"""Remove a user from a group"""
if user_id == requester_user_id:
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
@ -501,8 +494,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
"""One of our users was removed/kicked from a group
"""
"""One of our users was removed/kicked from a group"""
# TODO: Check if user in group
token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"

View File

@ -72,7 +72,10 @@ class IdentityHandler(BaseHandler):
)
def ratelimit_request_token_requests(
self, request: SynapseRequest, medium: str, address: str,
self,
request: SynapseRequest,
medium: str,
address: str,
):
"""Used to ratelimit requests to `/requestToken` by IP and address.

View File

@ -124,7 +124,8 @@ class InitialSyncHandler(BaseHandler):
joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
receipt = await self.store.get_linearized_receipts_for_rooms(
joined_rooms, to_key=int(now_token.receipt_key),
joined_rooms,
to_key=int(now_token.receipt_key),
)
tags_by_room = await self.store.get_tags_for_user(user_id)
@ -169,7 +170,10 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(None, event.stream_ordering,)
room_end_token = RoomStreamToken(
None,
event.stream_ordering,
)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
@ -284,7 +288,9 @@ class InitialSyncHandler(BaseHandler):
membership,
member_event_id,
) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True,
room_id,
user_id,
allow_departed_users=True,
)
is_peeking = member_event_id is None

View File

@ -65,8 +65,7 @@ logger = logging.getLogger(__name__)
class MessageHandler:
"""Contains some read only APIs to get state about a room
"""
"""Contains some read only APIs to get state about a room"""
def __init__(self, hs):
self.auth = hs.get_auth()
@ -88,9 +87,13 @@ class MessageHandler:
)
async def get_room_data(
self, user_id: str, room_id: str, event_type: str, state_key: str,
self,
user_id: str,
room_id: str,
event_type: str,
state_key: str,
) -> dict:
""" Get data from a room.
"""Get data from a room.
Args:
user_id
@ -174,7 +177,10 @@ class MessageHandler:
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = await filter_events_for_client(
self.storage, user_id, last_events, filter_send_to_client=False,
self.storage,
user_id,
last_events,
filter_send_to_client=False,
)
event = last_events[0]
@ -571,7 +577,7 @@ class EventCreationHandler:
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
""""Determine if an event to be sent is exempt from having to consent
""" "Determine if an event to be sent is exempt from having to consent
to the privacy policy
Args:
@ -793,9 +799,10 @@ class EventCreationHandler:
"""
if prev_event_ids is not None:
assert len(prev_event_ids) <= 10, (
"Attempting to create an event with %i prev_events"
% (len(prev_event_ids),)
assert (
len(prev_event_ids) <= 10
), "Attempting to create an event with %i prev_events" % (
len(prev_event_ids),
)
else:
prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id)
@ -821,7 +828,8 @@ class EventCreationHandler:
)
if not third_party_result:
logger.info(
"Event %s forbidden by third-party rules", event,
"Event %s forbidden by third-party rules",
event,
)
raise SynapseError(
403, "This event is not allowed in this context", Codes.FORBIDDEN
@ -1316,7 +1324,11 @@ class EventCreationHandler:
# Since this is a dummy-event it is OK if it is sent by a
# shadow-banned user.
await self.handle_new_client_event(
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
requester,
event,
context,
ratelimit=False,
ignore_shadow_ban=True,
)
return True
except AuthError:

View File

@ -73,8 +73,7 @@ JWKS = TypedDict("JWKS", {"keys": List[JWK]})
class OidcHandler:
"""Handles requests related to the OpenID Connect login flow.
"""
"""Handles requests related to the OpenID Connect login flow."""
def __init__(self, hs: "HomeServer"):
self._sso_handler = hs.get_sso_handler()
@ -216,8 +215,7 @@ class OidcHandler:
class OidcError(Exception):
"""Used to catch errors when calling the token_endpoint
"""
"""Used to catch errors when calling the token_endpoint"""
def __init__(self, error, error_description=None):
self.error = error
@ -252,7 +250,9 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
self._client_auth = ClientAuth(
provider.client_id, provider.client_secret, provider.client_auth_method,
provider.client_id,
provider.client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
@ -509,7 +509,10 @@ class OidcProvider:
# We're not using the SimpleHttpClient util methods as we don't want to
# check the HTTP status code and we do the body encoding ourself.
response = await self._http_client.request(
method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
method="POST",
uri=uri,
data=body.encode("utf-8"),
headers=headers,
)
# This is used in multiple error messages below
@ -966,7 +969,9 @@ class OidcSessionTokenGenerator:
A signed macaroon token with the session information.
"""
macaroon = pymacaroons.Macaroon(
location=self._server_name, identifier="key", key=self._macaroon_secret_key,
location=self._server_name,
identifier="key",
key=self._macaroon_secret_key,
)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")

View File

@ -197,7 +197,8 @@ class PaginationHandler:
stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
r = await self.store.get_room_event_before_stream_ordering(
room_id, stream_ordering,
room_id,
stream_ordering,
)
if not r:
logger.warning(
@ -223,7 +224,12 @@ class PaginationHandler:
# the background so that it's not blocking any other operation apart from
# other purges in the same room.
run_as_background_process(
"_purge_history", self._purge_history, purge_id, room_id, token, True,
"_purge_history",
self._purge_history,
purge_id,
room_id,
token,
True,
)
def start_purge_history(
@ -389,7 +395,9 @@ class PaginationHandler:
)
await self.hs.get_federation_handler().maybe_backfill(
room_id, curr_topo, limit=pagin_config.limit,
room_id,
curr_topo,
limit=pagin_config.limit,
)
to_room_key = None

View File

@ -635,8 +635,7 @@ class PresenceHandler(BasePresenceHandler):
self.external_process_last_updated_ms.pop(process_id, None)
async def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
"""
"""Get the current presence state for a user."""
res = await self.current_state_for_users([user_id])
return res[user_id]
@ -678,8 +677,7 @@ class PresenceHandler(BasePresenceHandler):
self.federation.send_presence(states)
async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server.
"""
"""Called when we receive a `m.presence` EDU from a remote server."""
if not self._presence_enabled:
return
@ -729,8 +727,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states(updates)
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
"""Set the presence state of the user."""
status_msg = state.get("status_msg", None)
presence = state["presence"]
@ -758,8 +755,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states([prev_state.copy_and_replace(**new_fields)])
async def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
"""Returns whether a user can see another user's presence."""
observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
@ -953,8 +949,7 @@ class PresenceHandler(BasePresenceHandler):
def should_notify(old_state, new_state):
"""Decides if a presence state change should be sent to interested parties.
"""
"""Decides if a presence state change should be sent to interested parties."""
if old_state == new_state:
return False

View File

@ -207,7 +207,8 @@ class ProfileHandler(BaseHandler):
# This must be done by the target user himself.
if by_admin:
requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity,
target_user,
authenticated_entity=requester.authenticated_entity,
)
await self.store.set_profile_displayname(

View File

@ -49,15 +49,15 @@ class ReceiptsHandler(BaseHandler):
)
else:
hs.get_federation_registry().register_instances_for_edu(
"m.receipt", hs.config.worker.writers.receipts,
"m.receipt",
hs.config.worker.writers.receipts,
)
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
for room_id, room_values in content.items():
for receipt_type, users in room_values.items():
@ -83,8 +83,7 @@ class ReceiptsHandler(BaseHandler):
await self._handle_new_receipts(receipts)
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier.
"""
"""Takes a list of receipts, stores them and informs the notifier."""
min_batch_id = None # type: Optional[int]
max_batch_id = None # type: Optional[int]

View File

@ -62,8 +62,8 @@ class RegistrationHandler(BaseHandler):
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
hs
)
self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
hs
self._post_registration_client = (
ReplicationPostRegisterActionsServlet.make_client(hs)
)
else:
self.device_handler = hs.get_device_handler()
@ -189,12 +189,15 @@ class RegistrationHandler(BaseHandler):
self.check_registration_ratelimit(address)
result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
threepid,
localpart,
user_agent_ips or [],
)
if result == RegistrationBehaviour.DENY:
logger.info(
"Blocked registration of %r", localpart,
"Blocked registration of %r",
localpart,
)
# We return a 429 to make it not obvious that they've been
# denied.
@ -203,7 +206,8 @@ class RegistrationHandler(BaseHandler):
shadow_banned = result == RegistrationBehaviour.SHADOW_BAN
if shadow_banned:
logger.info(
"Shadow banning registration of %r", localpart,
"Shadow banning registration of %r",
localpart,
)
# do not check_auth_blocking if the call is coming through the Admin API
@ -369,7 +373,9 @@ class RegistrationHandler(BaseHandler):
config["room_alias_name"] = room_alias.localpart
info, _ = await room_creation_handler.create_room(
fake_requester, config=config, ratelimit=False,
fake_requester,
config=config,
ratelimit=False,
)
# If the room does not require an invite, but another user
@ -753,7 +759,10 @@ class RegistrationHandler(BaseHandler):
return
await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)
# And we add an email pusher for them by default, but only
@ -805,5 +814,8 @@ class RegistrationHandler(BaseHandler):
raise
await self._auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"],
user_id,
threepid["medium"],
threepid["address"],
threepid["validated_at"],
)

View File

@ -198,7 +198,9 @@ class RoomCreationHandler(BaseHandler):
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
creator_id=user_id,
is_public=r["is_public"],
room_version=new_version,
)
logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
@ -236,7 +238,9 @@ class RoomCreationHandler(BaseHandler):
# now send the tombstone
await self.event_creation_handler.handle_new_client_event(
requester=requester, event=tombstone_event, context=tombstone_context,
requester=requester,
event=tombstone_event,
context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
@ -257,7 +261,10 @@ class RoomCreationHandler(BaseHandler):
# finally, shut down the PLs in the old room, and update them in the new
# room.
await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
requester,
old_room_id,
new_room_id,
old_room_state,
)
return new_room_id
@ -570,7 +577,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
"""Creates a new room.
Args:
requester:
@ -691,7 +698,9 @@ class RoomCreationHandler(BaseHandler):
is_public = visibility == "public"
room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
creator_id=user_id,
is_public=is_public,
room_version=room_version,
)
# Check whether this visibility value is blocked by a third party module
@ -884,7 +893,10 @@ class RoomCreationHandler(BaseHandler):
_,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False, ignore_shadow_ban=True,
creator,
event,
ratelimit=False,
ignore_shadow_ban=True,
)
return last_stream_id
@ -984,7 +996,10 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: bool, room_version: RoomVersion,
self,
creator_id: str,
is_public: bool,
room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.

View File

@ -191,7 +191,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# do it up front for efficiency.)
if txn_id and requester.access_token_id:
existing_event_id = await self.store.get_event_id_from_transaction_id(
room_id, requester.user.to_string(), requester.access_token_id, txn_id,
room_id,
requester.user.to_string(),
requester.access_token_id,
txn_id,
)
if existing_event_id:
event_pos = await self.store.get_position_for_event(existing_event_id)
@ -238,7 +241,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit,
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
if event.membership == Membership.LEAVE:
@ -583,7 +590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# send the rejection to the inviter's HS (with fallback to
# local event)
return await self.remote_reject_invite(
invite.event_id, txn_id, requester, content,
invite.event_id,
txn_id,
requester,
content,
)
# the inviter was on our server, but has now left. Carry on
@ -1056,8 +1066,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
"""Implements RoomMemberHandler._remote_join"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a
# 500.
@ -1211,7 +1220,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event.internal_metadata.out_of_band_membership = True
result_event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[UserID.from_string(target_user)],
requester,
event,
context,
extra_users=[UserID.from_string(target_user)],
)
# we know it was persisted, so must have a stream ordering
assert result_event.internal_metadata.stream_ordering
@ -1219,8 +1231,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return result_event.event_id, result_event.internal_metadata.stream_ordering
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
"""Implements RoomMemberHandler._user_left_room"""
user_left_room(self.distributor, target, room_id)
async def forget(self, user: UserID, room_id: str) -> None:

View File

@ -44,8 +44,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
user: UserID,
content: dict,
) -> Tuple[str, int]:
"""Implements RoomMemberHandler._remote_join
"""
"""Implements RoomMemberHandler._remote_join"""
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
@ -80,8 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret["event_id"], ret["stream_id"]
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
"""Implements RoomMemberHandler._user_left_room"""
await self._notify_change_client(
user_id=target.to_string(), room_id=room_id, change="left"
)

View File

@ -121,7 +121,8 @@ class SamlHandler(BaseHandler):
now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
creation_time=now,
ui_auth_session_id=ui_auth_session_id,
)
for key, value in info["headers"]:
@ -450,7 +451,8 @@ class DefaultSamlMappingProvider:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
"SAML2 response lacks a '%s' attestation",
self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)

View File

@ -327,7 +327,8 @@ class SsoHandler:
# Check if we already have a mapping for this user.
previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
auth_provider_id,
remote_user_id,
)
# A match was found, return the user ID.
@ -416,7 +417,8 @@ class SsoHandler:
with await self._mapping_lock.queue(auth_provider_id):
# first of all, check if we already have a mapping for this user
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
auth_provider_id,
remote_user_id,
)
# Check for grandfathering of users.
@ -461,7 +463,8 @@ class SsoHandler:
)
async def _call_attribute_mapper(
self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
self,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
) -> UserAttributes:
"""Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
@ -632,7 +635,8 @@ class SsoHandler:
"""
user_id = await self.get_sso_user_by_remote_user_id(
auth_provider_id, remote_user_id,
auth_provider_id,
remote_user_id,
)
user_id_to_verify = await self._auth_handler.get_session_data(
@ -671,7 +675,8 @@ class SsoHandler:
# render an error page.
html = self._bad_user_template.render(
server_name=self._server_name, user_id_to_verify=user_id_to_verify,
server_name=self._server_name,
user_id_to_verify=user_id_to_verify,
)
respond_with_html(request, 200, html)
@ -695,7 +700,9 @@ class SsoHandler:
raise SynapseError(400, "unknown session")
async def check_username_availability(
self, localpart: str, session_id: str,
self,
localpart: str,
session_id: str,
) -> bool:
"""Handle an "is username available" callback check
@ -833,7 +840,8 @@ class SsoHandler:
)
attributes = UserAttributes(
localpart=session.chosen_localpart, emails=session.emails_to_use,
localpart=session.chosen_localpart,
emails=session.emails_to_use,
)
if session.use_display_name:

View File

@ -63,8 +63,7 @@ class StatsHandler:
self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
"""Called when there may be more deltas to process"""
if not self.stats_enabled or self._is_processing:
return

View File

@ -339,8 +339,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Get the sync for client needed to match what the server has now.
"""
"""Get the sync for client needed to match what the server has now."""
return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user: UserID) -> JsonDict:
@ -564,7 +563,7 @@ class SyncHandler:
stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(),
) -> StateMap[str]:
""" Get the room state at a particular stream position
"""Get the room state at a particular stream position
Args:
room_id: room for which to get state
@ -598,7 +597,7 @@ class SyncHandler:
state: MutableStateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number
"""Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
@ -743,7 +742,7 @@ class SyncHandler:
now_token: StreamToken,
full_state: bool,
) -> MutableStateMap[EventBase]:
""" Works out the difference in state between the start of the timeline
"""Works out the difference in state between the start of the timeline
and the previous sync.
Args:
@ -820,8 +819,10 @@ class SyncHandler:
)
elif batch.limited:
if batch:
state_at_timeline_start = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
state_at_timeline_start = (
await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter
)
)
else:
# We can get here if the user has ignored the senders of all
@ -955,8 +956,7 @@ class SyncHandler:
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Generates a sync result.
"""
"""Generates a sync result."""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@ -1030,8 +1030,8 @@ class SyncHandler:
one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
user_id, device_id
unused_fallback_key_types = (
await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
)
logger.debug("Fetching group data")
@ -1176,8 +1176,10 @@ class SyncHandler:
# weren't in the previous sync *or* they left and rejoined.
users_that_have_changed.update(newly_joined_or_invited_users)
user_signatures_changed = await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key
user_signatures_changed = (
await self.store.get_users_whose_signatures_changed(
user_id, since_token.device_list_key
)
)
users_that_have_changed.update(user_signatures_changed)
@ -1393,8 +1395,10 @@ class SyncHandler:
logger.debug("no-oping sync")
return set(), set(), set(), set()
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
ignored_account_data = (
await self.store.get_global_account_data_by_type_for_user(
AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
)
)
# If there is ignored users account data and it matches the proper type,
@ -1499,8 +1503,7 @@ class SyncHandler:
async def _get_rooms_changed(
self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync.
"""
"""Gets the the changes that have happened since the last sync."""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token

View File

@ -61,7 +61,8 @@ class FollowerTypingHandler:
if hs.config.worker.writers.typing != hs.get_instance_name():
hs.get_federation_registry().register_instance_for_edu(
"m.typing", hs.config.worker.writers.typing,
"m.typing",
hs.config.worker.writers.typing,
)
# map room IDs to serial numbers
@ -76,8 +77,7 @@ class FollowerTypingHandler:
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
"""Reset the typing handler's data caches."""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
@ -149,8 +149,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
"""Should be called whenever we receive updates for typing stream."""
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just

View File

@ -97,8 +97,7 @@ class UserDirectoryHandler(StateDeltasHandler):
return results
def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
"""Called when there may be more deltas to process"""
if not self.update_user_directory:
return
@ -134,8 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
)
async def handle_user_deactivated(self, user_id: str) -> None:
"""Called when a user ID is deactivated
"""
"""Called when a user ID is deactivated"""
# FIXME(#3714): We should probably do this in the same worker as all
# the other changes.
await self.store.remove_from_user_dir(user_id)
@ -172,8 +170,7 @@ class UserDirectoryHandler(StateDeltasHandler):
await self.store.update_user_directory_stream_pos(max_pos)
async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
"""Called with the state deltas to process
"""
"""Called with the state deltas to process"""
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]

View File

@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
def get_request_user_agent(request: IRequest, default: str = "") -> str:
"""Return the last User-Agent header, or the given default.
"""
"""Return the last User-Agent header, or the given default."""
# There could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically

View File

@ -398,7 +398,8 @@ class SimpleHttpClient:
body_producer = None
if data is not None:
body_producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator,
BytesIO(data),
cooperator=self._cooperator,
)
request_deferred = treq.request(
@ -413,7 +414,9 @@ class SimpleHttpClient:
# we use our own timeout mechanism rather than treq's as a workaround
# for https://twistedmatrix.com/trac/ticket/9534.
request_deferred = timeout_deferred(
request_deferred, 60, self.hs.get_reactor(),
request_deferred,
60,
self.hs.get_reactor(),
)
# turn timeouts into RequestTimedOutErrors

View File

@ -195,8 +195,7 @@ class MatrixFederationAgent:
@implementer(IAgentEndpointFactory)
class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
"""Factory for MatrixHostnameEndpoint for parsing to an Agent."""
def __init__(
self,
@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
"""Implements IStreamClientEndpoint interface"""
return run_in_background(self._do_connect, protocol_factory)

View File

@ -81,8 +81,7 @@ class WellKnownLookupResult:
class WellKnownResolver:
"""Handles well-known lookups for matrix servers.
"""
"""Handles well-known lookups for matrix servers."""
def __init__(
self,

View File

@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
self.agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@ -652,7 +653,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
"""Sends the specified json data using PUT
Args:
destination: The remote server to send the HTTP request to.
@ -740,7 +741,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
args: Optional[QueryArgs] = None,
) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
"""Sends the specified json data using POST
Args:
destination: The remote server to send the HTTP request to.
@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
body = await _handle_json_response(
self.reactor, _sec_timeout, request, response, start_ms,
self.reactor,
_sec_timeout,
request,
response,
start_ms,
)
return body
@ -813,7 +818,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
"""GETs some json from the given host homeserver and path
Args:
destination: The remote server to send the HTTP request to.
@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
except BodyExceededMaxSize:
msg = "Requested file is too large > %r bytes" % (max_size,)
logger.warning(
"{%s} [%s] %s", request.txn_id, request.destination, msg,
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:

View File

@ -213,8 +213,7 @@ class RequestMetrics:
self.update_metrics()
def update_metrics(self):
"""Updates the in flight metrics with values from this request.
"""
"""Updates the in flight metrics with values from this request."""
new_stats = self.start_context.get_resource_usage()
diff = new_stats - self._request_stats

View File

@ -76,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients.
"""
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
error_code = f.value.code
@ -106,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
request, error_code, error_dict, send_cors=True,
request,
error_code,
error_dict,
send_cors=True,
)
def return_html_error(
f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
f: failure.Failure,
request: Request,
error_template: Union[str, jinja2.Template],
) -> None:
"""Sends an HTML error page corresponding to the given failure.
@ -189,8 +193,7 @@ ServletCallback = Callable[
class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server
"""
"""Interface for registering callbacks on a HTTP server"""
def register_paths(
self,
@ -199,7 +202,7 @@ class HttpServer(Protocol):
callback: ServletCallback,
servlet_classname: str,
) -> None:
""" Register a callback that gets fired if we receive a http request
"""Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
If the regex contains groups these gets passed to the callback via
@ -235,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
self._extract_context = extract_context
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
"""This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET
@ -287,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
@abc.abstractmethod
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
self,
request: SynapseRequest,
code: int,
response_object: Any,
) -> None:
raise NotImplementedError()
@abc.abstractmethod
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
self,
f: failure.Failure,
request: SynapseRequest,
) -> None:
raise NotImplementedError()
@ -308,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
self.canonical_json = canonical_json
def _send_response(
self, request: Request, code: int, response_object: Any,
self,
request: Request,
code: int,
response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
"""Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request,
@ -322,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource):
)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
self,
f: failure.Failure,
request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
"""Implements _AsyncResource._send_error_response"""
return_json_error(f, request)
class JsonResource(DirectServeJsonResource):
""" This implements the HttpServer interface and provides JSON support for
"""This implements the HttpServer interface and provides JSON support for
Resources.
Register callbacks via register_paths()
@ -443,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
def _send_response(
self, request: SynapseRequest, code: int, response_object: Any,
self,
request: SynapseRequest,
code: int,
response_object: Any,
):
"""Implements _AsyncResource._send_response
"""
"""Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write
assert isinstance(response_object, bytes)
html_bytes = response_object
@ -454,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
respond_with_html_bytes(request, 200, html_bytes)
def _send_error_response(
self, f: failure.Failure, request: SynapseRequest,
self,
f: failure.Failure,
request: SynapseRequest,
) -> None:
"""Implements _AsyncResource._send_error_response
"""
"""Implements _AsyncResource._send_error_response"""
return_html_error(f, request, self.ERROR_TEMPLATE)
@ -534,7 +547,9 @@ class _ByteProducer:
min_chunk_size = 1024
def __init__(
self, request: Request, iterator: Iterator[bytes],
self,
request: Request,
iterator: Iterator[bytes],
):
self._request = request
self._iterator = iterator
@ -654,7 +669,10 @@ def respond_with_json(
def respond_with_json_bytes(
request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
request: Request,
code: int,
json_bytes: bytes,
send_cors: bool = False,
):
"""Sends encoded JSON in response to the given request.
@ -769,7 +787,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
def finish_request(request: Request):
""" Finish writing the response to the request.
"""Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the
response was written but doesn't provide a convenient or reliable way to

View File

@ -258,7 +258,7 @@ def assert_params_in_dict(body, required):
class RestServlet:
""" A Synapse REST Servlet.
"""A Synapse REST Servlet.
An implementing class can either provide its own custom 'register' method,
or use the automatic pattern handling provided by the base class.

View File

@ -249,8 +249,7 @@ class SynapseRequest(Request):
)
def _finished_processing(self):
"""Log the completion of this request and update the metrics
"""
"""Log the completion of this request and update the metrics"""
assert self.logcontext is not None
usage = self.logcontext.get_resource_usage()
@ -276,7 +275,8 @@ class SynapseRequest(Request):
# authenticated (e.g. and admin is puppetting a user) then we log both.
if self.requester.user.to_string() != authenticated_entity:
authenticated_entity = "{},{}".format(
authenticated_entity, self.requester.user.to_string(),
authenticated_entity,
self.requester.user.to_string(),
)
elif self.requester is not None:
# This shouldn't happen, but we log it so we don't lose information
@ -322,8 +322,7 @@ class SynapseRequest(Request):
logger.warning("Failed to stop metrics: %r", e)
def _should_log_request(self) -> bool:
"""Whether we should log at INFO that we processed the request.
"""
"""Whether we should log at INFO that we processed the request."""
if self.path == b"/health":
return False

View File

@ -174,7 +174,9 @@ class RemoteHandler(logging.Handler):
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer, transport=result.transport, format=self.format,
buffer=self._buffer,
transport=result.transport,
format=self.format,
)
result.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()

View File

@ -60,7 +60,10 @@ def parse_drain_configs(
)
# Either use the default formatter or the tersejson one.
if logging_type in (DrainType.CONSOLE_JSON, DrainType.FILE_JSON,):
if logging_type in (
DrainType.CONSOLE_JSON,
DrainType.FILE_JSON,
):
formatter = "json" # type: Optional[str]
elif logging_type in (
DrainType.CONSOLE_JSON_TERSE,
@ -131,7 +134,9 @@ def parse_drain_configs(
)
def setup_structured_logging(log_config: dict,) -> dict:
def setup_structured_logging(
log_config: dict,
) -> dict:
"""
Convert a legacy structured logging configuration (from Synapse < v1.23.0)
to one compatible with the new standard library handlers.

View File

@ -338,7 +338,10 @@ class LoggingContext:
if self.previous_context != old_context:
logcontext_error(
"Expected previous context %r, found %r"
% (self.previous_context, old_context,)
% (
self.previous_context,
old_context,
)
)
return self
@ -562,7 +565,7 @@ class LoggingContextFilter(logging.Filter):
class PreserveLoggingContext:
"""Context manager which replaces the logging context
The previous logging context is restored on exit."""
The previous logging context is restored on exit."""
__slots__ = ["_old_context", "_new_context"]
@ -585,7 +588,10 @@ class PreserveLoggingContext:
else:
logcontext_error(
"Expected logging context %s but found %s"
% (self._new_context, context,)
% (
self._new_context,
context,
)
)

View File

@ -238,8 +238,7 @@ try:
@attr.s(slots=True, frozen=True)
class _WrappedRustReporter:
"""Wrap the reporter to ensure `report_span` never throws.
"""
"""Wrap the reporter to ensure `report_span` never throws."""
_reporter = attr.ib(type=Reporter, default=attr.Factory(Reporter))
@ -326,8 +325,7 @@ def noop_context_manager(*args, **kwargs):
def init_tracer(hs: "HomeServer"):
"""Set the whitelists and initialise the JaegerClient tracer
"""
"""Set the whitelists and initialise the JaegerClient tracer"""
global opentracing
if not hs.config.opentracer_enabled:
# We don't have a tracer
@ -384,7 +382,7 @@ def whitelisted_homeserver(destination):
Args:
destination (str)
"""
"""
if _homeserver_whitelist:
return _homeserver_whitelist.match(destination)

View File

@ -43,8 +43,7 @@ def _log_debug_as_f(f, msg, msg_args):
def log_function(f):
""" Function decorator that logs every call to that function.
"""
"""Function decorator that logs every call to that function."""
func_name = f.__name__
@wraps(f)

View File

@ -155,8 +155,7 @@ class InFlightGauge:
self._registrations.setdefault(key, set()).add(callback)
def unregister(self, key, callback):
"""Registers that we've exited a block with labels `key`.
"""
"""Registers that we've exited a block with labels `key`."""
with self._lock:
self._registrations.setdefault(key, set()).discard(callback)
@ -402,7 +401,9 @@ class PyPyGCStats:
# Total time spent in GC: 0.073 # s.total_gc_time
pypy_gc_time = CounterMetricFamily(
"pypy_gc_time_seconds_total", "Total time spent in PyPy GC", labels=[],
"pypy_gc_time_seconds_total",
"Total time spent in PyPy GC",
labels=[],
)
pypy_gc_time.add_metric([], s.total_gc_time / 1000)
yield pypy_gc_time

View File

@ -216,7 +216,7 @@ class MetricsHandler(BaseHTTPRequestHandler):
@classmethod
def factory(cls, registry):
"""Returns a dynamic MetricsHandler class tied
to the passed registry.
to the passed registry.
"""
# This implementation relies on MetricsHandler.registry
# (defined above and defaulted to REGISTRY).

View File

@ -208,7 +208,8 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
"Background process '%s' threw an exception",
desc,
)
finally:
_background_process_in_flight_count.labels(desc).dec()
@ -249,8 +250,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
self._proc = _BackgroundProcess(name, self)
def start(self, rusage: "Optional[resource._RUsage]"):
"""Log context has started running (again).
"""
"""Log context has started running (again)."""
super().start(rusage)
@ -261,8 +261,7 @@ class BackgroundProcessLoggingContext(LoggingContext):
_background_processes_active_since_last_scrape.add(self._proc)
def __exit__(self, type, value, traceback) -> None:
"""Log context has finished.
"""
"""Log context has finished."""
super().__exit__(type, value, traceback)

View File

@ -275,7 +275,9 @@ class ModuleApi:
redirect them directly if whitelisted).
"""
self._auth_handler._complete_sso_login(
registered_user_id, request, client_redirect_url,
registered_user_id,
request,
client_redirect_url,
)
async def complete_sso_login_async(
@ -352,7 +354,10 @@ class ModuleApi:
event,
_,
) = await self._hs.get_event_creation_handler().create_and_send_nonmember_event(
requester, event_dict, ratelimit=False, ignore_shadow_ban=True,
requester,
event_dict,
ratelimit=False,
ignore_shadow_ban=True,
)
return event

View File

@ -75,7 +75,7 @@ def count(func: Callable[[T], bool], it: Iterable[T]) -> int:
class _NotificationListener:
""" This represents a single client connection to the events stream.
"""This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
@ -119,7 +119,10 @@ class _NotifierUserStream:
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
self,
stream_key: str,
stream_id: Union[int, RoomStreamToken],
time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an
event source.
@ -140,7 +143,7 @@ class _NotifierUserStream:
noify_deferred.callback(self.current_token)
def remove(self, notifier: "Notifier"):
""" Remove this listener from all the indexes in the Notifier
"""Remove this listener from all the indexes in the Notifier
it knows about.
"""
@ -186,7 +189,7 @@ class _PendingRoomEventEntry:
class Notifier:
""" This class is responsible for notifying any listeners when there are
"""This class is responsible for notifying any listeners when there are
new events available for it.
Primarily used from the /events stream.
@ -265,8 +268,7 @@ class Notifier:
max_room_stream_token: RoomStreamToken,
extra_users: Collection[UserID] = [],
):
"""Unwraps event and calls `on_new_room_event_args`.
"""
"""Unwraps event and calls `on_new_room_event_args`."""
self.on_new_room_event_args(
event_pos=event_pos,
room_id=event.room_id,
@ -341,7 +343,10 @@ class Notifier:
if users or rooms:
self.on_new_event(
"room_key", max_room_stream_token, users=users, rooms=rooms,
"room_key",
max_room_stream_token,
users=users,
rooms=rooms,
)
self._on_updated_room_token(max_room_stream_token)
@ -392,7 +397,7 @@ class Notifier:
users: Collection[Union[str, UserID]] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.
"""Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""
@ -418,7 +423,9 @@ class Notifier:
# Notify appservices
self._notify_app_services_ephemeral(
stream_key, new_token, users,
stream_key,
new_token,
users,
)
def on_new_replication_data(self) -> None:
@ -502,7 +509,7 @@ class Notifier:
is_guest: bool = False,
explicit_room_id: str = None,
) -> EventStreamResult:
""" For the given user and rooms, return any new events for them. If
"""For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@ -651,8 +658,7 @@ class Notifier:
cb()
def notify_remote_server_up(self, server: str):
"""Notify any replication that a remote server has come back up
"""
"""Notify any replication that a remote server has come back up"""
# We call federation_sender directly rather than registering as a
# callback as a) we already have a reference to it and b) it introduces
# circular dependencies.

View File

@ -144,8 +144,7 @@ class BulkPushRuleEvaluator:
@lru_cache()
def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id
"""
"""Get the current RulesForRoom object for the given room id"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
@ -252,7 +251,9 @@ class BulkPushRuleEvaluator:
# notified for this event. (This will then get handled when we persist
# the event)
await self.store.add_push_actions_to_staging(
event.event_id, actions_by_user, count_as_unread,
event.event_id,
actions_by_user,
count_as_unread,
)

View File

@ -116,8 +116,7 @@ class EmailPusher(Pusher):
self._is_processing = True
def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing.
"""
"""Used by tests to resume processing of events after pausing."""
assert self._is_processing
self._is_processing = False
self._start_processing()
@ -157,8 +156,10 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
unprocessed = (
await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
)
soonest_due_at = None # type: Optional[int]
@ -222,12 +223,14 @@ class EmailPusher(Pusher):
self, last_stream_ordering: int
) -> None:
self.last_stream_ordering = last_stream_ordering
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.email,
self.user_id,
last_stream_ordering,
self.clock.time_msec(),
pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.email,
self.user_id,
last_stream_ordering,
self.clock.time_msec(),
)
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
@ -298,7 +301,8 @@ class EmailPusher(Pusher):
current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
self.throttle_params[room_id] = ThrottleParams(
self.clock.time_msec(), new_throttle_ms,
self.clock.time_msec(),
new_throttle_ms,
)
assert self.pusher_id is not None
await self.store.set_throttle_params(

View File

@ -176,8 +176,10 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
unprocessed = (
await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
)
logger.info(
@ -204,12 +206,14 @@ class HttpPusher(Pusher):
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
self.clock.time_msec(),
pusher_still_exists = (
await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
self.clock.time_msec(),
)
)
if not pusher_still_exists:
# The pusher has been deleted while we were processing, so
@ -290,7 +294,8 @@ class HttpPusher(Pusher):
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warning(
("Ignoring rejected pushkey %s because we didn't send it"), pk,
("Ignoring rejected pushkey %s because we didn't send it"),
pk,
)
else:
logger.info("Pushkey %s was rejected: removing", pk)

View File

@ -78,8 +78,7 @@ class PusherPool:
self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self) -> None:
"""Starts the pushers off in a background process.
"""
"""Starts the pushers off in a background process."""
if not self._should_start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
@ -297,8 +296,7 @@ class PusherPool:
return pusher
async def _start_pushers(self) -> None:
"""Start all the pushers
"""
"""Start all the pushers"""
pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the
@ -335,7 +333,8 @@ class PusherPool:
return None
except Exception:
logger.exception(
"Couldn't start pusher id %i: caught Exception", pusher_config.id,
"Couldn't start pusher id %i: caught Exception",
pusher_config.id,
)
return None

View File

@ -273,7 +273,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
method,
[pattern],
self._check_auth_and_handle,
self.__class__.__name__,
)
def _check_auth_and_handle(self, request, **kwargs):

Some files were not shown because too many files have changed in this diff Show More