Merge remote-tracking branch 'origin/develop' into dbkr/email_notifs

This commit is contained in:
David Baker 2016-05-10 14:40:19 +02:00
commit 997db04648
23 changed files with 620 additions and 89 deletions

10
docs/log_contexts.rst Normal file
View file

@ -0,0 +1,10 @@
What do I do about "Unexpected logging context" debug log-lines everywhere?
<Mjark> The logging context lives in thread local storage
<Mjark> Sometimes it gets out of sync with what it should actually be, usually because something scheduled something to run on the reactor without preserving the logging context.
<Matthew> what is the impact of it getting out of sync? and how and when should we preserve log context?
<Mjark> The impact is that some of the CPU and database metrics will be under-reported, and some log lines will be mis-attributed.
<Mjark> It should happen auto-magically in all the APIs that do IO or otherwise defer to the reactor.
<Erik> Mjark: the other place is if we branch, e.g. using defer.gatherResults
Unanswered: how and when should we preserve log context?

View file

@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID from synapse.types import Requester, UserID, get_domian_from_id
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -91,8 +91,8 @@ class Auth(object):
"Room %r does not exist" % (event.room_id,) "Room %r does not exist" % (event.room_id,)
) )
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domian_from_id(event.room_id)
originating_domain = UserID.from_string(event.sender).domain originating_domain = get_domian_from_id(event.sender)
if creating_domain != originating_domain: if creating_domain != originating_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -219,7 +219,7 @@ class Auth(object):
for event in curr_state.values(): for event in curr_state.values():
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
try: try:
if UserID.from_string(event.state_key).domain != host: if get_domian_from_id(event.state_key) != host:
continue continue
except: except:
logger.warn("state_key not user_id: %s", event.state_key) logger.warn("state_key not user_id: %s", event.state_key)
@ -266,8 +266,8 @@ class Auth(object):
target_user_id = event.state_key target_user_id = event.state_key
creating_domain = RoomID.from_string(event.room_id).domain creating_domain = get_domian_from_id(event.room_id)
target_domain = UserID.from_string(target_user_id).domain target_domain = get_domian_from_id(target_user_id)
if creating_domain != target_domain: if creating_domain != target_domain:
if not self.can_federate(event, auth_events): if not self.can_federate(event, auth_events):
raise AuthError( raise AuthError(
@ -889,8 +889,8 @@ class Auth(object):
if user_level >= redact_level: if user_level >= redact_level:
return False return False
redacter_domain = EventID.from_string(event.event_id).domain redacter_domain = get_domian_from_id(event.event_id)
redactee_domain = EventID.from_string(event.redacts).domain redactee_domain = get_domian_from_id(event.redacts)
if redacter_domain == redactee_domain: if redacter_domain == redactee_domain:
return True return True

View file

@ -387,6 +387,11 @@ class FederationServer(FederationBase):
"events": [ev.get_pdu_json(time_now) for ev in missing_events], "events": [ev.get_pdu_json(time_now) for ev in missing_events],
}) })
@log_function
def on_openid_userinfo(self, token):
ts_now_ms = self._clock.time_msec()
return self.store.get_user_id_for_open_id_token(token, ts_now_ms)
@log_function @log_function
def _get_persisted_pdu(self, origin, event_id, do_auth=True): def _get_persisted_pdu(self, origin, event_id, do_auth=True):
""" Get a PDU from the database with given origin and id. """ Get a PDU from the database with given origin and id.

View file

@ -20,6 +20,7 @@ from .persistence import TransactionActions
from .units import Transaction from .units import Transaction
from synapse.api.errors import HttpResponseException from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import ( from synapse.util.retryutils import (
@ -199,6 +200,8 @@ class TransactionQueue(object):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def _attempt_new_transaction(self, destination): def _attempt_new_transaction(self, destination):
yield run_on_reactor()
# list of (pending_pdu, deferred, order) # list of (pending_pdu, deferred, order)
if destination in self.pending_transactions: if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending # XXX: pending_transactions can get stuck on by a never-ending

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request, parse_string
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
import functools import functools
@ -323,7 +323,7 @@ class FederationSendLeaveServlet(BaseFederationServlet):
class FederationEventAuthServlet(BaseFederationServlet): class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)" PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
def on_GET(self, origin, content, query, context, event_id): def on_GET(self, origin, content, query, context, event_id):
return self.handler.on_event_auth(origin, context, event_id) return self.handler.on_event_auth(origin, context, event_id)
@ -448,6 +448,50 @@ class On3pidBindServlet(BaseFederationServlet):
return code return code
class OpenIdUserInfo(BaseFederationServlet):
"""
Exchange a bearer token for information about a user.
The response format should be compatible with:
http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse
GET /openid/userinfo?access_token=ABDEFGH HTTP/1.1
HTTP/1.1 200 OK
Content-Type: application/json
{
"sub": "@userpart:example.org",
}
"""
PATH = "/openid/userinfo"
@defer.inlineCallbacks
def on_GET(self, request):
token = parse_string(request, "access_token")
if token is None:
defer.returnValue((401, {
"errcode": "M_MISSING_TOKEN", "error": "Access Token required"
}))
return
user_id = yield self.handler.on_openid_userinfo(token)
if user_id is None:
defer.returnValue((401, {
"errcode": "M_UNKNOWN_TOKEN",
"error": "Access Token unknown or expired"
}))
defer.returnValue((200, {"sub": user_id}))
# Avoid doing remote HS authorization checks which are done by default by
# BaseFederationServlet.
def _wrap(self, code):
return code
SERVLET_CLASSES = ( SERVLET_CLASSES = (
FederationSendServlet, FederationSendServlet,
FederationPullServlet, FederationPullServlet,
@ -468,6 +512,7 @@ SERVLET_CLASSES = (
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,
OpenIdUserInfo,
) )

View file

@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias, Requester from synapse.types import UserID, RoomAlias, Requester, get_domian_from_id
from synapse.push.action_generator import ActionGenerator from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
@ -296,7 +296,7 @@ class BaseHandler(object):
return True return True
for (state_key, membership) in room_members: for (state_key, membership) in room_members:
if ( if (
UserID.from_string(state_key).domain == self.hs.hostname self.hs.is_mine_id(state_key)
and membership == Membership.JOIN and membership == Membership.JOIN
): ):
return True return True
@ -421,9 +421,7 @@ class BaseHandler(object):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except SynapseError: except SynapseError:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id

View file

@ -33,7 +33,7 @@ from synapse.util.frozenutils import unfreeze
from synapse.crypto.event_signing import ( from synapse.crypto.event_signing import (
compute_event_signature, add_hashes_and_signatures, compute_event_signature, add_hashes_and_signatures,
) )
from synapse.types import UserID from synapse.types import UserID, get_domian_from_id
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
@ -453,7 +453,7 @@ class FederationHandler(BaseHandler):
joined_domains = {} joined_domains = {}
for u, d in joined_users: for u, d in joined_users:
try: try:
dom = UserID.from_string(u).domain dom = get_domian_from_id(u)
old_d = joined_domains.get(dom) old_d = joined_domains.get(dom)
if old_d: if old_d:
joined_domains[dom] = min(d, old_d) joined_domains[dom] = min(d, old_d)
@ -743,9 +743,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.JOIN: if s.content["membership"] == Membership.JOIN:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id
@ -970,9 +968,7 @@ class FederationHandler(BaseHandler):
try: try:
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
if s.content["membership"] == Membership.LEAVE: if s.content["membership"] == Membership.LEAVE:
destinations.add( destinations.add(get_domian_from_id(s.state_key))
UserID.from_string(s.state_key).domain
)
except: except:
logger.warn( logger.warn(
"Failed to get destination from event %s", s.event_id "Failed to get destination from event %s", s.event_id

View file

@ -33,7 +33,7 @@ from synapse.util.logcontext import preserve_fn
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID from synapse.types import UserID, get_domian_from_id
import synapse.metrics import synapse.metrics
from ._base import BaseHandler from ._base import BaseHandler
@ -168,7 +168,7 @@ class PresenceHandler(BaseHandler):
# The initial delay is to allow disconnected clients a chance to # The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline. # reconnect before we treat them as offline.
self.clock.call_later( self.clock.call_later(
0 * 1000, 30 * 1000,
self.clock.looping_call, self.clock.looping_call,
self._handle_timeouts, self._handle_timeouts,
5000, 5000,
@ -440,7 +440,7 @@ class PresenceHandler(BaseHandler):
if not local_states: if not local_states:
continue continue
host = UserID.from_string(user_id).domain host = get_domian_from_id(user_id)
hosts_to_states.setdefault(host, []).extend(local_states) hosts_to_states.setdefault(host, []).extend(local_states)
# TODO: de-dup hosts_to_states, as a single host might have multiple # TODO: de-dup hosts_to_states, as a single host might have multiple

View file

@ -159,6 +159,15 @@ class ReplicationResource(Resource):
result = yield self.notifier.wait_for_replication(replicate, timeout) result = yield self.notifier.wait_for_replication(replicate, timeout)
for stream_name, stream_content in result.items():
logger.info(
"Replicating %d rows of %s from %s -> %s",
len(stream_content["rows"]),
stream_name,
stream_content["position"],
request_streams.get(stream_name),
)
request.write(json.dumps(result, ensure_ascii=False)) request.write(json.dumps(result, ensure_ascii=False))
finish_request(request) finish_request(request)

View file

@ -44,6 +44,8 @@ from synapse.rest.client.v2_alpha import (
tokenrefresh, tokenrefresh,
tags, tags,
account_data, account_data,
report_event,
openid,
) )
from synapse.http.server import JsonResource from synapse.http.server import JsonResource
@ -86,3 +88,5 @@ class ClientRestResource(JsonResource):
tokenrefresh.register_servlets(hs, client_resource) tokenrefresh.register_servlets(hs, client_resource)
tags.register_servlets(hs, client_resource) tags.register_servlets(hs, client_resource)
account_data.register_servlets(hs, client_resource) account_data.register_servlets(hs, client_resource)
report_event.register_servlets(hs, client_resource)
openid.register_servlets(hs, client_resource)

View file

@ -0,0 +1,96 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import client_v2_patterns
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.api.errors import AuthError
from synapse.util.stringutils import random_string
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class IdTokenServlet(RestServlet):
"""
Get a bearer token that may be passed to a third party to confirm ownership
of a matrix user id.
The format of the response could be made compatible with the format given
in http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse
But instead of returning a signed "id_token" the response contains the
name of the issuing matrix homeserver. This means that for now the third
party will need to check the validity of the "id_token" against the
federation /openid/userinfo endpoint of the homeserver.
Request:
POST /user/{user_id}/openid/request_token?access_token=... HTTP/1.1
{}
Response:
HTTP/1.1 200 OK
{
"access_token": "ABDEFGH",
"token_type": "Bearer",
"matrix_server_name": "example.com",
"expires_in": 3600,
}
"""
PATTERNS = client_v2_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token"
)
EXPIRES_MS = 3600 * 1000
def __init__(self, hs):
super(IdTokenServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
@defer.inlineCallbacks
def on_POST(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
# Parse the request body to make sure it's JSON, but ignore the contents
# for now.
parse_json_object_from_request(request)
token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
defer.returnValue((200, {
"access_token": token,
"token_type": "Bearer",
"matrix_server_name": self.server_name,
"expires_in": self.EXPIRES_MS / 1000,
}))
def register_servlets(hs, http_server):
IdTokenServlet(hs).register(http_server)

View file

@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_v2_patterns
import logging
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
super(ReportEventRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
yield self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
reason=body.get("reason"),
content=body,
received_ts=self.clock.time_msec(),
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
ReportEventRestServlet(hs).register(http_server)

View file

@ -44,6 +44,7 @@ from .receipts import ReceiptsStore
from .search import SearchStore from .search import SearchStore
from .tags import TagsStore from .tags import TagsStore
from .account_data import AccountDataStore from .account_data import AccountDataStore
from .openid import OpenIdStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
@ -81,7 +82,8 @@ class DataStore(RoomMemberStore, RoomStore,
SearchStore, SearchStore,
TagsStore, TagsStore,
AccountDataStore, AccountDataStore,
EventPushActionsStore EventPushActionsStore,
OpenIdStore,
): ):
def __init__(self, db_conn, hs): def __init__(self, db_conn, hs):
@ -114,6 +116,7 @@ class DataStore(RoomMemberStore, RoomStore,
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
self._push_rules_stream_id_gen = ChainedIdGenerator( self._push_rules_stream_id_gen = ChainedIdGenerator(

View file

@ -19,12 +19,14 @@ from twisted.internet import defer, reactor
from synapse.events import FrozenEvent, USE_FROZEN_DICTS from synapse.events import FrozenEvent, USE_FROZEN_DICTS
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import preserve_fn, PreserveLoggingContext from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import namedtuple from collections import deque, namedtuple
import logging import logging
import math import math
@ -50,6 +52,93 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
"""
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred",
))
def __init__(self):
self._event_persist_queues = {}
self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
"""Add events to the queue, with the given persist_event options.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
deferred = ObservableDeferred(defer.Deferred())
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
backfilled=backfilled,
current_state=current_state,
deferred=deferred,
))
return deferred.observe()
def handle_queue(self, room_id, per_item_callback):
"""Attempts to handle the queue for a room if not already being handled.
The given callback will be invoked with for each item in the queue,1
of type _EventPersistQueueItem. The per_item_callback will continuously
be called with new items, unless the queue becomnes empty. The return
value of the function will be given to the deferreds waiting on the item,
exceptions will be passed to the deferres as well.
This function should therefore be called whenever anything is added
to the queue.
If another callback is currently handling the queue then it will not be
invoked.
"""
if room_id in self._currently_persisting_rooms:
return
self._currently_persisting_rooms.add(room_id)
@defer.inlineCallbacks
def handle_queue_loop():
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
except Exception as e:
item.deferred.errback(e)
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
self._event_persist_queues[room_id] = queue
self._currently_persisting_rooms.discard(room_id)
preserve_fn(handle_queue_loop)()
def _get_drainining_queue(self, room_id):
queue = self._event_persist_queues.setdefault(room_id, deque())
try:
while True:
yield queue.popleft()
except IndexError:
# Queue has been drained.
pass
class EventsStore(SQLBaseStore): class EventsStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@ -60,19 +149,72 @@ class EventsStore(SQLBaseStore):
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
) )
@defer.inlineCallbacks self._event_persist_queue = _EventPeristenceQueue()
def persist_events(self, events_and_contexts, backfilled=False): def persist_events(self, events_and_contexts, backfilled=False):
""" """
Write events to the database Write events to the database
Args: Args:
events_and_contexts: list of tuples of (event, context) events_and_contexts: list of tuples of (event, context)
backfilled: ? backfilled: ?
Returns: Tuple of stream_orderings where the first is the minimum and
last is the maximum stream ordering assigned to the events when
persisting.
""" """
partitioned = {}
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs,
backfilled=backfilled,
current_state=None,
)
deferreds.append(d)
for room_id in partitioned.keys():
self._maybe_start_persisting(room_id)
return defer.gatherResults(deferreds, consumeErrors=True)
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, current_state=None, backfilled=False):
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
current_state=current_state,
)
self._maybe_start_persisting(event.room_id)
yield deferred
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False):
if not events_and_contexts: if not events_and_contexts:
return return
@ -119,8 +261,7 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def persist_event(self, event, context, current_state=None, backfilled=False): def _persist_event(self, event, context, current_state=None, backfilled=False):
try: try:
with self._stream_id_gen.get_next() as stream_ordering: with self._stream_id_gen.get_next() as stream_ordering:
with self._state_groups_id_gen.get_next() as state_group_id: with self._state_groups_id_gen.get_next() as state_group_id:
@ -137,9 +278,6 @@ class EventsStore(SQLBaseStore):
except _RollbackButIsFineException: except _RollbackButIsFineException:
pass pass
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((stream_ordering, max_persisted_id))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True, def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False, get_prev_content=False, allow_rejected=False,

32
synapse/storage/openid.py Normal file
View file

@ -0,0 +1,32 @@
from ._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
return self._simple_insert(
table="open_id_tokens",
values={
"token": token,
"ts_valid_until_ms": ts_valid_until_ms,
"user_id": user_id,
},
desc="insert_open_id_token"
)
def get_user_id_for_open_id_token(self, token, ts_now_ms):
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
" WHERE token = ? AND ? <= ts_valid_until_ms"
)
txn.execute(sql, (token, ts_now_ms))
rows = txn.fetchall()
if not rows:
return None
else:
return rows[0][0]
return self.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 31 SCHEMA_VERSION = 32
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -101,6 +101,7 @@ class RegistrationStore(SQLBaseStore):
make_guest, make_guest,
appservice_id appservice_id
) )
self.get_user_by_id.invalidate((user_id,))
self.is_guest.invalidate((user_id,)) self.is_guest.invalidate((user_id,))
def _register( def _register(
@ -156,6 +157,7 @@ class RegistrationStore(SQLBaseStore):
(next_id, user_id, token,) (next_id, user_id, token,)
) )
@cached()
def get_user_by_id(self, user_id): def get_user_by_id(self, user_id):
return self._simple_select_one( return self._simple_select_one(
table="users", table="users",
@ -193,6 +195,7 @@ class RegistrationStore(SQLBaseStore):
}, { }, {
'password_hash': password_hash 'password_hash': password_hash
}) })
self.get_user_by_id.invalidate((user_id,))
@defer.inlineCallbacks @defer.inlineCallbacks
def user_delete_access_tokens(self, user_id, except_token_ids=[]): def user_delete_access_tokens(self, user_id, except_token_ids=[]):

View file

@ -23,6 +23,7 @@ from .engines import PostgresEngine, Sqlite3Engine
import collections import collections
import logging import logging
import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -221,3 +222,20 @@ class RoomStore(SQLBaseStore):
aliases.extend(e.content['aliases']) aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases)) defer.returnValue((name, aliases))
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
return self._simple_insert(
table="event_reports",
values={
"id": next_id,
"received_ts": received_ts,
"room_id": room_id,
"event_id": event_id,
"user_id": user_id,
"reason": reason,
"content": json.dumps(content),
},
desc="add_event_report"
)

View file

@ -21,7 +21,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.types import UserID from synapse.types import get_domian_from_id
import logging import logging
@ -273,10 +273,7 @@ class RoomMemberStore(SQLBaseStore):
room_id, membership=Membership.JOIN room_id, membership=Membership.JOIN
) )
joined_domains = set( joined_domains = set(get_domian_from_id(r["user_id"]) for r in rows)
UserID.from_string(r["user_id"]).domain
for r in rows
)
return joined_domains return joined_domains

View file

@ -0,0 +1,9 @@
CREATE TABLE open_id_tokens (
token TEXT NOT NULL PRIMARY KEY,
ts_valid_until_ms bigint NOT NULL,
user_id TEXT NOT NULL,
UNIQUE (token)
);
CREATE index open_id_tokens_ts_valid_until_ms ON open_id_tokens(ts_valid_until_ms);

View file

@ -0,0 +1,25 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE event_reports(
id BIGINT NOT NULL PRIMARY KEY,
received_ts BIGINT NOT NULL,
room_id TEXT NOT NULL,
event_id TEXT NOT NULL,
user_id TEXT NOT NULL,
reason TEXT,
content TEXT
);

View file

@ -16,16 +16,56 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from twisted.internet import defer, reactor
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from collections import namedtuple
import itertools
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_TransactionRow = namedtuple(
"_TransactionRow", (
"id", "transaction_id", "destination", "ts", "response_code",
"response_json",
)
)
_UpdateTransactionRow = namedtuple(
"_TransactionRow", (
"response_code", "response_json",
)
)
class TransactionStore(SQLBaseStore): class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs. """A collection of queries for handling PDUs.
""" """
def __init__(self, hs):
super(TransactionStore, self).__init__(hs)
# New transactions that are currently in flights
self.inflight_transactions = {}
# Newly delievered transactions that *weren't* persisted while in flight
self.new_delivered_transactions = {}
# Newly delivered transactions that *were* persisted while in flight
self.update_delivered_transactions = {}
self.last_transaction = {}
reactor.addSystemEventTrigger("before", "shutdown", self._persist_in_mem_txns)
hs.get_clock().looping_call(
self._persist_in_mem_txns,
1000,
)
def get_received_txn_response(self, transaction_id, origin): def get_received_txn_response(self, transaction_id, origin):
"""For an incoming transaction from a given origin, check if we have """For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response already responded to it. If so, return the response code and response
@ -108,17 +148,30 @@ class TransactionStore(SQLBaseStore):
list: A list of previous transaction ids. list: A list of previous transaction ids.
""" """
return self.runInteraction( auto_id = self._transaction_id_gen.get_next()
"prep_send_transaction",
self._prep_send_transaction, txn_row = _TransactionRow(
transaction_id, destination, origin_server_ts id=auto_id,
transaction_id=transaction_id,
destination=destination,
ts=origin_server_ts,
response_code=0,
response_json=None,
) )
def _prep_send_transaction(self, txn, transaction_id, destination, self.inflight_transactions.setdefault(destination, {})[transaction_id] = txn_row
origin_server_ts):
next_id = self._transaction_id_gen.get_next() prev_txn = self.last_transaction.get(destination)
if prev_txn:
return defer.succeed(prev_txn)
else:
return self.runInteraction(
"_get_prevs_txn",
self._get_prevs_txn,
destination,
)
def _get_prevs_txn(self, txn, destination):
# First we find out what the prev_txns should be. # First we find out what the prev_txns should be.
# Since we know that we are only sending one transaction at a time, # Since we know that we are only sending one transaction at a time,
# we can simply take the last one. # we can simply take the last one.
@ -133,23 +186,6 @@ class TransactionStore(SQLBaseStore):
prev_txns = [r["transaction_id"] for r in results] prev_txns = [r["transaction_id"] for r in results]
# Actually add the new transaction to the sent_transactions table.
self._simple_insert_txn(
txn,
table="sent_transactions",
values={
"id": next_id,
"transaction_id": transaction_id,
"destination": destination,
"ts": origin_server_ts,
"response_code": 0,
"response_json": None,
}
)
# TODO Update the tx id -> pdu id mapping
return prev_txns return prev_txns
def delivered_txn(self, transaction_id, destination, code, response_dict): def delivered_txn(self, transaction_id, destination, code, response_dict):
@ -161,27 +197,23 @@ class TransactionStore(SQLBaseStore):
code (int) code (int)
response_json (str) response_json (str)
""" """
return self.runInteraction(
"delivered_txn",
self._delivered_txn,
transaction_id, destination, code,
buffer(encode_canonical_json(response_dict)),
)
def _delivered_txn(self, txn, transaction_id, destination, txn_row = self.inflight_transactions.get(
code, response_json): destination, {}
self._simple_update_one_txn( ).pop(transaction_id, None)
txn,
table="sent_transactions", self.last_transaction[destination] = transaction_id
keyvalues={
"transaction_id": transaction_id, if txn_row:
"destination": destination, d = self.new_delivered_transactions.setdefault(destination, {})
}, d[transaction_id] = txn_row._replace(
updatevalues={ response_code=code,
"response_code": code, response_json=None, # For now, don't persist response
"response_json": None, # For now, don't persist response_json
}
) )
else:
d = self.update_delivered_transactions.setdefault(destination, {})
# For now, don't persist response
d[transaction_id] = _UpdateTransactionRow(code, None)
def get_transactions_after(self, transaction_id, destination): def get_transactions_after(self, transaction_id, destination):
"""Get all transactions after a given local transaction_id. """Get all transactions after a given local transaction_id.
@ -305,3 +337,48 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (self._clock.time_msec(),)) txn.execute(query, (self._clock.time_msec(),))
return self.cursor_to_dict(txn) return self.cursor_to_dict(txn)
@defer.inlineCallbacks
def _persist_in_mem_txns(self):
try:
inflight = self.inflight_transactions
new_delivered = self.new_delivered_transactions
update_delivered = self.update_delivered_transactions
self.inflight_transactions = {}
self.new_delivered_transactions = {}
self.update_delivered_transactions = {}
full_rows = [
row._asdict()
for txn_map in itertools.chain(inflight.values(), new_delivered.values())
for row in txn_map.values()
]
def f(txn):
if full_rows:
self._simple_insert_many_txn(
txn=txn,
table="sent_transactions",
values=full_rows
)
for dest, txn_map in update_delivered.items():
for txn_id, update_row in txn_map.items():
self._simple_update_one_txn(
txn,
table="sent_transactions",
keyvalues={
"transaction_id": txn_id,
"destination": dest,
},
updatevalues={
"response_code": update_row.response_code,
"response_json": None, # For now, don't persist response
}
)
if full_rows or update_delivered:
yield self.runInteraction("_persist_in_mem_txns", f)
except:
logger.exception("Failed to persist transactions!")

View file

@ -21,6 +21,10 @@ from collections import namedtuple
Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"])
def get_domian_from_id(string):
return string.split(":", 1)[1]
class DomainSpecificString( class DomainSpecificString(
namedtuple("DomainSpecificString", ("localpart", "domain")) namedtuple("DomainSpecificString", ("localpart", "domain"))
): ):