0
0
Fork 1
mirror of https://mau.dev/maunium/synapse.git synced 2025-01-19 00:41:55 +01:00

Merge branch 'develop' of github.com:matrix-org/synapse into erikj/backfill_fixes

This commit is contained in:
Erik Johnston 2015-05-22 16:10:42 +01:00
commit 74b7de83ec
26 changed files with 259 additions and 144 deletions

View file

@ -1,3 +1,9 @@
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
* Add more database caches to reduce amount of work done for each pusher. This
radically reduces CPU usage when multiple pushers are set up in the same room.
Changes in synapse v0.9.0 (2015-05-07)
======================================

View file

@ -33,9 +33,10 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest()
data = {
"username": user,
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
@ -43,7 +44,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/v2_alpha/register" % (server_location,),
"%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)

View file

@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.9.0-r4"
__version__ = "0.9.0-r5"

View file

@ -277,9 +277,12 @@ class SynapseHomeServer(HomeServer):
config,
metrics_resource,
),
interface="127.0.0.1",
interface=config.metrics_bind_host,
)
logger.info(
"Metrics now running on %s port %d",
config.metrics_bind_host, config.metrics_port,
)
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(

View file

@ -148,8 +148,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for member in member_list:
if self.is_interested_in_user(member.state_key):
for user_id in member_list:
if self.is_interested_in_user(user_id):
return True
return False
@ -173,7 +173,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined room members in this room.
member_list(list): A list of all joined user_ids in this room.
Returns:
bool: True if this service would like to know about this event.
"""

View file

@ -20,6 +20,7 @@ class MetricsConfig(Config):
def read_config(self, config):
self.enable_metrics = config["enable_metrics"]
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, config_dir_path, server_name):
return """\
@ -28,6 +29,9 @@ class MetricsConfig(Config):
# Enable collection and rendering of performance metrics
enable_metrics: False
# Separate port to accept metrics requests on (on localhost)
# Separate port to accept metrics requests on
# metrics_port: 8081
# Which host to bind the metric listener to
# metrics_bind_host: 127.0.0.1
"""

View file

@ -194,6 +194,8 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns:
Deferred: Results in the requested PDU.

View file

@ -207,13 +207,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.info(
logger.debug(
"TX [%s] Transaction already in progress",
destination
)
return
logger.info("TX [%s] _attempt_new_transaction", destination)
logger.debug("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@ -221,11 +221,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.info("TX [%s] Nothing to send", destination)
logger.debug("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
@ -242,6 +242,8 @@ class TransactionQueue(object):
try:
self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self._clock,
@ -249,9 +251,9 @@ class TransactionQueue(object):
)
logger.debug(
"TX [%s] Attempting new transaction"
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination,
destination, txn_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
@ -261,7 +263,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=str(self._next_txn_id),
transaction_id=txn_id,
origin=self.server_name,
destination=destination,
pdus=pdus,
@ -275,9 +277,13 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] Sending transaction [%s]",
destination,
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
@ -313,7 +319,10 @@ class TransactionQueue(object):
code = e.code
response = e.response
logger.info("TX [%s] got %d response", destination, code)
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)

View file

@ -57,6 +57,8 @@ class TransportLayerClient(object):
destination (str): The host name of the remote home server we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Results in a dict received from the remote homeserver.

View file

@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data)
)
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
# origin = body["origin"]

View file

@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes
from synapse.appservice import ApplicationService
from synapse.types import UserID
@ -147,10 +147,7 @@ class ApplicationServicesHandler(object):
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
member_list = yield self.store.get_users_in_room(event.room_id)
services = yield self.store.get_app_services()
interested_list = [

View file

@ -146,6 +146,10 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {}
self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback(
"userCachemap:size",
lambda: len(self._user_cachemap),
@ -297,13 +301,34 @@ class PresenceHandler(BaseHandler):
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state):
statuscache = self._get_or_make_usercache(user)
"""Updates the presence state of a local user.
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial)
return self.push_presence(user, statuscache=statuscache)
statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
@log_function
def started_user_eventstream(self, user):
@ -326,13 +351,12 @@ class PresenceHandler(BaseHandler):
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the
# event source
self._user_cachemap_latest_serial += 1
statuscache.update({}, serial=self._user_cachemap_latest_serial)
statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
self.push_update_to_local_and_remote(
observed_user=user,
room_ids=[room_id],
@ -340,16 +364,17 @@ class PresenceHandler(BaseHandler):
)
# We also want to tell them about current presence of people.
rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id)
curr_users = yield self.get_joined_users_for_room_id(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = self._get_or_offline_usercache(local_user)
statuscache.update({}, serial=self._user_cachemap_latest_serial)
statuscache = yield self.update_presence_cache(
local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
statuscache=self._get_or_offline_usercache(local_user),
statuscache=statuscache,
)
@defer.inlineCallbacks
@ -546,8 +571,7 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
room_ids = yield self.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@ -747,8 +771,7 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes
localusers.add(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
room_ids = yield self.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@ -793,8 +816,7 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers
)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
room_ids = yield self.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@ -813,10 +835,8 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago")
)
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1
statuscache.update(state, serial=self._user_cachemap_latest_serial)
yield self.update_presence_cache(user, state, room_ids=room_ids)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
@ -874,6 +894,35 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
@defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[],
@ -996,39 +1045,11 @@ class PresenceEventSource(object):
self.hs = hs
self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks
@log_function
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
@ -1037,17 +1058,27 @@ class PresenceEventSource(object):
clock = self.clock
latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
for observed_user in user_ids_to_check & set(cachemap):
cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial:
continue
if not (yield self.is_visible(observer_user, observed_user)):
continue
latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock))
@ -1084,8 +1115,6 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key)
if pagination_config.to_key:
@ -1096,14 +1125,26 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
for observed_user in user_ids_to_check & set(cachemap):
if not (to_key < cachemap[observed_user].serial <= from_key):
continue
if (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cachemap[observed_user]))
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit

View file

@ -536,9 +536,7 @@ class RoomListHandler(BaseHandler):
chunk = yield self.store.get_rooms(is_public=True)
results = yield defer.gatherResults(
[
self.store.get_users_in_room(
room_id=room["room_id"],
)
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,

View file

@ -345,6 +345,9 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
Returns:
Deferred: Succeeds when we get *any* HTTP response.

View file

@ -84,25 +84,20 @@ class Pusher(object):
rules = baserules.list_with_base_rules(rawrules, user)
room_id = ev['room_id']
# get *our* member event for display name matching
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=None
)
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
our_member_event = yield self.store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=self.user_name,
)
if our_member_event:
my_display_name = our_member_event[0].content.get("displayname")
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
for r in rules:
if r['rule_id'] in enabled_map:
@ -287,9 +282,11 @@ class Pusher(object):
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",

View file

@ -121,6 +121,11 @@ class Cache(object):
self.sequence += 1
self.cache.pop(keyargs, None)
def invalidate_all(self):
self.check_thread()
self.sequence += 1
self.cache.clear()
def cached(max_entries=1000, num_args=1, lru=False):
""" A method decorator that applies a memoizing cache around the function.
@ -172,6 +177,7 @@ def cached(max_entries=1000, num_args=1, lru=False):
defer.returnValue(ret)
wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
wrapped.prefill = cache.prefill
return wrapped

View file

@ -489,3 +489,4 @@ class EventFederationStore(SQLBaseStore):
query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
txn.execute(query, (room_id,))
txn.call_after(self.get_latest_event_ids_in_room.invalidate, room_id)

View file

@ -125,6 +125,12 @@ class EventsStore(SQLBaseStore):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self.get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_room_name_and_aliases, event.room_id)
self._simple_delete_txn(
txn,
table="current_state_events",
@ -132,13 +138,6 @@ class EventsStore(SQLBaseStore):
)
for s in current_state:
if s.type == EventTypes.Member:
txn.call_after(
self.get_rooms_for_user.invalidate, s.state_key
)
txn.call_after(
self.get_joined_hosts_for_room.invalidate, s.room_id
)
self._simple_insert_txn(
txn,
"current_state_events",
@ -356,6 +355,18 @@ class EventsStore(SQLBaseStore):
)
if is_new_state and not context.rejected:
txn.call_after(
self.get_current_state_for_key.invalidate,
event.room_id, event.type, event.state_key
)
if (event.type == EventTypes.Name
or event.type == EventTypes.Aliases):
txn.call_after(
self.get_room_name_and_aliases.invalidate,
event.room_id
)
self._simple_upsert_txn(
txn,
"current_state_events",

View file

@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from ._base import SQLBaseStore, Table
from ._base import SQLBaseStore, cached
from twisted.internet import defer
import logging
@ -41,6 +39,7 @@ class PushRuleStore(SQLBaseStore):
defer.returnValue(rows)
@cached()
@defer.inlineCallbacks
def get_push_rules_enabled_for_user(self, user_name):
results = yield self._simple_select_list(
@ -151,6 +150,10 @@ class PushRuleStore(SQLBaseStore):
txn.execute(sql, (user_name, priority_class, new_rule_priority))
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
self._simple_insert_txn(
txn,
table=PushRuleTable.table_name,
@ -179,6 +182,10 @@ class PushRuleStore(SQLBaseStore):
new_rule['priority_class'] = priority_class
new_rule['priority'] = new_prio
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, user_name
)
self._simple_insert_txn(
txn,
table=PushRuleTable.table_name,
@ -201,6 +208,7 @@ class PushRuleStore(SQLBaseStore):
{'user_name': user_name, 'rule_id': rule_id},
desc="delete_push_rule",
)
self.get_push_rules_enabled_for_user.invalidate(user_name)
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
@ -220,6 +228,7 @@ class PushRuleStore(SQLBaseStore):
{'enabled': 1 if enabled else 0},
{'id': new_id},
)
self.get_push_rules_enabled_for_user.invalidate(user_name)
class RuleNotFoundException(Exception):
@ -230,7 +239,7 @@ class InconsistentRuleException(Exception):
pass
class PushRuleTable(Table):
class PushRuleTable(object):
table_name = "push_rules"
fields = [
@ -243,10 +252,8 @@ class PushRuleTable(Table):
"actions",
]
EntryType = collections.namedtuple("PushRuleEntry", fields)
class PushRuleEnableTable(Table):
class PushRuleEnableTable(object):
table_name = "push_rules_enable"
fields = [

View file

@ -17,7 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from ._base import SQLBaseStore
from ._base import SQLBaseStore, cached
import collections
import logging
@ -186,6 +186,7 @@ class RoomStore(SQLBaseStore):
}
)
@cached()
@defer.inlineCallbacks
def get_room_name_and_aliases(self, room_id):
def f(txn):

View file

@ -66,6 +66,7 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
txn.call_after(self.get_users_in_room.invalidate, event.room_id)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
@ -87,6 +88,7 @@ class RoomMemberStore(SQLBaseStore):
lambda events: events[0] if events else None
)
@cached()
def get_users_in_room(self, room_id):
def f(txn):

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore
from ._base import SQLBaseStore, cached
from twisted.internet import defer
@ -143,6 +143,12 @@ class StateStore(SQLBaseStore):
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
@ -167,6 +173,23 @@ class StateStore(SQLBaseStore):
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state_for_key", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View file

@ -20,7 +20,6 @@ import threading
class LruCache(object):
"""Least-recently-used cache."""
# TODO(mjark) Add mutex for linked list for thread safety.
def __init__(self, max_size):
cache = {}
list_root = []
@ -105,6 +104,12 @@ class LruCache(object):
else:
return default
@synchronized
def cache_clear():
list_root[NEXT] = list_root
list_root[PREV] = list_root
cache.clear()
@synchronized
def cache_len():
return len(cache)
@ -120,6 +125,7 @@ class LruCache(object):
self.pop = cache_pop
self.len = cache_len
self.contains = cache_contains
self.clear = cache_clear
def __getitem__(self, key):
result = self.get(key, self.sentinel)

View file

@ -217,18 +217,9 @@ class ApplicationServiceTestCase(unittest.TestCase):
_regex("@irc_.*")
)
join_list = [
Mock(
type="m.room.member", room_id="!foo:bar", sender="@alice:here",
state_key="@alice:here"
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@irc_fo:here",
state_key="@irc_fo:here" # AS user
),
Mock(
type="m.room.member", room_id="!foo:bar", sender="@bob:here",
state_key="@bob:here"
)
"@alice:here",
"@irc_fo:here", # AS user
"@bob:here",
]
self.event.sender = "@xmpp_foobar:matrix.org"

View file

@ -624,6 +624,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
"""
PRESENCE_LIST = {
'apple': [ "@banana:test", "@clementine:test" ],
'banana': [ "@apple:test" ],
}
@defer.inlineCallbacks
@ -836,12 +837,7 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks
def test_recv_remote(self):
# TODO(paul): Gut-wrenching
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato]
self.room_members = [self.u_apple, self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0)
@ -886,11 +882,8 @@ class PresencePushTestCase(MockedDatastorePresenceTestCase):
@defer.inlineCallbacks
def test_recv_remote_offline(self):
""" Various tests relating to SYN-261 """
potato_set = self.handler._remote_recvmap.setdefault(self.u_potato,
set())
potato_set.add(self.u_apple)
self.room_members = [self.u_banana, self.u_potato]
self.room_members = [self.u_apple, self.u_banana, self.u_potato]
self.assertEquals(self.event_source.get_current_key(), 0)

View file

@ -297,6 +297,9 @@ class PresenceEventStreamTestCase(unittest.TestCase):
else:
return []
hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
hs.handlers.room_member_handler.get_room_members = (
lambda r: self.room_members if r == "a-room" else []
)
self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)