Merge branch 'develop' of github.com:matrix-org/synapse into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2019-04-04 14:43:57 +01:00
commit 8467756dc1
82 changed files with 2826 additions and 3088 deletions

1
changelog.d/4555.bugfix Normal file
View file

@ -0,0 +1 @@
Avoid redundant URL encoding of redirect URL for SSO login in the fallback login page. Fixes a regression introduced in [#4220](https://github.com/matrix-org/synapse/pull/4220). Contributed by Marcel Fabian Krüger ("[zaugin](https://github.com/zauguin)").

1
changelog.d/4982.misc Normal file
View file

@ -0,0 +1 @@
Track which identity server is used when binding a threepid and use that for unbinding, as per MSC1915.

1
changelog.d/4985.misc Normal file
View file

@ -0,0 +1 @@
Rewrite KeyringTestCase as a HomeserverTestCase.

1
changelog.d/4987.misc Normal file
View file

@ -0,0 +1 @@
README updates: Corrected the default POSTGRES_USER. Added port forwarding hint in TLS section.

1
changelog.d/4989.feature Normal file
View file

@ -0,0 +1 @@
Remove presence list support as per MSC 1819.

1
changelog.d/4996.misc Normal file
View file

@ -0,0 +1 @@
Run `black` on the remainder of `synapse/storage/`.

1
changelog.d/4998.misc Normal file
View file

@ -0,0 +1 @@
Fix grammar in get_current_users_in_room and give it a docstring.

1
changelog.d/4999.bugfix Normal file
View file

@ -0,0 +1 @@
Prevent the ability to kick users from a room they aren't in.

1
changelog.d/5002.feature Normal file
View file

@ -0,0 +1 @@
Add a delete group admin API.

1
changelog.d/5003.bugfix Normal file
View file

@ -0,0 +1 @@
Fix issue #4596 so synapse_port_db script works with --curses option on Python 3. Contributed by Anders Jensen-Waud <anders@jensenwaud.com>.

1
changelog.d/5007.misc Normal file
View file

@ -0,0 +1 @@
Refactor synapse.storage._base._simple_select_list_paginate.

View file

@ -60,7 +60,8 @@ Synapse requires a valid TLS certificate. You can do one of the following:
* Provide your own certificate and key (as
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.crt` and
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.key`, or elsewhere by providing an
entire config as `${SYNAPSE_CONFIG_PATH}`).
entire config as `${SYNAPSE_CONFIG_PATH}`). In this case, you should forward
traffic to port 8448 in the container, for example with `-p 443:8448`.
* Use a reverse proxy to terminate incoming TLS, and forward the plain http
traffic to port 8008 in the container. In this case you should set `-e
@ -138,7 +139,7 @@ Database specific values (will use SQLite if not set):
**NOTE**: You are highly encouraged to use postgresql! Please use the compose
file to make it easier to deploy.
* `POSTGRES_USER` - The user for the synapse postgres database. [default:
`matrix`]
`synapse`]
Mail server specific values (will not send emails if not set):

View file

@ -0,0 +1,14 @@
# Delete a local group
This API lets a server admin delete a local group. Doing so will kick all
users out of the group so that their clients will correctly handle the group
being deleted.
The API is:
```
POST /_matrix/client/r0/admin/delete_group/<group_id>
```
including an `access_token` of a server admin.

View file

@ -811,7 +811,7 @@ class CursesProgress(Progress):
middle_space = 1
items = self.tables.items()
items.sort(key=lambda i: (i[1]["perc"], i[0]))
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
for i, (table, data) in enumerate(items):
if i + 2 >= rows:

View file

@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
@ -896,6 +897,78 @@ class GroupsServerHandler(object):
"group_id": group_id,
})
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
"""Deletes a group, kicking out all current members.
Only group admins or server admins can call this request
Args:
group_id (str)
request_user_id (str)
Returns:
Deferred
"""
yield self.check_group_is_ours(
group_id, requester_user_id,
and_exists=True,
)
# Only server admins or group admins can delete groups.
is_admin = yield self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
UserID.from_string(requester_user_id),
)
if not is_admin:
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
users = yield self.store.get_users_in_group(
group_id, include_private=True,
)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
if self.hs.is_mine_id(user_id):
groups_local = self.hs.get_groups_local_handler()
yield groups_local.user_removed_from_group(group_id, user_id, {})
else:
yield self.transport_client.remove_user_from_group_notification(
get_domain_from_id(user_id), group_id, user_id, {}
)
yield self.store.maybe_delete_remote_profile_cache(user_id)
# We kick users out in the order of:
# 1. Non-admins
# 2. Other admins
# 3. The requester
#
# This is so that if the deletion fails for some reason other admins or
# the requester still has auth to retry.
non_admins = []
admins = []
for u in users:
if u["user_id"] == requester_user_id:
continue
if u["is_admin"]:
admins.append(u["user_id"])
else:
non_admins.append(u["user_id"])
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
yield concurrently_execute(_kick_user_from_group, admins, 10)
yield _kick_user_from_group(requester_user_id)
yield self.store.delete_group(group_id)
def _parse_join_policy_from_contents(content):
"""Given a content for a request, return the specified join policy or None

View file

@ -912,7 +912,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address):
def delete_threepid(self, user_id, medium, address, id_server=None):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
@ -920,6 +920,10 @@ class AuthHandler(BaseHandler):
user_id (str)
medium (str)
address (str)
id_server (str|None): Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
@ -937,6 +941,7 @@ class AuthHandler(BaseHandler):
{
'medium': medium,
'address': address,
'id_server': id_server,
},
)

View file

@ -43,12 +43,15 @@ class DeactivateAccountHandler(BaseHandler):
hs.get_reactor().callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks
def deactivate_account(self, user_id, erase_data):
def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
Args:
user_id (str): ID of user to be deactivated
erase_data (bool): whether to GDPR-erase the user's data
id_server (str|None): Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
Returns:
Deferred[bool]: True if identity server supports removing
@ -74,6 +77,7 @@ class DeactivateAccountHandler(BaseHandler):
{
'medium': threepid['medium'],
'address': threepid['address'],
'id_server': id_server,
},
)
identity_server_supports_unbinding &= result

View file

@ -68,7 +68,7 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Add transactions.
# TODO(erikj): Check if there is a current association.
if not servers:
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.state.get_current_users_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
if not servers:
@ -268,7 +268,7 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND
)
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.state.get_current_users_in_room(room_id)
extra_servers = set(get_domain_from_id(u) for u in users)
servers = set(extra_servers) | set(servers)

View file

@ -102,7 +102,7 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
users = yield self.state.get_current_user_in_room(event.room_id)
users = yield self.state.get_current_users_in_room(event.room_id)
states = yield presence_handler.get_states(
users,
as_event=True,

View file

@ -132,6 +132,14 @@ class IdentityHandler(BaseHandler):
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
user_id=mxid,
medium=data["medium"],
address=data["address"],
id_server=id_server,
)
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
defer.returnValue(data)
@ -140,9 +148,48 @@ class IdentityHandler(BaseHandler):
def try_unbind_threepid(self, mxid, threepid):
"""Removes a binding from an identity server
Args:
mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be
removed, and an optional id_server.
Raises:
SynapseError: If we failed to contact the identity server
Returns:
Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding (or no identity server found to
contact).
"""
if threepid.get("id_server"):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
)
# We don't know where to unbind, so we don't have a choice but to return
if not id_servers:
defer.returnValue(False)
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
mxid, threepid, id_server,
)
defer.returnValue(changed)
@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
"""Removes a binding from an identity server
Args:
mxid (str): Matrix user ID of binding to be removed
threepid (dict): Dict with medium & address of binding to be removed
id_server (str): Identity server to unbind from
Raises:
SynapseError: If we failed to contact the identity server
@ -151,21 +198,13 @@ class IdentityHandler(BaseHandler):
Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding
"""
logger.debug("unbinding threepid %r from %s", threepid, mxid)
if not self.trusted_id_servers:
logger.warn("Can't unbind threepid: no trusted ID servers set in config")
defer.returnValue(False)
# We don't track what ID server we added 3pids on (perhaps we ought to)
# but we assume that any of the servers in the trusted list are in the
# same ID server federation, so we can pick any one of them to send the
# deletion request to.
id_server = next(iter(self.trusted_id_servers))
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
"threepid": threepid,
"threepid": {
"medium": threepid["medium"],
"address": threepid["address"],
},
}
# we abuse the federation http client to sign the request, but we have to send it
@ -188,16 +227,24 @@ class IdentityHandler(BaseHandler):
content,
headers,
)
changed = True
except HttpResponseException as e:
changed = False
if e.code in (400, 404, 501,):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
defer.returnValue(False)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
raise SynapseError(502, "Failed to contact identity server")
defer.returnValue(True)
yield self.store.remove_user_bound_threepid(
user_id=mxid,
medium=threepid["medium"],
address=threepid["address"],
id_server=id_server,
)
defer.returnValue(changed)
@defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):

View file

@ -192,7 +192,7 @@ class MessageHandler(object):
"Getting joined members after leaving is not implemented"
)
users_with_profile = yield self.state.get_current_user_in_room(room_id)
users_with_profile = yield self.state.get_current_users_in_room(room_id)
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there

View file

@ -113,27 +113,6 @@ class PresenceHandler(object):
federation_registry.register_edu_handler(
"m.presence", self.incoming_presence
)
federation_registry.register_edu_handler(
"m.presence_invite",
lambda origin, content: self.invite_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
federation_registry.register_edu_handler(
"m.presence_accept",
lambda origin, content: self.accept_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
federation_registry.register_edu_handler(
"m.presence_deny",
lambda origin, content: self.deny_presence(
observed_user=UserID.from_string(content["observed_user"]),
observer_user=UserID.from_string(content["observer_user"]),
)
)
active_presence = self.store.take_presence_startup_info()
@ -759,137 +738,6 @@ class PresenceHandler(object):
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Returns the presence for all users in their presence list.
"""
if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
observer_user.localpart, accepted=accepted
)
results = yield self.get_states(
target_user_ids=[row["observed_user_id"] for row in presence_list],
as_event=False,
)
now = self.clock.time_msec()
results[:] = [format_user_presence_state(r, now) for r in results]
is_accepted = {
row["observed_user_id"]: row["accepted"] for row in presence_list
}
for result in results:
result.update({
"accepted": is_accepted,
})
defer.returnValue(results)
@defer.inlineCallbacks
def send_presence_invite(self, observer_user, observed_user):
"""Sends a presence invite.
"""
yield self.store.add_presence_list_pending(
observer_user.localpart, observed_user.to_string()
)
if self.is_mine(observed_user):
yield self.invite_presence(observed_user, observer_user)
else:
yield self.federation.build_and_send_edu(
destination=observed_user.domain,
edu_type="m.presence_invite",
content={
"observed_user": observed_user.to_string(),
"observer_user": observer_user.to_string(),
}
)
@defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user):
"""Handles new presence invites.
"""
if not self.is_mine(observed_user):
raise SynapseError(400, "User is not hosted on this Home Server")
# TODO: Don't auto accept
if self.is_mine(observer_user):
yield self.accept_presence(observed_user, observer_user)
else:
self.federation.build_and_send_edu(
destination=observer_user.domain,
edu_type="m.presence_accept",
content={
"observed_user": observed_user.to_string(),
"observer_user": observer_user.to_string(),
}
)
state_dict = yield self.get_state(observed_user, as_event=False)
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
self.federation.build_and_send_edu(
destination=observer_user.domain,
edu_type="m.presence",
content={
"push": [state_dict]
}
)
@defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string()
)
# TODO(paul): Inform the user somehow?
@defer.inlineCallbacks
def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string()
)
# TODO: Inform the remote that we've dropped the presence list.
@defer.inlineCallbacks
def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
@ -904,11 +752,7 @@ class PresenceHandler(object):
if observer_room_ids & observed_room_ids:
defer.returnValue(True)
accepted_observers = yield self.store.get_presence_list_observers_accepted(
observed_user.to_string()
)
defer.returnValue(observer_user.to_string() in accepted_observers)
defer.returnValue(False)
@defer.inlineCallbacks
def get_all_presence_updates(self, last_id, current_id):
@ -1039,7 +883,7 @@ class PresenceHandler(object):
# TODO: Check that this is actually a new server joining the
# room.
user_ids = yield self.state.get_current_user_in_room(room_id)
user_ids = yield self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
@ -1204,10 +1048,7 @@ class PresenceEventSource(object):
updates for
"""
user_id = user.to_string()
plist = yield self.store.get_presence_list_accepted(
user.localpart, on_invalidate=cache_context.invalidate,
)
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
@ -1412,10 +1253,6 @@ def get_interested_parties(store, states):
for room_id in room_ids:
room_ids_to_states.setdefault(room_id, []).append(state)
plist = yield store.get_presence_list_observers_accepted(state.user_id)
for u in plist:
users_to_states.setdefault(u, []).append(state)
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)

View file

@ -171,7 +171,7 @@ class RoomListHandler(BaseHandler):
if not latest_event_ids:
return
joined_users = yield self.state_handler.get_current_user_in_room(
joined_users = yield self.state_handler.get_current_users_in_room(
room_id, latest_event_ids,
)

View file

@ -441,6 +441,9 @@ class RoomMemberHandler(object):
room_id, latest_event_ids=latest_event_ids,
)
# TODO: Refactor into dictionary of explicitly allowed transitions
# between old and new state, with specific error messages for some
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
old_state = yield self.store.get_event(old_state_id, allow_none=True)
@ -466,6 +469,9 @@ class RoomMemberHandler(object):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
@ -479,6 +485,9 @@ class RoomMemberHandler(object):
"You cannot reject this invite",
errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
)
else:
if action == "kick":
raise AuthError(403, "The target user is not in the room")
is_host_in_room = yield self._is_host_in_room(current_state_ids)

View file

@ -1052,11 +1052,11 @@ class SyncHandler(object):
# TODO: Be more clever than this, i.e. remove users who we already
# share a room with?
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_user_in_room(room_id)
joined_users = yield self.state.get_current_users_in_room(room_id)
newly_joined_users.update(joined_users)
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_user_in_room(room_id)
left_users = yield self.state.get_current_users_in_room(room_id)
newly_left_users.update(left_users)
# TODO: Check that these users are actually new, i.e. either they
@ -1216,7 +1216,7 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
@ -1858,7 +1858,7 @@ class SyncHandler(object):
extrems = yield self.store.get_forward_extremeties_for_room(
room_id, stream_ordering,
)
users_in_room = yield self.state.get_current_user_in_room(
users_in_room = yield self.state.get_current_users_in_room(
room_id, extrems,
)
if user_id in users_in_room:

View file

@ -218,7 +218,7 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
try:
users = yield self.state.get_current_user_in_room(member.room_id)
users = yield self.state.get_current_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@ -261,7 +261,7 @@ class TypingHandler(object):
)
return
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.state.get_current_users_in_room(room_id)
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains:

View file

@ -276,7 +276,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# ignore the change
return
users_with_profile = yield self.state.get_current_user_in_room(room_id)
users_with_profile = yield self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room.
for user_id in iterkeys(users_with_profile):
@ -325,7 +325,7 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id
)
# Now we update users who share rooms with users.
users_with_profile = yield self.state.get_current_user_in_room(room_id)
users_with_profile = yield self.state.get_current_users_in_room(room_id)
if is_public:
yield self.store.add_users_in_public_rooms(room_id, (user_id,))

View file

@ -39,16 +39,6 @@ class SlavedPresenceStore(BaseSlavedStore):
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
# XXX: This is a bit broken because we don't persist the accepted list in a
# way that can be replicated. This means that we don't have a way to
# invalidate the cache correctly.
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
get_presence_list_observers_accepted = PresenceStore.__dict__[
"get_presence_list_observers_accepted"
]
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()

View file

@ -499,7 +499,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
# desirable in case the first attempt at blocking the room failed below.
yield self.store.block_room(room_id, requester_user_id)
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.state.get_current_users_in_room(room_id)
kicked_users = []
failed_to_kick_users = []
for user_id in users:
@ -784,6 +784,31 @@ class SearchUsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
class DeleteGroupAdminRestServlet(ClientV1RestServlet):
"""Allows deleting of local groups
"""
PATTERNS = client_path_patterns("/admin/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
super(DeleteGroupAdminRestServlet, self).__init__(hs)
self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(requester.user)
if not is_admin:
raise AuthError(403, "You are not a server admin")
if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only delete local groups")
yield self.group_server.delete_group(group_id, requester.user.to_string())
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
@ -799,3 +824,4 @@ def register_servlets(hs, http_server):
ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)

View file

@ -93,72 +93,5 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
return (200, {})
class PresenceListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
def __init__(self, hs):
super(PresenceListRestServlet, self).__init__(hs)
self.presence_handler = hs.get_presence_handler()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if requester.user != user:
raise SynapseError(400, "Cannot get another user's presence list")
presence = yield self.presence_handler.get_presence_list(
observer_user=user, accepted=True
)
defer.returnValue((200, presence))
@defer.inlineCallbacks
def on_POST(self, request, user_id):
requester = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
raise SynapseError(400, "User not hosted on this Home Server")
if requester.user != user:
raise SynapseError(
400, "Cannot modify another user's presence list")
content = parse_json_object_from_request(request)
if "invite" in content:
for u in content["invite"]:
if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
invited_user = UserID.from_string(u)
yield self.presence_handler.send_presence_invite(
observer_user=user, observed_user=invited_user
)
if "drop" in content:
for u in content["drop"]:
if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
dropped_user = UserID.from_string(u)
yield self.presence_handler.drop(
observer_user=user, observed_user=dropped_user
)
defer.returnValue((200, {}))
def on_OPTIONS(self, request):
return (200, {})
def register_servlets(hs, http_server):
PresenceStatusRestServlet(hs).register(http_server)
PresenceListRestServlet(hs).register(http_server)

View file

@ -215,6 +215,7 @@ class DeactivateAccountRestServlet(RestServlet):
)
result = yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase,
id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@ -363,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
PATTERNS = client_v2_patterns("/account/3pid/delete$")
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
@ -380,7 +381,7 @@ class ThreepidDeleteRestServlet(RestServlet):
try:
ret = yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address']
user_id, body['medium'], body['address'], body.get("id_server"),
)
except Exception:
# NB. This endpoint should succeed if there is nothing to

View file

@ -161,10 +161,21 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def get_current_user_in_room(self, room_id, latest_event_ids=None):
def get_current_users_in_room(self, room_id, latest_event_ids=None):
"""
Get the users who are currently in a room.
Args:
room_id (str): The ID of the room.
latest_event_ids (List[str]|None): Precomputed list of latest
event IDs. Will be computed if None.
Returns:
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
profileinfo.
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)

View file

@ -49,7 +49,7 @@ var show_login = function() {
$("#loading").hide();
var this_page = window.location.origin + window.location.pathname;
$("#sso_redirect_url").val(encodeURIComponent(this_page));
$("#sso_redirect_url").val(this_page);
if (matrixLogin.serverAcceptsPassword) {
$("#password_flow").show();

View file

@ -18,6 +18,8 @@ import calendar
import logging
import time
from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.storage.devices import DeviceStore
from synapse.storage.user_erasure_store import UserErasureStore
@ -61,48 +63,60 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
logger = logging.getLogger(__name__)
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
EventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
ReceiptsStore,
EndToEndKeyStore,
EndToEndRoomKeyStore,
SearchStore,
TagsStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
):
class DataStore(
RoomMemberStore,
RoomStore,
RegistrationStore,
StreamStore,
ProfileStore,
PresenceStore,
TransactionStore,
DirectoryStore,
KeyStore,
StateStore,
SignatureStore,
ApplicationServiceStore,
EventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
FilteringStore,
PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore,
ReceiptsStore,
EndToEndKeyStore,
EndToEndRoomKeyStore,
SearchStore,
TagsStore,
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
ClientIpStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = hs.database_engine
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
db_conn,
"events",
"stream_ordering",
extra_tables=[("local_invites", "stream_id")],
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
@ -114,7 +128,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
db_conn, "device_lists_stream", "stream_id"
)
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
@ -125,16 +139,15 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
)
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id",
db_conn, "local_group_updates", "stream_id"
)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = StreamIdGenerator(
db_conn, "cache_invalidation_stream", "stream_id",
db_conn, "cache_invalidation_stream", "stream_id"
)
else:
self._cache_id_gen = None
@ -142,72 +155,82 @@ class DataStore(RoomMemberStore, RoomStore,
self._presence_on_startup = self._get_active_presence(db_conn)
presence_cache_prefill, min_presence_val = self._get_cache_dict(
db_conn, "presence_stream",
db_conn,
"presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
prefilled_cache=presence_cache_prefill
"PresenceStreamChangeCache",
min_presence_val,
prefilled_cache=presence_cache_prefill,
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
db_conn, "device_inbox",
db_conn,
"device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id,
"DeviceInboxStreamChangeCache",
min_device_inbox_id,
prefilled_cache=device_inbox_prefill,
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
db_conn, "device_federation_outbox",
db_conn,
"device_federation_outbox",
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
"DeviceFederationOutboxStreamChangeCache",
min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
"DeviceListStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
"DeviceListFederationStreamChangeCache", device_list_max
)
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
db_conn, "current_state_delta_stream",
db_conn,
"current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
"_curr_state_delta_stream_cache",
min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
db_conn, "local_group_updates",
db_conn,
"local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache", min_group_updates_id,
"_group_updates_stream_cache",
min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
@ -250,6 +273,7 @@ class DataStore(RoomMemberStore, RoomStore,
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
@ -277,6 +301,7 @@ class DataStore(RoomMemberStore, RoomStore,
Returns counts globaly for a given user as well as breaking
by platform
"""
def _count_r30_users(txn):
thirty_days_in_secs = 86400 * 30
now = int(self._clock.time())
@ -313,8 +338,7 @@ class DataStore(RoomMemberStore, RoomStore,
"""
results = {}
txn.execute(sql, (thirty_days_ago_in_secs,
thirty_days_ago_in_secs))
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
for row in txn:
if row[0] == 'unknown':
@ -341,8 +365,7 @@ class DataStore(RoomMemberStore, RoomStore,
) u
"""
txn.execute(sql, (thirty_days_ago_in_secs,
thirty_days_ago_in_secs))
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
count, = txn.fetchone()
results['all'] = count
@ -356,15 +379,14 @@ class DataStore(RoomMemberStore, RoomStore,
Returns millisecond unixtime for start of UTC day.
"""
now = time.gmtime()
today_start = calendar.timegm((
now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
))
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
return today_start * 1000
def generate_user_daily_visits(self):
"""
Generates daily visit data for use in cohort/ retention analysis
"""
def _generate_user_daily_visits(txn):
logger.info("Calling _generate_user_daily_visits")
today_start = self._get_start_of_day()
@ -395,25 +417,29 @@ class DataStore(RoomMemberStore, RoomStore,
# often to minimise this case.
if today_start > self._last_user_visit_update:
yesterday_start = today_start - a_day_in_milliseconds
txn.execute(sql, (
yesterday_start, yesterday_start,
self._last_user_visit_update, today_start
))
txn.execute(
sql,
(
yesterday_start,
yesterday_start,
self._last_user_visit_update,
today_start,
),
)
self._last_user_visit_update = today_start
txn.execute(sql, (
today_start, today_start,
self._last_user_visit_update,
now
))
txn.execute(
sql, (today_start, today_start, self._last_user_visit_update, now)
)
# Update _last_user_visit_update to now. The reason to do this
# rather just clamping to the beginning of the day is to limit
# the size of the join - meaning that the query can be run more
# frequently
self._last_user_visit_update = now
return self.runInteraction("generate_user_daily_visits",
_generate_user_daily_visits)
return self.runInteraction(
"generate_user_daily_visits", _generate_user_daily_visits
)
def get_users(self):
"""Function to reterive a list of users in users table.
@ -425,15 +451,11 @@ class DataStore(RoomMemberStore, RoomStore,
return self._simple_select_list(
table="users",
keyvalues={},
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
retcols=["name", "password_hash", "is_guest", "admin"],
desc="get_users",
)
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
"""Function to reterive a paginated list of users from
users list. This will return a json object, which contains
@ -446,27 +468,19 @@ class DataStore(RoomMemberStore, RoomStore,
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
is_guest = 0
i_start = (int)(start)
i_limit = (int)(limit)
return self.get_user_list_paginate(
users = yield self.runInteraction(
"get_users_paginate",
self._simple_select_list_paginate_txn,
table="users",
keyvalues={
"is_guest": is_guest
},
pagevalues=[
order,
i_limit,
i_start
],
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
desc="get_users_paginate",
keyvalues={"is_guest": False},
orderby=order,
start=start,
limit=limit,
retcols=["name", "password_hash", "is_guest", "admin"],
)
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
defer.returnValue(retval)
def search_users(self, term):
"""Function to search users list for one or more users with
@ -482,12 +496,7 @@ class DataStore(RoomMemberStore, RoomStore,
table="users",
term=term,
col="name",
retcols=[
"name",
"password_hash",
"is_guest",
"admin"
],
retcols=["name", "password_hash", "is_guest", "admin"],
desc="search_users",
)

View file

@ -41,7 +41,7 @@ try:
MAX_TXN_ID = sys.maxint - 1
except AttributeError:
# python 3 does not have a maximum int value
MAX_TXN_ID = 2**63 - 1
MAX_TXN_ID = 2 ** 63 - 1
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
@ -76,12 +76,18 @@ class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
method."""
__slots__ = [
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
"txn",
"name",
"database_engine",
"after_callbacks",
"exception_callbacks",
]
def __init__(self, txn, name, database_engine, after_callbacks,
exception_callbacks):
def __init__(
self, txn, name, database_engine, after_callbacks, exception_callbacks
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
object.__setattr__(self, "database_engine", database_engine)
@ -110,6 +116,7 @@ class LoggingTransaction(object):
def execute_batch(self, sql, args):
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
for val in args:
@ -134,10 +141,7 @@ class LoggingTransaction(object):
sql = self.database_engine.convert_param_style(sql)
if args:
try:
sql_logger.debug(
"[SQL values] {%s} %r",
self.name, args[0]
)
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
except Exception:
# Don't let logging failures stop SQL from working
pass
@ -145,9 +149,7 @@ class LoggingTransaction(object):
start = time.time()
try:
return func(
sql, *args
)
return func(sql, *args)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
@ -176,11 +178,9 @@ class PerformanceCounters(object):
counters = []
for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
count - prev_count,
name
))
counters.append(
((cum_time - prev_time) / interval_duration, count - prev_count, name)
)
self.previous_counters = dict(self.current_counters)
@ -212,8 +212,9 @@ class SQLBaseStore(object):
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache("*getEvent*", keylen=3,
max_entries=hs.config.event_cache_size)
self._get_event_cache = Cache(
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
@ -239,7 +240,7 @@ class SQLBaseStore(object):
0.0,
run_as_background_process,
"upsert_safety_check",
self._check_safe_to_upsert
self._check_safe_to_upsert,
)
@defer.inlineCallbacks
@ -271,7 +272,7 @@ class SQLBaseStore(object):
15.0,
run_as_background_process,
"upsert_safety_check",
self._check_safe_to_upsert
self._check_safe_to_upsert,
)
def start_profiling(self):
@ -298,13 +299,16 @@ class SQLBaseStore(object):
perf_logger.info(
"Total database time: %.3f%% {%s} {%s}",
ratio * 100, top_three_counters, top_3_event_counters
ratio * 100,
top_three_counters,
top_3_event_counters,
)
self._clock.looping_call(loop, 10000)
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
func, *args, **kwargs):
def _new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
start = time.time()
txn_id = self._TXN_ID
@ -312,7 +316,7 @@ class SQLBaseStore(object):
# growing really large.
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
name = "%s-%x" % (desc, txn_id)
transaction_logger.debug("[TXN START] {%s}", name)
@ -323,7 +327,10 @@ class SQLBaseStore(object):
try:
txn = conn.cursor()
txn = LoggingTransaction(
txn, name, self.database_engine, after_callbacks,
txn,
name,
self.database_engine,
after_callbacks,
exception_callbacks,
)
r = func(txn, *args, **kwargs)
@ -334,7 +341,10 @@ class SQLBaseStore(object):
# transaction.
logger.warning(
"[TXN OPERROR] {%s} %s %d/%d",
name, exception_to_unicode(e), i, N
name,
exception_to_unicode(e),
i,
N,
)
if i < N:
i += 1
@ -342,8 +352,7 @@ class SQLBaseStore(object):
conn.rollback()
except self.database_engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
name, exception_to_unicode(e1),
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
)
continue
raise
@ -357,7 +366,8 @@ class SQLBaseStore(object):
except self.database_engine.module.Error as e1:
logger.warning(
"[TXN EROLL] {%s} %s",
name, exception_to_unicode(e1),
name,
exception_to_unicode(e1),
)
continue
raise
@ -396,16 +406,17 @@ class SQLBaseStore(object):
exception_callbacks = []
if LoggingContext.current_context() == LoggingContext.sentinel:
logger.warn(
"Starting db txn '%s' from sentinel context",
desc,
)
logger.warn("Starting db txn '%s' from sentinel context", desc)
try:
result = yield self.runWithConnection(
self._new_transaction,
desc, after_callbacks, exception_callbacks, func,
*args, **kwargs
desc,
after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@ -434,7 +445,7 @@ class SQLBaseStore(object):
parent_context = LoggingContext.current_context()
if parent_context == LoggingContext.sentinel:
logger.warn(
"Starting db connection from sentinel context: metrics will be lost",
"Starting db connection from sentinel context: metrics will be lost"
)
parent_context = None
@ -453,9 +464,7 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(
inner_func, *args, **kwargs
)
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
defer.returnValue(result)
@ -469,9 +478,7 @@ class SQLBaseStore(object):
A list of dicts where the key is the column header.
"""
col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor
)
results = list(dict(zip(col_headers, row)) for row in cursor)
return results
def _execute(self, desc, decoder, query, *args):
@ -485,6 +492,7 @@ class SQLBaseStore(object):
Returns:
The result of decoder(results)
"""
def interaction(txn):
txn.execute(query, args)
if decoder:
@ -498,8 +506,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
def _simple_insert(self, table, values, or_ignore=False,
desc="_simple_insert"):
def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@ -511,10 +518,7 @@ class SQLBaseStore(object):
`or_ignore` is True
"""
try:
yield self.runInteraction(
desc,
self._simple_insert_txn, table, values,
)
yield self.runInteraction(desc, self._simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@ -530,15 +534,13 @@ class SQLBaseStore(object):
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys),
", ".join("?" for _ in keys)
", ".join("?" for _ in keys),
)
txn.execute(sql, vals)
def _simple_insert_many(self, table, values, desc):
return self.runInteraction(
desc, self._simple_insert_many_txn, table, values
)
return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
@staticmethod
def _simple_insert_many_txn(txn, table, values):
@ -553,24 +555,18 @@ class SQLBaseStore(object):
#
# The sort is to ensure that we don't rely on dictionary iteration
# order.
keys, vals = zip(*[
zip(
*(sorted(i.items(), key=lambda kv: kv[0]))
)
for i in values
if i
])
keys, vals = zip(
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
)
for k in keys:
if k != keys[0]:
raise RuntimeError(
"All items must have the same keys"
)
raise RuntimeError("All items must have the same keys")
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
table,
", ".join(k for k in keys[0]),
", ".join("?" for _ in keys[0])
", ".join("?" for _ in keys[0]),
)
txn.executemany(sql, vals)
@ -583,7 +579,7 @@ class SQLBaseStore(object):
values,
insertion_values={},
desc="_simple_upsert",
lock=True
lock=True,
):
"""
@ -599,7 +595,7 @@ class SQLBaseStore(object):
Args:
table (str): The table to upsert into
keyvalues (dict): The unique key tables and their new values
keyvalues (dict): The unique key columns and their new values
values (dict): The nonunique columns and their new values
insertion_values (dict): additional key/values to use only when
inserting
@ -631,17 +627,11 @@ class SQLBaseStore(object):
# presumably we raced with another transaction: let's retry.
logger.warn(
"%s when upserting into %s; retrying: %s", e.__name__, table, e
"IntegrityError when upserting into %s; retrying: %s", table, e
)
def _simple_upsert_txn(
self,
txn,
table,
keyvalues,
values,
insertion_values={},
lock=True,
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
Pick the UPSERT method which works best on the platform. Either the
@ -665,11 +655,7 @@ class SQLBaseStore(object):
and table not in self._unsafe_to_upsert_tables
):
return self._simple_upsert_txn_native_upsert(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
return self._simple_upsert_txn_emulated(
@ -714,7 +700,7 @@ class SQLBaseStore(object):
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
" AND ".join(_getwhere(k) for k in keyvalues)
" AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
@ -726,7 +712,7 @@ class SQLBaseStore(object):
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues)
" AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@ -773,19 +759,14 @@ class SQLBaseStore(object):
latter = "NOTHING"
else:
allvalues.update(values)
latter = (
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = (
"INSERT INTO %s (%s) VALUES (%s) "
"ON CONFLICT (%s) DO %s"
) % (
sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
latter
latter,
)
txn.execute(sql, list(allvalues.values()))
@ -870,8 +851,8 @@ class SQLBaseStore(object):
latter = "NOTHING"
value_values = [() for x in range(len(key_values))]
else:
latter = (
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
latter = "UPDATE SET " + ", ".join(
k + "=EXCLUDED." + k for k in value_names
)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
@ -889,8 +870,9 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args)
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False, desc="_simple_select_one"):
def _simple_select_one(
self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@ -903,14 +885,17 @@ class SQLBaseStore(object):
statement returns no rows
"""
return self.runInteraction(
desc,
self._simple_select_one_txn,
table, keyvalues, retcols, allow_none,
desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
)
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False,
desc="_simple_select_one_onecol"):
def _simple_select_one_onecol(
self,
table,
keyvalues,
retcol,
allow_none=False,
desc="_simple_select_one_onecol",
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@ -922,17 +907,18 @@ class SQLBaseStore(object):
return self.runInteraction(
desc,
self._simple_select_one_onecol_txn,
table, keyvalues, retcol, allow_none=allow_none,
table,
keyvalues,
retcol,
allow_none=allow_none,
)
@classmethod
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
allow_none=False):
def _simple_select_one_onecol_txn(
cls, txn, table, keyvalues, retcol, allow_none=False
):
ret = cls._simple_select_onecol_txn(
txn,
table=table,
keyvalues=keyvalues,
retcol=retcol,
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
if ret:
@ -945,12 +931,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = (
"SELECT %(retcol)s FROM %(table)s"
) % {
"retcol": retcol,
"table": table,
}
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
@ -960,8 +941,9 @@ class SQLBaseStore(object):
return [r[0] for r in txn]
def _simple_select_onecol(self, table, keyvalues, retcol,
desc="_simple_select_onecol"):
def _simple_select_onecol(
self, table, keyvalues, retcol, desc="_simple_select_onecol"
):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@ -974,13 +956,12 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
desc,
self._simple_select_onecol_txn,
table, keyvalues, retcol
desc, self._simple_select_onecol_txn, table, keyvalues, retcol
)
def _simple_select_list(self, table, keyvalues, retcols,
desc="_simple_select_list"):
def _simple_select_list(
self, table, keyvalues, retcols, desc="_simple_select_list"
):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -994,9 +975,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self._simple_select_list_txn,
table, keyvalues, retcols
desc, self._simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
@ -1016,22 +995,26 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
table
)
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
txn.execute(sql)
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def _simple_select_many_batch(self, table, column, iterable, retcols,
keyvalues={}, desc="_simple_select_many_batch",
batch_size=100):
def _simple_select_many_batch(
self,
table,
column,
iterable,
retcols,
keyvalues={},
desc="_simple_select_many_batch",
batch_size=100,
):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
it_list = list(iterable)
chunks = [
it_list[i:i + batch_size]
for i in range(0, len(it_list), batch_size)
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
desc,
self._simple_select_many_txn,
table, column, chunk, keyvalues, retcols
table,
column,
chunk,
keyvalues,
retcols,
)
results.extend(rows)
@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
clauses = []
values = []
clauses.append(
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
sql = "%s WHERE %s" % (
sql,
" AND ".join(clauses),
)
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
def _simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
desc,
self._simple_update_txn,
table, keyvalues, updatevalues,
desc, self._simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
where,
)
txn.execute(
update_sql,
list(updatevalues.values()) + list(keyvalues.values())
)
txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
return txn.rowcount
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
def _simple_update_one(
self, table, keyvalues, updatevalues, desc="_simple_update_one"
):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
desc,
self._simple_update_one_txn,
table, keyvalues, updatevalues,
desc, self._simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
def _simple_select_one_txn(txn, table, keyvalues, retcols,
allow_none=False):
def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues)
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(select_sql, list(keyvalues.values()))
@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(
desc, self._simple_delete_one_txn, table, keyvalues
)
return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
@staticmethod
def _simple_delete_one_txn(txn, table, keyvalues):
@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
txn.execute(sql, list(keyvalues.values()))
@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
def _simple_delete(self, table, keyvalues, desc):
return self.runInteraction(
desc, self._simple_delete_txn, table, keyvalues
)
return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
@staticmethod
def _simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
return txn.execute(sql, list(keyvalues.values()))
@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
clauses = []
values = []
clauses.append(
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
)
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
for key, value in iteritems(keyvalues):
@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
values.append(value)
if clauses:
sql = "%s WHERE %s" % (
sql,
" AND ".join(clauses),
)
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
return txn.execute(sql, values)
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value, limit=100000):
def _get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
cache = {
row[0]: int(row[1])
for row in txn
}
cache = {row[0]: int(row[1]) for row in txn}
txn.close()
@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
# be safe.
for chunk in batch_iter(members_changed, 50):
keys = itertools.chain([room_id], chunk)
self._send_invalidation_to_replication(
txn, _CURRENT_STATE_CACHE_NAME, keys,
)
self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
@ -1356,22 +1317,12 @@ class SQLBaseStore(object):
changed
"""
for host in set(get_domain_from_id(u) for u in members_changed):
self._attempt_to_invalidate_cache(
"is_host_joined", (room_id, host,),
)
self._attempt_to_invalidate_cache(
"was_host_joined", (room_id, host,),
)
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
self._attempt_to_invalidate_cache(
"get_users_in_room", (room_id,),
)
self._attempt_to_invalidate_cache(
"get_room_summary", (room_id,),
)
self._attempt_to_invalidate_cache(
"get_current_state_ids", (room_id,),
)
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(self, cache_name, key):
"""Attempts to invalidate the cache of the given name, ignoring if the
@ -1419,7 +1370,7 @@ class SQLBaseStore(object):
"cache_func": cache_name,
"keys": list(keys),
"invalidation_ts": self.clock.time_msec(),
}
},
)
def get_all_updated_caches(self, last_id, current_id, limit):
@ -1435,11 +1386,10 @@ class SQLBaseStore(object):
" FROM cache_invalidation_stream"
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_id, limit,))
txn.execute(sql, (last_id, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_updated_caches", get_all_updated_caches_txn
)
return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
def get_cache_stream_token(self):
if self._cache_id_gen:
@ -1447,33 +1397,61 @@ class SQLBaseStore(object):
else:
return 0
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
desc="_simple_select_list_paginate"):
"""Executes a SELECT query on the named table with start and limit,
def _simple_select_list_paginate(
self,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
desc="_simple_select_list_paginate",
):
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction=order_direction,
)
@classmethod
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
"""Executes a SELECT query on the named table with start and limit,
def _simple_select_list_paginate_txn(
cls,
txn,
table,
keyvalues,
orderby,
start,
limit,
retcols,
order_direction="ASC",
):
"""
Executes a SELECT query on the named table with start and limit,
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
@ -1483,67 +1461,33 @@ class SQLBaseStore(object):
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
if keyvalues:
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
", ".join(retcols),
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
table,
" ? ASC LIMIT ? OFFSET ?"
)
txn.execute(sql, pagevalues)
where_clause = ""
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols),
table,
where_clause,
orderby,
order_direction,
)
txn.execute(sql, list(keyvalues.values()) + [limit, start])
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
desc="get_user_list_paginate"):
"""Get a list of users from start row to a limit number of rows. This will
return a json object with users and total number of users in users list.
Args:
table (str): the table name
keyvalues (dict[str, Any] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
pagevalues ([]):
order (str): order the select by this column
start (int): start number to begin the query from
limit (int): number of rows to reterive
retcols (iterable[str]): the names of the columns to return
Returns:
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
"""
users = yield self.runInteraction(
desc,
self._simple_select_list_paginate_txn,
table, keyvalues, pagevalues, retcols
)
count = yield self.runInteraction(
desc,
self.get_user_count_txn
)
retval = {
"users": users,
"total": count
}
defer.returnValue(retval)
def get_user_count_txn(self, txn):
"""Get a total number of registered users in the users list.
@ -1556,8 +1500,9 @@ class SQLBaseStore(object):
txn.execute(sql_count)
return txn.fetchone()[0]
def _simple_search_list(self, table, term, col, retcols,
desc="_simple_search_list"):
def _simple_search_list(
self, table, term, col, retcols, desc="_simple_search_list"
):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@ -1572,9 +1517,7 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
desc,
self._simple_search_list_txn,
table, term, col, retcols
desc, self._simple_search_list_txn, table, term, col, retcols
)
@classmethod
@ -1593,11 +1536,7 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]] or None
"""
if term:
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
", ".join(retcols),
table,
col
)
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
termvalues = ["%%" + term + "%%"]
txn.execute(sql, termvalues)
else:
@ -1618,6 +1557,7 @@ class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
something went wrong.
"""
pass

View file

@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
"AccountDataAndTagsChangeCache", account_max
)
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
def get_account_data_for_user_txn(txn):
rows = self._simple_select_list_txn(
txn, "account_data", {"user_id": user_id},
["account_data_type", "content"]
txn,
"account_data",
{"user_id": user_id},
["account_data_type", "content"],
)
global_account_data = {
@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
}
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id},
["room_id", "account_data_type", "content"]
txn,
"room_account_data",
{"user_id": user_id},
["room_id", "account_data_type", "content"],
)
by_room = {}
@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
result = yield self._simple_select_one_onecol(
table="account_data",
keyvalues={
"user_id": user_id,
"account_data_type": data_type,
},
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
desc="get_global_account_data_by_type_for_user",
allow_none=True,
@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
A deferred dict of the room account_data
"""
def get_account_data_for_room_txn(txn):
rows = self._simple_select_list_txn(
txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
["account_data_type", "content"]
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
["account_data_type", "content"],
)
return {
@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
A deferred of the room account_data for that type, or None if
there isn't any set.
"""
def get_account_data_for_room_and_type_txn(txn):
content_json = self._simple_select_one_onecol_txn(
txn,
@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
"account_data_type": account_data_type,
},
retcol="content",
allow_none=True
allow_none=True,
)
return json.loads(content_json) if content_json else None
return self.runInteraction(
"get_account_data_for_room_and_type",
get_account_data_for_room_and_type_txn,
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
def get_all_updated_account_data(self, last_global_id, last_room_id,
current_id, limit):
def get_all_updated_account_data(
self, last_global_id, last_room_id, current_id, limit
):
"""Get all the client account_data that has changed on the server
Args:
last_global_id(int): The position to fetch from for top level data
@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
return (global_results, room_results)
return self.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {
row[0]: json.loads(row[1]) for row in txn
}
global_account_data = {row[0]: json.loads(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
"m.ignored_user_list", ignorer_user_id,
"m.ignored_user_list",
ignorer_user_id,
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
"room_id": room_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
},
values={"stream_id": next_id, "content": content_json},
lock=False,
)
@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_account_data_for_room.invalidate((user_id, room_id,))
self.get_account_data_for_room.invalidate((user_id, room_id))
self.get_account_data_for_room_and_type.prefill(
(user_id, room_id, account_data_type,), content,
(user_id, room_id, account_data_type), content
)
result = self._account_data_id_gen.get_current_token()
@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
yield self._simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={
"user_id": user_id,
"account_data_type": account_data_type,
},
values={
"stream_id": next_id,
"content": content_json,
},
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
values={"stream_id": next_id, "content": content_json},
lock=False,
)
@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
# transaction.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(
user_id, next_id,
)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
self.get_account_data_for_user.invalidate((user_id,))
self.get_global_account_data_by_type_for_user.invalidate(
(account_data_type, user_id,)
(account_data_type, user_id)
)
result = self._account_data_id_gen.get_current_token()
@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
Args:
next_id(int): The the revision to advance to.
"""
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
" WHERE stream_id < ?"
)
txn.execute(update_max_id_sql, (next_id, next_id))
return self.runInteraction(
"update_account_data_max_stream_id",
_update,
)
return self.runInteraction("update_account_data_max_stream_id", _update)

View file

@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname,
hs.config.app_service_config_files
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
pass
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
EventsWorkerStore):
class ApplicationServiceTransactionWorkerStore(
ApplicationServiceWorkerStore, EventsWorkerStore
):
@defer.inlineCallbacks
def get_appservices_by_state(self, state):
"""Get a list of application services based on their state.
@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
may be empty.
"""
results = yield self._simple_select_list(
"application_services_state",
dict(state=state),
["as_id"]
"application_services_state", dict(state=state), ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves when the state was set successfully.
"""
return self._simple_upsert(
"application_services_state",
dict(as_id=service.id),
dict(state=state)
"application_services_state", dict(as_id=service.id), dict(state=state)
)
def create_appservice_txn(self, service, events):
@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
Returns:
AppServiceTransaction: A new transaction.
"""
def _create_appservice_txn(txn):
# work out new txn id (highest txn id for this service += 1)
# The highest id may be the last one sent (in which case it is last_txn)
@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
(service.id,)
(service.id,),
)
highest_txn_id = txn.fetchone()[0]
if highest_txn_id is None:
@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
"VALUES(?,?,?)",
(service.id, new_txn_id, event_ids)
)
return AppServiceTransaction(
service=service, id=new_txn_id, events=events
(service.id, new_txn_id, event_ids),
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.runInteraction(
"create_appservice_txn",
_create_appservice_txn,
)
return self.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
"appservice: Completing a transaction which has an ID > 1 from "
"the last ID sent to this AS. We've either dropped events or "
"sent it to the AS out of order. FIX ME. last_txn=%s "
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
service.id
"completing_txn=%s service_id=%s",
last_txn_id,
txn_id,
service.id,
)
# Set current txn_id for AS to 'txn_id'
self._simple_upsert_txn(
txn, "application_services_state", dict(as_id=service.id),
dict(last_txn=txn_id)
txn,
"application_services_state",
dict(as_id=service.id),
dict(last_txn=txn_id),
)
# Delete txn
self._simple_delete_txn(
txn, "application_services_txns",
dict(txn_id=txn_id, as_id=service.id)
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
)
return self.runInteraction(
"complete_appservice_txn",
_complete_appservice_txn,
)
return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
A Deferred which resolves to an AppServiceTransaction or
None.
"""
def _get_oldest_unsent_txn(txn):
# Monotonically increasing txn ids, so just select the smallest
# one in the txns table (we delete them when they are sent)
txn.execute(
"SELECT * FROM application_services_txns WHERE as_id=?"
" ORDER BY txn_id ASC LIMIT 1",
(service.id,)
(service.id,),
)
rows = self.cursor_to_dict(txn)
if not rows:
@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return entry
entry = yield self.runInteraction(
"get_oldest_unsent_appservice_txn",
_get_oldest_unsent_txn,
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
if not entry:
@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
events = yield self._get_events(event_ids)
defer.returnValue(AppServiceTransaction(
service=service, id=entry["txn_id"], events=events
))
defer.returnValue(
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
)
def _get_last_txn(self, txn, service_id):
txn.execute(
"SELECT last_txn FROM application_services_state WHERE as_id=?",
(service_id,)
(service_id,),
)
last_txn_id = txn.fetchone()
if last_txn_id is None or last_txn_id[0] is None: # no row exists
@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
events = yield self._get_events(event_ids)

View file

@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore):
self._all_done = False
def start_doing_background_updates(self):
run_as_background_process(
"background_updates", self._run_background_updates,
)
run_as_background_process("background_updates", self._run_background_updates)
@defer.inlineCallbacks
def _run_background_updates(self):
logger.info("Starting background schema updates")
while True:
yield self.hs.get_clock().sleep(
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
result = yield self.do_next_background_update(
@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
logger.info("Starting update batch on background update '%s'",
update_name)
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = yield self._simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json"
retcol="progress_json",
)
progress = json.loads(progress_json)
@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info(
"Updating %r. Updated %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name, items_updated, duration_ms,
update_name,
items_updated,
duration_ms,
performance.total_items_per_ms(),
performance.average_items_per_ms(),
performance.total_item_count,
@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore):
Args:
update_name (str): Name of update
"""
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore):
self.register_background_update_handler(update_name, noop_update)
def register_background_index_update(self, update_name, index_name,
table, columns, where_clause=None,
unique=False,
psql_only=False):
def register_background_index_update(
self,
update_name,
index_name,
table,
columns,
where_clause=None,
unique=False,
psql_only=False,
):
"""Helper for store classes to do a background index addition
To use:
@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore):
"name": index_name,
"table": table,
"columns": ", ".join(columns),
"where_clause": "WHERE " + where_clause if where_clause else ""
"where_clause": "WHERE " + where_clause if where_clause else "",
}
logger.debug("[SQL] %s", sql)
c.execute(sql)
@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore):
return self._simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json}
{"update_name": update_name, "progress_json": progress_json},
)
def _end_background_update(self, update_name):

View file

@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
max_entries=50000 * CACHE_SIZE_FACTOR,
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
self.register_background_update_handler(
"user_ips_analyze",
self._analyze_user_ip,
"user_ips_analyze", self._analyze_user_ip
)
self.register_background_update_handler(
"user_ips_remove_dupes",
self._remove_user_ip_dupes,
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
# Drop the old non-unique index
self.register_background_update_handler(
"user_ips_drop_nonunique_index",
self._remove_user_ip_nonunique,
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS user_ips_user_ip"
)
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
yield self.runWithConnection(f)
@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
yield self.runInteraction(
"user_ips_analyze", user_ips_analyze
)
yield self.runInteraction("user_ips_analyze", user_ips_analyze)
yield self._end_background_update("user_ips_analyze")
@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
LIMIT 1
OFFSET ?
""",
(begin_last_seen, batch_size)
(begin_last_seen, batch_size),
)
row = txn.fetchone()
if row:
@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
logger.info(
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
begin_last_seen, end_last_seen,
begin_last_seen,
end_last_seen,
)
def remove(txn):
@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
INNER JOIN user_ips USING (user_id, access_token, ip)
GROUP BY user_id, access_token, ip
HAVING count(*) > 1
""".format(clause),
args
""".format(
clause
),
args,
)
res = txn.fetchall()
@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
""",
(user_id, access_token, ip, last_seen)
(user_id, access_token, ip, last_seen),
)
if txn.rowcount == count - 1:
# We deleted all but one of the duplicate rows, i.e. there
@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
continue
elif txn.rowcount >= count:
raise Exception(
"We deleted more duplicate rows from 'user_ips' than expected",
"We deleted more duplicate rows from 'user_ips' than expected"
)
# The previous step didn't delete enough rows, so we fallback to
@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
DELETE FROM user_ips
WHERE user_id = ? AND access_token = ? AND ip = ?
""",
(user_id, access_token, ip)
(user_id, access_token, ip),
)
# Add in one to be the last_seen
@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen)
VALUES (?, ?, ?, ?, ?, ?)
""",
(user_id, access_token, ip, device_id, user_agent, last_seen)
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
self._background_update_progress_txn(
@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
defer.returnValue(batch_size)
@defer.inlineCallbacks
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
now=None):
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
):
if not now:
now = int(self._clock.time_msec())
key = (user_id, access_token, ip)
@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
return self.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn,
to_update,
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
return run_as_background_process(
"update_client_ips", update,
)
return run_as_background_process("update_client_ips", update)
def _update_client_ips_batch_txn(self, txn, to_update):
if "user_ips" in self._unsafe_to_upsert_tables or (
@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
res = yield self.runInteraction(
"get_last_client_ip_by_device",
self._get_last_client_ip_by_device_txn,
user_id, device_id,
user_id,
device_id,
retcols=(
"user_id",
"access_token",
@ -416,7 +409,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
bindings = []
if device_id is None:
where_clauses.append("user_id = ?")
bindings.extend((user_id, ))
bindings.extend((user_id,))
else:
where_clauses.append("(user_id = ? AND device_id = ?)")
bindings.extend((user_id, device_id))
@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
"WHERE %(where)s "
"GROUP BY user_id, device_id"
) % {
"where": " OR ".join(where_clauses),
}
) % {"where": " OR ".join(where_clauses)}
sql = (
"SELECT %(retcols)s FROM user_ips "
@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
rows = yield self._simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=[
"access_token", "ip", "user_agent", "last_seen"
],
retcols=["access_token", "ip", "user_agent", "last_seen"],
desc="get_user_ip_and_agents",
)
@ -472,12 +461,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
defer.returnValue(list(
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
))
defer.returnValue(
list(
{
"access_token": access_token,
"ip": ip,
"user_agent": user_agent,
"last_seen": last_seen,
}
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
)

View file

@ -57,9 +57,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
user_id, device_id, last_stream_id, current_stream_id, limit
))
txn.execute(
sql, (user_id, device_id, last_stream_id, current_stream_id, limit)
)
messages = []
for row in txn:
stream_pos = row[0]
@ -69,7 +69,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
return (messages, stream_pos)
return self.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn,
"get_new_messages_for_device", get_new_messages_for_device_txn
)
@defer.inlineCallbacks
@ -146,9 +146,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (
destination, last_stream_id, current_stream_id, limit
))
txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))
messages = []
for row in txn:
stream_pos = row[0]
@ -172,6 +170,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
Returns:
A deferred that resolves when the messages have been deleted.
"""
def delete_messages_for_remote_destination_txn(txn):
sql = (
"DELETE FROM device_federation_outbox"
@ -181,8 +180,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (destination, up_to_stream_id))
return self.runInteraction(
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
@ -200,8 +198,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
)
self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID,
self._background_drop_index_device_inbox,
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
# Map of (user_id, device_id) to the last stream_id that has been
@ -214,8 +211,9 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
remote_messages_by_destination):
def add_messages_to_device_inbox(
self, local_messages_by_user_then_device, remote_messages_by_destination
):
"""Used to send messages from this server.
Args:
@ -252,15 +250,10 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
yield self.runInteraction(
"add_messages_to_device_inbox",
add_messages_txn,
now_ms,
stream_id,
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
for destination in remote_messages_by_destination.keys():
self._device_federation_outbox_stream_cache.entity_has_changed(
destination, stream_id
@ -277,7 +270,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
already_inserted = self._simple_select_one_txn(
txn, table="device_federation_inbox",
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
retcols=("message_id",),
allow_none=True,
@ -288,7 +282,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed
# it.
self._simple_insert_txn(
txn, table="device_federation_inbox",
txn,
table="device_federation_inbox",
values={
"origin": origin,
"message_id": message_id,
@ -311,19 +306,14 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
stream_id,
)
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(
user_id, stream_id
)
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
defer.returnValue(stream_id)
def _add_messages_to_local_device_inbox_txn(self, txn, stream_id,
messages_by_user_then_device):
sql = (
"UPDATE device_max_stream_id"
" SET stream_id = ?"
" WHERE stream_id < ?"
)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
@ -332,10 +322,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ?"
)
sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
@ -428,9 +415,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)

View file

@ -67,7 +67,7 @@ class DeviceWorkerStore(SQLBaseStore):
table="devices",
keyvalues={"user_id": user_id},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user"
desc="get_devices_by_user",
)
defer.returnValue({d["device_id"]: d for d in devices})
@ -87,21 +87,23 @@ class DeviceWorkerStore(SQLBaseStore):
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
"get_devices_by_remote",
self._get_devices_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id
):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
LIMIT 20
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
txn.execute(sql, (destination, from_stream_id, now_stream_id, False))
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in txn}
@ -112,7 +114,10 @@ class DeviceWorkerStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
txn,
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,
)
prev_sent_id_sql = """
@ -157,8 +162,10 @@ class DeviceWorkerStore(SQLBaseStore):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
@ -173,7 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id,))
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
sql = """
@ -181,16 +188,14 @@ class DeviceWorkerStore(SQLBaseStore):
SET stream_id = ?
WHERE destination = ? AND user_id = ?
"""
txn.executemany(
sql, ((row[1], destination, row[0],) for row in rows if row[2])
)
txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
sql = """
INSERT INTO device_lists_outbound_last_success
(destination, user_id, stream_id) VALUES (?, ?, ?)
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows if not row[2])
sql, ((destination, row[0], row[1]) for row in rows if not row[2])
)
# Delete all sent outbound pokes
@ -198,7 +203,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (destination, stream_id,))
txn.execute(sql, (destination, stream_id))
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
@ -240,10 +245,7 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_user_device(self, user_id, device_id):
content = yield self._simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
desc="_get_cached_user_device",
)
@ -253,16 +255,13 @@ class DeviceWorkerStore(SQLBaseStore):
def _get_cached_devices_for_user(self, user_id):
devices = yield self._simple_select_list(
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
defer.returnValue({
device["device_id"]: db_to_json(device["content"])
for device in devices
})
defer.returnValue(
{device["device_id"]: db_to_json(device["content"]) for device in devices}
)
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@ -272,7 +271,8 @@ class DeviceWorkerStore(SQLBaseStore):
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
self._get_devices_with_keys_by_user_txn,
user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
@ -286,9 +286,7 @@ class DeviceWorkerStore(SQLBaseStore):
user_devices = devices[user_id]
results = []
for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
@ -315,7 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """
SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
rows = yield self._execute(
"get_user_whose_devices_changed", None, sql, from_key
)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
@ -333,8 +333,7 @@ class DeviceWorkerStore(SQLBaseStore):
GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
sql, from_key, to_key
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@cached(max_entries=10000)
@ -350,21 +349,22 @@ class DeviceWorkerStore(SQLBaseStore):
allow_none=True,
)
@cachedList(cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids", inlineCallbacks=True)
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
rows = yield self._simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
retcols=("user_id", "stream_id",),
retcols=("user_id", "stream_id"),
desc="get_device_list_last_stream_id_for_remotes",
)
results = {user_id: None for user_id in user_ids}
results.update({
row["user_id"]: row["stream_id"] for row in rows
})
results.update({row["user_id"]: row["stream_id"] for row in rows})
defer.returnValue(results)
@ -376,14 +376,10 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists",
keylen=2,
max_entries=10000,
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
self.register_background_index_update(
"device_lists_stream_idx",
@ -417,8 +413,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name):
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
Args:
@ -440,7 +435,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name
"display_name": initial_device_display_name,
},
desc="store_device",
or_ignore=True,
@ -448,12 +443,17 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self.device_id_exists_cache.prefill(key, True)
defer.returnValue(inserted)
except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s",
type(device_id).__name__, device_id,
type(user_id).__name__, user_id,
type(initial_device_display_name).__name__,
initial_device_display_name, e)
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s",
type(device_id).__name__,
device_id,
type(user_id).__name__,
user_id,
type(initial_device_display_name).__name__,
initial_device_display_name,
e,
)
raise StoreError(500, "Problem storing device.")
@defer.inlineCallbacks
@ -525,15 +525,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
yield self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
)
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id
):
"""Updates a single device in the cache of a remote user's devicelist.
Note: assumes that we are the only thread that can be updating this user's
@ -551,42 +550,35 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
user_id,
device_id,
content,
stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
def _update_remote_device_list_cache_entry_txn(
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
keyvalues={"user_id": user_id, "device_id": device_id},
)
txn.call_after(
self.device_id_exists_cache.invalidate, (user_id, device_id,)
)
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
},
keyvalues={"user_id": user_id, "device_id": device_id},
values={"content": json.dumps(content)},
# we don't need to lock, because we assume we are the only thread
# updating this user's devices.
lock=False,
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
@ -595,13 +587,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
},
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
# again, we can assume we are the only thread updating this user's
# extremity.
lock=False,
@ -624,17 +611,14 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
user_id,
devices,
stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
self._simple_insert_many_txn(
@ -647,7 +631,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"content": json.dumps(content),
}
for content in devices
]
],
)
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@ -659,13 +643,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
},
keyvalues={"user_id": user_id},
values={"stream_id": stream_id},
# we don't need to lock, because we can assume we are the only thread
# updating this user's extremity.
lock=False,
@ -678,8 +657,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
device_ids,
hosts,
stream_id,
)
defer.returnValue(stream_id)
@ -687,13 +670,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
self._device_list_stream_cache.entity_has_changed, user_id, stream_id
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
host,
stream_id,
)
# Delete older entries in the table, as we really only care about
@ -703,20 +686,16 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
[(user_id, device_id, stream_id) for device_id in device_ids]
[(user_id, device_id, stream_id) for device_id in device_ids],
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
for device_id in device_ids
]
],
)
self._simple_insert_many_txn(
@ -733,7 +712,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
}
for destination in hosts
for device_id in device_ids
]
],
)
def _prune_old_outbound_device_pokes(self):
@ -764,11 +743,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
txn.executemany(
delete_sql,
(
(yesterday, row[0], row[1], row[2])
for row in rows
)
delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
)
# Since we've deleted unsent deltas, we need to remove the entry
@ -792,12 +767,8 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_lists_remote_cache_id"
)
txn.execute(
"DROP INDEX IF EXISTS device_lists_remote_extremeties_id"
)
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)

View file

@ -22,10 +22,7 @@ from synapse.util.caches.descriptors import cached
from ._base import SQLBaseStore
RoomAliasMapping = namedtuple(
"RoomAliasMapping",
("room_id", "room_alias", "servers",)
)
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
class DirectoryWorkerStore(SQLBaseStore):
@ -63,16 +60,12 @@ class DirectoryWorkerStore(SQLBaseStore):
defer.returnValue(None)
return
defer.returnValue(
RoomAliasMapping(room_id, room_alias.to_string(), servers)
)
defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
table="room_aliases",
keyvalues={
"room_alias": room_alias,
},
keyvalues={"room_alias": room_alias},
retcol="creator",
desc="get_room_alias_creator",
)
@ -101,6 +94,7 @@ class DirectoryStore(DirectoryWorkerStore):
Returns:
Deferred
"""
def alias_txn(txn):
self._simple_insert_txn(
txn,
@ -115,10 +109,10 @@ class DirectoryStore(DirectoryWorkerStore):
self._simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[{
"room_alias": room_alias.to_string(),
"server": server,
} for server in servers],
values=[
{"room_alias": room_alias.to_string(), "server": server}
for server in servers
],
)
self._invalidate_cache_and_stream(
@ -126,9 +120,7 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
ret = yield self.runInteraction(
"create_room_alias_association", alias_txn
)
ret = yield self.runInteraction("create_room_alias_association", alias_txn)
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
@ -138,9 +130,7 @@ class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
room_id = yield self.runInteraction(
"delete_room_alias",
self._delete_room_alias_txn,
room_alias,
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
defer.returnValue(room_id)
@ -148,7 +138,7 @@ class DirectoryStore(DirectoryWorkerStore):
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),)
(room_alias.to_string(),),
)
res = txn.fetchone()
@ -158,31 +148,29 @@ class DirectoryStore(DirectoryWorkerStore):
return None
txn.execute(
"DELETE FROM room_aliases WHERE room_alias = ?",
(room_alias.to_string(),)
"DELETE FROM room_aliases WHERE room_alias = ?", (room_alias.to_string(),)
)
txn.execute(
"DELETE FROM room_alias_servers WHERE room_alias = ?",
(room_alias.to_string(),)
(room_alias.to_string(),),
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (room_id,)
)
self._invalidate_cache_and_stream(txn, self.get_aliases_for_room, (room_id,))
return room_id
def update_aliases_for_room(self, old_room_id, new_room_id, creator):
def _update_aliases_for_room_txn(txn):
sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
txn.execute(sql, (new_room_id, creator, old_room_id,))
txn.execute(sql, (new_room_id, creator, old_room_id))
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)

View file

@ -23,7 +23,6 @@ from ._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_room_key(self, user_id, version, room_id, session_id):
"""Get the encrypted E2E room key for a given session from a given
@ -97,9 +96,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@defer.inlineCallbacks
def get_e2e_room_keys(
self, user_id, version, room_id=None, session_id=None
):
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
@ -123,10 +120,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
except ValueError:
defer.returnValue({'rooms': {}})
keyvalues = {
"user_id": user_id,
"version": version,
}
keyvalues = {"user_id": user_id, "version": version}
if room_id:
keyvalues['room_id'] = room_id
if session_id:
@ -160,9 +154,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
defer.returnValue(sessions)
@defer.inlineCallbacks
def delete_e2e_room_keys(
self, user_id, version, room_id=None, session_id=None
):
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
room or a given session.
@ -180,19 +172,14 @@ class EndToEndRoomKeyStore(SQLBaseStore):
A deferred of the deletion transaction
"""
keyvalues = {
"user_id": user_id,
"version": int(version),
}
keyvalues = {"user_id": user_id, "version": int(version)}
if room_id:
keyvalues['room_id'] = room_id
if session_id:
keyvalues['session_id'] = session_id
yield self._simple_delete(
table="e2e_room_keys",
keyvalues=keyvalues,
desc="delete_e2e_room_keys",
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@staticmethod
@ -200,7 +187,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions "
"WHERE user_id=? AND deleted=0",
(user_id,)
(user_id,),
)
row = txn.fetchone()
if not row:
@ -238,24 +225,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result = self._simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
"deleted": 0,
},
retcols=(
"version",
"algorithm",
"auth_data",
),
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
return result
return self.runInteraction(
"get_e2e_room_keys_version_info",
_get_e2e_room_keys_version_info_txn
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
def create_e2e_room_keys_version(self, user_id, info):
@ -273,7 +251,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _create_e2e_room_keys_version_txn(txn):
txn.execute(
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
(user_id,)
(user_id,),
)
current_version = txn.fetchone()[0]
if current_version is None:
@ -309,14 +287,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update(
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": version,
},
updatevalues={
"auth_data": json.dumps(info["auth_data"]),
},
desc="update_e2e_room_keys_version"
keyvalues={"user_id": user_id, "version": version},
updatevalues={"auth_data": json.dumps(info["auth_data"])},
desc="update_e2e_room_keys_version",
)
def delete_e2e_room_keys_version(self, user_id, version=None):
@ -341,16 +314,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return self._simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={
"user_id": user_id,
"version": this_version,
},
updatevalues={
"deleted": 1,
}
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
return self.runInteraction(
"delete_e2e_room_keys_version",
_delete_e2e_room_keys_version_txn
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)

View file

@ -26,8 +26,7 @@ from ._base import SQLBaseStore, db_to_json
class EndToEndKeyWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
include_deleted_devices=False,
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
Args:
@ -45,8 +44,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue({})
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices, include_deleted_devices,
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
)
for user_id, device_keys in iteritems(results):
@ -56,8 +58,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
defer.returnValue(results)
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False,
include_deleted_devices=False,
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
):
query_clauses = []
query_params = []
@ -87,7 +88,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
" WHERE %s"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses)
" OR ".join("(" + q + ")" for q in query_clauses),
)
txn.execute(sql, query_params)
@ -124,17 +125,14 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcols=("algorithm", "key_id", "key_json"),
keyvalues={"user_id": user_id, "device_id": device_id},
desc="add_e2e_one_time_keys_check",
)
defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
})
defer.returnValue(
{(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
)
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
@ -155,7 +153,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
self._simple_insert_many_txn(
txn, table="e2e_one_time_keys_json",
txn,
table="e2e_one_time_keys_json",
values=[
{
"user_id": user_id,
@ -169,8 +168,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
],
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
yield self.runInteraction(
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
)
@ -181,6 +181,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
Dict mapping from algorithm to number of keys for that algorithm.
"""
def _count_e2e_one_time_keys(txn):
sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
@ -192,9 +193,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for algorithm, key_count in txn:
result[algorithm] = key_count
return result
return self.runInteraction(
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
@ -202,14 +202,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="key_json",
allow_none=True,
)
@ -224,24 +222,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
keyvalues={"user_id": user_id, "device_id": device_id},
values={"ts_added_ms": time_now, "key_json": new_key_json},
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
def claim_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database"""
def _claim_e2e_one_time_keys(txn):
sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json"
@ -265,12 +256,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
for user_id, device_id, algorithm, key_id in delete:
txn.execute(sql, (user_id, device_id, algorithm, key_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return result
return self.runInteraction(
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
def delete_e2e_keys_by_device(self, user_id, device_id):
def delete_e2e_keys_by_device_txn(txn):
@ -285,8 +275,9 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
keyvalues={"user_id": user_id, "device_id": device_id},
)
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id,)
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
return self.runInteraction(
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
)

View file

@ -20,10 +20,7 @@ from ._base import IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
SUPPORTED_MODULE = {
"sqlite3": Sqlite3Engine,
"psycopg2": PostgresEngine,
}
SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
def create_engine(database_config):
@ -32,15 +29,12 @@ def create_engine(database_config):
if engine_class:
# pypy requires psycopg2cffi rather than psycopg2
if (name == "psycopg2" and
platform.python_implementation() == "PyPy"):
if name == "psycopg2" and platform.python_implementation() == "PyPy":
name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)
)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
__all__ = ["create_engine", "IncorrectDatabaseSetup"]

View file

@ -23,7 +23,7 @@ class PostgresEngine(object):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
self._version = None # unknown as yet
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@ -31,8 +31,7 @@ class PostgresEngine(object):
if rows and rows[0][0] != "UTF8":
raise IncorrectDatabaseSetup(
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
"See docs/postgres.rst for more information."
% (rows[0][0],)
"See docs/postgres.rst for more information." % (rows[0][0],)
)
def convert_param_style(self, sql):
@ -103,12 +102,6 @@ class PostgresEngine(object):
# https://www.postgresql.org/docs/current/libpq-status.html#LIBPQ-PQSERVERVERSION
if numver >= 100000:
return "%i.%i" % (
numver / 10000, numver % 10000,
)
return "%i.%i" % (numver / 10000, numver % 10000)
else:
return "%i.%i.%i" % (
numver / 10000,
(numver % 10000) / 100,
numver % 100,
)
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

View file

@ -82,9 +82,10 @@ class Sqlite3Engine(object):
# Following functions taken from: https://github.com/coleifer/peewee
def _parse_match_info(buf):
bufsize = len(buf)
return [struct.unpack('@I', buf[i:i + 4])[0] for i in range(0, bufsize, 4)]
return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
@ -98,7 +99,7 @@ def _rank(raw_match_info):
phrase_info_idx = 2 + (phrase_num * c * 3)
for col_num in range(c):
col_idx = phrase_info_idx + (col_num * 3)
x1, x2 = match_info[col_idx:col_idx + 2]
x1, x2 = match_info[col_idx : col_idx + 2]
if x1 > 0:
score += float(x1) / x2
return score

View file

@ -32,8 +32,7 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
SQLBaseStore):
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@ -45,7 +44,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of events
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given,
event_ids, include_given=include_given
).addCallback(self._get_events)
def get_auth_chain_ids(self, event_ids, include_given=False):
@ -59,9 +58,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
list of event_ids
"""
return self.runInteraction(
"get_auth_chain_ids",
self._get_auth_chain_ids_txn,
event_ids, include_given
"get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
)
def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
@ -70,23 +67,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
else:
results = set()
base_sql = (
"SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
)
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
front = set(event_ids)
while front:
new_front = set()
front_list = list(front)
chunks = [
front_list[x:x + 100]
for x in range(0, len(front), 100)
]
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks:
txn.execute(
base_sql % (",".join(["?"] * len(chunk)),),
chunk
)
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
new_front.update([r[0] for r in txn])
new_front -= results
@ -98,9 +87,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_oldest_events_in_room(self, room_id):
return self.runInteraction(
"get_oldest_events_in_room",
self._get_oldest_events_in_room_txn,
room_id,
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
@ -121,7 +108,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
" GROUP BY b.event_id"
)
txn.execute(sql, (room_id, False,))
txn.execute(sql, (room_id, False))
return dict(txn)
@ -152,9 +139,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return self._simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
retcol="event_id",
)
@ -209,9 +194,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
table="event_forward_extremities",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
@ -225,14 +208,13 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
"WHERE f.room_id = ?"
)
txn.execute(sql, (room_id, ))
txn.execute(sql, (room_id,))
results = []
for event_id, depth in txn.fetchall():
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
}
results.append((event_id, prev_hashes, depth))
@ -242,9 +224,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
""" For hte given room, get the minimum depth we have seen for it.
"""
return self.runInteraction(
"get_min_depth",
self._get_min_depth_interaction,
room_id,
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
@ -300,7 +280,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
sql = ("""
sql = """
SELECT event_id FROM stream_ordering_to_exterm
INNER JOIN (
SELECT room_id, MAX(stream_ordering) AS stream_ordering
@ -308,15 +288,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
WHERE stream_ordering <= ? GROUP BY room_id
) AS rms USING (room_id, stream_ordering)
WHERE room_id = ?
""")
"""
def get_forward_extremeties_for_room_txn(txn):
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.runInteraction(
"get_forward_extremeties_for_room",
get_forward_extremeties_for_room_txn
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
def get_backfill_events(self, room_id, event_list, limit):
@ -329,19 +308,21 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_list (list)
limit (int)
"""
return self.runInteraction(
"get_backfill_events",
self._get_backfill_events, room_id, event_list, limit
).addCallback(
self._get_events
).addCallback(
lambda l: sorted(l, key=lambda e: -e.depth)
return (
self.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self._get_events)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug(
"_get_backfill_events: %s, %s, %s",
room_id, repr(event_list), limit
"_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
)
event_results = set()
@ -364,10 +345,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
depth = self._simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={
"event_id": event_id,
"room_id": room_id,
},
keyvalues={"event_id": event_id, "room_id": room_id},
retcol="depth",
allow_none=True,
)
@ -386,10 +364,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
event_results.add(event_id)
txn.execute(
query,
(event_id, False, limit - len(event_results))
)
txn.execute(query, (event_id, False, limit - len(event_results)))
for row in txn:
if row[1] not in event_results:
@ -398,18 +373,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
return event_results
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events,
limit):
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
ids = yield self.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id, earliest_events, latest_events, limit,
room_id,
earliest_events,
latest_events,
limit,
)
events = yield self._get_events(ids)
defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
limit):
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
seen_events = set(earliest_events)
front = set(latest_events) - seen_events
@ -425,8 +401,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
new_front = set()
for event_id in front:
txn.execute(
query,
(room_id, event_id, False, limit - len(event_results))
query, (room_id, event_id, False, limit - len(event_results))
)
new_results = set(t[0] for t in txn) - seen_events
@ -457,12 +432,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
column="prev_event_id",
iterable=event_ids,
retcols=("event_id",),
desc="get_successor_events"
desc="get_successor_events",
)
defer.returnValue([
row["event_id"] for row in rows
])
defer.returnValue([row["event_id"] for row in rows])
class EventFederationStore(EventFederationWorkerStore):
@ -481,12 +454,11 @@ class EventFederationStore(EventFederationWorkerStore):
super(EventFederationStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY,
self._background_delete_non_state_event_auth,
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
hs.get_clock().looping_call(
self._delete_old_forward_extrem_cache, 60 * 60 * 1000,
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
@ -498,12 +470,8 @@ class EventFederationStore(EventFederationWorkerStore):
self._simple_upsert_txn(
txn,
table="room_depth",
keyvalues={
"room_id": room_id,
},
values={
"min_depth": depth,
},
keyvalues={"room_id": room_id},
values={"min_depth": depth},
)
def _handle_mult_prev_events(self, txn, events):
@ -553,11 +521,15 @@ class EventFederationStore(EventFederationWorkerStore):
" )"
)
txn.executemany(query, [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
])
txn.executemany(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
for ev in events
for e_id in ev.prev_event_ids()
if not ev.internal_metadata.is_outlier()
],
)
query = (
"DELETE FROM event_backward_extremities"
@ -566,16 +538,17 @@ class EventFederationStore(EventFederationWorkerStore):
txn.executemany(
query,
[
(ev.event_id, ev.room_id) for ev in events
(ev.event_id, ev.room_id)
for ev in events
if not ev.internal_metadata.is_outlier()
]
],
)
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
# the only entries for a room.
sql = ("""
sql = """
DELETE FROM stream_ordering_to_exterm
WHERE
room_id IN (
@ -583,11 +556,11 @@ class EventFederationStore(EventFederationWorkerStore):
FROM stream_ordering_to_exterm
WHERE stream_ordering > ?
) AND stream_ordering < ?
""")
"""
txn.execute(
sql,
(self.stream_ordering_month_ago, self.stream_ordering_month_ago,)
sql, (self.stream_ordering_month_ago, self.stream_ordering_month_ago)
)
return run_as_background_process(
"delete_old_forward_extrem_cache",
self.runInteraction,
@ -597,9 +570,7 @@ class EventFederationStore(EventFederationWorkerStore):
def clean_room_for_join(self, room_id):
return self.runInteraction(
"clean_room_for_join",
self._clean_room_for_join_txn,
room_id,
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
def _clean_room_for_join_txn(self, txn, room_id):
@ -635,7 +606,7 @@ class EventFederationStore(EventFederationWorkerStore):
)
"""
txn.execute(sql, (min_stream_id, max_stream_id,))
txn.execute(sql, (min_stream_id, max_stream_id))
new_progress = {
"target_min_stream_id_inclusive": target_min_stream_id,

View file

@ -31,7 +31,9 @@ logger = logging.getLogger(__name__)
DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}]
DEFAULT_HIGHLIGHT_ACTION = [
"notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"}
"notify",
{"set_tweak": "sound", "value": "default"},
{"set_tweak": "highlight"},
]
@ -91,25 +93,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
self, room_id, user_id, last_read_event_id
):
ret = yield self.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id, user_id, last_read_event_id
room_id,
user_id,
last_read_event_id,
)
defer.returnValue(ret)
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
last_read_event_id):
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
):
sql = (
"SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
txn.execute(
sql, (room_id, last_read_event_id)
)
txn.execute(sql, (room_id, last_read_event_id))
results = txn.fetchall()
if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0}
@ -138,10 +141,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone()
notify_count = row[0] if row else 0
txn.execute("""
txn.execute(
"""
SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""", (room_id, user_id, stream_ordering,))
""",
(room_id, user_id, stream_ordering),
)
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
@ -161,10 +167,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
row = txn.fetchone()
highlight_count = row[0] if row else 0
return {
"notify_count": notify_count,
"highlight_count": highlight_count,
}
return {"notify_count": notify_count, "highlight_count": highlight_count}
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
@ -175,6 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
@ -223,12 +227,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@ -253,12 +255,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@ -269,7 +269,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"room_id": row[1],
"stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]),
} for row in after_read_receipt + no_read_receipt
}
for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
@ -326,12 +327,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@ -356,12 +355,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.stream_ordering <= ?"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [
user_id, user_id,
min_stream_ordering, max_stream_ordering, limit,
]
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@ -374,7 +371,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"stream_ordering": row[2],
"actions": _deserialize_action(row[3], row[4]),
"received_ts": row[5],
} for row in after_read_receipt + no_read_receipt
}
for row in after_read_receipt + no_read_receipt
]
# Now sort it so it's ordered correctly, since currently it will
@ -408,7 +406,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
LIMIT 1
"""
txn.execute(sql, (user_id, min_stream_ordering,))
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
return self.runInteraction(
@ -454,10 +452,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?)
"""
txn.executemany(sql, (
_gen_entry(user_id, actions)
for user_id, actions in iteritems(user_id_actions)
))
txn.executemany(
sql,
(
_gen_entry(user_id, actions)
for user_id, actions in iteritems(user_id_actions)
),
)
return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
@ -475,9 +476,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
try:
res = yield self._simple_delete(
table="event_push_actions_staging",
keyvalues={
"event_id": event_id,
},
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
defer.returnValue(res)
@ -486,7 +485,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# another exception here really isn't helpful - there's nothing
# the caller can do about it. Just log the exception and move on.
logger.exception(
"Error removing push actions after event persistence failure",
"Error removing push actions after event persistence failure"
)
def _find_stream_orderings_for_times(self):
@ -503,16 +502,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 month ago: it's %d",
self.stream_ordering_month_ago
"Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago
)
logger.info("Searching for stream ordering 1 day ago")
self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn(
txn, self._clock.time_msec() - 24 * 60 * 60 * 1000
)
logger.info(
"Found stream ordering 1 day ago: it's %d",
self.stream_ordering_day_ago
"Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago
)
def find_first_stream_ordering_after_ts(self, ts):
@ -631,16 +628,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
index_name="event_push_actions_highlights_index",
table="event_push_actions",
columns=["user_id", "room_id", "topological_ordering", "stream_ordering"],
where_clause="highlight=1"
where_clause="highlight=1",
)
self._doing_notif_rotation = False
self._rotate_notif_loop = self._clock.looping_call(
self._start_rotate_notifs, 30 * 60 * 1000,
self._start_rotate_notifs, 30 * 60 * 1000
)
def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts,
all_events_and_contexts):
def _set_push_actions_for_event_and_users_txn(
self, txn, events_and_contexts, all_events_and_contexts
):
"""Handles moving push actions from staging table to main
event_push_actions table for all events in `events_and_contexts`.
@ -667,43 +665,44 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"""
if events_and_contexts:
txn.executemany(sql, (
txn.executemany(
sql,
(
event.room_id, event.internal_metadata.stream_ordering,
event.depth, event.event_id,
)
for event, _ in events_and_contexts
))
(
event.room_id,
event.internal_metadata.stream_ordering,
event.depth,
event.event_id,
)
for event, _ in events_and_contexts
),
)
for event, _ in events_and_contexts:
user_ids = self._simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={
"event_id": event.event_id,
},
keyvalues={"event_id": event.event_id},
retcol="user_id",
)
for uid in user_ids:
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(event.room_id, uid,)
(event.room_id, uid),
)
# Now we delete the staging area for *all* events that were being
# persisted.
txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
(
(event.event_id,)
for event, _ in all_events_and_contexts
)
((event.event_id,) for event, _ in all_events_and_contexts),
)
@defer.inlineCallbacks
def get_push_actions_for_user(self, user_id, before=None, limit=50,
only_highlight=False):
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
):
def f(txn):
before_clause = ""
if before:
@ -727,15 +726,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE epa.event_id = e.event_id"
" AND epa.user_id = ? %s"
" ORDER BY epa.stream_ordering DESC"
" LIMIT ?"
% (before_clause,)
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
push_actions = yield self.runInteraction(
"get_push_actions_for_user", f
)
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
defer.returnValue(push_actions)
@ -753,6 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
defer.returnValue(result[0] if result else None)
@ -761,24 +758,24 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = yield self.runInteraction(
"get_latest_push_action_stream_ordering", f
)
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
defer.returnValue(result[0] or 0)
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,)
(room_id,),
)
txn.execute(
"DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
(room_id, event_id)
(room_id, event_id),
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
stream_ordering):
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
"""
Purges old push actions for a user and room before a given
stream_ordering.
@ -795,7 +792,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"""
txn.call_after(
self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id, user_id, )
(room_id, user_id),
)
# We need to join on the events table to get the received_ts for
@ -811,13 +808,16 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" WHERE user_id = ? AND room_id = ? AND "
" stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
)
txn.execute("""
txn.execute(
"""
DELETE FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
""", (room_id, user_id, stream_ordering))
""",
(room_id, user_id, stream_ordering),
)
def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs)
@ -833,8 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
logger.info("Rotating notifications")
caught_up = yield self.runInteraction(
"_rotate_notifs",
self._rotate_notifs_txn
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
break
@ -856,11 +855,14 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# We don't to try and rotate millions of rows at once, so we cap the
# maximum stream ordering we'll rotate before.
txn.execute("""
txn.execute(
"""
SELECT stream_ordering FROM event_push_actions
WHERE stream_ordering > ?
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
""", (old_rotate_stream_ordering, self._rotate_count))
""",
(old_rotate_stream_ordering, self._rotate_count),
)
stream_row = txn.fetchone()
if stream_row:
offset_stream_ordering, = stream_row
@ -904,7 +906,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,))
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
rows = txn.fetchall()
logger.info("Rotating notifications, handling %d rows", len(rows))
@ -922,8 +924,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
"notif_count": row[2],
"stream_ordering": row[3],
}
for row in rows if row[4] is None
]
for row in rows
if row[4] is None
],
)
txn.executemany(
@ -931,20 +934,20 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None)
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
)
txn.execute(
"DELETE FROM event_push_actions"
" WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0",
(old_rotate_stream_ordering, rotate_to_stream_ordering,)
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
logger.info("Rotating notifications, deleted %s push actions", txn.rowcount)
txn.execute(
"UPDATE event_push_summary_stream_ordering SET stream_ordering = ?",
(rotate_to_stream_ordering,)
(rotate_to_stream_ordering,),
)

View file

@ -71,17 +71,21 @@ class EventsWorkerStore(SQLBaseStore):
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={
"event_id": event_id,
},
keyvalues={"event_id": event_id},
retcol="received_ts",
desc="get_received_ts",
)
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False,
allow_none=False, check_room_id=None):
def get_event(
self,
event_id,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
allow_none=False,
check_room_id=None,
):
"""Get an event from the database by event_id.
Args:
@ -118,8 +122,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue(event)
@defer.inlineCallbacks
def get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
def get_events(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
"""Get events from the database
Args:
@ -143,8 +152,13 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
def _get_events(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
if not event_ids:
defer.returnValue([])
@ -152,8 +166,7 @@ class EventsWorkerStore(SQLBaseStore):
event_ids = set(event_ids)
event_entry_map = self._get_events_from_cache(
event_ids,
allow_rejected=allow_rejected,
event_ids, allow_rejected=allow_rejected
)
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
@ -169,8 +182,7 @@ class EventsWorkerStore(SQLBaseStore):
#
# _enqueue_events is a bit of a rubbish name but naming is hard.
missing_events = yield self._enqueue_events(
missing_events_ids,
allow_rejected=allow_rejected,
missing_events_ids, allow_rejected=allow_rejected
)
event_entry_map.update(missing_events)
@ -214,7 +226,10 @@ class EventsWorkerStore(SQLBaseStore):
)
expected_domain = get_domain_from_id(entry.event.sender)
if orig_sender and get_domain_from_id(orig_sender) == expected_domain:
if (
orig_sender
and get_domain_from_id(orig_sender) == expected_domain
):
# This redaction event is allowed. Mark as not needing a
# recheck.
entry.event.internal_metadata.recheck_redaction = False
@ -267,8 +282,7 @@ class EventsWorkerStore(SQLBaseStore):
for event_id in events:
ret = self._get_event_cache.get(
(event_id,), None,
update_metrics=update_metrics,
(event_id,), None, update_metrics=update_metrics
)
if not ret:
continue
@ -318,19 +332,13 @@ class EventsWorkerStore(SQLBaseStore):
with Measure(self._clock, "_fetch_event_list"):
try:
event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
event_ids = [item for sublist in event_id_lists for item in sublist]
rows = self._new_transaction(
conn, "do_fetch", [], [],
self._fetch_event_rows, event_ids,
conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
)
row_dict = {
r["event_id"]: r
for r in rows
}
row_dict = {r["event_id"]: r for r in rows}
# We only want to resolve deferreds from the main thread
def fire(lst, res):
@ -338,13 +346,10 @@ class EventsWorkerStore(SQLBaseStore):
if not d.called:
try:
with PreserveLoggingContext():
d.callback([
res[i]
for i in ids
if i in res
])
d.callback([res[i] for i in ids if i in res])
except Exception:
logger.exception("Failed to callback")
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
except Exception as e:
@ -371,9 +376,7 @@ class EventsWorkerStore(SQLBaseStore):
events_d = defer.Deferred()
with self._event_fetch_lock:
self._event_fetch_list.append(
(events, events_d)
)
self._event_fetch_list.append((events, events_d))
self._event_fetch_lock.notify()
@ -385,9 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
if should_start:
run_as_background_process(
"fetch_events",
self.runWithConnection,
self._do_fetch,
"fetch_events", self.runWithConnection, self._do_fetch
)
logger.debug("Loading %d events", len(events))
@ -398,29 +399,30 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
res = yield make_deferred_yieldable(defer.gatherResults(
[
run_in_background(
self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
format_version=row["format_version"],
)
for row in rows
],
consumeErrors=True
))
res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self._get_event_from_row,
row["internal_metadata"],
row["json"],
row["redacts"],
rejected_reason=row["rejects"],
format_version=row["format_version"],
)
for row in rows
],
consumeErrors=True,
)
)
defer.returnValue({
e.event.event_id: e
for e in res if e
})
defer.returnValue({e.event.event_id: e for e in res if e})
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
for i in range(1 + len(events) // N):
evs = events[i * N:(i + 1) * N]
evs = events[i * N : (i + 1) * N]
if not evs:
break
@ -444,8 +446,9 @@ class EventsWorkerStore(SQLBaseStore):
return rows
@defer.inlineCallbacks
def _get_event_from_row(self, internal_metadata, js, redacted,
format_version, rejected_reason=None):
def _get_event_from_row(
self, internal_metadata, js, redacted, format_version, rejected_reason=None
):
with Measure(self._clock, "_get_event_from_row"):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
@ -484,9 +487,7 @@ class EventsWorkerStore(SQLBaseStore):
# Get the redaction event.
because = yield self.get_event(
redaction_id,
check_redacted=False,
allow_none=True,
redaction_id, check_redacted=False, allow_none=True
)
if because:
@ -508,8 +509,7 @@ class EventsWorkerStore(SQLBaseStore):
redacted_event = None
cache_entry = _EventCacheEntry(
event=original_ev,
redacted_event=redacted_event,
event=original_ev, redacted_event=redacted_event
)
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
@ -545,23 +545,17 @@ class EventsWorkerStore(SQLBaseStore):
results = set()
def have_seen_events_txn(txn, chunk):
sql = (
"SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
% (",".join("?" * len(chunk)), )
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
",".join("?" * len(chunk)),
)
txn.execute(sql, chunk)
for (event_id, ) in txn:
for (event_id,) in txn:
results.add(event_id)
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
[]):
yield self.runInteraction(
"have_seen_events",
have_seen_events_txn,
chunk,
)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
defer.returnValue(results)
def get_seen_events_with_rejections(self, event_ids):

View file

@ -35,10 +35,7 @@ class FilteringStore(SQLBaseStore):
def_json = yield self._simple_select_one_onecol(
table="user_filters",
keyvalues={
"user_id": user_localpart,
"filter_id": filter_id,
},
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
allow_none=False,
desc="get_user_filter",
@ -61,10 +58,7 @@ class FilteringStore(SQLBaseStore):
if filter_id_response is not None:
return filter_id_response[0]
sql = (
"SELECT MAX(filter_id) FROM user_filters "
"WHERE user_id = ?"
)
sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:

File diff suppressed because it is too large Load diff

View file

@ -56,12 +56,13 @@ class KeyStore(SQLBaseStore):
desc="get_server_certificate",
)
tls_certificate = OpenSSL.crypto.load_certificate(
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes,
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate_bytes
)
defer.returnValue(tls_certificate)
def store_server_certificate(self, server_name, from_server, time_now_ms,
tls_certificate):
def store_server_certificate(
self, server_name, from_server, time_now_ms, tls_certificate
):
"""Stores the TLS X.509 certificate for the given server
Args:
server_name (str): The name of the server.
@ -75,10 +76,7 @@ class KeyStore(SQLBaseStore):
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
return self._simple_upsert(
table="server_tls_certificates",
keyvalues={
"server_name": server_name,
"fingerprint": fingerprint,
},
keyvalues={"server_name": server_name, "fingerprint": fingerprint},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
@ -91,19 +89,14 @@ class KeyStore(SQLBaseStore):
def _get_server_verify_key(self, server_name, key_id):
verify_key_bytes = yield self._simple_select_one_onecol(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
"key_id": key_id,
},
keyvalues={"server_name": server_name, "key_id": key_id},
retcol="verify_key",
desc="_get_server_verify_key",
allow_none=True,
)
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
key_id, bytes(verify_key_bytes)
))
defer.returnValue(decode_verify_key_bytes(key_id, bytes(verify_key_bytes)))
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
@ -123,8 +116,9 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
def store_server_verify_key(
self, server_name, from_server, time_now_ms, verify_key
):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
@ -139,10 +133,7 @@ class KeyStore(SQLBaseStore):
self._simple_upsert_txn(
txn,
table="server_signature_keys",
keyvalues={
"server_name": server_name,
"key_id": key_id,
},
keyvalues={"server_name": server_name, "key_id": key_id},
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
@ -150,14 +141,14 @@ class KeyStore(SQLBaseStore):
},
)
txn.call_after(
self._get_server_verify_key.invalidate,
(server_name, key_id)
self._get_server_verify_key.invalidate, (server_name, key_id)
)
return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
):
"""Stores the JSON bytes for a set of keys from a server
The JSON should be signed by the originating server, the intermediate
server, and by this server. Updates the value for the
@ -200,6 +191,7 @@ class KeyStore(SQLBaseStore):
Dict mapping (server_name, key_id, source) triplets to dicts with
"ts_valid_until_ms" and "key_json" keys.
"""
def _get_server_keys_json_txn(txn):
results = {}
for server_name, key_id, from_server in server_keys:
@ -222,6 +214,5 @@ class KeyStore(SQLBaseStore):
)
results[(server_name, key_id, from_server)] = rows
return results
return self.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)
return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)

View file

@ -38,15 +38,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository",
{"media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
"quarantined_by", "url_cache",
"media_type",
"media_length",
"upload_name",
"created_ts",
"quarantined_by",
"url_cache",
),
allow_none=True,
desc="get_local_media",
)
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
media_length, user_id, url_cache=None):
def store_local_media(
self,
media_id,
media_type,
time_now_ms,
upload_name,
media_length,
user_id,
url_cache=None,
):
return self._simple_insert(
"local_media_repository",
{
@ -66,6 +78,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
Returns:
None if the URL isn't cached.
"""
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
@ -92,16 +105,25 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if not row:
return None
return dict(zip((
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return dict(
zip(
(
'response_code',
'etag',
'expires_ts',
'og',
'media_id',
'download_ts',
),
row,
)
)
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
return self.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
return self._simple_insert(
"local_media_repository_url_cache",
{
@ -121,15 +143,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"local_media_repository_thumbnails",
{"media_id": media_id},
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length",
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
),
desc="get_local_media_thumbnails",
)
def store_local_thumbnail(self, media_id, thumbnail_width,
thumbnail_height, thumbnail_type,
thumbnail_method, thumbnail_length):
def store_local_thumbnail(
self,
media_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
return self._simple_insert(
"local_media_repository_thumbnails",
{
@ -148,16 +179,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
"media_type", "media_length", "upload_name", "created_ts",
"filesystem_id", "quarantined_by",
"media_type",
"media_length",
"upload_name",
"created_ts",
"filesystem_id",
"quarantined_by",
),
allow_none=True,
desc="get_cached_remote_media",
)
def store_cached_remote_media(self, origin, media_id, media_type,
media_length, time_now_ms, upload_name,
filesystem_id):
def store_cached_remote_media(
self,
origin,
media_id,
media_type,
media_length,
time_now_ms,
upload_name,
filesystem_id,
):
return self._simple_insert(
"remote_media_cache",
{
@ -181,26 +223,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
remote_media (iterable[(str, str)]): Set of (server_name, media_id)
time_ms: Current time in milliseconds
"""
def update_cache_txn(txn):
sql = (
"UPDATE remote_media_cache SET last_access_ts = ?"
" WHERE media_origin = ? AND media_id = ?"
)
txn.executemany(sql, (
(time_ms, media_origin, media_id)
for media_origin, media_id in remote_media
))
txn.executemany(
sql,
(
(time_ms, media_origin, media_id)
for media_origin, media_id in remote_media
),
)
sql = (
"UPDATE local_media_repository SET last_access_ts = ?"
" WHERE media_id = ?"
)
txn.executemany(sql, (
(time_ms, media_id)
for media_id in local_media
))
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
@ -209,16 +252,27 @@ class MediaRepositoryStore(BackgroundUpdateStore):
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
"thumbnail_width", "thumbnail_height", "thumbnail_method",
"thumbnail_type", "thumbnail_length", "filesystem_id",
"thumbnail_width",
"thumbnail_height",
"thumbnail_method",
"thumbnail_type",
"thumbnail_length",
"filesystem_id",
),
desc="get_remote_media_thumbnails",
)
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
thumbnail_width, thumbnail_height,
thumbnail_type, thumbnail_method,
thumbnail_length):
def store_remote_media_thumbnail(
self,
origin,
media_id,
filesystem_id,
thumbnail_width,
thumbnail_height,
thumbnail_type,
thumbnail_method,
thumbnail_length,
):
return self._simple_insert(
"remote_media_cache_thumbnails",
{
@ -250,17 +304,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
self._simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={
"media_origin": media_origin, "media_id": media_id
},
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
self._simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={
"media_origin": media_origin, "media_id": media_id
},
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
@ -281,10 +332,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
if len(media_ids) == 0:
return
sql = (
"DELETE FROM local_media_repository_url_cache"
" WHERE media_id = ?"
)
sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
@ -304,7 +352,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return [row[0] for row in txn]
return self.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn,
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
def delete_url_cache_media(self, media_ids):
@ -312,20 +360,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
sql = (
"DELETE FROM local_media_repository"
" WHERE media_id = ?"
)
sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
sql = (
"DELETE FROM local_media_repository_thumbnails"
" WHERE media_id = ?"
)
sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
return self.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn,
"delete_url_cache_media", _delete_url_cache_media_txn
)

View file

@ -35,9 +35,12 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._new_transaction(
dbconn, "initialise_mau_threepids", [], [],
dbconn,
"initialise_mau_threepids",
[],
[],
self._initialise_reserved_users,
hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@ -51,10 +54,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(
txn,
tp["medium"], tp["address"]
)
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
@ -62,9 +62,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
logger.warning(
"mau limit reserved threepid %s not found in db" % tp
)
logger.warning("mau limit reserved threepid %s not found in db" % tp)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
@ -75,12 +73,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Returns:
Deferred[]
"""
def _reap_users(txn):
# Purge stale users
thirty_days_ago = (
int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
)
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
@ -158,6 +155,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn.execute(sql)
count, = txn.fetchone()
return count
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
@ -198,14 +196,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return
yield self.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
user_id
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
user_in_mau = self.user_last_seen_monthly_active.cache.get(
(user_id,),
None,
update_metrics=False
(user_id,), None, update_metrics=False
)
if user_in_mau is None:
self.get_monthly_active_count.invalidate(())
@ -247,12 +242,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_insert = self._simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
values={
"timestamp": int(self._clock.time_msec()),
},
keyvalues={"user_id": user_id},
values={"timestamp": int(self._clock.time_msec())},
)
return is_insert
@ -268,15 +259,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"""
return(self._simple_select_one_onecol(
return self._simple_select_one_onecol(
table="monthly_active_users",
keyvalues={
"user_id": user_id,
},
keyvalues={"user_id": user_id},
retcol="timestamp",
allow_none=True,
desc="user_last_seen_monthly_active",
))
)
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):

View file

@ -10,7 +10,7 @@ class OpenIdStore(SQLBaseStore):
"ts_valid_until_ms": ts_valid_until_ms,
"user_id": user_id,
},
desc="insert_open_id_token"
desc="insert_open_id_token",
)
def get_user_id_for_open_id_token(self, token, ts_now_ms):
@ -27,6 +27,5 @@ class OpenIdStore(SQLBaseStore):
return None
else:
return rows[0][0]
return self.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)
return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 53
SCHEMA_VERSION = 54
dir_path = os.path.abspath(os.path.dirname(__file__))
@ -143,10 +143,9 @@ def _setup_new_database(cur, database_engine):
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)"
"INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
),
(max_current_ver, False,)
(max_current_ver, False),
)
_upgrade_existing_database(
@ -160,8 +159,15 @@ def _setup_new_database(cur, database_engine):
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config, is_empty=False):
def _upgrade_existing_database(
cur,
current_version,
applied_delta_files,
upgraded,
database_engine,
config,
is_empty=False,
):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@ -209,8 +215,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if current_version > SCHEMA_VERSION:
raise ValueError(
"Cannot use this database as it is too " +
"new for the server to understand"
"Cannot use this database as it is too "
+ "new for the server to understand"
)
start_ver = current_version
@ -239,20 +245,14 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if relative_path in applied_delta_files:
continue
absolute_path = os.path.join(
dir_path, "schema", "delta", relative_path,
)
absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
# package and then execute its `run_upgrade` function.
module_name = "synapse.storage.v%d_%s" % (
v, root_name
)
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
module = imp.load_source(
module_name, absolute_path, python_file
)
module = imp.load_source(module_name, absolute_path, python_file)
logger.info("Running script %s", relative_path)
module.run_create(cur, database_engine)
if not is_empty:
@ -269,8 +269,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
else:
# Not a valid delta file.
logger.warn(
"Found directory entry that did not end in .py or"
" .sql: %s",
"Found directory entry that did not end in .py or" " .sql: %s",
relative_path,
)
continue
@ -278,19 +277,17 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_schema_deltas (version, file)"
" VALUES (?,?)",
"INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
),
(v, relative_path)
(v, relative_path),
)
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
"INSERT INTO schema_version (version, upgraded)"
" VALUES (?,?)",
"INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
),
(v, True)
(v, True),
)
@ -308,7 +305,7 @@ def _apply_module_schemas(txn, database_engine, config):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
txn, database_engine, modname, mod.get_db_schema_files(),
txn, database_engine, modname, mod.get_db_schema_files()
)
@ -326,7 +323,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
database_engine.convert_param_style(
"SELECT file FROM applied_module_schemas WHERE module_name = ?"
),
(modname,)
(modname,),
)
applied_deltas = set(d for d, in cur)
for (name, stream) in names_and_streams:
@ -336,7 +333,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
root_name, ext = os.path.splitext(name)
if ext != '.sql':
raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas",
"only .sql files are currently supported for module schemas"
)
logger.info("applying schema %s for %s", name, modname)
@ -346,10 +343,9 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
"INSERT INTO applied_module_schemas (module_name, file)"
" VALUES (?,?)",
"INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
),
(modname, name)
(modname, name),
)
@ -386,10 +382,7 @@ def get_statements(f):
statements = line.split(";")
# We must prepend statement_buffer to the first statement
first_statement = "%s %s" % (
statement_buffer.strip(),
statements[0].strip()
)
first_statement = "%s %s" % (statement_buffer.strip(), statements[0].strip())
statements[0] = first_statement
# Every entry, except the last, is a full statement
@ -409,9 +402,7 @@ def executescript(txn, schema_path):
def _get_or_create_schema_state(txn, database_engine):
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
@ -424,7 +415,7 @@ def _get_or_create_schema_state(txn, database_engine):
database_engine.convert_param_style(
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
),
(current_version,)
(current_version,),
)
applied_deltas = [d for d, in txn]
return current_version, applied_deltas, upgraded

View file

@ -19,15 +19,25 @@ from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
class UserPresenceState(namedtuple("UserPresenceState",
("user_id", "state", "last_active_ts",
"last_federation_update_ts", "last_user_sync_ts",
"status_msg", "currently_active"))):
class UserPresenceState(
namedtuple(
"UserPresenceState",
(
"user_id",
"state",
"last_active_ts",
"last_federation_update_ts",
"last_user_sync_ts",
"status_msg",
"currently_active",
),
)
):
"""Represents the current presence state of the user.
user_id (str)
@ -75,22 +85,21 @@ class PresenceStore(SQLBaseStore):
with stream_ordering_manager as stream_orderings:
yield self.runInteraction(
"update_presence",
self._update_presence_txn, stream_orderings, presence_states,
self._update_presence_txn,
stream_orderings,
presence_states,
)
defer.returnValue((
stream_orderings[-1], self._presence_id_gen.get_current_token()
))
defer.returnValue(
(stream_orderings[-1], self._presence_id_gen.get_current_token())
)
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
txn.call_after(
self.presence_stream_cache.entity_has_changed,
state.user_id, stream_id,
)
txn.call_after(
self._get_presence_for_user.invalidate, (state.user_id,)
self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
)
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
self._simple_insert_many_txn(
@ -113,18 +122,13 @@ class PresenceStore(SQLBaseStore):
# Delete old rows to stop database from getting really big
sql = (
"DELETE FROM presence_stream WHERE"
" stream_id < ?"
" AND user_id IN (%s)"
"DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
)
for states in batch_iter(presence_states, 50):
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(
sql % (",".join("?" for _ in states),),
args
)
txn.execute(sql % (",".join("?" for _ in states),), args)
def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id:
@ -149,8 +153,12 @@ class PresenceStore(SQLBaseStore):
def _get_presence_for_user(self, user_id):
raise NotImplementedError()
@cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids",
num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
rows = yield self._simple_select_many_batch(
table="presence_stream",
@ -180,8 +188,10 @@ class PresenceStore(SQLBaseStore):
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
table="presence_allow_inbound",
values={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
values={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible",
or_ignore=True,
)
@ -189,89 +199,9 @@ class PresenceStore(SQLBaseStore):
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_delete_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
keyvalues={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert(
table="presence_list",
values={"user_id": observer_localpart,
"observed_user_id": observed_userid,
"accepted": False},
desc="add_presence_list_pending",
)
def set_presence_list_accepted(self, observer_localpart, observed_userid):
def update_presence_list_txn(txn):
result = self._simple_update_one_txn(
txn,
table="presence_list",
keyvalues={
"user_id": observer_localpart,
"observed_user_id": observed_userid
},
updatevalues={"accepted": True},
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_accepted, (observer_localpart,)
)
self._invalidate_cache_and_stream(
txn, self.get_presence_list_observers_accepted, (observed_userid,)
)
return result
return self.runInteraction(
"set_presence_list_accepted", update_presence_list_txn,
)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:
return self.get_presence_list_accepted(observer_localpart)
else:
keyvalues = {"user_id": observer_localpart}
if accepted is not None:
keyvalues["accepted"] = accepted
return self._simple_select_list(
table="presence_list",
keyvalues=keyvalues,
retcols=["observed_user_id", "accepted"],
desc="get_presence_list",
)
@cached()
def get_presence_list_accepted(self, observer_localpart):
return self._simple_select_list(
table="presence_list",
keyvalues={"user_id": observer_localpart, "accepted": True},
retcols=["observed_user_id", "accepted"],
desc="get_presence_list_accepted",
)
@cachedInlineCallbacks()
def get_presence_list_observers_accepted(self, observed_userid):
user_localparts = yield self._simple_select_onecol(
table="presence_list",
keyvalues={"observed_user_id": observed_userid, "accepted": True},
retcol="user_id",
desc="get_presence_list_accepted",
)
defer.returnValue([
"@%s:%s" % (u, self.hs.hostname,) for u in user_localparts
])
@defer.inlineCallbacks
def del_presence_list(self, observer_localpart, observed_userid):
yield self._simple_delete_one(
table="presence_list",
keyvalues={"user_id": observer_localpart,
"observed_user_id": observed_userid},
desc="del_presence_list",
)
self.get_presence_list_accepted.invalidate((observer_localpart,))
self.get_presence_list_observers_accepted.invalidate((observed_userid,))

View file

@ -41,8 +41,7 @@ class ProfileWorkerStore(SQLBaseStore):
defer.returnValue(
ProfileInfo(
avatar_url=profile['avatar_url'],
display_name=profile['displayname'],
avatar_url=profile['avatar_url'], display_name=profile['displayname']
)
)
@ -66,16 +65,14 @@ class ProfileWorkerStore(SQLBaseStore):
return self._simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url",),
retcols=("displayname", "avatar_url"),
allow_none=True,
desc="get_from_remote_profile_cache",
)
def create_profile(self, user_localpart):
return self._simple_insert(
table="profiles",
values={"user_id": user_localpart},
desc="create_profile",
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
@ -141,6 +138,7 @@ class ProfileStore(ProfileWorkerStore):
def get_remote_profile_cache_entries_that_expire(self, last_checked):
"""Get all users who haven't been checked since `last_checked`
"""
def _get_remote_profile_cache_entries_that_expire_txn(txn):
sql = """
SELECT user_id, displayname, avatar_url

View file

@ -57,11 +57,13 @@ def _load_rules(rawrules, enabled_map):
return rules
class PushRulesWorkerStore(ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
SQLBaseStore):
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
@ -74,14 +76,16 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
push_rules_prefill, push_rules_id = self._get_cache_dict(
db_conn, "push_rules_stream",
db_conn,
"push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self.get_max_push_rules_stream_id(),
)
self.push_rules_stream_cache = StreamChangeCache(
"PushRulesStreamChangeCache", push_rules_id,
"PushRulesStreamChangeCache",
push_rules_id,
prefilled_cache=push_rules_prefill,
)
@ -98,19 +102,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_for_user(self, user_id):
rows = yield self._simple_select_list(
table="push_rules",
keyvalues={
"user_name": user_id,
},
keyvalues={"user_name": user_id},
retcols=(
"user_name", "rule_id", "priority_class", "priority",
"conditions", "actions",
"user_name",
"rule_id",
"priority_class",
"priority",
"conditions",
"actions",
),
desc="get_push_rules_enabled_for_user",
)
rows.sort(
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
@ -122,22 +126,19 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
keyvalues={
'user_name': user_id
},
retcols=(
"user_name", "rule_id", "enabled",
),
keyvalues={'user_name': user_id},
retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
defer.returnValue({
r['rule_id']: False if r['enabled'] == 0 else True for r in results
})
defer.returnValue(
{r['rule_id']: False if r['enabled'] == 0 else True for r in results}
)
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
else:
def have_push_rules_changed_txn(txn):
sql = (
"SELECT COUNT(stream_id) FROM push_rules_stream"
@ -146,20 +147,22 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
txn.execute(sql, (user_id, last_id))
count, = txn.fetchone()
return bool(count)
return self.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
@cachedList(cached_method_name="get_push_rules_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules(self, user_ids):
if not user_ids:
defer.returnValue({})
results = {
user_id: []
for user_id in user_ids
}
results = {user_id: [] for user_id in user_ids}
rows = yield self._simple_select_many_batch(
table="push_rules",
@ -169,9 +172,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
desc="bulk_get_push_rules",
)
rows.sort(
key=lambda row: (-int(row["priority_class"]), -int(row["priority"]))
)
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
results.setdefault(row['user_name'], []).append(row)
@ -179,16 +180,12 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {})
)
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
defer.returnValue(results)
@defer.inlineCallbacks
def move_push_rule_from_room_to_room(
self, new_room_id, user_id, rule,
):
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
"""Move a single push rule from one room to another for a specific user.
Args:
@ -219,7 +216,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
@defer.inlineCallbacks
def move_push_rules_from_room_to_room_for_user(
self, old_room_id, new_room_id, user_id,
self, old_room_id, new_room_id, user_id
):
"""Move all of the push rules from one room to another for a specific
user.
@ -236,11 +233,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# delete them from the old room
for rule in user_push_rules:
conditions = rule.get("conditions", [])
if any((c.get("key") == "room_id" and
c.get("pattern") == old_room_id) for c in conditions):
self.move_push_rule_from_room_to_room(
new_room_id, user_id, rule,
)
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
):
self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
@ -259,8 +256,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,
cache_context, event=None):
def _bulk_get_push_rules_for_room(
self, room_id, state_group, current_state_ids, cache_context, event=None
):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# with a state_group of None are likely to be different.
@ -273,7 +271,9 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# sent a read receipt into the room.
users_in_room = yield self._get_joined_users_from_context(
room_id, state_group, current_state_ids,
room_id,
state_group,
current_state_ids,
on_invalidate=cache_context.invalidate,
event=event,
)
@ -282,7 +282,8 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# up the `get_if_users_have_pushers` cache with AS entries that we
# know don't have pushers, nor even read receipts.
local_users_in_room = set(
u for u in users_in_room
u
for u in users_in_room
if self.hs.is_mine_id(u)
and not self.get_if_app_services_interested_in_user(u)
)
@ -290,15 +291,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
if_users_with_pushers = yield self.get_if_users_have_pushers(
local_users_in_room,
on_invalidate=cache_context.invalidate,
local_users_in_room, on_invalidate=cache_context.invalidate
)
user_ids = set(
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
)
users_with_receipts = yield self.get_users_with_read_receipts_in_room(
room_id, on_invalidate=cache_context.invalidate,
room_id, on_invalidate=cache_context.invalidate
)
# any users with pushers must be ours: they have pushers
@ -307,29 +307,30 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
user_ids.add(uid)
rules_by_user = yield self.bulk_get_push_rules(
user_ids, on_invalidate=cache_context.invalidate,
user_ids, on_invalidate=cache_context.invalidate
)
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
defer.returnValue(rules_by_user)
@cachedList(cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids", num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
defer.returnValue({})
results = {
user_id: {}
for user_id in user_ids
}
results = {user_id: {} for user_id in user_ids}
rows = yield self._simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
retcols=("user_name", "rule_id", "enabled",),
retcols=("user_name", "rule_id", "enabled"),
desc="bulk_get_push_rules_enabled",
)
for row in rows:
@ -341,8 +342,14 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
def add_push_rule(
self, user_id, rule_id, priority_class, conditions, actions,
before=None, after=None
self,
user_id,
rule_id,
priority_class,
conditions,
actions,
before=None,
after=None,
):
conditions_json = json.dumps(conditions)
actions_json = json.dumps(actions)
@ -352,20 +359,41 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction(
"_add_push_rule_relative_txn",
self._add_push_rule_relative_txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
)
else:
yield self.runInteraction(
"_add_push_rule_highest_priority_txn",
self._add_push_rule_highest_priority_txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
)
def _add_push_rule_relative_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json, before, after
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
before,
after,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
@ -376,10 +404,7 @@ class PushRuleStore(PushRulesWorkerStore):
res = self._simple_select_one_txn(
txn,
table="push_rules",
keyvalues={
"user_name": user_id,
"rule_id": relative_to_rule,
},
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
retcols=["priority_class", "priority"],
allow_none=True,
)
@ -416,13 +441,27 @@ class PushRuleStore(PushRulesWorkerStore):
txn.execute(sql, (user_id, priority_class, new_rule_priority))
self._upsert_push_rule_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
new_rule_priority, conditions_json, actions_json,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
new_rule_priority,
conditions_json,
actions_json,
)
def _add_push_rule_highest_priority_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
conditions_json, actions_json
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
conditions_json,
actions_json,
):
# Lock the table since otherwise we'll have annoying races between the
# SELECT here and the UPSERT below.
@ -443,13 +482,28 @@ class PushRuleStore(PushRulesWorkerStore):
self._upsert_push_rule_txn(
txn,
stream_id, event_stream_ordering, user_id, rule_id, priority_class, new_prio,
conditions_json, actions_json,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
new_prio,
conditions_json,
actions_json,
)
def _upsert_push_rule_txn(
self, txn, stream_id, event_stream_ordering, user_id, rule_id, priority_class,
priority, conditions_json, actions_json, update_stream=True
self,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
conditions_json,
actions_json,
update_stream=True,
):
"""Specialised version of _simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
@ -461,10 +515,10 @@ class PushRuleStore(PushRulesWorkerStore):
" WHERE user_name = ? AND rule_id = ?"
)
txn.execute(sql, (
priority_class, priority, conditions_json, actions_json,
user_id, rule_id,
))
txn.execute(
sql,
(priority_class, priority, conditions_json, actions_json, user_id, rule_id),
)
if txn.rowcount == 0:
# We didn't update a row with the given rule_id so insert one
@ -486,14 +540,18 @@ class PushRuleStore(PushRulesWorkerStore):
if update_stream:
self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ADD",
data={
"priority_class": priority_class,
"priority": priority,
"conditions": conditions_json,
"actions": actions_json,
}
},
)
@defer.inlineCallbacks
@ -507,22 +565,23 @@ class PushRuleStore(PushRulesWorkerStore):
user_id (str): The matrix ID of the push rule owner
rule_id (str): The rule_id of the rule to be deleted
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self._simple_delete_one_txn(
txn,
"push_rules",
{'user_name': user_id, 'rule_id': rule_id},
txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
)
self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
op="DELETE"
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"delete_push_rule", delete_push_rule_txn, stream_id, event_stream_ordering
"delete_push_rule",
delete_push_rule_txn,
stream_id,
event_stream_ordering,
)
@defer.inlineCallbacks
@ -532,7 +591,11 @@ class PushRuleStore(PushRulesWorkerStore):
yield self.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
stream_id, event_stream_ordering, user_id, rule_id, enabled
stream_id,
event_stream_ordering,
user_id,
rule_id,
enabled,
)
def _set_push_rule_enabled_txn(
@ -548,8 +611,12 @@ class PushRuleStore(PushRulesWorkerStore):
)
self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
op="ENABLE" if enabled else "DISABLE"
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ENABLE" if enabled else "DISABLE",
)
@defer.inlineCallbacks
@ -563,9 +630,16 @@ class PushRuleStore(PushRulesWorkerStore):
priority_class = -1
priority = 1
self._upsert_push_rule_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
priority_class, priority, "[]", actions_json,
update_stream=False
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
priority_class,
priority,
"[]",
actions_json,
update_stream=False,
)
else:
self._simple_update_one_txn(
@ -576,15 +650,22 @@ class PushRuleStore(PushRulesWorkerStore):
)
self._insert_push_rules_update_txn(
txn, stream_id, event_stream_ordering, user_id, rule_id,
op="ACTIONS", data={"actions": actions_json}
txn,
stream_id,
event_stream_ordering,
user_id,
rule_id,
op="ACTIONS",
data={"actions": actions_json},
)
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
yield self.runInteraction(
"set_push_rule_actions", set_push_rule_actions_txn,
stream_id, event_stream_ordering
"set_push_rule_actions",
set_push_rule_actions_txn,
stream_id,
event_stream_ordering,
)
def _insert_push_rules_update_txn(
@ -602,12 +683,8 @@ class PushRuleStore(PushRulesWorkerStore):
self._simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(
self.get_push_rules_for_user.invalidate, (user_id,)
)
txn.call_after(
self.get_push_rules_enabled_for_user.invalidate, (user_id,)
)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
txn.call_after(
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
@ -627,6 +704,7 @@ class PushRuleStore(PushRulesWorkerStore):
)
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.runInteraction(
"get_all_push_rule_updates", get_all_push_rule_updates_txn
)

View file

@ -47,7 +47,9 @@ class PusherWorkerStore(SQLBaseStore):
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
r['id'], dataJson, e.args[0],
r['id'],
dataJson,
e.args[0],
)
pass
@ -64,20 +66,16 @@ class PusherWorkerStore(SQLBaseStore):
defer.returnValue(ret is not None)
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({
"app_id": app_id,
"pushkey": pushkey,
})
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({
"user_name": user_id,
})
return self.get_pushers_by({"user_name": user_id})
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self._simple_select_list(
"pushers", keyvalues,
"pushers",
keyvalues,
[
"id",
"user_name",
@ -94,7 +92,8 @@ class PusherWorkerStore(SQLBaseStore):
"last_stream_ordering",
"last_success",
"failing_since",
], desc="get_pushers_by"
],
desc="get_pushers_by",
)
defer.returnValue(self._decode_pushers_rows(ret))
@ -135,6 +134,7 @@ class PusherWorkerStore(SQLBaseStore):
deleted = txn.fetchall()
return (updated, deleted)
return self.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@ -177,6 +177,7 @@ class PusherWorkerStore(SQLBaseStore):
results.sort() # Sort so that they're ordered by stream id
return results
return self.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@ -186,15 +187,19 @@ class PusherWorkerStore(SQLBaseStore):
# This only exists for the cachedList decorator
raise NotImplementedError()
@cachedList(cached_method_name="get_if_user_has_pusher",
list_name="user_ids", num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch(
table='pushers',
column='user_name',
iterable=user_ids,
retcols=['user_name'],
desc='get_if_users_have_pushers'
desc='get_if_users_have_pushers',
)
result = {user_id: False for user_id in user_ids}
@ -208,20 +213,27 @@ class PusherStore(PusherWorkerStore):
return self._pushers_id_gen.get_current_token()
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""):
def add_pusher(
self,
user_id,
access_token,
kind,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
lang,
data,
last_stream_ordering,
profile_tag="",
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so _simple_upsert will retry
yield self._simple_upsert(
table="pushers",
keyvalues={
"app_id": app_id,
"pushkey": pushkey,
"user_name": user_id,
},
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
"access_token": access_token,
"kind": kind,
@ -247,7 +259,8 @@ class PusherStore(PusherWorkerStore):
yield self.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher, (user_id,)
self.get_if_user_has_pusher,
(user_id,),
)
@defer.inlineCallbacks
@ -260,7 +273,7 @@ class PusherStore(PusherWorkerStore):
self._simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
)
# it's possible for us to end up with duplicate rows for
@ -278,13 +291,12 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
yield self.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id,
last_stream_ordering):
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@ -293,23 +305,21 @@ class PusherStore(PusherWorkerStore):
)
@defer.inlineCallbacks
def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey,
user_id,
last_stream_ordering,
last_success):
def update_pusher_last_stream_ordering_and_success(
self, app_id, pushkey, user_id, last_stream_ordering, last_success
):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{
'last_stream_ordering': last_stream_ordering,
'last_success': last_success
'last_success': last_success,
},
desc="update_pusher_last_stream_ordering_and_success",
)
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id,
failing_since):
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
@ -323,14 +333,14 @@ class PusherStore(PusherWorkerStore):
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
desc="get_throttle_params_by_room"
desc="get_throttle_params_by_room",
)
params_by_room = {}
for row in res:
params_by_room[row["room_id"]] = {
"last_sent_ts": row["last_sent_ts"],
"throttle_ms": row["throttle_ms"]
"throttle_ms": row["throttle_ms"],
}
defer.returnValue(params_by_room)

View file

@ -64,10 +64,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_room(self, room_id, receipt_type):
return self._simple_select_list(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
},
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
desc="get_receipts_for_room",
)
@ -79,7 +76,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
keyvalues={
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id
"user_id": user_id,
},
retcol="event_id",
desc="get_own_receipt_for_user",
@ -90,10 +87,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_receipts_for_user(self, user_id, receipt_type):
rows = yield self._simple_select_list(
table="receipts_linearized",
keyvalues={
"user_id": user_id,
"receipt_type": receipt_type,
},
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
desc="get_receipts_for_user",
)
@ -114,16 +108,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
txn.execute(sql, (user_id,))
return txn.fetchall()
rows = yield self.runInteraction(
"get_receipts_for_user_with_orderings", f
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
defer.returnValue(
{
row[0]: {
"event_id": row[1],
"topological_ordering": row[2],
"stream_ordering": row[3],
}
for row in rows
}
)
defer.returnValue({
row[0]: {
"event_id": row[1],
"topological_ordering": row[2],
"stream_ordering": row[3],
} for row in rows
})
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@ -177,6 +173,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
def _get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""See get_linearized_receipts_for_room
"""
def f(txn):
if from_key:
sql = (
@ -184,48 +181,40 @@ class ReceiptsWorkerStore(SQLBaseStore):
" room_id = ? AND stream_id > ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, from_key, to_key)
)
txn.execute(sql, (room_id, from_key, to_key))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id = ? AND stream_id <= ?"
)
txn.execute(
sql,
(room_id, to_key)
)
txn.execute(sql, (room_id, to_key))
rows = self.cursor_to_dict(txn)
return rows
rows = yield self.runInteraction(
"get_linearized_receipts_for_room", f
)
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
defer.returnValue([])
content = {}
for row in rows:
content.setdefault(
row["event_id"], {}
).setdefault(
row["receipt_type"], {}
)[row["user_id"]] = json.loads(row["data"])
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
] = json.loads(row["data"])
defer.returnValue([{
"type": "m.receipt",
"room_id": room_id,
"content": content,
}])
defer.returnValue(
[{"type": "m.receipt", "room_id": room_id, "content": content}]
)
@cachedList(cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
@ -235,9 +224,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.extend([from_key, to_key])
@ -246,9 +233,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.append(to_key)
@ -257,19 +242,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.cursor_to_dict(txn)
txn_results = yield self.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(row["room_id"], {
"type": "m.receipt",
"room_id": row["room_id"],
"content": {},
})
room_event = results.setdefault(
row["room_id"],
{"type": "m.receipt", "room_id": row["room_id"], "content": {}},
)
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
@ -301,21 +283,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
return (
r[0:5] + (json.loads(r[5]), ) for r in txn
)
return (r[0:5] + (json.loads(r[5]),) for r in txn)
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
user_id):
def _invalidate_get_users_with_receipts_in_room(
self, room_id, receipt_type, user_id
):
if receipt_type != "m.read":
return
# Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
room_id, None, update_metrics=False
)
# first handle the Deferred case
@ -346,8 +328,9 @@ class ReceiptsStore(ReceiptsWorkerStore):
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
def insert_linearized_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_id, data, stream_id
):
"""Inserts a read-receipt into the database if it's newer than the current RR
Returns: int|None
@ -360,7 +343,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
table="events",
retcols=["stream_ordering", "received_ts"],
keyvalues={"event_id": event_id},
allow_none=True
allow_none=True,
)
stream_ordering = int(res["stream_ordering"]) if res else None
@ -381,31 +364,31 @@ class ReceiptsStore(ReceiptsWorkerStore):
logger.debug(
"Ignoring new receipt for %s in favour of existing "
"one for later event %s",
event_id, eid,
event_id,
eid,
)
return None
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id,
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
room_id,
receipt_type,
user_id,
)
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
txn.call_after(
self._receipts_stream_cache.entity_has_changed,
room_id, stream_id
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
)
txn.call_after(
self.get_last_receipt_event_id_for_user.invalidate,
(user_id, room_id, receipt_type)
(user_id, room_id, receipt_type),
)
self._simple_delete_txn(
@ -415,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
},
)
self._simple_insert_txn(
@ -428,15 +411,12 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
}
},
)
if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn,
room_id=room_id,
user_id=user_id,
stream_ordering=stream_ordering,
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
)
return rx_ts
@ -479,7 +459,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
event_ts = yield self.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id, receipt_type, user_id, linearized_event_id,
room_id,
receipt_type,
user_id,
linearized_event_id,
data,
stream_id=stream_id,
)
@ -490,39 +473,43 @@ class ReceiptsStore(ReceiptsWorkerStore):
now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
linearized_event_id, room_id, now - event_ts,
linearized_event_id,
room_id,
now - event_ts,
)
yield self.insert_graph_receipt(
room_id, receipt_type, user_id, event_ids, data
)
yield self.insert_graph_receipt(room_id, receipt_type, user_id, event_ids, data)
max_persisted_id = self._receipts_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids,
data):
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id, receipt_type, user_id, event_ids, data
room_id,
receipt_type,
user_id,
event_ids,
data,
)
def insert_graph_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_ids, data):
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
def insert_graph_receipt_txn(
self, txn, room_id, receipt_type, user_id, event_ids, data
):
txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type))
txn.call_after(
self._invalidate_get_users_with_receipts_in_room,
room_id, receipt_type, user_id,
)
txn.call_after(
self.get_receipts_for_user.invalidate, (user_id, receipt_type)
room_id,
receipt_type,
user_id,
)
txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type))
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate_many, (room_id,))
txn.call_after(
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
self._simple_delete_txn(
txn,
@ -531,7 +518,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
}
},
)
self._simple_insert_txn(
txn,
@ -542,5 +529,5 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
"event_ids": json.dumps(event_ids),
"data": json.dumps(data),
}
},
)

View file

@ -37,13 +37,15 @@ class RegistrationWorkerStore(SQLBaseStore):
def get_user_by_id(self, user_id):
return self._simple_select_one(
table="users",
keyvalues={
"name": user_id,
},
keyvalues={"name": user_id},
retcols=[
"name", "password_hash", "is_guest",
"consent_version", "consent_server_notice_sent",
"appservice_id", "creation_ts",
"name",
"password_hash",
"is_guest",
"consent_version",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
],
allow_none=True,
desc="get_user_by_id",
@ -81,9 +83,7 @@ class RegistrationWorkerStore(SQLBaseStore):
including the keys `name`, `is_guest`, `device_id`, `token_id`.
"""
return self.runInteraction(
"get_user_by_access_token",
self._query_for_auth,
token
"get_user_by_access_token", self._query_for_auth, token
)
@defer.inlineCallbacks
@ -143,10 +143,10 @@ class RegistrationWorkerStore(SQLBaseStore):
"""Gets users that match user_id case insensitively.
Returns a mapping of user_id -> password_hash.
"""
def f(txn):
sql = (
"SELECT name, password_hash FROM users"
" WHERE lower(name) = lower(?)"
"SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
)
txn.execute(sql, (user_id,))
return dict(txn)
@ -156,6 +156,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
rows = self.cursor_to_dict(txn)
@ -173,6 +174,7 @@ class RegistrationWorkerStore(SQLBaseStore):
3) bridged users
who registered on the homeserver in the past 24 hours
"""
def _count_daily_user_type(txn):
yesterday = int(self._clock.time()) - (60 * 60 * 24)
@ -193,15 +195,18 @@ class RegistrationWorkerStore(SQLBaseStore):
for row in txn:
results[row[0]] = row[1]
return results
return self.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
txn.execute("""
txn.execute(
"""
SELECT COALESCE(COUNT(*), 0) FROM users
WHERE appservice_id IS NULL
""")
"""
)
count, = txn.fetchone()
return count
@ -220,6 +225,7 @@ class RegistrationWorkerStore(SQLBaseStore):
avoid the case of ID 10000000 being pre-allocated, so us wasting the
first (and shortest) many generated user IDs.
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
@ -227,7 +233,7 @@ class RegistrationWorkerStore(SQLBaseStore):
found = set()
for user_id, in txn:
for (user_id,) in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
@ -235,20 +241,22 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
defer.returnValue((yield self.runInteraction(
"find_next_generated_user_id",
_find_next_generated_user_id
)))
defer.returnValue(
(
yield self.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
)
@defer.inlineCallbacks
def get_3pid_guest_access_token(self, medium, address):
ret = yield self._simple_select_one(
"threepid_guest_access_tokens",
{
"medium": medium,
"address": address
},
["guest_access_token"], True, 'get_3pid_guest_access_token'
{"medium": medium, "address": address},
["guest_access_token"],
True,
'get_3pid_guest_access_token',
)
if ret:
defer.returnValue(ret["guest_access_token"])
@ -266,8 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
medium, address
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
defer.returnValue(user_id)
@ -285,11 +292,9 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = self._simple_select_one_txn(
txn,
"user_threepids",
{
"medium": medium,
"address": address
},
['user_id'], True
{"medium": medium, "address": address},
['user_id'],
True,
)
if ret:
return ret['user_id']
@ -297,41 +302,110 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
yield self._simple_upsert("user_threepids", {
"medium": medium,
"address": address,
}, {
"user_id": user_id,
"validated_at": validated_at,
"added_at": added_at,
})
yield self._simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
ret = yield self._simple_select_list(
"user_threepids", {
"user_id": user_id
},
"user_threepids",
{"user_id": user_id},
['medium', 'address', 'validated_at', 'added_at'],
'user_get_threepids'
'user_get_threepids',
)
defer.returnValue(ret)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
desc="user_delete_threepids",
)
def add_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied a bind request to the given identity server on
behalf of the given user. We need to remember this in case the user
asks us to unbind the threepid.
Args:
user_id (str)
medium (str)
address (str)
id_server (str)
Returns:
Deferred
"""
# We need to use an upsert, in case they user had already bound the
# threepid
return self._simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
"id_server": id_server,
},
values={},
insertion_values={},
desc="add_user_bound_threepid",
)
def remove_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
identity server.
Args:
user_id (str)
medium (str)
address (str)
id_server (str)
Returns:
Deferred
"""
return self._simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
"id_server": id_server,
},
desc="remove_user_bound_threepid",
)
def get_id_servers_user_bound(self, user_id, medium, address):
"""Get the list of identity servers that the server proxied bind
requests to for given user and threepid
Args:
user_id (str)
medium (str)
address (str)
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
return self._simple_select_onecol(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
retcol="id_server",
desc="get_id_servers_user_bound",
)
class RegistrationStore(RegistrationWorkerStore,
background_updates.BackgroundUpdateStore):
class RegistrationStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
@ -356,6 +430,10 @@ class RegistrationStore(RegistrationWorkerStore,
# clear the background update.
self.register_noop_background_update("refresh_tokens_device_index")
self.register_background_update_handler(
"user_threepids_grandfather", self._bg_user_threepids_grandfather,
)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@ -372,18 +450,22 @@ class RegistrationStore(RegistrationWorkerStore,
yield self._simple_insert(
"access_tokens",
{
"id": next_id,
"user_id": user_id,
"token": token,
"device_id": device_id,
},
{"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
desc="add_access_token_to_user",
)
def register(self, user_id, token=None, password_hash=None,
was_guest=False, make_guest=False, appservice_id=None,
create_profile_with_displayname=None, admin=False, user_type=None):
def register(
self,
user_id,
token=None,
password_hash=None,
was_guest=False,
make_guest=False,
appservice_id=None,
create_profile_with_displayname=None,
admin=False,
user_type=None,
):
"""Attempts to register an account.
Args:
@ -417,7 +499,7 @@ class RegistrationStore(RegistrationWorkerStore,
appservice_id,
create_profile_with_displayname,
admin,
user_type
user_type,
)
def _register(
@ -447,10 +529,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_select_one_txn(
txn,
"users",
keyvalues={
"name": user_id,
"is_guest": 1,
},
keyvalues={"name": user_id, "is_guest": 1},
retcols=("name",),
allow_none=False,
)
@ -458,10 +537,7 @@ class RegistrationStore(RegistrationWorkerStore,
self._simple_update_one_txn(
txn,
"users",
keyvalues={
"name": user_id,
"is_guest": 1,
},
keyvalues={"name": user_id, "is_guest": 1},
updatevalues={
"password_hash": password_hash,
"upgrade_ts": now,
@ -469,7 +545,7 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
}
},
)
else:
self._simple_insert_txn(
@ -483,20 +559,17 @@ class RegistrationStore(RegistrationWorkerStore,
"appservice_id": appservice_id,
"admin": 1 if admin else 0,
"user_type": user_type,
}
},
)
except self.database_engine.module.IntegrityError:
raise StoreError(
400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if token:
# it's possible for this to get a conflict, but only for a single user
# since tokens are namespaced based on their user ID
txn.execute(
"INSERT INTO access_tokens(id, user_id, token)"
" VALUES (?,?,?)",
(next_id, user_id, token,)
"INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
(next_id, user_id, token),
)
if create_profile_with_displayname:
@ -507,12 +580,10 @@ class RegistrationStore(RegistrationWorkerStore,
# while everything else uses the full mxid.
txn.execute(
"INSERT INTO profiles(user_id, displayname) VALUES (?,?)",
(user_id_obj.localpart, create_profile_with_displayname)
(user_id_obj.localpart, create_profile_with_displayname),
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
def user_set_password_hash(self, user_id, password_hash):
@ -521,22 +592,14 @@ class RegistrationStore(RegistrationWorkerStore,
removes most of the entries subsequently anyway so it would be
pointless. Use flush_user separately.
"""
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
txn,
'users', {
'name': user_id
},
{
'password_hash': password_hash
}
txn, 'users', {'name': user_id}, {'password_hash': password_hash}
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
)
return self.runInteraction(
"user_set_password_hash", user_set_password_hash_txn
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@ -549,16 +612,16 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
keyvalues={'name': user_id, },
updatevalues={'consent_version': consent_version, },
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
keyvalues={'name': user_id},
updatevalues={'consent_version': consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
@ -573,20 +636,19 @@ class RegistrationStore(RegistrationWorkerStore,
Raises:
StoreError(404) if user not found
"""
def f(txn):
self._simple_update_one_txn(
txn,
table='users',
keyvalues={'name': user_id, },
updatevalues={'consent_server_notice_sent': consent_version, },
)
self._invalidate_cache_and_stream(
txn, self.get_user_by_id, (user_id,)
keyvalues={'name': user_id},
updatevalues={'consent_server_notice_sent': consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
return self.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None):
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
Invalidate access tokens belonging to a user
@ -601,10 +663,9 @@ class RegistrationStore(RegistrationWorkerStore,
defer.Deferred[list[str, int, str|None, int]]: a list of
(token, token id, device id) for each of the deleted tokens
"""
def f(txn):
keyvalues = {
"user_id": user_id,
}
keyvalues = {"user_id": user_id}
if device_id is not None:
keyvalues["device_id"] = device_id
@ -616,8 +677,9 @@ class RegistrationStore(RegistrationWorkerStore,
values.append(except_token_id)
txn.execute(
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
values
"SELECT token, id, device_id FROM access_tokens WHERE %s"
% where_clause,
values,
)
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
@ -626,25 +688,16 @@ class RegistrationStore(RegistrationWorkerStore,
txn, self.get_user_by_access_token, (token,)
)
txn.execute(
"DELETE FROM access_tokens WHERE %s" % where_clause,
values
)
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
return tokens_and_devices
return self.runInteraction(
"user_delete_access_tokens", f,
)
return self.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
self._simple_delete_one_txn(
txn,
table="access_tokens",
keyvalues={
"token": access_token
},
txn, table="access_tokens", keyvalues={"token": access_token}
)
self._invalidate_cache_and_stream(
@ -667,7 +720,7 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def save_or_get_3pid_guest_access_token(
self, medium, address, access_token, inviter_user_id
self, medium, address, access_token, inviter_user_id
):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
@ -683,12 +736,13 @@ class RegistrationStore(RegistrationWorkerStore,
deferred str: Whichever access token is persisted at the end
of this function call.
"""
def insert(txn):
txn.execute(
"INSERT INTO threepid_guest_access_tokens "
"(medium, address, guest_access_token, first_inviter) "
"VALUES (?, ?, ?, ?)",
(medium, address, access_token, inviter_user_id)
(medium, address, access_token, inviter_user_id),
)
try:
@ -705,9 +759,7 @@ class RegistrationStore(RegistrationWorkerStore,
"""
return self._simple_insert(
"users_pending_deactivation",
values={
"user_id": user_id,
},
values={"user_id": user_id},
desc="add_user_pending_deactivation",
)
@ -720,9 +772,7 @@ class RegistrationStore(RegistrationWorkerStore,
# the table, so somehow duplicate entries have ended up in it.
return self._simple_delete(
"users_pending_deactivation",
keyvalues={
"user_id": user_id,
},
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
)
@ -738,3 +788,34 @@ class RegistrationStore(RegistrationWorkerStore,
allow_none=True,
desc="get_users_pending_deactivation",
)
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
)
yield self._end_background_update("user_threepids_grandfather")
defer.returnValue(1)

View file

@ -36,9 +36,7 @@ class RejectionsStore(SQLBaseStore):
return self._simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={
"event_id": event_id,
},
keyvalues={"event_id": event_id},
allow_none=True,
desc="get_rejection_reason",
)

View file

@ -30,13 +30,11 @@ logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
"OpsLevel",
("ban_level", "kick_level", "redact_level",)
"OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride",
("messages_per_second", "burst_count",)
"RatelimitOverride", ("messages_per_second", "burst_count")
)
@ -60,9 +58,7 @@ class RoomWorkerStore(SQLBaseStore):
def get_public_room_ids(self):
return self._simple_select_onecol(
table="rooms",
keyvalues={
"is_public": True,
},
keyvalues={"is_public": True},
retcol="room_id",
desc="get_public_room_ids",
)
@ -79,11 +75,11 @@ class RoomWorkerStore(SQLBaseStore):
return self.runInteraction(
"get_public_room_ids_at_stream_id",
self.get_public_room_ids_at_stream_id_txn,
stream_id, network_tuple=network_tuple
stream_id,
network_tuple=network_tuple,
)
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
network_tuple):
def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
return {
rm
for rm, vis in self.get_published_at_stream_id_txn(
@ -96,7 +92,7 @@ class RoomWorkerStore(SQLBaseStore):
if network_tuple:
# We want to get from a particular list. No aggregation required.
sql = ("""
sql = """
SELECT room_id, visibility FROM public_room_list_stream
INNER JOIN (
SELECT room_id, max(stream_id) AS stream_id
@ -104,25 +100,22 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ? %s
GROUP BY room_id
) grouped USING (room_id, stream_id)
""")
"""
if network_tuple.appservice_id is not None:
txn.execute(
sql % ("AND appservice_id = ? AND network_id = ?",),
(stream_id, network_tuple.appservice_id, network_tuple.network_id,)
(stream_id, network_tuple.appservice_id, network_tuple.network_id),
)
else:
txn.execute(
sql % ("AND appservice_id IS NULL",),
(stream_id,)
)
txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
return dict(txn)
else:
# We want to get from all lists, so we need to aggregate the results
logger.info("Executing full list")
sql = ("""
sql = """
SELECT room_id, visibility
FROM public_room_list_stream
INNER JOIN (
@ -133,12 +126,9 @@ class RoomWorkerStore(SQLBaseStore):
WHERE stream_id <= ?
GROUP BY room_id, appservice_id, network_id
) grouped USING (room_id, stream_id)
""")
"""
txn.execute(
sql,
(stream_id,)
)
txn.execute(sql, (stream_id,))
results = {}
# A room is visible if its visible on any list.
@ -147,8 +137,7 @@ class RoomWorkerStore(SQLBaseStore):
return results
def get_public_room_changes(self, prev_stream_id, new_stream_id,
network_tuple):
def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
def get_public_room_changes_txn(txn):
then_rooms = self.get_public_room_ids_at_stream_id_txn(
txn, prev_stream_id, network_tuple
@ -158,9 +147,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, new_stream_id, network_tuple
)
now_rooms_visible = set(
rm for rm, vis in now_rooms_dict.items() if vis
)
now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
now_rooms_not_visible = set(
rm for rm, vis in now_rooms_dict.items() if not vis
)
@ -178,9 +165,7 @@ class RoomWorkerStore(SQLBaseStore):
def is_room_blocked(self, room_id):
return self._simple_select_one_onecol(
table="blocked_rooms",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
retcol="1",
allow_none=True,
desc="is_room_blocked",
@ -208,16 +193,17 @@ class RoomWorkerStore(SQLBaseStore):
)
if row:
defer.returnValue(RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
))
defer.returnValue(
RatelimitOverride(
messages_per_second=row["messages_per_second"],
burst_count=row["burst_count"],
)
)
else:
defer.returnValue(None)
class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@ -231,6 +217,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
StoreError if the room could not be stored.
"""
try:
def store_room_txn(txn, next_id):
self._simple_insert_txn(
txn,
@ -249,13 +236,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
"stream_id": next_id,
"room_id": room_id,
"visibility": is_public,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"store_room_txn",
store_room_txn, next_id,
)
yield self.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@ -297,19 +282,19 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": None,
"network_id": None,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public",
set_room_is_public_txn, next_id,
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
@defer.inlineCallbacks
def set_room_is_public_appservice(self, room_id, appservice_id, network_id,
is_public):
def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list.
Each appservice can have a number of published room lists associated
@ -324,6 +309,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
is_public (bool): Whether to publish or unpublish the room from the
list.
"""
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
@ -333,7 +319,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
values={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
"room_id": room_id,
},
)
except self.database_engine.module.IntegrityError:
@ -346,7 +332,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
keyvalues={
"appservice_id": appservice_id,
"network_id": network_id,
"room_id": room_id
"room_id": room_id,
},
)
@ -377,13 +363,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
"visibility": is_public,
"appservice_id": appservice_id,
"network_id": network_id,
}
},
)
with self._public_room_id_gen.get_next() as next_id:
yield self.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn, next_id,
set_room_is_public_appservice_txn,
next_id,
)
self.hs.get_notifier().on_new_replication_data()
@ -397,9 +384,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
return self.runInteraction(
"get_rooms", f
)
return self.runInteraction("get_rooms", f)
def _store_room_topic_txn(self, txn, event):
if hasattr(event, "content") and "topic" in event.content:
@ -414,7 +399,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
)
self.store_event_search_txn(
txn, event, "content.topic", event.content["topic"],
txn, event, "content.topic", event.content["topic"]
)
def _store_room_name_txn(self, txn, event):
@ -426,17 +411,17 @@ class RoomStore(RoomWorkerStore, SearchStore):
"event_id": event.event_id,
"room_id": event.room_id,
"name": event.content["name"],
}
},
)
self.store_event_search_txn(
txn, event, "content.name", event.content["name"],
txn, event, "content.name", event.content["name"]
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
self.store_event_search_txn(
txn, event, "content.body", event.content["body"],
txn, event, "content.body", event.content["body"]
)
def _store_history_visibility_txn(self, txn, event):
@ -452,14 +437,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
" (event_id, room_id, %(key)s)"
" VALUES (?, ?, ?)" % {"key": key}
)
txn.execute(sql, (
event.event_id,
event.room_id,
event.content[key]
))
txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
return self._simple_insert(
table="event_reports",
@ -472,7 +454,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
"reason": reason,
"content": json.dumps(content),
},
desc="add_event_report"
desc="add_event_report",
)
def get_current_public_room_stream_id(self):
@ -480,23 +462,21 @@ class RoomStore(RoomWorkerStore, SearchStore):
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
sql = """
SELECT stream_id, room_id, visibility, appservice_id, network_id
FROM public_room_list_stream
WHERE stream_id > ? AND stream_id <= ?
ORDER BY stream_id ASC
LIMIT ?
""")
"""
txn.execute(sql, (prev_id, current_id, limit,))
txn.execute(sql, (prev_id, current_id, limit))
return txn.fetchall()
if prev_id == current_id:
return defer.succeed([])
return self.runInteraction(
"get_all_new_public_rooms", get_all_new_public_rooms
)
return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
@ -511,19 +491,16 @@ class RoomStore(RoomWorkerStore, SearchStore):
"""
yield self._simple_upsert(
table="blocked_rooms",
keyvalues={
"room_id": room_id,
},
keyvalues={"room_id": room_id},
values={},
insertion_values={
"user_id": user_id,
},
insertion_values={"user_id": user_id},
desc="block_room",
)
yield self.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked, (room_id,),
self.is_room_blocked,
(room_id,),
)
def get_media_mxcs_in_room(self, room_id):
@ -536,6 +513,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
"""
def _get_media_mxcs_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
local_media_mxcs = []
@ -548,23 +526,28 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
return local_media_mxcs, remote_media_mxcs
return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
# Now update all the tables to set the quarantined_by flag
txn.executemany("""
txn.executemany(
"""
UPDATE local_media_repository
SET quarantined_by = ?
WHERE media_id = ?
""", ((quarantined_by, media_id) for media_id in local_mxcs))
""",
((quarantined_by, media_id) for media_id in local_mxcs),
)
txn.executemany(
"""
@ -575,7 +558,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
(
(quarantined_by, origin, media_id)
for origin, media_id in remote_mxcs
)
),
)
total_media_quarantined += len(local_mxcs)
@ -584,8 +567,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
return total_media_quarantined
return self.runInteraction(
"quarantine_media_in_room",
_quarantine_media_in_room_txn,
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
def _get_media_mxcs_in_room_txn(self, txn, room_id):

View file

@ -0,0 +1,29 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Tracks which identity server a user bound their threepid via.
CREATE TABLE user_threepid_id_server (
user_id TEXT NOT NULL,
medium TEXT NOT NULL,
address TEXT NOT NULL,
id_server TEXT NOT NULL
);
CREATE UNIQUE INDEX user_threepid_id_server_idx ON user_threepid_id_server(
user_id, medium, address, id_server
);
INSERT INTO background_updates (update_name, progress_json) VALUES
('user_threepids_grandfather', '{}');

View file

@ -0,0 +1,16 @@
/* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
DROP TABLE IF EXISTS presence_list;

View file

@ -28,13 +28,5 @@ CREATE TABLE IF NOT EXISTS presence_allow_inbound(
UNIQUE (observed_user_id, observer_user_id)
);
-- For each of /my/ users (watcher), which possibly-remote users are they
-- watching?
CREATE TABLE IF NOT EXISTS presence_list(
user_id TEXT NOT NULL,
observed_user_id TEXT NOT NULL, -- a UserID,
accepted BOOLEAN NOT NULL,
UNIQUE (user_id, observed_user_id)
);
CREATE INDEX presence_list_user_id ON presence_list (user_id);
-- We used to create a table called presence_list, but this is no longer used
-- and is removed in delta 54.

View file

@ -30,10 +30,10 @@ from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [
'key', 'value', 'event_id', 'room_id', 'stream_ordering',
'origin_server_ts',
])
SearchEntry = namedtuple(
'SearchEntry',
['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
)
class SearchStore(BackgroundUpdateStore):
@ -53,8 +53,7 @@ class SearchStore(BackgroundUpdateStore):
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
self.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_reindex_search_order
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
# we used to have a background update to turn the GIN index into a
@ -62,13 +61,10 @@ class SearchStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
self.register_noop_background_update(
self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME,
)
self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
self.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME,
self._background_reindex_gin_search
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
@defer.inlineCallbacks
@ -138,21 +134,23 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
event_search_rows.append(SearchEntry(
key=key,
value=value,
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
origin_server_ts=origin_server_ts,
))
event_search_rows.append(
SearchEntry(
key=key,
value=value,
event_id=event_id,
room_id=room_id,
stream_ordering=stream_ordering,
origin_server_ts=origin_server_ts,
)
)
self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
"rows_inserted": rows_inserted + len(event_search_rows)
"rows_inserted": rows_inserted + len(event_search_rows),
}
self._background_update_progress_txn(
@ -191,6 +189,7 @@ class SearchStore(BackgroundUpdateStore):
# doesn't support CREATE INDEX IF EXISTS so we just catch the
# exception and ignore it.
import psycopg2
try:
c.execute(
"CREATE INDEX CONCURRENTLY event_search_fts_idx"
@ -198,14 +197,11 @@ class SearchStore(BackgroundUpdateStore):
)
except psycopg2.ProgrammingError as e:
logger.warn(
"Ignoring error %r when trying to switch from GIST to GIN",
e
"Ignoring error %r when trying to switch from GIST to GIN", e
)
# we should now be able to delete the GIST index.
c.execute(
"DROP INDEX IF EXISTS event_search_fts_idx_gist"
)
c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
finally:
conn.set_session(autocommit=False)
@ -223,6 +219,7 @@ class SearchStore(BackgroundUpdateStore):
have_added_index = progress['have_added_indexes']
if not have_added_index:
def create_index(conn):
conn.rollback()
conn.set_session(autocommit=True)
@ -248,7 +245,8 @@ class SearchStore(BackgroundUpdateStore):
yield self.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
self._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME, pg,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
def reindex_search_txn(txn):
@ -302,14 +300,16 @@ class SearchStore(BackgroundUpdateStore):
"""
self.store_search_entries_txn(
txn,
(SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),),
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
def store_search_entries_txn(self, txn, entries):
@ -329,10 +329,17 @@ class SearchStore(BackgroundUpdateStore):
" VALUES (?,?,?,to_tsvector('english', ?),?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
entry.stream_ordering, entry.origin_server_ts,
) for entry in entries)
args = (
(
entry.event_id,
entry.room_id,
entry.key,
entry.value,
entry.stream_ordering,
entry.origin_server_ts,
)
for entry in entries
)
# inserts to a GIN index are normally batched up into a pending
# list, and then all committed together once the list gets to a
@ -363,9 +370,10 @@ class SearchStore(BackgroundUpdateStore):
"INSERT INTO event_search (event_id, room_id, key, value)"
" VALUES (?,?,?,?)"
)
args = ((
entry.event_id, entry.room_id, entry.key, entry.value,
) for entry in entries)
args = (
(entry.event_id, entry.room_id, entry.key, entry.value)
for entry in entries
)
txn.executemany(sql, args)
else:
@ -394,9 +402,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
local_clauses = []
@ -404,9 +410,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
clauses.append("(%s)" % (" OR ".join(local_clauses),))
count_args = args
count_clauses = clauses
@ -452,18 +456,13 @@ class SearchStore(BackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
results = yield self._execute(
"search_msgs", self.cursor_to_dict, sql, *args
)
results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
event_map = {ev.event_id: ev for ev in events}
highlights = None
if isinstance(self.database_engine, PostgresEngine):
@ -477,18 +476,17 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({
"results": [
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
}
for r in results
if r["event_id"] in event_map
],
"highlights": highlights,
"count": count,
})
defer.returnValue(
{
"results": [
{"event": event_map[r["event_id"]], "rank": r["rank"]}
for r in results
if r["event_id"] in event_map
],
"highlights": highlights,
"count": count,
}
)
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@ -513,9 +511,7 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append(
"room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
)
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
local_clauses = []
@ -523,9 +519,7 @@ class SearchStore(BackgroundUpdateStore):
local_clauses.append("key = ?")
args.append(key)
clauses.append(
"(%s)" % (" OR ".join(local_clauses),)
)
clauses.append("(%s)" % (" OR ".join(local_clauses),))
# take copies of the current args and clauses lists, before adding
# pagination clauses to main query.
@ -607,18 +601,13 @@ class SearchStore(BackgroundUpdateStore):
args.append(limit)
results = yield self._execute(
"search_rooms", self.cursor_to_dict, sql, *args
)
results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
event_map = {
ev.event_id: ev
for ev in events
}
event_map = {ev.event_id: ev for ev in events}
highlights = None
if isinstance(self.database_engine, PostgresEngine):
@ -632,21 +621,22 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
defer.returnValue({
"results": [
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s" % (
r["origin_server_ts"], r["stream_ordering"]
),
}
for r in results
if r["event_id"] in event_map
],
"highlights": highlights,
"count": count,
})
defer.returnValue(
{
"results": [
{
"event": event_map[r["event_id"]],
"rank": r["rank"],
"pagination_token": "%s,%s"
% (r["origin_server_ts"], r["stream_ordering"]),
}
for r in results
if r["event_id"] in event_map
],
"highlights": highlights,
"count": count,
}
)
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
@ -662,6 +652,7 @@ class SearchStore(BackgroundUpdateStore):
Returns:
deferred : A set of strings.
"""
def f(txn):
highlight_words = set()
for event in events:
@ -689,13 +680,15 @@ class SearchStore(BackgroundUpdateStore):
stop_sel += ">"
query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
_to_postgres_options({
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
})
_to_postgres_options(
{
"StartSel": start_sel,
"StopSel": stop_sel,
"MaxFragments": "50",
}
)
)
txn.execute(query, (value, search_query,))
txn.execute(query, (value, search_query))
headline, = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
@ -714,9 +707,7 @@ class SearchStore(BackgroundUpdateStore):
def _to_postgres_options(options_dict):
return "'%s'" % (
",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
)
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
def _parse_query(database_engine, search_term):

View file

@ -39,8 +39,9 @@ class SignatureWorkerStore(SQLBaseStore):
# to use its cache
raise NotImplementedError()
@cachedList(cached_method_name="get_event_reference_hash",
list_name="event_ids", num_args=1)
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
def get_event_reference_hashes(self, event_ids):
def f(txn):
return {
@ -48,21 +49,13 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
return self.runInteraction(
"get_event_reference_hashes",
f
)
return self.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
hashes = yield self.get_event_reference_hashes(
event_ids
)
hashes = yield self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {
k: encode_base64(v) for k, v in h.items()
if k == "sha256"
}
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
for e_id, h in hashes.items()
}
@ -81,7 +74,7 @@ class SignatureWorkerStore(SQLBaseStore):
" FROM event_reference_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
txn.execute(query, (event_id,))
return {k: v for k, v in txn}
@ -98,14 +91,12 @@ class SignatureStore(SignatureWorkerStore):
vals = []
for event in events:
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
"hash": db_binary_type(ref_hash_bytes),
})
vals.append(
{
"event_id": event.event_id,
"algorithm": ref_alg,
"hash": db_binary_type(ref_hash_bytes),
}
)
self._simple_insert_many_txn(
txn,
table="event_reference_hashes",
values=vals,
)
self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)

View file

@ -40,10 +40,13 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))):
class _GetStateGroupDelta(
namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
):
"""Return type of get_state_group_delta that implements __len__, which lets
us use the itrable flag when caching
"""
__slots__ = []
def __len__(self):
@ -70,10 +73,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary
if self.include_others:
self.types = {
k: v for k, v in iteritems(self.types)
if v is not None
}
self.types = {k: v for k, v in iteritems(self.types) if v is not None}
@staticmethod
def all():
@ -130,10 +130,7 @@ class StateFilter(object):
Returns:
StateFilter
"""
return StateFilter(
types={EventTypes.Member: set(members)},
include_others=True,
)
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
"""Creates a new StateFilter where type wild cards have been removed
@ -243,9 +240,7 @@ class StateFilter(object):
if where_clause:
where_clause += " OR "
where_clause += "type NOT IN (%s)" % (
",".join(["?"] * len(self.types)),
)
where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
where_args.extend(self.types)
return where_clause, where_args
@ -305,12 +300,8 @@ class StateFilter(object):
bool
"""
return (
self.include_others
or any(
state_keys is None
for state_keys in itervalues(self.types)
)
return self.include_others or any(
state_keys is None for state_keys in itervalues(self.types)
)
def concrete_types(self):
@ -406,11 +397,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
50000 * get_cache_factor_for("stateGroupCache")
50000 * get_cache_factor_for("stateGroupCache"),
)
self._state_group_members_cache = DictionaryCache(
"*stateGroupMembersCache*",
500000 * get_cache_factor_for("stateGroupMembersCache")
500000 * get_cache_factor_for("stateGroupMembersCache"),
)
@defer.inlineCallbacks
@ -488,22 +479,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: dict of (type, state_key) -> event_id
"""
def _get_current_state_ids_txn(txn):
txn.execute(
"""SELECT type, state_key, event_id FROM current_state_events
WHERE room_id = ?
""",
(room_id,)
(room_id,),
)
return {
(intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
}
return self.runInteraction(
"get_current_state_ids",
_get_current_state_ids_txn,
)
return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
# FIXME: how should this be cached?
def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
@ -544,8 +533,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
return self.runInteraction(
"get_filtered_current_state_ids",
_get_filtered_current_state_ids_txn,
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
@defer.inlineCallbacks
@ -559,9 +547,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[str|None]: The canonical alias, if any
"""
state = yield self.get_filtered_current_state_ids(room_id, StateFilter.from_types(
[(EventTypes.CanonicalAlias, "")]
))
state = yield self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
@ -581,13 +569,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
(prev_group, delta_ids), where both may be None.
"""
def _get_state_group_delta_txn(txn):
prev_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
},
keyvalues={"state_group": state_group},
retcol="prev_state_group",
allow_none=True,
)
@ -598,20 +585,16 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
delta_ids = self._simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
},
retcols=("type", "state_key", "event_id",)
keyvalues={"state_group": state_group},
retcols=("type", "state_key", "event_id"),
)
return _GetStateGroupDelta(prev_group, {
(row["type"], row["state_key"]): row["event_id"]
for row in delta_ids
})
return self.runInteraction(
"get_state_group_delta",
_get_state_group_delta_txn,
)
return _GetStateGroupDelta(
prev_group,
{(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
)
return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@ -628,9 +611,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_ids:
defer.returnValue({})
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
@ -666,19 +647,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_event_map = yield self.get_events(
[
ev_id for group_ids in itervalues(group_to_ids)
ev_id
for group_ids in itervalues(group_to_ids)
for ev_id in itervalues(group_ids)
],
get_prev_content=False
get_prev_content=False,
)
defer.returnValue({
group: [
state_event_map[v] for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
})
defer.returnValue(
{
group: [
state_event_map[v]
for v in itervalues(event_id_map)
if v in state_event_map
]
for group, event_id_map in iteritems(group_to_ids)
}
)
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@ -695,18 +680,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = {}
chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn, chunk, state_filter,
self._get_state_groups_from_groups_txn,
chunk,
state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all(),
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
@ -776,7 +763,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args
args,
)
results[group].update(
((typ, state_key), event_id)
@ -791,8 +778,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None and
len(results[group]) == max_entries_returned
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
@ -819,16 +806,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False
get_prev_content=False,
)
event_to_state = {
@ -856,9 +841,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self._get_state_group_for_events(
event_ids,
)
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, state_filter)
@ -906,16 +889,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_state_group_for_event(self, event_id):
return self._simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={
"event_id": event_id,
},
keyvalues={"event_id": event_id},
retcol="state_group",
allow_none=True,
desc="_get_state_group_for_event",
)
@cachedList(cached_method_name="_get_state_group_for_event",
list_name="event_ids", num_args=1, inlineCallbacks=True)
@cachedList(
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@ -924,7 +909,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
column="event_id",
iterable=event_ids,
keyvalues={},
retcols=("event_id", "state_group",),
retcols=("event_id", "state_group"),
desc="_get_state_group_for_events",
)
@ -989,15 +974,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Now we look them up in the member and non-member caches
non_member_state, incomplete_groups_nm, = (
yield self._get_state_for_groups_using_cache(
groups, self._state_group_cache,
state_filter=non_member_filter,
groups, self._state_group_cache, state_filter=non_member_filter
)
)
member_state, incomplete_groups_m, = (
yield self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache,
state_filter=member_filter,
groups, self._state_group_members_cache, state_filter=member_filter
)
)
@ -1019,8 +1002,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
db_state_filter = state_filter.return_expanded()
group_to_state_dict = yield self._get_state_groups_from_groups(
list(incomplete_groups),
state_filter=db_state_filter,
list(incomplete_groups), state_filter=db_state_filter
)
# Now lets update the caches
@ -1040,9 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(state)
def _get_state_for_groups_using_cache(
self, groups, cache, state_filter,
):
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@ -1074,8 +1054,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results, incomplete_groups
def _insert_into_cache(self, group_to_state_dict, state_filter,
cache_seq_num_members, cache_seq_num_non_members):
def _insert_into_cache(
self,
group_to_state_dict,
state_filter,
cache_seq_num_members,
cache_seq_num_non_members,
):
"""Inserts results from querying the database into the relevant cache.
Args:
@ -1132,8 +1117,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
fetched_keys=non_member_types,
)
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
):
"""Store a new set of state, returning a newly assigned state group.
Args:
@ -1149,6 +1135,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
def _store_state_group_txn(txn):
if current_state_ids is None:
# AFAIK, this can never happen
@ -1159,11 +1146,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
self._simple_insert_txn(
txn,
table="state_groups",
values={
"id": state_group,
"room_id": room_id,
"event_id": event_id,
},
values={"id": state_group, "room_id": room_id, "event_id": event_id},
)
# We persist as a delta if we can, while also ensuring the chain
@ -1182,17 +1165,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
% (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
"state_group": state_group,
"prev_state_group": prev_group,
},
values={"state_group": state_group, "prev_state_group": prev_group},
)
self._simple_insert_many_txn(
@ -1264,7 +1242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = ("""
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
@ -1272,7 +1250,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
""")
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
@ -1331,8 +1309,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self._background_deduplicate_state,
)
self.register_background_update_handler(
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
@ -1366,18 +1343,14 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn,
table="event_to_state_groups",
values=[
{
"state_group": state_group_id,
"event_id": event_id,
}
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
@ -1395,7 +1368,8 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
if max_group is None:
rows = yield self._execute(
"_background_deduplicate_state", None,
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
)
max_group = rows[0][0]
@ -1408,7 +1382,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
" WHERE ? < id AND id <= ?"
" ORDER BY id ASC"
" LIMIT 1",
(new_last_state_group, max_group,)
(new_last_state_group, max_group),
)
row = txn.fetchone()
if row:
@ -1420,7 +1394,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT state_group FROM state_group_edges"
" WHERE state_group = ?",
(state_group,)
(state_group,),
)
# If we reach a point where we've already started inserting
@ -1431,27 +1405,25 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
txn.execute(
"SELECT coalesce(max(id), 0) FROM state_groups"
" WHERE id < ? AND room_id = ?",
(state_group, room_id,)
(state_group, room_id),
)
prev_group, = txn.fetchone()
new_last_state_group = state_group
if prev_group:
potential_hops = self._count_state_group_hops_txn(
txn, prev_group
)
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if potential_hops >= MAX_STATE_DELTA_HOPS:
# We want to ensure chains are at most this long,#
# otherwise read performance degrades.
continue
prev_state = self._get_state_groups_from_groups_txn(
txn, [prev_group],
txn, [prev_group]
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
txn, [state_group],
txn, [state_group]
)
curr_state = curr_state[state_group]
@ -1460,16 +1432,15 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys
delta_state = {
key: value for key, value in iteritems(curr_state)
key: value
for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
self._simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={
"state_group": state_group,
}
keyvalues={"state_group": state_group},
)
self._simple_insert_txn(
@ -1478,15 +1449,13 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
values={
"state_group": state_group,
"prev_state_group": prev_group,
}
},
)
self._simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={
"state_group": state_group,
}
keyvalues={"state_group": state_group},
)
self._simple_insert_many_txn(
@ -1521,7 +1490,9 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
)
if finished:
yield self._end_background_update(self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME)
yield self._end_background_update(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
@ -1538,9 +1509,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
finally:
conn.set_session(autocommit=False)
else:
@ -1549,9 +1518,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"CREATE INDEX state_groups_state_type_idx"
" ON state_groups_state(state_group, type, state_key)"
)
txn.execute(
"DROP INDEX IF EXISTS state_groups_state_id"
)
txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
yield self.runWithConnection(reindex_txn)

View file

@ -21,10 +21,11 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id):
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
return []
def get_current_state_deltas_txn(txn):
@ -58,7 +59,7 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, max_stream_id,))
txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn)
return self.runInteraction(

View file

@ -59,9 +59,9 @@ _TOPOLOGICAL_TOKEN = "topological"
# Used as return values for pagination APIs
_EventDictReturn = namedtuple("_EventDictReturn", (
"event_id", "topological_ordering", "stream_ordering",
))
_EventDictReturn = namedtuple(
"_EventDictReturn", ("event_id", "topological_ordering", "stream_ordering")
)
def lower_bound(token, engine, inclusive=False):
@ -74,13 +74,20 @@ def lower_bound(token, engine, inclusive=False):
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) <%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
@ -94,13 +101,20 @@ def upper_bound(token, engine, inclusive=True):
# as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
return "((%d,%d) >%s (%s,%s))" % (
token.topological, token.stream, inclusive,
"topological_ordering", "stream_ordering",
token.topological,
token.stream,
inclusive,
"topological_ordering",
"stream_ordering",
)
return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
token.topological, "topological_ordering",
token.topological, "topological_ordering",
token.stream, inclusive, "stream_ordering",
token.topological,
"topological_ordering",
token.topological,
"topological_ordering",
token.stream,
inclusive,
"stream_ordering",
)
@ -116,9 +130,7 @@ def filter_to_clause(event_filter):
args = []
if event_filter.types:
clauses.append(
"(%s)" % " OR ".join("type = ?" for _ in event_filter.types)
)
clauses.append("(%s)" % " OR ".join("type = ?" for _ in event_filter.types))
args.extend(event_filter.types)
for typ in event_filter.not_types:
@ -126,9 +138,7 @@ def filter_to_clause(event_filter):
args.append(typ)
if event_filter.senders:
clauses.append(
"(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders)
)
clauses.append("(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders))
args.extend(event_filter.senders)
for sender in event_filter.not_senders:
@ -136,9 +146,7 @@ def filter_to_clause(event_filter):
args.append(sender)
if event_filter.rooms:
clauses.append(
"(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms)
)
clauses.append("(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms))
args.extend(event_filter.rooms)
for room_id in event_filter.not_rooms:
@ -165,17 +173,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events_max = self.get_room_max_stream_ordering()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
db_conn,
"events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
"EventsRoomStreamChangeCache",
min_event_val,
prefilled_cache=event_cache_prefill,
)
self._membership_stream_cache = StreamChangeCache(
"MembershipStreamChangeCache", events_max,
"MembershipStreamChangeCache", events_max
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
@ -189,8 +199,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotImplementedError()
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0,
order='DESC'):
def get_room_events_stream_for_rooms(
self, room_ids, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`.
Args:
@ -221,14 +232,23 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(defer.gatherResults([
run_in_background(
self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
for rm_ids in (room_ids[i : i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.get_room_events_stream_for_room,
room_id,
from_key,
to_key,
limit,
order=order,
)
for room_id in rm_ids
],
consumeErrors=True,
)
for room_id in rm_ids
], consumeErrors=True))
)
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@ -243,13 +263,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
room_id for room_id in room_ids
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):
def get_room_events_stream_for_room(
self, room_id, from_key, to_key, limit=0, order='DESC'
):
"""Get new room events in stream ordering since `from_key`.
@ -297,10 +319,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@ -340,7 +359,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" AND e.stream_ordering > ? AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
)
txn.execute(sql, (user_id, from_id, to_id,))
txn.execute(sql, (user_id, from_id, to_id))
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@ -348,10 +367,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
)
ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
self._set_before_and_after(ret, rows, topo_order=False)
@ -374,13 +390,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
rows, token = yield self.get_recent_event_ids_for_room(
room_id, limit, end_token,
room_id, limit, end_token
)
logger.debug("stream before")
events = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
[r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@ -410,8 +425,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
end_token = RoomStreamToken.parse(end_token)
rows, token = yield self.runInteraction(
"get_recent_event_ids_for_room", self._paginate_room_events_txn,
room_id, from_token=end_token, limit=limit,
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=end_token,
limit=limit,
)
# We want to return the results in ascending order.
@ -430,6 +448,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Deferred[(int, int, str)]:
(stream ordering, topological ordering, event_id)
"""
def _f(txn):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
@ -439,12 +458,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" ORDER BY stream_ordering"
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering, ))
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
return self.runInteraction(
"get_room_event_after_stream_ordering", _f,
)
return self.runInteraction("get_room_event_after_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@ -459,8 +476,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue("s%d" % (token,))
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn,
room_id,
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
defer.returnValue("t%d-%d" % (topo, token))
@ -474,9 +490,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred "s%d" stream token.
"""
return self._simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
def get_topological_token_for_event(self, event_id):
@ -493,8 +507,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
).addCallback(lambda row: "t%d-%d" % (
row["topological_ordering"], row["stream_ordering"],)
).addCallback(
lambda row: "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
)
def get_max_topological_token(self, room_id, stream_key):
@ -503,17 +517,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE room_id = ? AND stream_ordering < ?"
)
return self._execute(
"get_max_topological_token", None,
sql, room_id, stream_key,
).addCallback(
lambda r: r[0][0] if r else 0
)
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
"SELECT MAX(topological_ordering) FROM events"
" WHERE room_id = ?",
(room_id,)
"SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
(room_id,),
)
rows = txn.fetchall()
@ -540,14 +550,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
internal = event.internal_metadata
internal.before = str(RoomStreamToken(topo, stream - 1))
internal.after = str(RoomStreamToken(topo, stream))
internal.order = (
int(topo) if topo else 0,
int(stream),
)
internal.order = (int(topo) if topo else 0, int(stream))
@defer.inlineCallbacks
def get_events_around(
self, room_id, event_id, before_limit, after_limit, event_filter=None,
self, room_id, event_id, before_limit, after_limit, event_filter=None
):
"""Retrieve events and pagination tokens around a given event in a
room.
@ -564,29 +571,34 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
results = yield self.runInteraction(
"get_events_around", self._get_events_around_txn,
room_id, event_id, before_limit, after_limit, event_filter,
"get_events_around",
self._get_events_around_txn,
room_id,
event_id,
before_limit,
after_limit,
event_filter,
)
events_before = yield self._get_events(
[e for e in results["before"]["event_ids"]],
get_prev_content=True
[e for e in results["before"]["event_ids"]], get_prev_content=True
)
events_after = yield self._get_events(
[e for e in results["after"]["event_ids"]],
get_prev_content=True
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
defer.returnValue({
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
})
defer.returnValue(
{
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
}
)
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter,
self, txn, room_id, event_id, before_limit, after_limit, event_filter
):
"""Retrieves event_ids and pagination tokens around a given event in a
room.
@ -605,46 +617,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = self._simple_select_one_txn(
txn,
"events",
keyvalues={
"event_id": event_id,
"room_id": room_id,
},
keyvalues={"event_id": event_id, "room_id": room_id},
retcols=["stream_ordering", "topological_ordering"],
)
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
results["topological_ordering"] - 1,
results["stream_ordering"],
results["topological_ordering"] - 1, results["stream_ordering"]
)
after_token = RoomStreamToken(
results["topological_ordering"],
results["stream_ordering"],
results["topological_ordering"], results["stream_ordering"]
)
rows, start_token = self._paginate_room_events_txn(
txn, room_id, before_token, direction='b', limit=before_limit,
txn,
room_id,
before_token,
direction='b',
limit=before_limit,
event_filter=event_filter,
)
events_before = [r.event_id for r in rows]
rows, end_token = self._paginate_room_events_txn(
txn, room_id, after_token, direction='f', limit=after_limit,
txn,
room_id,
after_token,
direction='f',
limit=after_limit,
event_filter=event_filter,
)
events_after = [r.event_id for r in rows]
return {
"before": {
"event_ids": events_before,
"token": start_token,
},
"after": {
"event_ids": events_after,
"token": end_token,
},
"before": {"event_ids": events_before, "token": start_token},
"after": {"event_ids": events_after, "token": end_token},
}
@defer.inlineCallbacks
@ -685,7 +694,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
upper_bound, event_ids = yield self.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn,
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self._get_events(event_ids)
@ -697,7 +706,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
desc="get_federation_out_pos"
desc="get_federation_out_pos",
)
def update_federation_out_pos(self, typ, stream_id):
@ -711,8 +720,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
direction='b', limit=-1, event_filter=None):
def _paginate_room_events_txn(
self,
txn,
room_id,
from_token,
to_token=None,
direction='b',
limit=-1,
event_filter=None,
):
"""Returns list of events before or after a given token.
Args:
@ -741,22 +758,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
bounds = upper_bound(
from_token, self.database_engine
)
bounds = upper_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (bounds, lower_bound(
to_token, self.database_engine
))
bounds = "%s AND %s" % (
bounds,
lower_bound(to_token, self.database_engine),
)
else:
order = "ASC"
bounds = lower_bound(
from_token, self.database_engine
)
bounds = lower_bound(from_token, self.database_engine)
if to_token:
bounds = "%s AND %s" % (bounds, upper_bound(
to_token, self.database_engine
))
bounds = "%s AND %s" % (
bounds,
upper_bound(to_token, self.database_engine),
)
filter_clause, filter_args = filter_to_clause(event_filter)
@ -772,10 +787,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
" stream_ordering %(order)s LIMIT ?"
) % {
"bounds": bounds,
"order": order,
}
) % {"bounds": bounds, "order": order}
txn.execute(sql, args)
@ -796,11 +808,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token),
return rows, str(next_token)
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1, event_filter=None):
def paginate_room_events(
self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
Args:
@ -826,13 +839,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_key = RoomStreamToken.parse(to_key)
rows, token = yield self.runInteraction(
"paginate_room_events", self._paginate_room_events_txn,
room_id, from_key, to_key, direction, limit, event_filter,
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
from_key,
to_key,
direction,
limit,
event_filter,
)
events = yield self._get_events(
[r.event_id for r in rows],
get_prev_content=True
[r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(events, rows)

View file

@ -84,9 +84,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
def get_tag_content(txn, tag_ids):
sql = (
"SELECT tag, content"
" FROM room_tags"
" WHERE user_id=? AND room_id=?"
"SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
)
results = []
for stream_id, user_id, room_id in tag_ids:
@ -105,7 +103,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i:i + batch_size],
tag_ids[i : i + batch_size],
)
results.extend(tags)
@ -123,6 +121,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
A deferred dict mapping from room_id strings to lists of tag
strings for all the rooms that changed since the stream_id token.
"""
def get_updated_tags_txn(txn):
sql = (
"SELECT room_id from room_tags_revisions"
@ -138,9 +137,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
if not changed:
defer.returnValue({})
room_ids = yield self.runInteraction(
"get_updated_tags", get_updated_tags_txn
)
room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
results = {}
if room_ids:
@ -163,9 +160,9 @@ class TagsWorkerStore(AccountDataWorkerStore):
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
desc="get_tags_for_room",
).addCallback(lambda rows: {
row["tag"]: json.loads(row["content"]) for row in rows
})
).addCallback(
lambda rows: {row["tag"]: json.loads(row["content"]) for row in rows}
)
class TagsStore(TagsWorkerStore):
@ -186,14 +183,8 @@ class TagsStore(TagsWorkerStore):
self._simple_upsert_txn(
txn,
table="room_tags",
keyvalues={
"user_id": user_id,
"room_id": room_id,
"tag": tag,
},
values={
"content": content_json,
}
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
values={"content": content_json},
)
self._update_revision_txn(txn, user_id, room_id, next_id)
@ -211,6 +202,7 @@ class TagsStore(TagsWorkerStore):
Returns:
A deferred that completes once the tag has been removed
"""
def remove_tag_txn(txn, next_id):
sql = (
"DELETE FROM room_tags "
@ -238,8 +230,7 @@ class TagsStore(TagsWorkerStore):
"""
txn.call_after(
self._account_data_stream_cache.entity_has_changed,
user_id, next_id
self._account_data_stream_cache.entity_has_changed, user_id, next_id
)
update_max_id_sql = (

View file

@ -38,16 +38,12 @@ logger = logging.getLogger(__name__)
_TransactionRow = namedtuple(
"_TransactionRow", (
"id", "transaction_id", "destination", "ts", "response_code",
"response_json",
)
"_TransactionRow",
("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
)
_UpdateTransactionRow = namedtuple(
"_TransactionRow", (
"response_code", "response_json",
)
"_TransactionRow", ("response_code", "response_json")
)
SENTINEL = object()
@ -84,19 +80,22 @@ class TransactionStore(SQLBaseStore):
return self.runInteraction(
"get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
self._get_received_txn_response,
transaction_id,
origin,
)
def _get_received_txn_response(self, txn, transaction_id, origin):
result = self._simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={
"transaction_id": transaction_id,
"origin": origin,
},
keyvalues={"transaction_id": transaction_id, "origin": origin},
retcols=(
"transaction_id", "origin", "ts", "response_code", "response_json",
"transaction_id",
"origin",
"ts",
"response_code",
"response_json",
"has_been_referenced",
),
allow_none=True,
@ -108,8 +107,7 @@ class TransactionStore(SQLBaseStore):
else:
return None
def set_received_txn_response(self, transaction_id, origin, code,
response_dict):
def set_received_txn_response(self, transaction_id, origin, code, response_dict):
"""Persist the response we returened for an incoming transaction, and
should return for subsequent transactions with the same transaction_id
and origin.
@ -135,8 +133,7 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
def prep_send_transaction(self, transaction_id, destination,
origin_server_ts):
def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@ -182,7 +179,9 @@ class TransactionStore(SQLBaseStore):
result = yield self.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings, destination)
self._get_destination_retry_timings,
destination,
)
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
@ -193,9 +192,7 @@ class TransactionStore(SQLBaseStore):
result = self._simple_select_one_txn(
txn,
table="destinations",
keyvalues={
"destination": destination,
},
keyvalues={"destination": destination},
retcols=("destination", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@ -205,8 +202,7 @@ class TransactionStore(SQLBaseStore):
else:
return None
def set_destination_retry_timings(self, destination,
retry_last_ts, retry_interval):
def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
@ -225,8 +221,9 @@ class TransactionStore(SQLBaseStore):
retry_interval,
)
def _set_destination_retry_timings(self, txn, destination,
retry_last_ts, retry_interval):
def _set_destination_retry_timings(
self, txn, destination, retry_last_ts, retry_interval
):
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
@ -235,9 +232,7 @@ class TransactionStore(SQLBaseStore):
prev_row = self._simple_select_one_txn(
txn,
table="destinations",
keyvalues={
"destination": destination,
},
keyvalues={"destination": destination},
retcols=("retry_last_ts", "retry_interval"),
allow_none=True,
)
@ -250,15 +245,13 @@ class TransactionStore(SQLBaseStore):
"destination": destination,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
}
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
self._simple_update_one_txn(
txn,
"destinations",
keyvalues={
"destination": destination,
},
keyvalues={"destination": destination},
updatevalues={
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
@ -273,8 +266,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
"get_destinations_needing_retry",
self._get_destinations_needing_retry
"get_destinations_needing_retry", self._get_destinations_needing_retry
)
def _get_destinations_needing_retry(self, txn):
@ -288,7 +280,7 @@ class TransactionStore(SQLBaseStore):
def _start_cleanup_transactions(self):
return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions,
"cleanup_transactions", self._cleanup_transactions
)
def _cleanup_transactions(self):

View file

@ -194,7 +194,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
room_id
)
users_with_profile = yield state.get_current_user_in_room(room_id)
users_with_profile = yield state.get_current_users_in_room(room_id)
user_ids = set(users_with_profile)
# Update each user in the user directory.

View file

@ -40,9 +40,7 @@ class UserErasureWorkerStore(SQLBaseStore):
).addCallback(operator.truth)
@cachedList(
cached_method_name="is_user_erased",
list_name="user_ids",
inlineCallbacks=True,
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
"""
@ -61,16 +59,13 @@ class UserErasureWorkerStore(SQLBaseStore):
def _get_erased_users(txn):
txn.execute(
"SELECT user_id FROM erased_users WHERE user_id IN (%s)" % (
",".join("?" * len(user_ids))
),
"SELECT user_id FROM erased_users WHERE user_id IN (%s)"
% (",".join("?" * len(user_ids))),
user_ids,
)
return set(r[0] for r in txn)
erased_users = yield self.runInteraction(
"are_users_erased", _get_erased_users,
)
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids)
defer.returnValue(res)
@ -82,22 +77,16 @@ class UserErasureStore(UserErasureWorkerStore):
Args:
user_id (str): full user_id to be erased
"""
def f(txn):
# first check if they are already in the list
txn.execute(
"SELECT 1 FROM erased_users WHERE user_id = ?",
(user_id, )
)
txn.execute("SELECT 1 FROM erased_users WHERE user_id = ?", (user_id,))
if txn.fetchone():
return
# they are not already there: do the insert.
txn.execute(
"INSERT INTO erased_users (user_id) VALUES (?)",
(user_id, )
)
txn.execute("INSERT INTO erased_users (user_id) VALUES (?)", (user_id,))
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
self._invalidate_cache_and_stream(
txn, self.is_user_erased, (user_id,)
)
return self.runInteraction("mark_user_erased", f)

View file

@ -43,9 +43,9 @@ def _load_current_id(db_conn, table, column, step=1):
"""
cur = db_conn.cursor()
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
val, = cur.fetchone()
cur.close()
current_id = int(val) if val else step
@ -77,6 +77,7 @@ class StreamIdGenerator(object):
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock()
@ -84,8 +85,7 @@ class StreamIdGenerator(object):
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current,
_load_current_id(db_conn, table, column, step)
self._current, _load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@ -121,7 +121,7 @@ class StreamIdGenerator(object):
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step
self._step,
)
self._current += n * self._step

View file

@ -19,14 +19,14 @@ from mock import Mock
import signedjson.key
import signedjson.sign
from twisted.internet import defer, reactor
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
from synapse.util import Clock, logcontext
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
from tests import unittest, utils
from tests import unittest
class MockPerspectiveServer(object):
@ -52,75 +52,50 @@ class MockPerspectiveServer(object):
return res
class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
class KeyringTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.mock_perspective_server = MockPerspectiveServer()
self.http_client = Mock()
self.hs = yield utils.setup_test_homeserver(
self.addCleanup, handlers=None, http_client=self.http_client
)
hs = self.setup_test_homeserver(handlers=None, http_client=self.http_client)
keys = self.mock_perspective_server.get_verify_keys()
self.hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
def assert_sentinel_context(self):
if LoggingContext.current_context() != LoggingContext.sentinel:
self.fail(
"Expected sentinel context but got %s" % (
LoggingContext.current_context(),
)
)
hs.config.perspectives = {self.mock_perspective_server.server_name: keys}
return hs
def check_context(self, _, expected):
self.assertEquals(
getattr(LoggingContext.current_context(), "request", None), expected
)
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
kr = keyring.Keyring(self.hs)
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
with LoggingContext("one") as context_one:
context_one.request = "one"
# we run the lookup in a logcontext so that the patched inlineCallbacks can check
# it is doing the right thing with logcontexts.
wait_1_deferred = run_in_context(
kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_1_deferred}
)
wait_1_deferred = kr.wait_for_previous_lookups(
["server1"], {"server1": lookup_1_deferred}
)
# there were no previous lookups, so the deferred should be ready
self.successResultOf(wait_1_deferred)
# there were no previous lookups, so the deferred should be ready
self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one)
wait_1_deferred.addBoth(self.check_context, "one")
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = run_in_context(
kr.wait_for_previous_lookups, ["server1"], {"server1": lookup_2_deferred}
)
with LoggingContext("two") as context_two:
context_two.request = "two"
self.assertFalse(wait_2_deferred.called)
# set off another wait. It should block because the first lookup
# hasn't yet completed.
wait_2_deferred = kr.wait_for_previous_lookups(
["server1"], {"server1": lookup_2_deferred}
)
self.assertFalse(wait_2_deferred.called)
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
# ... so we should have reset the LoggingContext.
self.assert_sentinel_context()
# now the second wait should complete.
self.successResultOf(wait_2_deferred)
wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
# now the second wait should complete and restore our
# loggingcontext.
yield wait_2_deferred
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@ -145,81 +120,103 @@ class KeyringTestCase(unittest.TestCase):
self.http_client.post_json.side_effect = get_perspectives
with LoggingContext("11") as context_11:
context_11.request = "11"
# start off a first set of lookups
@defer.inlineCallbacks
def first_lookup():
with LoggingContext("11") as context_11:
context_11.request = "11"
# start off a first set of lookups
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), ("server11", {})]
)
res_deferreds = kr.verify_json_objects_for_server(
[("server10", json1), ("server11", {})]
)
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
# the unsigned json should be rejected pretty quickly
self.assertTrue(res_deferreds[1].called)
try:
yield res_deferreds[1]
self.assertFalse("unsigned json didn't cause a failure")
except SynapseError:
pass
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
self.assertFalse(res_deferreds[0].called)
res_deferreds[0].addBoth(self.check_context, None)
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
yield logcontext.make_deferred_yieldable(res_deferreds[0])
self.assertIs(LoggingContext.current_context(), context_11)
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
yield self.clock.sleep(0)
context_12 = LoggingContext("12")
context_12.request = "12"
with logcontext.PreserveLoggingContext(context_12):
# a second request for a server with outstanding requests
# should block rather than start a second call
d0 = first_lookup()
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
self.pump()
self.http_client.post_json.assert_called_once()
# a second request for a server with outstanding requests
# should block rather than start a second call
@defer.inlineCallbacks
def second_lookup():
with LoggingContext("12") as context_12:
context_12.request = "12"
self.http_client.post_json.reset_mock()
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)]
[("server10", json1, )]
)
yield clock.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
# complete the first request
with logcontext.PreserveLoggingContext():
persp_deferred.callback(persp_resp)
self.assertIs(LoggingContext.current_context(), context_11)
# let verify_json_objects_for_server finish its work before we kill the
# logcontext
yield self.clock.sleep(0)
with logcontext.PreserveLoggingContext():
yield res_deferreds[0]
yield res_deferreds_2[0]
d2 = second_lookup()
self.pump()
self.http_client.post_json.assert_not_called()
# complete the first request
persp_deferred.callback(persp_resp)
self.get_success(d0)
self.get_success(d2)
@defer.inlineCallbacks
def test_verify_json_for_server(self):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
yield self.hs.datastore.store_server_verify_key(
r = self.hs.datastore.store_server_verify_key(
"server9", "", time.time() * 1000, signedjson.key.get_verify_key(key1)
)
self.get_success(r)
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
with LoggingContext("one") as context_one:
context_one.request = "one"
# should fail immediately on an unsigned object
d = _verify_json_for_server(kr, "server9", {})
self.failureResultOf(d, SynapseError)
defer = kr.verify_json_for_server("server9", {})
try:
yield defer
self.fail("should fail on unsigned json")
except SynapseError:
pass
self.assertIs(LoggingContext.current_context(), context_one)
d = _verify_json_for_server(kr, "server9", json1)
self.assertFalse(d.called)
self.get_success(d)
defer = kr.verify_json_for_server("server9", json1)
self.assertFalse(defer.called)
self.assert_sentinel_context()
yield defer
self.assertIs(LoggingContext.current_context(), context_one)
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx"):
rv = yield f(*args, **kwargs)
defer.returnValue(rv)
def _verify_json_for_server(keyring, server_name, json_object):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
rv1 = yield keyring.verify_json_for_server(server_name, json_object)
defer.returnValue(rv1)
return run_in_context(v)

View file

@ -121,9 +121,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
def get_current_user_in_room(room_id):
def get_current_users_in_room(room_id):
return set(str(u) for u in self.room_members)
hs.get_state_handler().get_current_user_in_room = get_current_user_in_room
hs.get_state_handler().get_current_users_in_room = get_current_users_in_room
self.datastore.get_user_directory_stream_pos.return_value = (
# we deliberately return a non-None stream pos to avoid doing an initial_spam

View file

@ -21,6 +21,7 @@ from mock import Mock
from synapse.api.constants import UserTypes
from synapse.rest.client.v1 import admin, events, login, room
from synapse.rest.client.v2_alpha import groups
from tests import unittest
@ -490,3 +491,126 @@ class ShutdownRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"],
)
class DeleteGroupTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
groups.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_token = self.login("user", "pass")
def test_delete_group(self):
# Create a new group
request, channel = self.make_request(
"POST",
"/create_group".encode('ascii'),
access_token=self.admin_user_tok,
content={
"localpart": "test",
}
)
self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"],
)
group_id = channel.json_body["group_id"]
self._check_group(group_id, expect_code=200)
# Invite/join another user
url = "/groups/%s/admin/users/invite/%s" % (group_id, self.other_user)
request, channel = self.make_request(
"PUT",
url.encode('ascii'),
access_token=self.admin_user_tok,
content={}
)
self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"],
)
url = "/groups/%s/self/accept_invite" % (group_id,)
request, channel = self.make_request(
"PUT",
url.encode('ascii'),
access_token=self.other_user_token,
content={}
)
self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"],
)
# Check other user knows they're in the group
self.assertIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
self.assertIn(group_id, self._get_groups_user_is_in(self.other_user_token))
# Now delete the group
url = "/admin/delete_group/" + group_id
request, channel = self.make_request(
"POST",
url.encode('ascii'),
access_token=self.admin_user_tok,
content={
"localpart": "test",
}
)
self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"],
)
# Check group returns 404
self._check_group(group_id, expect_code=404)
# Check users don't think they're in the group
self.assertNotIn(group_id, self._get_groups_user_is_in(self.admin_user_tok))
self.assertNotIn(group_id, self._get_groups_user_is_in(self.other_user_token))
def _check_group(self, group_id, expect_code):
"""Assert that trying to fetch the given group results in the given
HTTP status code
"""
url = "/groups/%s/profile" % (group_id,)
request, channel = self.make_request(
"GET",
url.encode('ascii'),
access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(
expect_code, int(channel.result["code"]), msg=channel.result["body"],
)
def _get_groups_user_is_in(self, access_token):
"""Returns the list of groups the user is in (given their access token)
"""
request, channel = self.make_request(
"GET",
"/joined_groups".encode('ascii'),
access_token=access_token,
)
self.render(request)
self.assertEqual(
200, int(channel.result["code"]), msg=channel.result["body"],
)
return channel.json_body["groups"]

View file

@ -1,118 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.types import UserID
from tests import unittest
from tests.utils import setup_test_homeserver
class PresenceStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield setup_test_homeserver(self.addCleanup)
self.store = hs.get_datastore()
self.u_apple = UserID.from_string("@apple:test")
self.u_banana = UserID.from_string("@banana:test")
@defer.inlineCallbacks
def test_presence_list(self):
self.assertEquals(
[],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart
)
),
)
self.assertEquals(
[],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart, accepted=True
)
),
)
yield self.store.add_presence_list_pending(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 0}],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart
)
),
)
self.assertEquals(
[],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart, accepted=True
)
),
)
yield self.store.set_presence_list_accepted(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart
)
),
)
self.assertEquals(
[{"observed_user_id": "@banana:test", "accepted": 1}],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart, accepted=True
)
),
)
yield self.store.del_presence_list(
observer_localpart=self.u_apple.localpart,
observed_userid=self.u_banana.to_string(),
)
self.assertEquals(
[],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart
)
),
)
self.assertEquals(
[],
(
yield self.store.get_presence_list(
observer_localpart=self.u_apple.localpart, accepted=True
)
),
)