Do not convert async functions to Deferreds in the interactive_auth_handler (#7944)

This commit is contained in:
Patrick Cloke 2020-07-24 09:43:49 -04:00 committed by GitHub
parent 5ea29d7f85
commit 53f7b49f5b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 22 additions and 26 deletions

1
changelog.d/7944.misc Normal file
View file

@ -0,0 +1 @@
Convert the interactive_auth_handler wrapper to async/await.

View file

@ -17,8 +17,7 @@
"""
import logging
import re
from twisted.internet import defer
from typing import Iterable, Pattern
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
@ -27,15 +26,23 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
def client_patterns(
path_regex: str,
releases: Iterable[int] = (0,),
unstable: bool = True,
v1: bool = False,
) -> Iterable[Pattern]:
"""Creates a regex compiled client path with the correct client path
prefix.
Args:
path_regex (str): The regex string to match. This should NOT have a ^
path_regex: The regex string to match. This should NOT have a ^
as this will be prefixed.
releases: An iterable of releases to include this endpoint under.
unstable: If true, include this endpoint under the "unstable" prefix.
v1: If true, include this endpoint under the "api/v1" prefix.
Returns:
SRE_Pattern
An iterable of patterns.
"""
patterns = []
@ -73,34 +80,22 @@ def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int)
def interactive_auth_handler(orig):
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
Takes a on_POST method which returns a deferred (errcode, body) response
Takes a on_POST method which returns an Awaitable (errcode, body) response
and adds exception handling to turn a InteractiveAuthIncompleteError into
a 401 response.
Normal usage is:
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
async def on_POST(self, request):
# ...
yield self.auth_handler.check_auth
"""
await self.auth_handler.check_auth
"""
def wrapped(*args, **kwargs):
res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res
async def wrapped(*args, **kwargs):
try:
return await orig(*args, **kwargs)
except InteractiveAuthIncompleteError as e:
return 401, e.result
return wrapped
def _catch_incomplete_interactive_auth(f):
"""helper for interactive_auth_handler
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
Args:
f (failure.Failure):
"""
f.trap(InteractiveAuthIncompleteError)
return 401, f.value.result