mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 20:53:52 +01:00
Merge pull request #990 from matrix-org/erikj/fed_vers
Add federation /version API
This commit is contained in:
commit
6bf6bc1d1d
6 changed files with 81 additions and 50 deletions
|
@ -165,7 +165,7 @@ def start(config_options):
|
|||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
|
|
|
@ -285,7 +285,7 @@ def setup(config_options):
|
|||
# check any extra requirements we have now we have a config
|
||||
check_requirements(config)
|
||||
|
||||
version_string = get_version_string("Synapse", synapse)
|
||||
version_string = "Synapse/" + get_version_string(synapse)
|
||||
|
||||
logger.info("Server hostname: %s", config.server_name)
|
||||
logger.info("Server version: %s", version_string)
|
||||
|
|
|
@ -273,7 +273,7 @@ def start(config_options):
|
|||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
|
|
|
@ -424,7 +424,7 @@ def start(config_options):
|
|||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
config=config,
|
||||
version_string=get_version_string("Synapse", synapse),
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
application_service_handler=SynchrotronApplicationService(),
|
||||
)
|
||||
|
|
|
@ -18,13 +18,14 @@ from twisted.internet import defer
|
|||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import parse_json_object_from_request, parse_string
|
||||
from synapse.http.servlet import parse_json_object_from_request
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
|
||||
import functools
|
||||
import logging
|
||||
import simplejson as json
|
||||
import re
|
||||
import synapse
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource):
|
|||
)
|
||||
|
||||
|
||||
class AuthenticationError(SynapseError):
|
||||
"""There was a problem authenticating the request"""
|
||||
pass
|
||||
|
||||
|
||||
class NoAuthenticationError(AuthenticationError):
|
||||
"""The request had no authentication information"""
|
||||
pass
|
||||
|
||||
|
||||
class Authenticator(object):
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
|
@ -67,7 +78,7 @@ class Authenticator(object):
|
|||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
@defer.inlineCallbacks
|
||||
def authenticate_request(self, request):
|
||||
def authenticate_request(self, request, content):
|
||||
json_request = {
|
||||
"method": request.method,
|
||||
"uri": request.uri,
|
||||
|
@ -75,17 +86,10 @@ class Authenticator(object):
|
|||
"signatures": {},
|
||||
}
|
||||
|
||||
content = None
|
||||
origin = None
|
||||
if content is not None:
|
||||
json_request["content"] = content
|
||||
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
try:
|
||||
content_bytes = request.content.read()
|
||||
content = json.loads(content_bytes)
|
||||
json_request["content"] = content
|
||||
except:
|
||||
raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON)
|
||||
origin = None
|
||||
|
||||
def parse_auth_header(header_str):
|
||||
try:
|
||||
|
@ -103,14 +107,14 @@ class Authenticator(object):
|
|||
sig = strip_quotes(param_dict["sig"])
|
||||
return (origin, key, sig)
|
||||
except:
|
||||
raise SynapseError(
|
||||
raise AuthenticationError(
|
||||
400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
||||
)
|
||||
|
||||
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
||||
|
||||
if not auth_headers:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
@ -121,7 +125,7 @@ class Authenticator(object):
|
|||
json_request["signatures"].setdefault(origin, {})[key] = sig
|
||||
|
||||
if not json_request["signatures"]:
|
||||
raise SynapseError(
|
||||
raise NoAuthenticationError(
|
||||
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
|
||||
)
|
||||
|
||||
|
@ -130,10 +134,12 @@ class Authenticator(object):
|
|||
logger.info("Request from %s", origin)
|
||||
request.authenticated_entity = origin
|
||||
|
||||
defer.returnValue((origin, content))
|
||||
defer.returnValue(origin)
|
||||
|
||||
|
||||
class BaseFederationServlet(object):
|
||||
REQUIRE_AUTH = True
|
||||
|
||||
def __init__(self, handler, authenticator, ratelimiter, server_name,
|
||||
room_list_handler):
|
||||
self.handler = handler
|
||||
|
@ -141,29 +147,46 @@ class BaseFederationServlet(object):
|
|||
self.ratelimiter = ratelimiter
|
||||
self.room_list_handler = room_list_handler
|
||||
|
||||
def _wrap(self, code):
|
||||
def _wrap(self, func):
|
||||
authenticator = self.authenticator
|
||||
ratelimiter = self.ratelimiter
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@functools.wraps(code)
|
||||
def new_code(request, *args, **kwargs):
|
||||
@functools.wraps(func)
|
||||
def new_func(request, *args, **kwargs):
|
||||
content = None
|
||||
if request.method in ["PUT", "POST"]:
|
||||
# TODO: Handle other method types? other content types?
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
try:
|
||||
(origin, content) = yield authenticator.authenticate_request(request)
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield code(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
origin = yield authenticator.authenticate_request(request, content)
|
||||
except NoAuthenticationError:
|
||||
origin = None
|
||||
if self.REQUIRE_AUTH:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
except:
|
||||
logger.exception("authenticate_request failed")
|
||||
raise
|
||||
|
||||
if origin:
|
||||
with ratelimiter.ratelimit(origin) as d:
|
||||
yield d
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
response = yield func(
|
||||
origin, content, request.args, *args, **kwargs
|
||||
)
|
||||
|
||||
defer.returnValue(response)
|
||||
|
||||
# Extra logic that functools.wraps() doesn't finish
|
||||
new_code.__self__ = code.__self__
|
||||
new_func.__self__ = func.__self__
|
||||
|
||||
return new_code
|
||||
return new_func
|
||||
|
||||
def register(self, server):
|
||||
pattern = re.compile("^" + PREFIX + self.PATH + "$")
|
||||
|
@ -429,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
|
|||
class On3pidBindServlet(BaseFederationServlet):
|
||||
PATH = "/3pid/onbind"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
content = parse_json_object_from_request(request)
|
||||
def on_POST(self, origin, content, query):
|
||||
if "invites" in content:
|
||||
last_exception = None
|
||||
for invite in content["invites"]:
|
||||
|
@ -453,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet):
|
|||
raise last_exception
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class OpenIdUserInfo(BaseFederationServlet):
|
||||
"""
|
||||
|
@ -478,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet):
|
|||
|
||||
PATH = "/openid/userinfo"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
token = parse_string(request, "access_token")
|
||||
def on_GET(self, origin, content, query):
|
||||
token = query.get("access_token", [None])[0]
|
||||
if token is None:
|
||||
defer.returnValue((401, {
|
||||
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
|
||||
|
@ -497,11 +518,6 @@ class OpenIdUserInfo(BaseFederationServlet):
|
|||
|
||||
defer.returnValue((200, {"sub": user_id}))
|
||||
|
||||
# Avoid doing remote HS authorization checks which are done by default by
|
||||
# BaseFederationServlet.
|
||||
def _wrap(self, code):
|
||||
return code
|
||||
|
||||
|
||||
class PublicRoomList(BaseFederationServlet):
|
||||
"""
|
||||
|
@ -542,6 +558,20 @@ class PublicRoomList(BaseFederationServlet):
|
|||
defer.returnValue((200, data))
|
||||
|
||||
|
||||
class FederationVersionServlet(BaseFederationServlet):
|
||||
PATH = "/version"
|
||||
|
||||
REQUIRE_AUTH = False
|
||||
|
||||
def on_GET(self, origin, content, query):
|
||||
return defer.succeed((200, {
|
||||
"server": {
|
||||
"name": "Synapse",
|
||||
"version": get_version_string(synapse)
|
||||
},
|
||||
}))
|
||||
|
||||
|
||||
SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
FederationPullServlet,
|
||||
|
@ -565,6 +595,7 @@ SERVLET_CLASSES = (
|
|||
On3pidBindServlet,
|
||||
OpenIdUserInfo,
|
||||
PublicRoomList,
|
||||
FederationVersionServlet,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import logging
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_version_string(name, module):
|
||||
def get_version_string(module):
|
||||
try:
|
||||
null = open(os.devnull, 'w')
|
||||
cwd = os.path.dirname(os.path.abspath(module.__file__))
|
||||
|
@ -74,11 +74,11 @@ def get_version_string(name, module):
|
|||
)
|
||||
|
||||
return (
|
||||
"%s/%s (%s)" % (
|
||||
name, module.__version__, git_version,
|
||||
"%s (%s)" % (
|
||||
module.__version__, git_version,
|
||||
)
|
||||
).encode("ascii")
|
||||
except Exception as e:
|
||||
logger.info("Failed to check for git repository: %s", e)
|
||||
|
||||
return ("%s/%s" % (name, module.__version__,)).encode("ascii")
|
||||
return module.__version__.encode("ascii")
|
||||
|
|
Loading…
Reference in a new issue