0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2024-12-17 21:33:50 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into release-v0.19.3

This commit is contained in:
Erik Johnston 2017-03-13 09:59:54 +00:00
commit 672dcf59d3
48 changed files with 1491 additions and 245 deletions

View file

@ -332,9 +332,8 @@ https://obs.infoserver.lv/project/monitor/matrix-synapse
ArchLinux ArchLinux
--------- ---------
The quickest way to get up and running with ArchLinux is probably with Ivan The quickest way to get up and running with ArchLinux is probably with the community package
Shapovalov's AUR package from https://www.archlinux.org/packages/community/any/matrix-synapse/, which should pull in all
https://aur.archlinux.org/packages/matrix-synapse/, which should pull in all
the necessary dependencies. the necessary dependencies.
Alternatively, to install using pip a few changes may be needed as ArchLinux Alternatively, to install using pip a few changes may be needed as ArchLinux

View file

@ -32,7 +32,7 @@ import urlparse
import nacl.signing import nacl.signing
import nacl.encoding import nacl.encoding
from syutil.crypto.jsonsign import verify_signed_json, SignatureVerifyException from signedjson.sign import verify_signed_json, SignatureVerifyException
CONFIG_JSON = "cmdclient_config.json" CONFIG_JSON = "cmdclient_config.json"

View file

@ -0,0 +1,48 @@
# Example log_config file for synapse. To enable, point `log_config` to it in
# `homeserver.yaml`, and restart synapse.
#
# This configuration will produce similar results to the defaults within
# synapse, but can be edited to give more flexibility.
version: 1
formatters:
fmt:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s- %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
# example output to console
console:
class: logging.StreamHandler
filters: [context]
# example output to file - to enable, edit 'root' config below.
file:
class: logging.handlers.RotatingFileHandler
formatter: fmt
filename: /var/log/synapse/homeserver.log
maxBytes: 100000000
backupCount: 3
filters: [context]
root:
level: INFO
handlers: [console] # to use file handler instead, switch to [file]
loggers:
synapse:
level: INFO
synapse.storage:
level: INFO
# example of enabling debugging for a component:
#
# synapse.federation.transport.server:
# level: DEBUG

View file

@ -1,5 +1,5 @@
# This assumes that Synapse has been installed as a system package # This assumes that Synapse has been installed as a system package
# (e.g. https://aur.archlinux.org/packages/matrix-synapse/ for ArchLinux) # (e.g. https://www.archlinux.org/packages/community/any/matrix-synapse/ for ArchLinux)
# rather than in a user home directory or similar under virtualenv. # rather than in a user home directory or similar under virtualenv.
[Unit] [Unit]

View file

@ -25,6 +25,5 @@ Configuring IP used for auth
The ReCaptcha API requires that the IP address of the user who solved the The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer, captcha is sent. If the client is connecting through a proxy or load balancer,
it may be required to use the X-Forwarded-For (XFF) header instead of the origin it may be required to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:: IP address. This can be configured using the x_forwarded directive in the
listeners section of the homeserver.yaml configuration file.
captcha_ip_origin_is_x_forwarded: true

View file

@ -0,0 +1,22 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export WORKSPACE
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
export HAPROXY_BIN=/home/haproxy/haproxy-1.6.11/haproxy
./jenkins/prepare_synapse.sh
./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git
./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git
./dendron/jenkins/build_dendron.sh
./sytest/jenkins/prep_sytest_for_postgres.sh
./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \
--haproxy \

View file

@ -17,9 +17,3 @@ export SYNAPSE_CACHE_FACTOR=1
./sytest/jenkins/install_and_run.sh \ ./sytest/jenkins/install_and_run.sh \
--synapse-directory $WORKSPACE \ --synapse-directory $WORKSPACE \
--dendron $WORKSPACE/dendron/bin/dendron \ --dendron $WORKSPACE/dendron/bin/dendron \
--pusher \
--synchrotron \
--federation-reader \
--client-reader \
--appservice \
--federation-sender \

View file

@ -40,6 +40,7 @@ BOOLEAN_COLUMNS = {
"presence_list": ["accepted"], "presence_list": ["accepted"],
"presence_stream": ["currently_active"], "presence_stream": ["currently_active"],
"public_room_list_stream": ["visibility"], "public_room_list_stream": ["visibility"],
"device_lists_outbound_pokes": ["sent"],
} }

View file

@ -43,9 +43,6 @@ class JoinRules(object):
class LoginType(object): class LoginType(object):
PASSWORD = u"m.login.password" PASSWORD = u"m.login.password"
OAUTH = u"m.login.oauth2"
EMAIL_CODE = u"m.login.email.code"
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity" EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha" RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy" DUMMY = u"m.login.dummy"

View file

@ -87,6 +87,10 @@ class SynchrotronSlavedStore(
RoomMemberStore.__dict__["who_forgot_in_room"] RoomMemberStore.__dict__["who_forgot_in_room"]
) )
did_forget = (
RoomMemberStore.__dict__["did_forget"]
)
# XXX: This is a bit broken because we don't persist the accepted list in a # XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to # way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly. # invalidate the cache correctly.

View file

@ -95,7 +95,7 @@ class TlsConfig(Config):
# make HTTPS requests to this server will check that the TLS # make HTTPS requests to this server will check that the TLS
# certificates returned by this server match one of the fingerprints. # certificates returned by this server match one of the fingerprints.
# #
# Synapse automatically adds its the fingerprint of its own certificate # Synapse automatically adds the fingerprint of its own certificate
# to the list. So if federation traffic is handle directly by synapse # to the list. So if federation traffic is handle directly by synapse
# then no modification to the list is required. # then no modification to the list is required.
# #

View file

@ -206,8 +206,7 @@ class FederationClient(FederationBase):
Args: Args:
destinations (list): Which home servers to query destinations (list): Which home servers to query
pdu_origin (str): The home server that originally sent the pdu. event_id (str): event to fetch
event_id (str)
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False` of the current block of PDUs. Defaults to `False`

View file

@ -303,18 +303,10 @@ class TransactionQueue(object):
try: try:
self.pending_transactions[destination] = 1 self.pending_transactions[destination] = 1
# XXX: what's this for?
yield run_on_reactor() yield run_on_reactor()
while True: while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, destination,
self.clock, self.clock,
@ -326,6 +318,24 @@ class TransactionQueue(object):
yield self._get_new_device_messages(destination) yield self._get_new_device_messages(destination)
) )
# BEGIN CRITICAL SECTION
#
# In order to avoid a race condition, we need to make sure that
# the following code (from popping the queues up to the point
# where we decide if we actually have any pending messages) is
# atomic - otherwise new PDUs or EDUs might arrive in the
# meantime, but not get sent because we hold the
# pending_transactions flag.
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
pending_edus.extend(device_message_edus) pending_edus.extend(device_message_edus)
if pending_presence: if pending_presence:
pending_edus.append( pending_edus.append(
@ -355,6 +365,8 @@ class TransactionQueue(object):
) )
return return
# END CRITICAL SECTION
success = yield self._send_new_transaction( success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures, destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter, limiter=limiter,

View file

@ -19,7 +19,6 @@ from ._base import BaseHandler
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -54,3 +53,46 @@ class AdminHandler(BaseHandler):
} }
defer.returnValue(ret) defer.returnValue(ret)
@defer.inlineCallbacks
def get_users(self):
"""Function to reterive a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
ret = yield self.store.get_users()
defer.returnValue(ret)
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from
users list. This will return a json object, which contains
list of users and the total number of users in users table.
Args:
order (str): column name to order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
ret = yield self.store.get_users_paginate(order, start, limit)
defer.returnValue(ret)
@defer.inlineCallbacks
def search_users(self, term):
"""Function to search users list for one or more users with
the matched term.
Args:
term (str): search term
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
ret = yield self.store.search_users(term)
defer.returnValue(ret)

View file

@ -12,11 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.api import errors from synapse.api import errors
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func from synapse.util.metrics import measure_func
from synapse.types import get_domain_from_id, RoomStreamToken from synapse.types import get_domain_from_id, RoomStreamToken
from twisted.internet import defer from twisted.internet import defer
@ -35,10 +35,11 @@ class DeviceHandler(BaseHandler):
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender() self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer() self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self._edu_updater = DeviceListEduUpdater(hs, self)
self.federation.register_edu_handler( self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update, "m.device_list_update", self._edu_updater.incoming_device_list_update,
) )
self.federation.register_query_handler( self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices, "user_devices", self.on_federation_query_user_devices,
@ -246,30 +247,51 @@ class DeviceHandler(BaseHandler):
# Then work out if any users have since joined # Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key) rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
stream_ordering = RoomStreamToken.parse_stream_token(
from_token.room_key).stream
possibly_changed = set(changed) possibly_changed = set(changed)
for room_id in rooms_changed: for room_id in rooms_changed:
# Fetch the current state at the time. # Fetch the current state at the time.
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key)
try: try:
event_ids = yield self.store.get_forward_extremeties_for_room( event_ids = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering room_id, stream_ordering=stream_ordering
) )
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids) except errors.StoreError:
except: # we have purged the stream_ordering index since the stream
prev_state_ids = {} # ordering: treat it the same as a new room
event_ids = []
current_state_ids = yield self.state.get_current_state_ids(room_id) current_state_ids = yield self.state.get_current_state_ids(room_id)
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
for key, event_id in current_state_ids.iteritems():
etype, state_key = key
if etype != EventTypes.Member:
continue
possibly_changed.add(state_key)
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
# If there has been any change in membership, include them in the # If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below, # possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users. # and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems(): for key, event_id in current_state_ids.iteritems():
etype, state_key = key etype, state_key = key
if etype == EventTypes.Member: if etype != EventTypes.Member:
prev_event_id = prev_state_ids.get(key, None) continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
for state_dict in prev_state_ids.values():
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id: if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key) possibly_changed.add(state_key)
break
users_who_share_room = yield self.store.get_users_who_share_room_with_user( users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id user_id
@ -279,58 +301,6 @@ class DeviceHandler(BaseHandler):
# and those that actually still share a room with the user # and those that actually still share a room with the user
defer.returnValue(users_who_share_room & possibly_changed) defer.returnValue(users_who_share_room & possibly_changed)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id): def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id) stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
@ -356,3 +326,129 @@ def _update_device_from_client_ips(device, client_ips):
"last_seen_ts": ip.get("last_seen"), "last_seen_ts": ip.get("last_seen"),
"last_seen_ip": ip.get("ip"), "last_seen_ip": ip.get("ip"),
}) })
class DeviceListEduUpdater(object):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs, device_handler):
self.store = hs.get_datastore()
self.federation = hs.get_replication_layer()
self.clock = hs.get_clock()
self.device_handler = device_handler
self._remote_edu_linearizer = Linearizer(name="remote_device_list")
# user_id -> list of updates waiting to be handled.
self._pending_updates = {}
# Recently seen stream ids. We don't bother keeping these in the DB,
# but they're useful to have them about to reduce the number of spurious
# resyncs.
self._seen_updates = ExpiringCache(
cache_name="device_update_edu",
clock=self.clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
iterable=True,
)
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
)
yield self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _handle_device_updates(self, user_id):
"Actually handle pending updates."
with (yield self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
return
resync = yield self._need_to_do_resync(user_id, pending_updates)
if resync:
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates]
)
self._seen_updates.setdefault(user_id, set()).update(
stream_id for _, stream_id, _, _ in pending_updates
)
@defer.inlineCallbacks
def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:
if not prev_ids:
# We always do a resync if there are no previous IDs
defer.returnValue(True)
for prev_id in prev_ids:
if prev_id == extremity:
continue
elif prev_id in seen_updates:
continue
elif prev_id in stream_id_in_updates:
continue
else:
defer.returnValue(True)
stream_id_in_updates.add(stream_id)
defer.returnValue(False)

View file

@ -1096,7 +1096,7 @@ class FederationHandler(BaseHandler):
if prev_id != event.event_id: if prev_id != event.event_id:
results[(event.type, event.state_key)] = prev_id results[(event.type, event.state_key)] = prev_id
else: else:
del results[(event.type, event.state_key)] results.pop((event.type, event.state_key), None)
defer.returnValue(results.values()) defer.returnValue(results.values())
else: else:

View file

@ -531,7 +531,7 @@ class PresenceHandler(object):
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = yield self.store.get_presence_for_users(missing) res = yield self.store.get_presence_for_users(missing)
states.update({state.user_id: state for state in res}) states.update(res)
missing = [user_id for user_id, state in states.items() if not state] missing = [user_id for user_id, state in states.items() if not state]
if missing: if missing:

View file

@ -719,7 +719,9 @@ class RoomMemberHandler(BaseHandler):
) )
membership = member.membership if member else None membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE: if membership is not None and membership not in [
Membership.LEAVE, Membership.BAN
]:
raise SynapseError(400, "User %s in room %s" % ( raise SynapseError(400, "User %s in room %s" % (
user_id, room_id user_id, room_id
)) ))

View file

@ -609,14 +609,14 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device( deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id user_id, device_id, since_stream_id
) )
logger.info("Deleted %d to-device messages up to %d", logger.debug("Deleted %d to-device messages up to %d",
deleted, since_stream_id) deleted, since_stream_id)
messages, stream_id = yield self.store.get_new_messages_for_device( messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key user_id, device_id, since_stream_id, now_token.to_device_key
) )
logger.info( logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)", "Returning %d to-device messages between %d and %d (current token: %d)",
len(messages), since_stream_id, stream_id, now_token.to_device_key len(messages), since_stream_id, stream_id, now_token.to_device_key
) )

View file

@ -218,7 +218,8 @@ class EmailPusher(object):
) )
def seconds_until(self, ts_msec): def seconds_until(self, ts_msec):
return (ts_msec - self.clock.time_msec()) / 1000 secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0)
def get_room_throttle_ms(self, room_id): def get_room_throttle_ms(self, room_id):
if room_id in self.throttle_params: if room_id in self.throttle_params:

View file

@ -54,7 +54,9 @@ class BaseSlavedStore(SQLBaseStore):
try: try:
getattr(self, cache_func).invalidate(tuple(keys)) getattr(self, cache_func).invalidate(tuple(keys))
except AttributeError: except AttributeError:
logger.info("Got unexpected cache_func: %r", cache_func) # We probably haven't pulled in the cache in this worker,
# which is fine.
pass
self._cache_id_gen.advance(int(stream["position"])) self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None) return defer.succeed(None)

View file

@ -46,6 +46,12 @@ class SlavedAccountDataStore(BaseSlavedStore):
) )
get_tags_for_user = TagsStore.__dict__["get_tags_for_user"] get_tags_for_user = TagsStore.__dict__["get_tags_for_user"]
get_tags_for_room = (
DataStore.get_tags_for_room.__func__
)
get_account_data_for_room = (
DataStore.get_account_data_for_room.__func__
)
get_updated_tags = DataStore.get_updated_tags.__func__ get_updated_tags = DataStore.get_updated_tags.__func__
get_updated_account_data_for_user = ( get_updated_account_data_for_user = (

View file

@ -17,6 +17,7 @@ from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.expiringcache import ExpiringCache
class SlavedDeviceInboxStore(BaseSlavedStore): class SlavedDeviceInboxStore(BaseSlavedStore):
@ -34,6 +35,13 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
self._device_inbox_id_gen.get_current_token() self._device_inbox_id_gen.get_current_token()
) )
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__ get_to_device_stream_token = DataStore.get_to_device_stream_token.__func__
get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__ get_new_messages_for_device = DataStore.get_new_messages_for_device.__func__
get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__ get_new_device_msgs_for_remote = DataStore.get_new_device_msgs_for_remote.__func__

View file

@ -85,6 +85,12 @@ class SlavedEventStore(BaseSlavedStore):
get_unread_event_push_actions_by_room_for_user = ( get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"] EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
) )
_get_unread_counts_by_receipt_txn = (
DataStore._get_unread_counts_by_receipt_txn.__func__
)
_get_unread_counts_by_pos_txn = (
DataStore._get_unread_counts_by_pos_txn.__func__
)
_get_state_group_for_events = ( _get_state_group_for_events = (
StateStore.__dict__["_get_state_group_for_events"] StateStore.__dict__["_get_state_group_for_events"]
) )

View file

@ -18,6 +18,7 @@ from ._slaved_id_tracker import SlavedIdTracker
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.storage import DataStore from synapse.storage import DataStore
from synapse.storage.presence import PresenceStore
class SlavedPresenceStore(BaseSlavedStore): class SlavedPresenceStore(BaseSlavedStore):
@ -35,7 +36,8 @@ class SlavedPresenceStore(BaseSlavedStore):
_get_active_presence = DataStore._get_active_presence.__func__ _get_active_presence = DataStore._get_active_presence.__func__
take_presence_startup_info = DataStore.take_presence_startup_info.__func__ take_presence_startup_info = DataStore.take_presence_startup_info.__func__
get_presence_for_users = DataStore.get_presence_for_users.__func__ _get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()

View file

@ -87,9 +87,17 @@ class HttpTransactionCache(object):
deferred = fn(*args, **kwargs) deferred = fn(*args, **kwargs)
# We don't add an errback to the raw deferred, so we ask ObservableDeferred # if the request fails with a Twisted failure, remove it
# to swallow the error. This is fine as the error will still be reported # from the transaction map. This is done to ensure that we don't
# to the observers. # cache transient errors like rate-limiting errors, etc.
def remove_from_map(err):
self.transactions.pop(txn_key, None)
return err
deferred.addErrback(remove_from_map)
# We don't add any other errbacks to the raw deferred, so we ask
# ObservableDeferred to swallow the error. This is fine as the error will
# still be reported to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True) observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec()) self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe() return observable.observe()

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
from synapse.types import UserID from synapse.types import UserID
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns from .base import ClientV1RestServlet, client_path_patterns
@ -25,6 +26,34 @@ import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class UsersRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(UsersRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
# To allow all users to get the users list
# if not is_admin and target_user != auth_user:
# raise AuthError(403, "You are not a server admin")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
ret = yield self.handlers.admin_handler.get_users()
defer.returnValue((200, ret))
class WhoisRestServlet(ClientV1RestServlet): class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)") PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
@ -128,8 +157,199 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
defer.returnValue((200, {})) defer.returnValue((200, {}))
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This need a user have a administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
@user:to_reset_password?access_token=admin_access_token
JsonBodyToSend:
{
"new_password": "secret"
}
Returns:
200 OK with empty object if success otherwise an error.
"""
PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to allow an administrator reset password for a user.
This need a user have a administrator access in Synapse.
"""
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
params = parse_json_object_from_request(request)
new_password = params['new_password']
if not new_password:
raise SynapseError(400, "Missing 'new_password' arg")
logger.info("new_password: %r", new_password)
yield self.auth_handler.set_password(
target_user_id, new_password, requester
)
defer.returnValue((200, {}))
class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Get request to get specific number of users from Synapse.
This need a user have a administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
super(GetUsersPaginatedRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, target_user_id):
"""Get request to get specific number of users from Synapse.
This need a user have a administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
# To allow all users to get the users list
# if not is_admin and target_user != auth_user:
# raise AuthError(403, "You are not a server admin")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
order = "name" # order by name in user table
start = request.args.get("start")[0]
limit = request.args.get("limit")[0]
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
order, start, limit
)
defer.returnValue((200, ret))
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
"""Post request to get specific number of users from Synapse..
This need a user have a administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
@admin:user?access_token=admin_access_token
JsonBodyToSend:
{
"start": "0",
"limit": "10
}
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
limit = params['limit']
start = params['start']
if not limit:
raise SynapseError(400, "Missing 'limit' arg")
if not start:
raise SynapseError(400, "Missing 'start' arg")
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(
order, start, limit
)
defer.returnValue((200, ret))
class SearchUsersRestServlet(ClientV1RestServlet):
"""Get request to search user table for specific users according to
search term.
This need a user have a administrator access in Synapse.
Example:
http://localhost:8008/_matrix/client/api/v1/admin/search_users/
@admin:user?access_token=admin_access_token&term=alice
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
super(SearchUsersRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, target_user_id):
"""Get request to search user table for specific users according to
search term.
This need a user have a administrator access in Synapse.
"""
target_user = UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
# To allow all users to get the users list
# if not is_admin and target_user != auth_user:
# raise AuthError(403, "You are not a server admin")
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
term = request.args.get("term")[0]
if not term:
raise SynapseError(400, "Missing 'term' arg")
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(
term
)
defer.returnValue((200, ret))
def register_servlets(hs, http_server): def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server) WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server) PurgeMediaCacheRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server) PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)

View file

@ -46,6 +46,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True) requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -55,7 +56,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_displayname( yield self.handlers.profile_handler.set_displayname(
user, requester, new_name) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))
@ -88,6 +89,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
def on_PUT(self, request, user_id): def on_PUT(self, request, user_id):
requester = yield self.auth.get_user_by_req(request) requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = yield self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:
@ -96,7 +98,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
defer.returnValue((400, "Unable to parse name")) defer.returnValue((400, "Unable to parse name"))
yield self.handlers.profile_handler.set_avatar_url( yield self.handlers.profile_handler.set_avatar_url(
user, requester, new_name) user, requester, new_name, is_admin)
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -608,6 +608,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
raise SynapseError(400, "Missing user_id key.") raise SynapseError(400, "Missing user_id key.")
target = UserID.from_string(content["user_id"]) target = UserID.from_string(content["user_id"])
event_content = None
if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']}
yield self.handlers.room_member_handler.update_membership( yield self.handlers.room_member_handler.update_membership(
requester=requester, requester=requester,
target=target, target=target,
@ -615,6 +619,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
action=membership_action, action=membership_action,
txn_id=txn_id, txn_id=txn_id,
third_party_signed=content.get("third_party_signed", None), third_party_signed=content.get("third_party_signed", None),
content=event_content,
) )
defer.returnValue((200, {})) defer.returnValue((200, {}))

View file

@ -240,6 +240,9 @@ class MediaRepository(object):
if t_method == "crop": if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type) t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale": elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type) t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else: else:
t_len = None t_len = None

View file

@ -297,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore,
desc="get_user_ip_and_agents", desc="get_user_ip_and_agents",
) )
def get_users(self):
"""Function to reterive a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self._simple_select_list(
table="users",
keyvalues={},
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
desc="get_users",
)
def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from
users list. This will return a json object, which contains
list of users and the total number of users in users table.
Args:
order (str): column name to order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
is_guest = 0
i_start = (int)(start)
i_limit = (int)(limit)
return self.get_user_list_paginate(
table="users",
keyvalues={
"is_guest": is_guest
},
pagevalues=[
order,
i_limit,
i_start
],
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
desc="get_users_paginate",
)
def search_users(self, term):
"""Function to search users list for one or more users with
the matched term.
Args:
term (str): search term
col (str): column to query term should be matched to
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self._simple_search_list(
table="users",
term=term,
col="name",
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
desc="search_users",
)
def are_all_users_on_domain(txn, database_engine, domain): def are_all_users_on_domain(txn, database_engine, domain):
sql = database_engine.convert_param_style( sql = database_engine.convert_param_style(

View file

@ -18,7 +18,6 @@ from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
import synapse.metrics import synapse.metrics
@ -80,7 +79,13 @@ class LoggingTransaction(object):
def executemany(self, sql, *args): def executemany(self, sql, *args):
self._do_execute(self.txn.executemany, sql, *args) self._do_execute(self.txn.executemany, sql, *args)
def _make_sql_one_line(self, sql):
"Strip newlines out of SQL so that the loggers in the DB are on one line"
return " ".join(l.strip() for l in sql.splitlines() if l.strip())
def _do_execute(self, func, sql, *args): def _do_execute(self, func, sql, *args):
sql = self._make_sql_one_line(sql)
# TODO(paul): Maybe use 'info' and 'debug' for values? # TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql) sql_logger.debug("[SQL] {%s} %s", self.name, sql)
@ -350,9 +355,9 @@ class SQLBaseStore(object):
Returns: Returns:
A list of dicts where the key is the column header. A list of dicts where the key is the column header.
""" """
col_headers = list(column[0] for column in cursor.description) col_headers = list(intern(column[0]) for column in cursor.description)
results = list( results = list(
intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() dict(zip(col_headers, row)) for row in cursor.fetchall()
) )
return results return results
@ -483,10 +488,6 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues) " AND ".join("%s = ?" % (k,) for k in keyvalues)
) )
sqlargs = values.values() + keyvalues.values() sqlargs = values.values() + keyvalues.values()
logger.debug(
"[SQL] %s Args=%s",
sql, sqlargs,
)
txn.execute(sql, sqlargs) txn.execute(sql, sqlargs)
if txn.rowcount == 0: if txn.rowcount == 0:
@ -501,10 +502,6 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues), ", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues) ", ".join("?" for _ in allvalues)
) )
logger.debug(
"[SQL] %s Args=%s",
sql, keyvalues.values(),
)
txn.execute(sql, allvalues.values()) txn.execute(sql, allvalues.values())
return True return True
@ -934,6 +931,165 @@ class SQLBaseStore(object):
else: else:
return 0 return 0
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
desc="_simple_select_list_paginate"):
"""Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
retcols (iterable[str]): the names of the columns to return
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols
)
@classmethod
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
"""Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
txn : Transaction object
table (str): the table name
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
txn.execute(sql, keyvalues.values() + pagevalues)
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
table,
" ? ASC LIMIT ? OFFSET ?"
)
txn.execute(sql, pagevalues)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
desc="get_user_list_paginate"):
"""Get a list of users from start row to a limit number of rows. This will
return a json object with users and total number of users in users list.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
users = yield self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols
)
count = yield self.runInteraction(
desc,
self.get_user_count_txn
)
retval = {
"users": users,
"total": count
}
defer.returnValue(retval)
def get_user_count_txn(self, txn):
"""Get a total number of registerd users in the users list.
Args:
txn : Transaction object
Returns:
defer.Deferred: resolves to int
"""
sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
txn.execute(sql_count)
count = txn.fetchone()[0]
defer.returnValue(count)
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
table (str): the table name
term (str | None):
term for searching the table matched to a column.
col (str): column to query term should be matched to
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
return self.runInteraction(
desc,
self._simple_search_list_txn,
table, term, col, retcols
)
@classmethod
def _simple_search_list_txn(cls, txn, table, term, col, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
txn : Transaction object
table (str): the table name
term (str | None):
term for searching the table matched to a column.
col (str): column to query term should be matched to
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
", ".join(retcols),
table,
col
)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
return 0
return cls.cursor_to_dict(txn)
class _RollbackButIsFineException(Exception): class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying """ This exception is used to rollback a transaction without implying

View file

@ -20,6 +20,8 @@ from twisted.internet import defer
from .background_updates import BackgroundUpdateStore from .background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore):
self._background_drop_index_device_inbox, self._background_drop_index_device_inbox,
) )
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
expiry_ms=30 * 60 * 1000,
)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device, def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination): remote_messages_by_destination):
@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
"get_new_messages_for_device", get_new_messages_for_device_txn, "get_new_messages_for_device", get_new_messages_for_device_txn,
) )
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
""" """
Args: Args:
@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore):
Returns: Returns:
A deferred that resolves to the number of messages deleted. A deferred that resolves to the number of messages deleted.
""" """
# If we have cached the last stream id we've deleted up to, we can
# check if there is likely to be anything that needs deleting
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
defer.returnValue(0)
def delete_messages_for_device_txn(txn): def delete_messages_for_device_txn(txn):
sql = ( sql = (
"DELETE FROM device_inbox" "DELETE FROM device_inbox"
@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id)) txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount return txn.rowcount
return self.runInteraction( count = yield self.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn "delete_messages_for_device", delete_messages_for_device_txn
) )
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
last_deleted_stream_id, up_to_stream_id
)
defer.returnValue(count)
def get_all_new_device_messages(self, last_pos, current_pos, limit): def get_all_new_device_messages(self, last_pos, current_pos, limit):
""" """
Args: Args:

View file

@ -19,6 +19,8 @@ from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -31,6 +33,13 @@ class DeviceStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000 self._prune_old_outbound_device_pokes, 60 * 60 * 1000
) )
self.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
columns=["user_id", "device_id"],
)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, def store_device(self, user_id, device_id,
initial_device_display_name): initial_device_display_name):
@ -144,6 +153,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id): def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't """Get the last stream_id we got for a user. May be None if we haven't
got any information for them. got any information for them.
@ -156,16 +166,36 @@ class DeviceStore(SQLBaseStore):
allow_none=True, allow_none=True,
) )
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
desc="get_user_devices_from_cache",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
defer.returnValue(results)
@defer.inlineCallbacks
def mark_remote_user_device_list_as_unsubscribed(self, user_id): def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user. """Mark that we no longer track device lists for remote user.
""" """
return self._simple_delete( yield self._simple_delete(
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
}, },
desc="mark_remote_user_device_list_as_unsubscribed", desc="mark_remote_user_device_list_as_unsubscribed",
) )
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content, def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id): stream_id):
@ -191,6 +221,12 @@ class DeviceStore(SQLBaseStore):
} }
) )
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -234,6 +270,12 @@ class DeviceStore(SQLBaseStore):
] ]
) )
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="device_lists_remote_extremeties", table="device_lists_remote_extremeties",
@ -320,6 +362,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, results) return (now_stream_id, results)
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list): def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache. """Get the devices (and keys if any) for remote users from the cache.
@ -332,27 +375,11 @@ class DeviceStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info user_id -> device_id -> device_info
""" """
return self.runInteraction( user_ids = set(user_id for user_id, _ in query_list)
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn, user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
query_list, user_ids_in_cache = set(
user_id for user_id, stream_id in user_map.items() if stream_id
) )
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {} results = {}
@ -361,32 +388,40 @@ class DeviceStore(SQLBaseStore):
continue continue
if device_id: if device_id:
content = self._simple_select_one_onecol_txn( device = yield self._get_cached_user_device(user_id, device_id)
txn, results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
defer.returnValue((user_ids_not_in_cache, results))
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
"device_id": device_id, "device_id": device_id,
}, },
retcol="content", retcol="content",
desc="_get_cached_user_device",
) )
results.setdefault(user_id, {})[device_id] = json.loads(content) defer.returnValue(json.loads(content))
else:
devices = self._simple_select_list_txn( @cachedInlineCallbacks()
txn, def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache", table="device_lists_remote_cache",
keyvalues={ keyvalues={
"user_id": user_id, "user_id": user_id,
}, },
retcols=("device_id", "content"), retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
) )
results[user_id] = { defer.returnValue({
device["device_id"]: json.loads(device["content"]) device["device_id"]: json.loads(device["content"])
for device in devices for device in devices
} })
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id): def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user """Get all devices (with any device keys) for a user
@ -473,7 +508,7 @@ class DeviceStore(SQLBaseStore):
defer.returnValue(set(changed)) defer.returnValue(set(changed))
sql = """ sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ? SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
""" """
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows)) defer.returnValue(set(row[0] for row in rows))
@ -489,7 +524,7 @@ class DeviceStore(SQLBaseStore):
WHERE stream_id > ? WHERE stream_id > ?
""" """
return self._execute( return self._execute(
"get_users_and_hosts_device_list", None, "get_all_device_list_changes_for_remotes", None,
sql, from_key, sql, from_key,
) )
@ -518,6 +553,16 @@ class DeviceStore(SQLBaseStore):
host, stream_id, host, stream_id,
) )
# Delete older entries in the table, as we really only care about
# when the latest change happened.
txn.executemany(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
[(user_id, device_id, stream_id) for device_id in device_ids]
)
self._simple_insert_many_txn( self._simple_insert_many_txn(
txn, txn,
table="device_lists_stream", table="device_lists_stream",

View file

@ -93,7 +93,7 @@ class EndToEndKeyStore(SQLBaseStore):
query_clause = "user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
if device_id: if device_id is not None:
query_clause += " AND device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)

View file

@ -281,15 +281,30 @@ class EventFederationStore(SQLBaseStore):
) )
def get_forward_extremeties_for_room(self, room_id, stream_ordering): def get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
Throws a StoreError if we have since purged the index for
stream_orderings from that point.
Args:
room_id (str):
stream_ordering (int):
Returns:
deferred, which resolves to a list of event_ids
"""
# We want to make the cache more effective, so we clamp to the last # We want to make the cache more effective, so we clamp to the last
# change before the given ordering. # change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)
# We don't always have a full stream_to_exterm_id table, e.g. after # We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a # the upgrade that introduced it, so we make sure we never ask for a
# try and pin to a stream_ordering from before a restart # stream_ordering from before a restart
last_change = max(self._stream_order_on_start, last_change) last_change = max(self._stream_order_on_start, last_change)
# provided the last_change is recent enough, we now clamp the requested
# stream_ordering to it.
if last_change > self.stream_ordering_month_ago: if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering) stream_ordering = min(last_change, stream_ordering)

View file

@ -15,6 +15,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
from .stream import lower_bound from .stream import lower_bound
@ -25,11 +26,46 @@ import ujson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
DEFAULT_HIGHLIGHT_ACTION = [
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
]
def _serialize_action(actions, is_highlight):
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
highlights).
We store these default actions as the empty string rather than the full JSON.
Since the empty string isn't valid JSON there is no risk of this clashing with
any real JSON actions
"""
if is_highlight:
if actions == DEFAULT_HIGHLIGHT_ACTION:
return "" # We use empty string as the column is non-NULL
else:
if actions == DEFAULT_NOTIF_ACTION:
return ""
return json.dumps(actions)
def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
return json.loads(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
else:
return DEFAULT_NOTIF_ACTION
class EventPushActionsStore(SQLBaseStore): class EventPushActionsStore(SQLBaseStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, hs): def __init__(self, hs):
self.stream_ordering_month_ago = None
super(EventPushActionsStore, self).__init__(hs) super(EventPushActionsStore, self).__init__(hs)
self.register_background_index_update( self.register_background_index_update(
@ -47,6 +83,11 @@ class EventPushActionsStore(SQLBaseStore):
where_clause="highlight=1" where_clause="highlight=1"
) )
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
self._rotate_notifs, 30 * 60 * 1000
)
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
""" """
Args: Args:
@ -55,15 +96,17 @@ class EventPushActionsStore(SQLBaseStore):
""" """
values = [] values = []
for uid, actions in tuples: for uid, actions in tuples:
is_highlight = 1 if _action_has_highlight(actions) else 0
values.append({ values.append({
'room_id': event.room_id, 'room_id': event.room_id,
'event_id': event.event_id, 'event_id': event.event_id,
'user_id': uid, 'user_id': uid,
'actions': json.dumps(actions), 'actions': _serialize_action(actions, is_highlight),
'stream_ordering': event.internal_metadata.stream_ordering, 'stream_ordering': event.internal_metadata.stream_ordering,
'topological_ordering': event.depth, 'topological_ordering': event.depth,
'notif': 1, 'notif': 1,
'highlight': 1 if _action_has_highlight(actions) else 0, 'highlight': is_highlight,
}) })
for uid, __ in tuples: for uid, __ in tuples:
@ -77,7 +120,15 @@ class EventPushActionsStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user( def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id self, room_id, user_id, last_read_event_id
): ):
def _get_unread_event_push_actions_by_room(txn): ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id, user_id, last_read_event_id
)
defer.returnValue(ret)
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
last_read_event_id):
sql = ( sql = (
"SELECT stream_ordering, topological_ordering" "SELECT stream_ordering, topological_ordering"
" FROM events" " FROM events"
@ -92,6 +143,13 @@ class EventPushActionsStore(SQLBaseStore):
stream_ordering = results[0][0] stream_ordering = results[0][0]
topological_ordering = results[0][1] topological_ordering = results[0][1]
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, topological_ordering, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
stream_ordering):
token = RoomStreamToken( token = RoomStreamToken(
topological_ordering, stream_ordering topological_ordering, stream_ordering
) )
@ -112,6 +170,14 @@ class EventPushActionsStore(SQLBaseStore):
row = txn.fetchone() row = txn.fetchone()
notify_count = row[0] if row else 0 notify_count = row[0] if row else 0
txn.execute("""
SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""", (room_id, user_id, stream_ordering,))
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
# Now get the number of highlights # Now get the number of highlights
sql = ( sql = (
"SELECT count(*)" "SELECT count(*)"
@ -132,12 +198,6 @@ class EventPushActionsStore(SQLBaseStore):
"highlight_count": highlight_count, "highlight_count": highlight_count,
} }
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
_get_unread_event_push_actions_by_room
)
defer.returnValue(ret)
@defer.inlineCallbacks @defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
def f(txn): def f(txn):
@ -176,7 +236,8 @@ class EventPushActionsStore(SQLBaseStore):
# find rooms that have a read receipt in them and return the next # find rooms that have a read receipt in them and return the next
# push actions # push actions
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" ep.highlight "
" FROM (" " FROM ("
" SELECT room_id," " SELECT room_id,"
" MAX(topological_ordering) as topological_ordering," " MAX(topological_ordering) as topological_ordering,"
@ -217,7 +278,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn): def get_no_receipt(txn):
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts" " ep.highlight "
" FROM event_push_actions AS ep" " FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)" " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE" " WHERE"
@ -246,7 +307,7 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0], "event_id": row[0],
"room_id": row[1], "room_id": row[1],
"stream_ordering": row[2], "stream_ordering": row[2],
"actions": json.loads(row[3]), "actions": _deserialize_action(row[3], row[4]),
} for row in after_read_receipt + no_read_receipt } for row in after_read_receipt + no_read_receipt
] ]
@ -285,7 +346,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_after_receipt(txn): def get_after_receipt(txn):
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts" " ep.highlight, e.received_ts"
" FROM (" " FROM ("
" SELECT room_id," " SELECT room_id,"
" MAX(topological_ordering) as topological_ordering," " MAX(topological_ordering) as topological_ordering,"
@ -327,7 +388,7 @@ class EventPushActionsStore(SQLBaseStore):
def get_no_receipt(txn): def get_no_receipt(txn):
sql = ( sql = (
"SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,"
" e.received_ts" " ep.highlight, e.received_ts"
" FROM event_push_actions AS ep" " FROM event_push_actions AS ep"
" INNER JOIN events AS e USING (room_id, event_id)" " INNER JOIN events AS e USING (room_id, event_id)"
" WHERE" " WHERE"
@ -357,8 +418,8 @@ class EventPushActionsStore(SQLBaseStore):
"event_id": row[0], "event_id": row[0],
"room_id": row[1], "room_id": row[1],
"stream_ordering": row[2], "stream_ordering": row[2],
"actions": json.loads(row[3]), "actions": _deserialize_action(row[3], row[4]),
"received_ts": row[4], "received_ts": row[5],
} for row in after_read_receipt + no_read_receipt } for row in after_read_receipt + no_read_receipt
] ]
@ -392,7 +453,7 @@ class EventPushActionsStore(SQLBaseStore):
sql = ( sql = (
"SELECT epa.event_id, epa.room_id," "SELECT epa.event_id, epa.room_id,"
" epa.stream_ordering, epa.topological_ordering," " epa.stream_ordering, epa.topological_ordering,"
" epa.actions, epa.profile_tag, e.received_ts" " epa.actions, epa.highlight, epa.profile_tag, e.received_ts"
" FROM event_push_actions epa, events e" " FROM event_push_actions epa, events e"
" WHERE epa.event_id = e.event_id" " WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s" " AND epa.user_id = ? %s"
@ -407,7 +468,7 @@ class EventPushActionsStore(SQLBaseStore):
"get_push_actions_for_user", f "get_push_actions_for_user", f
) )
for pa in push_actions: for pa in push_actions:
pa["actions"] = json.loads(pa["actions"]) pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions) defer.returnValue(push_actions)
@defer.inlineCallbacks @defer.inlineCallbacks
@ -448,10 +509,14 @@ class EventPushActionsStore(SQLBaseStore):
) )
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
topological_ordering): topological_ordering, stream_ordering):
""" """
Purges old, stale push actions for a user and room before a given Purges old push actions for a user and room before a given
topological_ordering topological_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
Args: Args:
txn: The transcation txn: The transcation
room_id: Room ID to delete from room_id: Room ID to delete from
@ -475,10 +540,16 @@ class EventPushActionsStore(SQLBaseStore):
txn.execute( txn.execute(
"DELETE FROM event_push_actions " "DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND " " WHERE user_id = ? AND room_id = ? AND "
" topological_ordering < ? AND stream_ordering < ?", " topological_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, topological_ordering, self.stream_ordering_month_ago) (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
) )
txn.execute("""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""", (room_id, user_id, stream_ordering))
@defer.inlineCallbacks @defer.inlineCallbacks
def _find_stream_orderings_for_times(self): def _find_stream_orderings_for_times(self):
yield self.runInteraction( yield self.runInteraction(
@ -495,6 +566,14 @@ class EventPushActionsStore(SQLBaseStore):
"Found stream ordering 1 month ago: it's %d", "Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago self.stream_ordering_month_ago
) )
logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 day ago: it's %d",
self.stream_ordering_day_ago
)
def _find_first_stream_ordering_after_ts_txn(self, txn, ts): def _find_first_stream_ordering_after_ts_txn(self, txn, ts):
""" """
@ -534,6 +613,131 @@ class EventPushActionsStore(SQLBaseStore):
return range_end return range_end
@defer.inlineCallbacks
def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
try:
while True:
logger.info("Rotating notifications")
caught_up = yield self.runInteraction(
"_rotate_notifs",
self._rotate_notifs_txn
)
if caught_up:
break
yield sleep(5)
finally:
self._doing_notif_rotation = False
def _rotate_notifs_txn(self, txn):
"""Archives older notifications into event_push_summary. Returns whether
the archiving process has caught up or not.
"""
old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
# We don't to try and rotate millions of rows at once, so we cap the
# maximum stream ordering we'll rotate before.
txn.execute("""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
""", (old_rotate_stream_ordering,))
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
caught_up = offset_stream_ordering >= self.stream_ordering_day_ago
else:
rotate_to_stream_ordering = self.stream_ordering_day_ago
caught_up = True
logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering)
self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering)
# We have caught up iff we were limited by `stream_ordering_day_ago`
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
retcol="stream_ordering",
)
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
coalesce(old.notif_count, 0) + upd.notif_count,
upd.stream_ordering,
old.user_id
FROM (
SELECT user_id, room_id, count(*) as notif_count,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
rows = txn.fetchall()
logger.info("Rotating notifications, handling %d rows", len(rows))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
self._simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
{
"user_id": row[0],
"room_id": row[1],
"notif_count": row[2],
"stream_ordering": row[3],
}
for row in rows if row[4] is None
]
)
txn.executemany(
"""
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
)
txn.execute(
"DELETE FROM event_push_actions"
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
(old_rotate_stream_ordering, rotate_to_stream_ordering,)
)
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
txn.execute(
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
(rotate_to_stream_ordering,)
)
def _action_has_highlight(actions): def _action_has_highlight(actions):
for action in actions: for action in actions:

View file

@ -311,6 +311,21 @@ class EventsStore(SQLBaseStore):
new_forward_extremeties[room_id] = new_latest_event_ids new_forward_extremeties[room_id] = new_latest_event_ids
len_1 = (
len(latest_event_ids) == 1
and len(new_latest_event_ids) == 1
)
if len_1:
all_single_prev_not_state = all(
len(event.prev_events) == 1
and not event.is_state()
for event, ctx in ev_ctx_rm
)
# Don't bother calculating state if they're just
# a long chain of single ancestor non-state events.
if all_single_prev_not_state:
continue
state = yield self._calculate_state_delta( state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids room_id, ev_ctx_rm, new_latest_event_ids
) )

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 40 SCHEMA_VERSION = 41
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -15,7 +15,7 @@
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.api.constants import PresenceState from synapse.api.constants import PresenceState
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from collections import namedtuple from collections import namedtuple
from twisted.internet import defer from twisted.internet import defer
@ -85,6 +85,9 @@ class PresenceStore(SQLBaseStore):
self.presence_stream_cache.entity_has_changed, self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id, state.user_id, stream_id,
) )
self._invalidate_cache_and_stream(
txn, self._get_presence_for_user, (state.user_id,)
)
# Actually insert new rows # Actually insert new rows
self._simple_insert_many_txn( self._simple_insert_many_txn(
@ -143,7 +146,12 @@ class PresenceStore(SQLBaseStore):
"get_all_presence_updates", get_all_presence_updates_txn "get_all_presence_updates", get_all_presence_updates_txn
) )
@defer.inlineCallbacks @cached()
def _get_presence_for_user(self, user_id):
raise NotImplementedError()
@cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
num_args=1, inlineCallbacks=True)
def get_presence_for_users(self, user_ids): def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch( rows = yield self._simple_select_many_batch(
table="presence_stream", table="presence_stream",
@ -165,7 +173,7 @@ class PresenceStore(SQLBaseStore):
for row in rows: for row in rows:
row["currently_active"] = bool(row["currently_active"]) row["currently_active"] = bool(row["currently_active"])
defer.returnValue([UserPresenceState(**row) for row in rows]) defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
def get_current_presence_token(self): def get_current_presence_token(self):
return self._presence_id_gen.get_current_token() return self._presence_id_gen.get_current_token()

View file

@ -351,6 +351,7 @@ class ReceiptsStore(SQLBaseStore):
room_id=room_id, room_id=room_id,
user_id=user_id, user_id=user_id,
topological_ordering=topological_ordering, topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
) )
return True return True

View file

@ -0,0 +1,37 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Aggregate of old notification counts that have been deleted out of the
-- main event_push_actions table. This count does not include those that were
-- highlights, as they remain in the event_push_actions table.
CREATE TABLE event_push_summary (
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
notif_count BIGINT NOT NULL,
stream_ordering BIGINT NOT NULL
);
CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id);
-- The stream ordering up to which we have aggregated the event_push_actions
-- table into event_push_summary
CREATE TABLE event_push_summary_stream_ordering (
Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
stream_ordering BIGINT NOT NULL,
CHECK (Lock='X')
);
INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);

View file

@ -0,0 +1,39 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
access_token BIGINT DEFAULT NULL,
profile_tag TEXT NOT NULL,
kind TEXT NOT NULL,
app_id TEXT NOT NULL,
app_display_name TEXT NOT NULL,
device_display_name TEXT NOT NULL,
pushkey TEXT NOT NULL,
ts BIGINT NOT NULL,
lang TEXT,
data TEXT,
last_stream_ordering INTEGER,
last_success BIGINT,
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
);
INSERT INTO pushers2 SELECT * FROM PUSHERS;
DROP TABLE PUSHERS;
ALTER TABLE pushers2 RENAME TO pushers;

View file

@ -0,0 +1,17 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT into background_updates (update_name, progress_json)
VALUES ('device_lists_stream_idx', '{}');

View file

@ -0,0 +1,16 @@
/* Copyright 2017 Vector Creations Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id);

View file

@ -413,7 +413,19 @@ class StateStore(SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids}) defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks @defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, types): def get_state_ids_for_events(self, event_ids, types=None):
"""
Get the state dicts corresponding to a list of events
Args:
event_ids(list(str)): events whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from event_id -> (type, state_key) -> state_event
"""
event_to_groups = yield self._get_state_group_for_events( event_to_groups = yield self._get_state_group_for_events(
event_ids, event_ids,
) )

View file

@ -100,6 +100,13 @@ class ExpiringCache(object):
except KeyError: except KeyError:
return default return default
def setdefault(self, key, value):
try:
return self[key]
except KeyError:
self[key] = value
return value
def _prune_cache(self): def _prune_cache(self):
if not self._expiry_ms: if not self._expiry_ms:
# zero expiry time means don't expire. This should never get called # zero expiry time means don't expire. This should never get called

View file

@ -17,9 +17,15 @@ from twisted.internet import defer
import tests.unittest import tests.unittest
import tests.utils import tests.utils
from mock import Mock
USER_ID = "@user:example.com" USER_ID = "@user:example.com"
PlAIN_NOTIF = ["notify", {"set_tweak": "highlight", "value": False}]
HIGHLIGHT = [
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
]
class EventPushActionsStoreTestCase(tests.unittest.TestCase): class EventPushActionsStoreTestCase(tests.unittest.TestCase):
@ -39,3 +45,83 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield self.store.get_unread_push_actions_for_user_in_range_for_email( yield self.store.get_unread_push_actions_for_user_in_range_for_email(
USER_ID, 0, 1000, 20 USER_ID, 0, 1000, 20
) )
@defer.inlineCallbacks
def test_count_aggregation(self):
room_id = "!foo:example.com"
user_id = "@user1235:example.com"
@defer.inlineCallbacks
def _assert_counts(noitf_count, highlight_count):
counts = yield self.store.runInteraction(
"", self.store._get_unread_counts_by_pos_txn,
room_id, user_id, 0, 0
)
self.assertEquals(
counts,
{"notify_count": noitf_count, "highlight_count": highlight_count}
)
def _inject_actions(stream, action):
event = Mock()
event.room_id = room_id
event.event_id = "$test:example.com"
event.internal_metadata.stream_ordering = stream
event.depth = stream
tuples = [(user_id, action)]
return self.store.runInteraction(
"", self.store._set_push_actions_for_event_and_users_txn,
event, tuples
)
def _rotate(stream):
return self.store.runInteraction(
"", self.store._rotate_notifs_before_txn, stream
)
def _mark_read(stream, depth):
return self.store.runInteraction(
"", self.store._remove_old_push_actions_before_txn,
room_id, user_id, depth, stream
)
yield _assert_counts(0, 0)
yield _inject_actions(1, PlAIN_NOTIF)
yield _assert_counts(1, 0)
yield _rotate(2)
yield _assert_counts(1, 0)
yield _inject_actions(3, PlAIN_NOTIF)
yield _assert_counts(2, 0)
yield _rotate(4)
yield _assert_counts(2, 0)
yield _inject_actions(5, PlAIN_NOTIF)
yield _mark_read(3, 3)
yield _assert_counts(1, 0)
yield _mark_read(5, 5)
yield _assert_counts(0, 0)
yield _inject_actions(6, PlAIN_NOTIF)
yield _rotate(7)
yield self.store._simple_delete(
table="event_push_actions",
keyvalues={"1": 1},
desc="",
)
yield _assert_counts(1, 0)
yield _mark_read(7, 7)
yield _assert_counts(0, 0)
yield _inject_actions(8, HIGHLIGHT)
yield _assert_counts(1, 1)
yield _rotate(9)
yield _assert_counts(1, 1)
yield _rotate(10)
yield _assert_counts(1, 1)