0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-15 04:33:53 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/state_ids

This commit is contained in:
Erik Johnston 2016-08-26 09:48:13 +01:00
commit 30961182f2
8 changed files with 77 additions and 15 deletions

View file

@ -85,3 +85,8 @@ class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat" PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat" PUBLIC_CHAT = "public_chat"
TRUSTED_PRIVATE_CHAT = "trusted_private_chat" TRUSTED_PRIVATE_CHAT = "trusted_private_chat"
class ThirdPartyEntityKind(object):
USER = "user"
LOCATION = "location"

View file

@ -25,4 +25,3 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2" SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0" MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1" LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View file

@ -88,6 +88,8 @@ class ApplicationService(object):
self.sender = sender self.sender = sender
self.namespaces = self._check_namespaces(namespaces) self.namespaces = self._check_namespaces(namespaces)
self.id = id self.id = id
# .protocols is a publicly visible field
if protocols: if protocols:
self.protocols = set(protocols) self.protocols = set(protocols)
else: else:

View file

@ -14,10 +14,11 @@
# limitations under the License. # limitations under the License.
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind from synapse.util.caches.response_cache import ResponseCache
import logging import logging
import urllib import urllib
@ -25,6 +26,12 @@ import urllib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HOUR_IN_MS = 60 * 60 * 1000
APP_SERVICE_PREFIX = "/_matrix/app/unstable"
def _is_valid_3pe_result(r, field): def _is_valid_3pe_result(r, field):
if not isinstance(r, dict): if not isinstance(r, dict):
return False return False
@ -56,6 +63,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs) super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks @defer.inlineCallbacks
def query_user(self, service, user_id): def query_user(self, service, user_id):
uri = service.url + ("/users/%s" % urllib.quote(user_id)) uri = service.url + ("/users/%s" % urllib.quote(user_id))
@ -97,16 +106,20 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks @defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields): def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER: if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
required_field = "userid" required_field = "userid"
elif kind == ThirdPartyEntityKind.LOCATION: elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
required_field = "alias" required_field = "alias"
else: else:
raise ValueError( raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind "Unrecognised 'kind' argument %r to query_3pe()", kind
) )
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
urllib.quote(protocol)
)
try: try:
response = yield self.get_json(uri, fields) response = yield self.get_json(uri, fields)
if not isinstance(response, list): if not isinstance(response, list):
@ -131,6 +144,26 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("query_3pe to %s threw exception %s", uri, ex) logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([]) defer.returnValue([])
def get_3pe_protocol(self, service, protocol):
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
urllib.quote(protocol)
)
try:
defer.returnValue((yield self.get_json(uri, {})))
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s",
uri, ex)
defer.returnValue({})
key = (service.id, protocol)
return self.protocol_meta_cache.get(key) or (
self.protocol_meta_cache.set(key, _get())
)
@defer.inlineCallbacks @defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None): def push_bulk(self, service, events, txn_id=None):
events = self._serialize(events) events = self._serialize(events)

View file

@ -175,6 +175,16 @@ class ApplicationServicesHandler(object):
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_3pe_protocols(self):
services = yield self.store.get_app_services()
protocols = {}
for s in services:
for p in s.protocols:
protocols[p] = yield self.appservice_api.get_3pe_protocol(s, p)
defer.returnValue(protocols)
@defer.inlineCallbacks @defer.inlineCallbacks
def _get_services_for_event(self, event): def _get_services_for_event(self, event):
"""Retrieve a list of application services interested in this event. """Retrieve a list of application services interested in this event.

View file

@ -387,7 +387,9 @@ class FederationHandler(BaseHandler):
)).addErrback(unwrapFirstError) )).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a}) auth_events.update({a.event_id: a for a in results if a})
required_auth.update( required_auth.update(
a_id for event in results for a_id, _ in event.auth_events if event a_id
for event in results if event
for a_id, _ in event.auth_events
) )
missing_auth = required_auth - set(auth_events) missing_auth = required_auth - set(auth_events)

View file

@ -18,15 +18,32 @@ import logging
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns from ._base import client_v2_patterns
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
PATTERNS = client_v2_patterns("/thirdparty/protocols", releases=())
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
@defer.inlineCallbacks
def on_GET(self, request):
yield self.auth.get_user_by_req(request)
protocols = yield self.appservice_handler.get_3pe_protocols()
defer.returnValue((200, protocols))
class ThirdPartyUserServlet(RestServlet): class ThirdPartyUserServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pu(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -50,7 +67,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet):
PATTERNS = client_v2_patterns("/3pl(/(?P<protocol>[^/]+))?$", PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$",
releases=()) releases=())
def __init__(self, hs): def __init__(self, hs):
@ -74,5 +91,6 @@ class ThirdPartyLocationServlet(RestServlet):
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
ThirdPartyProtocolsServlet(hs).register(http_server)
ThirdPartyUserServlet(hs).register(http_server) ThirdPartyUserServlet(hs).register(http_server)
ThirdPartyLocationServlet(hs).register(http_server) ThirdPartyLocationServlet(hs).register(http_server)

View file

@ -269,10 +269,3 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream) return "t%d-%d" % (self.topological, self.stream)
else: else:
return "s%d" % (self.stream,) return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'