Merge branch 'develop' into babolivier/mark_unread

This commit is contained in:
Brendan Abolivier 2020-06-15 16:37:52 +01:00
commit 6efb2b0ad4
No known key found for this signature in database
GPG key ID: 1E015C145F1916CD
57 changed files with 363 additions and 292 deletions

1
changelog.d/7648.bugfix Normal file
View file

@ -0,0 +1 @@
In working mode, ensure that replicated data has not already been received.

1
changelog.d/7688.bugfix Normal file
View file

@ -0,0 +1 @@
Fix "Starting db txn 'get_completed_ui_auth_stages' from sentinel context" warning. The bug was introduced in 1.13.0rc1.

1
changelog.d/7689.bugfix Normal file
View file

@ -0,0 +1 @@
Compare the URI and method during user interactive authentication (instead of the URI twice). Bug introduced in 1.13.0rc1.

1
changelog.d/7692.misc Normal file
View file

@ -0,0 +1 @@
Replace uses of `six.iterkeys`/`iteritems`/`itervalues` with `keys()`/`items()`/`values()`.

View file

@ -16,8 +16,6 @@
import logging import logging
from typing import Optional from typing import Optional
from six import itervalues
import pymacaroons import pymacaroons
from netaddr import IPAddress from netaddr import IPAddress
@ -90,7 +88,7 @@ class Auth(object):
event, prev_state_ids, for_verification=True event, prev_state_ids, for_verification=True
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check( event_auth.check(

View file

@ -19,7 +19,6 @@
import logging import logging
from typing import Dict, List from typing import Dict, List
from six import iteritems
from six.moves import http_client from six.moves import http_client
from canonicaljson import json from canonicaljson import json
@ -497,7 +496,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
A dict representing the error response JSON. A dict representing the error response JSON.
""" """
err = {"error": msg, "errcode": code} err = {"error": msg, "errcode": code}
for key, value in iteritems(kwargs): for key, value in kwargs.items():
err[key] = value err[key] = value
return err return err

View file

@ -738,6 +738,11 @@ class GenericWorkerReplicationHandler(ReplicationDataHandler):
except Exception: except Exception:
logger.exception("Error processing replication") logger.exception("Error processing replication")
async def on_position(self, stream_name: str, instance_name: str, token: int):
await super().on_position(stream_name, instance_name, token)
# Also call on_rdata to ensure that stream positions are properly reset.
await self.on_rdata(stream_name, instance_name, token, [])
def stop_pusher(self, user_id, app_id, pushkey): def stop_pusher(self, user_id, app_id, pushkey):
if not self.notify_pushers: if not self.notify_pushers:
return return

View file

@ -24,8 +24,6 @@ import os
import resource import resource
import sys import sys
from six import iteritems
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.application import service from twisted.application import service
@ -525,7 +523,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["total_nonbridged_users"] = total_nonbridged_users stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type() daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
for name, count in iteritems(daily_user_type_results): for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count() room_count = yield hs.get_datastore().get_room_count()
@ -537,7 +535,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages() stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users() r30_results = yield hs.get_datastore().count_r30_users()
for name, count in iteritems(r30_results): for name, count in r30_results.items():
stats["r30_users_" + name] = count stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()

View file

@ -20,8 +20,6 @@ import os
from distutils.util import strtobool from distutils.util import strtobool
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
import six
from unpaddedbase64 import encode_base64 from unpaddedbase64 import encode_base64
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@ -290,7 +288,7 @@ class EventBase(metaclass=abc.ABCMeta):
return list(self._dict.items()) return list(self._dict.items())
def keys(self): def keys(self):
return six.iterkeys(self._dict) return self._dict.keys()
def prev_event_ids(self): def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order """Returns the list of prev event IDs. The order matches the order

View file

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
from typing import Optional, Union from typing import Optional, Union
from six import iteritems
import attr import attr
from frozendict import frozendict from frozendict import frozendict
@ -341,7 +339,7 @@ def _encode_state_dict(state_dict):
if state_dict is None: if state_dict is None:
return None return None
return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)] return [(etype, state_key, v) for (etype, state_key), v in state_dict.items()]
def _decode_state_dict(input): def _decode_state_dict(input):

View file

@ -93,8 +93,8 @@ class FederationBase(object):
# *actual* redacted copy to be on the safe side.) # *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu) redacted_event = prune_event(pdu)
if set(redacted_event.keys()) == set(pdu.keys()) and set( if set(redacted_event.keys()) == set(pdu.keys()) and set(
six.iterkeys(redacted_event.content) redacted_event.content.keys()
) == set(six.iterkeys(pdu.content)): ) == set(pdu.content.keys()):
logger.info( logger.info(
"Event %s seems to have been redacted; using our redacted " "Event %s seems to have been redacted; using our redacted "
"copy", "copy",

View file

@ -18,7 +18,6 @@ import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
import six import six
from six import iteritems
from canonicaljson import json from canonicaljson import json
from prometheus_client import Counter from prometheus_client import Counter
@ -534,9 +533,9 @@ class FederationServer(FederationBase):
",".join( ",".join(
( (
"%s for %s:%s" % (key_id, user_id, device_id) "%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result) for user_id, user_keys in json_result.items()
for device_id, device_keys in iteritems(user_keys) for device_id, device_keys in user_keys.items()
for key_id, _ in iteritems(device_keys) for key_id, _ in device_keys.items()
) )
), ),
) )

View file

@ -33,8 +33,6 @@ import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, List, Tuple, Type from typing import Dict, List, Tuple, Type
from six import iteritems
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from twisted.internet import defer from twisted.internet import defer
@ -327,7 +325,7 @@ class FederationRemoteSendQueue(object):
# stream position. # stream position.
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
for ((destination, edu_key), pos) in iteritems(keyed_edus): for ((destination, edu_key), pos) in keyed_edus.items():
rows.append( rows.append(
( (
pos, pos,
@ -530,10 +528,10 @@ def process_rows_for_federation(transaction_queue, rows):
states=[state], destinations=destinations states=[state], destinations=destinations
) )
for destination, edu_map in iteritems(buff.keyed_edus): for destination, edu_map in buff.keyed_edus.items():
for key, edu in edu_map.items(): for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key) transaction_queue.send_edu(edu, key)
for destination, edu_list in iteritems(buff.edus): for destination, edu_list in buff.edus.items():
for edu in edu_list: for edu in edu_list:
transaction_queue.send_edu(edu, None) transaction_queue.send_edu(edu, None)

View file

@ -16,8 +16,6 @@
import logging import logging
from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple from typing import Dict, Hashable, Iterable, List, Optional, Set, Tuple
from six import itervalues
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
@ -218,7 +216,7 @@ class FederationSender(object):
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(handle_room_events, evs) run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room) for evs in events_by_room.values()
], ],
consumeErrors=True, consumeErrors=True,
) )

View file

@ -15,8 +15,6 @@
import logging import logging
from six import itervalues
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
@ -125,7 +123,7 @@ class ApplicationServicesHandler(object):
defer.gatherResults( defer.gatherResults(
[ [
run_in_background(handle_room_events, evs) run_in_background(handle_room_events, evs)
for evs in itervalues(events_by_room) for evs in events_by_room.values()
], ],
consumeErrors=True, consumeErrors=True,
) )

View file

@ -297,7 +297,7 @@ class AuthHandler(BaseHandler):
# Convert the URI and method to strings. # Convert the URI and method to strings.
uri = request.uri.decode("utf-8") uri = request.uri.decode("utf-8")
method = request.uri.decode("utf-8") method = request.method.decode("utf-8")
# If there's no session ID, create a new session. # If there's no session ID, create a new session.
if not sid: if not sid:

View file

@ -17,8 +17,6 @@
import logging import logging
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from six import iteritems, itervalues
from twisted.internet import defer from twisted.internet import defer
from synapse.api import errors from synapse.api import errors
@ -159,7 +157,7 @@ class DeviceWorkerHandler(BaseHandler):
# The user may have left the room # The user may have left the room
# TODO: Check if they actually did or if we were just invited. # TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids: if room_id not in room_ids:
for key, event_id in iteritems(current_state_ids): for key, event_id in current_state_ids.items():
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -182,7 +180,7 @@ class DeviceWorkerHandler(BaseHandler):
log_kv( log_kv(
{"event": "encountered empty previous state", "room_id": room_id} {"event": "encountered empty previous state", "room_id": room_id}
) )
for key, event_id in iteritems(current_state_ids): for key, event_id in current_state_ids.items():
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -198,10 +196,10 @@ class DeviceWorkerHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to # Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users. # the "possibly changed" users.
for state_dict in itervalues(prev_state_ids): for state_dict in prev_state_ids.values():
member_event = state_dict.get((EventTypes.Member, user_id), None) member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id: if not member_event or member_event != current_member_id:
for key, event_id in iteritems(current_state_ids): for key, event_id in current_state_ids.items():
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
@ -211,14 +209,14 @@ class DeviceWorkerHandler(BaseHandler):
# If there has been any change in membership, include them in the # If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below, # possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users. # and we're not toooo worried about spuriously adding users.
for key, event_id in iteritems(current_state_ids): for key, event_id in current_state_ids.items():
etype, state_key = key etype, state_key = key
if etype != EventTypes.Member: if etype != EventTypes.Member:
continue continue
# check if this member has changed since any of the extremities # check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so. # at the stream_ordering, and add them to the list if so.
for state_dict in itervalues(prev_state_ids): for state_dict in prev_state_ids.values():
prev_event_id = state_dict.get(key, None) prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id: if not prev_event_id or prev_event_id != event_id:
if state_key != user_id: if state_key != user_id:

View file

@ -17,8 +17,6 @@
import logging import logging
from six import iteritems
import attr import attr
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from signedjson.key import decode_verify_key_bytes from signedjson.key import decode_verify_key_bytes
@ -135,7 +133,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {} remote_queries_not_in_cache = {}
if remote_queries: if remote_queries:
query_list = [] query_list = []
for user_id, device_ids in iteritems(remote_queries): for user_id, device_ids in remote_queries.items():
if device_ids: if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids) query_list.extend((user_id, device_id) for device_id in device_ids)
else: else:
@ -145,9 +143,9 @@ class E2eKeysHandler(object):
user_ids_not_in_cache, user_ids_not_in_cache,
remote_results, remote_results,
) = yield self.store.get_user_devices_from_cache(query_list) ) = yield self.store.get_user_devices_from_cache(query_list)
for user_id, devices in iteritems(remote_results): for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {}) user_devices = results.setdefault(user_id, {})
for device_id, device in iteritems(devices): for device_id, device in devices.items():
keys = device.get("keys", None) keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None) device_display_name = device.get("device_display_name", None)
if keys: if keys:
@ -446,9 +444,9 @@ class E2eKeysHandler(object):
",".join( ",".join(
( (
"%s for %s:%s" % (key_id, user_id, device_id) "%s for %s:%s" % (key_id, user_id, device_id)
for user_id, user_keys in iteritems(json_result) for user_id, user_keys in json_result.items()
for device_id, device_keys in iteritems(user_keys) for device_id, device_keys in user_keys.items()
for key_id, _ in iteritems(device_keys) for key_id, _ in device_keys.items()
) )
), ),
) )

View file

@ -16,8 +16,6 @@
import logging import logging
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import ( from synapse.api.errors import (
@ -205,8 +203,8 @@ class E2eRoomKeysHandler(object):
) )
to_insert = [] # batch the inserts together to_insert = [] # batch the inserts together
changed = False # if anything has changed, we need to update the etag changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]): for room_id, room in room_keys["rooms"].items():
for session_id, room_key in iteritems(room["sessions"]): for session_id, room_key in room["sessions"].items():
if not isinstance(room_key["is_verified"], bool): if not isinstance(room_key["is_verified"], bool):
msg = ( msg = (
"is_verified must be a boolean in keys for session %s in" "is_verified must be a boolean in keys for session %s in"

View file

@ -21,8 +21,6 @@ import itertools
import logging import logging
from typing import Dict, Iterable, List, Optional, Sequence, Tuple from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
from six.moves import http_client, zip from six.moves import http_client, zip
import attr import attr
@ -398,7 +396,7 @@ class FederationHandler(BaseHandler):
) )
event_map.update(evs) event_map.update(evs)
state = [event_map[e] for e in six.itervalues(state_map)] state = [event_map[e] for e in state_map.values()]
except Exception: except Exception:
logger.warning( logger.warning(
"[%s %s] Error attempting to resolve state at missing " "[%s %s] Error attempting to resolve state at missing "
@ -1009,7 +1007,7 @@ class FederationHandler(BaseHandler):
""" """
joined_users = [ joined_users = [
(state_key, int(event.depth)) (state_key, int(event.depth))
for (e_type, state_key), event in iteritems(state) for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member and event.membership == Membership.JOIN if e_type == EventTypes.Member and event.membership == Membership.JOIN
] ]
@ -1099,16 +1097,16 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states])) states = dict(zip(event_ids, [s.state for s in states]))
state_map = await self.store.get_events( state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)], [e_id for ids in states.values() for e_id in ids.values()],
get_prev_content=False, get_prev_content=False,
) )
states = { states = {
key: { key: {
k: state_map[e_id] k: state_map[e_id]
for k, e_id in iteritems(state_dict) for k, e_id in state_dict.items()
if e_id in state_map if e_id in state_map
} }
for key, state_dict in iteritems(states) for key, state_dict in states.items()
} }
for e_id, _ in sorted_extremeties_tuple: for e_id, _ in sorted_extremeties_tuple:
@ -1733,7 +1731,7 @@ class FederationHandler(BaseHandler):
state_groups = await self.state_store.get_state_groups(room_id, [event_id]) state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups: if state_groups:
_, state = list(iteritems(state_groups)).pop() _, state = list(state_groups.items()).pop()
results = {(e.type, e.state_key): e for e in state} results = {(e.type, e.state_key): e for e in state}
if event.is_state(): if event.is_state():
@ -2096,7 +2094,7 @@ class FederationHandler(BaseHandler):
room_version, state_sets, event room_version, state_sets, event
) )
current_state_ids = { current_state_ids = {
k: e.event_id for k, e in iteritems(current_state_ids) k: e.event_id for k, e in current_state_ids.items()
} }
else: else:
current_state_ids = await self.state_handler.get_current_state_ids( current_state_ids = await self.state_handler.get_current_state_ids(
@ -2112,7 +2110,7 @@ class FederationHandler(BaseHandler):
# Now check if event pass auth against said current state # Now check if event pass auth against said current state
auth_types = auth_types_for_event(event) auth_types = auth_types_for_event(event)
current_state_ids = [ current_state_ids = [
e for k, e in iteritems(current_state_ids) if k in auth_types e for k, e in current_state_ids.items() if k in auth_types
] ]
current_auth_events = await self.store.get_events(current_state_ids) current_auth_events = await self.store.get_events(current_state_ids)
@ -2428,7 +2426,7 @@ class FederationHandler(BaseHandler):
else: else:
event_key = None event_key = None
state_updates = { state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key k: a.event_id for k, a in auth_events.items() if k != event_key
} }
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
@ -2439,7 +2437,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids) prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)}) prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
# create a new state group as a delta from the existing one. # create a new state group as a delta from the existing one.
prev_group = context.state_group prev_group = context.state_group

View file

@ -16,8 +16,6 @@
import logging import logging
from six import iteritems
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import get_domain_from_id from synapse.types import get_domain_from_id
@ -227,7 +225,7 @@ class GroupsLocalWorkerHandler(object):
results = {} results = {}
failed_results = [] failed_results = []
for destination, dest_user_ids in iteritems(destinations): for destination, dest_user_ids in destinations.items():
try: try:
r = await self.transport_client.bulk_get_publicised_groups( r = await self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids) destination, list(dest_user_ids)

View file

@ -17,7 +17,7 @@
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple
from six import iteritems, itervalues, string_types from six import string_types
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
@ -246,7 +246,7 @@ class MessageHandler(object):
"avatar_url": profile.avatar_url, "avatar_url": profile.avatar_url,
"display_name": profile.display_name, "display_name": profile.display_name,
} }
for user_id, profile in iteritems(users_with_profile) for user_id, profile in users_with_profile.items()
} }
def maybe_schedule_expiry(self, event): def maybe_schedule_expiry(self, event):
@ -988,7 +988,7 @@ class EventCreationHandler(object):
state_to_include_ids = [ state_to_include_ids = [
e_id e_id
for k, e_id in iteritems(current_state_ids) for k, e_id in current_state_ids.items()
if k[0] in self.room_invite_state_types if k[0] in self.room_invite_state_types
or k == (EventTypes.Member, event.sender) or k == (EventTypes.Member, event.sender)
] ]
@ -1002,7 +1002,7 @@ class EventCreationHandler(object):
"content": e.content, "content": e.content,
"sender": e.sender, "sender": e.sender,
} }
for e in itervalues(state_to_include) for e in state_to_include.values()
] ]
invitee = UserID.from_string(event.state_key) invitee = UserID.from_string(event.state_key)

View file

@ -15,8 +15,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from twisted.python.failure import Failure from twisted.python.failure import Failure
@ -145,7 +143,7 @@ class PaginationHandler(object):
logger.debug("[purge] Rooms to purge: %s", rooms) logger.debug("[purge] Rooms to purge: %s", rooms)
for room_id, retention_policy in iteritems(rooms): for room_id, retention_policy in rooms.items():
logger.info("[purge] Attempting to purge messages in room %s", room_id) logger.info("[purge] Attempting to purge messages in room %s", room_id)
if room_id in self._purges_in_progress_by_room: if room_id in self._purges_in_progress_by_room:

View file

@ -27,8 +27,6 @@ import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Iterable, List, Set from typing import Dict, Iterable, List, Set
from six import iteritems, itervalues
from prometheus_client import Counter from prometheus_client import Counter
from typing_extensions import ContextManager from typing_extensions import ContextManager
@ -170,14 +168,14 @@ class BasePresenceHandler(abc.ABC):
for user_id in user_ids for user_id in user_ids
} }
missing = [user_id for user_id, state in iteritems(states) if not state] missing = [user_id for user_id, state in states.items() if not state]
if missing: if missing:
# There are things not in our in memory cache. Lets pull them out of # There are things not in our in memory cache. Lets pull them out of
# the database. # the database.
res = await self.store.get_presence_for_users(missing) res = await self.store.get_presence_for_users(missing)
states.update(res) states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state] missing = [user_id for user_id, state in states.items() if not state]
if missing: if missing:
new = { new = {
user_id: UserPresenceState.default(user_id) for user_id in missing user_id: UserPresenceState.default(user_id) for user_id in missing
@ -632,7 +630,7 @@ class PresenceHandler(BasePresenceHandler):
await self._update_states( await self._update_states(
[ [
prev_state.copy_and_replace(last_user_sync_ts=time_now_ms) prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
for prev_state in itervalues(prev_states) for prev_state in prev_states.values()
] ]
) )
self.external_process_last_updated_ms.pop(process_id, None) self.external_process_last_updated_ms.pop(process_id, None)
@ -1087,7 +1085,7 @@ class PresenceEventSource(object):
return (list(updates.values()), max_token) return (list(updates.values()), max_token)
else: else:
return ( return (
[s for s in itervalues(updates) if s.state != PresenceState.OFFLINE], [s for s in updates.values() if s.state != PresenceState.OFFLINE],
max_token, max_token,
) )
@ -1323,11 +1321,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms. # hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states) room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
for room_id, states in iteritems(room_ids_to_states): for room_id, states in room_ids_to_states.items():
hosts = yield state_handler.get_current_hosts_in_room(room_id) hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states)) hosts_and_states.append((hosts, states))
for user_id, states in iteritems(users_to_states): for user_id, states in users_to_states.items():
host = get_domain_from_id(user_id) host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states)) hosts_and_states.append(([host], states))

View file

@ -24,7 +24,7 @@ import string
from collections import OrderedDict from collections import OrderedDict
from typing import Tuple from typing import Tuple
from six import iteritems, string_types from six import string_types
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
@ -377,7 +377,7 @@ class RoomCreationHandler(BaseHandler):
# map from event_id to BaseEvent # map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values()) old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids): for k, old_event_id in old_room_state_ids.items():
old_event = old_room_state_events.get(old_event_id) old_event = old_room_state_events.get(old_event_id)
if old_event: if old_event:
initial_state[k] = old_event.content initial_state[k] = old_event.content
@ -430,7 +430,7 @@ class RoomCreationHandler(BaseHandler):
old_room_member_state_events = await self.store.get_events( old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values() old_room_member_state_ids.values()
) )
for k, old_event in iteritems(old_room_member_state_events): for k, old_event in old_room_member_state_events.items():
# Only transfer ban events # Only transfer ban events
if ( if (
"membership" in old_event.content "membership" in old_event.content

View file

@ -17,8 +17,6 @@ import logging
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from six import iteritems
import msgpack import msgpack
from unpaddedbase64 import decode_base64, encode_base64 from unpaddedbase64 import decode_base64, encode_base64
@ -271,7 +269,7 @@ class RoomListHandler(BaseHandler):
event_map = yield self.store.get_events( event_map = yield self.store.get_events(
[ [
event_id event_id
for key, event_id in iteritems(current_state_ids) for key, event_id in current_state_ids.items()
if key[0] if key[0]
in ( in (
EventTypes.Create, EventTypes.Create,

View file

@ -18,8 +18,6 @@ import itertools
import logging import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from six import iteritems, itervalues
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
@ -390,7 +388,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache # result returned by the event source is poor form (it might cache
# the object) # the object)
room_id = event["room_id"] room_id = event["room_id"]
event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0" receipt_key = since_token.receipt_key if since_token else "0"
@ -408,7 +406,7 @@ class SyncHandler(object):
for event in receipts: for event in receipts:
room_id = event["room_id"] room_id = event["room_id"]
# exclude room id, as above # exclude room id, as above
event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"} event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy) ephemeral_by_room.setdefault(room_id, []).append(event_copy)
return now_token, ephemeral_by_room return now_token, ephemeral_by_room
@ -454,7 +452,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids( current_state_ids_map = await self.state.get_current_state_ids(
room_id room_id
) )
current_state_ids = frozenset(itervalues(current_state_ids_map)) current_state_ids = frozenset(current_state_ids_map.values())
recents = await filter_events_for_client( recents = await filter_events_for_client(
self.storage, self.storage,
@ -509,7 +507,7 @@ class SyncHandler(object):
current_state_ids_map = await self.state.get_current_state_ids( current_state_ids_map = await self.state.get_current_state_ids(
room_id room_id
) )
current_state_ids = frozenset(itervalues(current_state_ids_map)) current_state_ids = frozenset(current_state_ids_map.values())
loaded_recents = await filter_events_for_client( loaded_recents = await filter_events_for_client(
self.storage, self.storage,
@ -909,7 +907,7 @@ class SyncHandler(object):
logger.debug("filtering state from %r...", state_ids) logger.debug("filtering state from %r...", state_ids)
state_ids = { state_ids = {
t: event_id t: event_id
for t, event_id in iteritems(state_ids) for t, event_id in state_ids.items()
if cache.get(t[1]) != event_id if cache.get(t[1]) != event_id
} }
logger.debug("...to %r", state_ids) logger.debug("...to %r", state_ids)
@ -1430,7 +1428,7 @@ class SyncHandler(object):
if since_token: if since_token:
for joined_sync in sync_result_builder.joined: for joined_sync in sync_result_builder.joined:
it = itertools.chain( it = itertools.chain(
joined_sync.timeline.events, itervalues(joined_sync.state) joined_sync.timeline.events, joined_sync.state.values()
) )
for event in it: for event in it:
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
@ -1505,7 +1503,7 @@ class SyncHandler(object):
newly_left_rooms = [] newly_left_rooms = []
room_entries = [] room_entries = []
invited = [] invited = []
for room_id, events in iteritems(mem_change_events_by_room_id): for room_id, events in mem_change_events_by_room_id.items():
logger.debug( logger.debug(
"Membership changes in %s: [%s]", "Membership changes in %s: [%s]",
room_id, room_id,
@ -1996,17 +1994,17 @@ def _calculate_state(
event_id_to_key = { event_id_to_key = {
e: key e: key
for key, e in itertools.chain( for key, e in itertools.chain(
iteritems(timeline_contains), timeline_contains.items(),
iteritems(previous), previous.items(),
iteritems(timeline_start), timeline_start.items(),
iteritems(current), current.items(),
) )
} }
c_ids = set(itervalues(current)) c_ids = set(current.values())
ts_ids = set(itervalues(timeline_start)) ts_ids = set(timeline_start.values())
p_ids = set(itervalues(previous)) p_ids = set(previous.values())
tc_ids = set(itervalues(timeline_contains)) tc_ids = set(timeline_contains.values())
# If we are lazyloading room members, we explicitly add the membership events # If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync, # for the senders in the timeline into the state block returned by /sync,
@ -2020,7 +2018,7 @@ def _calculate_state(
if lazy_load_members: if lazy_load_members:
p_ids.difference_update( p_ids.difference_update(
e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member e for t, e in timeline_start.items() if t[0] == EventTypes.Member
) )
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids

View file

@ -15,8 +15,6 @@
import logging import logging
from six import iteritems, iterkeys
import synapse.metrics import synapse.metrics
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.handlers.state_deltas import StateDeltasHandler from synapse.handlers.state_deltas import StateDeltasHandler
@ -289,7 +287,7 @@ class UserDirectoryHandler(StateDeltasHandler):
users_with_profile = await self.state.get_current_users_in_room(room_id) users_with_profile = await self.state.get_current_users_in_room(room_id)
# Remove every user from the sharing tables for that room. # Remove every user from the sharing tables for that room.
for user_id in iterkeys(users_with_profile): for user_id in users_with_profile.keys():
await self.store.remove_user_who_share_room(user_id, room_id) await self.store.remove_user_who_share_room(user_id, room_id)
# Then, re-add them to the tables. # Then, re-add them to the tables.
@ -298,7 +296,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# which when ran over an entire room, will result in the same values # which when ran over an entire room, will result in the same values
# being added multiple times. The batching upserts shouldn't make this # being added multiple times. The batching upserts shouldn't make this
# too bad, though. # too bad, though.
for user_id, profile in iteritems(users_with_profile): for user_id, profile in users_with_profile.items():
await self._handle_new_user(room_id, user_id, profile) await self._handle_new_user(room_id, user_id, profile)
async def _handle_new_user(self, room_id, user_id, profile): async def _handle_new_user(self, room_id, user_id, profile):

View file

@ -22,8 +22,6 @@ import threading
import time import time
from typing import Callable, Dict, Iterable, Optional, Tuple, Union from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import six
import attr import attr
from prometheus_client import Counter, Gauge, Histogram from prometheus_client import Counter, Gauge, Histogram
from prometheus_client.core import ( from prometheus_client.core import (
@ -83,7 +81,7 @@ class LaterGauge(object):
return return
if isinstance(calls, dict): if isinstance(calls, dict):
for k, v in six.iteritems(calls): for k, v in calls.items():
g.add_metric(k, v) g.add_metric(k, v)
else: else:
g.add_metric([], calls) g.add_metric([], calls)
@ -194,7 +192,7 @@ class InFlightGauge(object):
gauge = GaugeMetricFamily( gauge = GaugeMetricFamily(
"_".join([self.name, name]), "", labels=self.labels "_".join([self.name, name]), "", labels=self.labels
) )
for key, metrics in six.iteritems(metrics_by_key): for key, metrics in metrics_by_key.items():
gauge.add_metric(key, getattr(metrics, name)) gauge.add_metric(key, getattr(metrics, name))
yield gauge yield gauge

View file

@ -17,8 +17,6 @@
import logging import logging
from collections import namedtuple from collections import namedtuple
from six import iteritems, itervalues
from prometheus_client import Counter from prometheus_client import Counter
from twisted.internet import defer from twisted.internet import defer
@ -130,7 +128,7 @@ class BulkPushRuleEvaluator(object):
event, prev_state_ids, for_verification=False event, prev_state_ids, for_verification=False
) )
auth_events = yield self.store.get_events(auth_events_ids) auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)} auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events) sender_level = get_user_power_level(event.sender, auth_events)
@ -162,7 +160,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {} condition_cache = {}
for uid, rules in iteritems(rules_by_user): for uid, rules in rules_by_user.items():
if event.sender == uid: if event.sender == uid:
continue continue
@ -398,7 +396,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts # If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it. # map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member: if event.type == EventTypes.Member:
for event_id in itervalues(member_event_ids): for event_id in member_event_ids.values():
if event_id == event.event_id: if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership) members[event_id] = (event.state_key, event.membership)
@ -407,7 +405,7 @@ class RulesForRoom(object):
interested_in_user_ids = { interested_in_user_ids = {
user_id user_id
for user_id, membership in itervalues(members) for user_id, membership in members.values()
if membership == Membership.JOIN if membership == Membership.JOIN
} }
@ -418,7 +416,7 @@ class RulesForRoom(object):
) )
user_ids = { user_ids = {
uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
} }
logger.debug("With pushers: %r", user_ids) logger.debug("With pushers: %r", user_ids)
@ -439,7 +437,7 @@ class RulesForRoom(object):
) )
ret_rules_by_user.update( ret_rules_by_user.update(
item for item in iteritems(rules_by_user) if item[0] is not None item for item in rules_by_user.items() if item[0] is not None
) )
self.update_cache(sequence, members, ret_rules_by_user, state_group) self.update_cache(sequence, members, ret_rules_by_user, state_group)

View file

@ -149,7 +149,7 @@ class RdataCommand(Command):
class PositionCommand(Command): class PositionCommand(Command):
"""Sent by the server to tell the client the stream postition without """Sent by the server to tell the client the stream position without
needing to send an RDATA. needing to send an RDATA.
Format:: Format::
@ -188,7 +188,7 @@ class ErrorCommand(_SimpleCommand):
class PingCommand(_SimpleCommand): class PingCommand(_SimpleCommand):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp) """Sent by either side as a keep alive. The data is arbitrary (often timestamp)
""" """
NAME = "PING" NAME = "PING"

View file

@ -112,8 +112,8 @@ class ReplicationCommandHandler:
"replication_position", clock=self._clock "replication_position", clock=self._clock
) )
# Map of stream to batched updates. See RdataCommand for info on how # Map of stream name to batched updates. See RdataCommand for info on
# batching works. # how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]] self._pending_batches = {} # type: Dict[str, List[Any]]
# The factory used to create connections. # The factory used to create connections.
@ -123,7 +123,8 @@ class ReplicationCommandHandler:
# outgoing replication commands to.) # outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection] self._connections = [] # type: List[AbstractConnection]
# For each connection, the incoming streams that are coming from that connection # For each connection, the incoming stream names that are coming from
# that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]] self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge( LaterGauge(
@ -310,6 +311,27 @@ class ReplicationCommandHandler:
# Check if this is the last of a batch of updates # Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, []) rows = self._pending_batches.pop(stream_name, [])
rows.append(row) rows.append(row)
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got RDATA for unknown stream: %s", stream_name)
return
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows) await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata( async def on_rdata(

View file

@ -20,8 +20,6 @@ import os
import shutil import shutil
from typing import Dict, Tuple from typing import Dict, Tuple
from six import iteritems
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.resource import Resource from twisted.web.resource import Resource
@ -606,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it # Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in iteritems(thumbnails): for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail # Generate the thumbnail
if t_method == "crop": if t_method == "crop":
t_byte_source = await defer_to_thread( t_byte_source = await defer_to_thread(

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import iteritems, string_types from six import string_types
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.api.urls import ConsentURIBuilder from synapse.api.urls import ConsentURIBuilder
@ -121,7 +121,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, string_types): if isinstance(x, string_types):
return x % substitutions return x % substitutions
if isinstance(x, dict): if isinstance(x, dict):
return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)} return {k: copy_with_str_subst(v, substitutions) for (k, v) in x.items()}
if isinstance(x, (list, tuple)): if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x] return [copy_with_str_subst(y) for y in x]

View file

@ -14,8 +14,6 @@
# limitations under the License. # limitations under the License.
import logging import logging
from six import iteritems
from synapse.api.constants import ( from synapse.api.constants import (
EventTypes, EventTypes,
LimitBlockingTypes, LimitBlockingTypes,
@ -214,7 +212,7 @@ class ResourceLimitsServerNotices(object):
referenced_events = list(pinned_state_event.content.get("pinned", [])) referenced_events = list(pinned_state_event.content.get("pinned", []))
events = await self._store.get_events(referenced_events) events = await self._store.get_events(referenced_events)
for event_id, event in iteritems(events): for event_id, event in events.items():
if event.type != EventTypes.Message: if event.type != EventTypes.Message:
continue continue
if event.content.get("msgtype") == ServerNoticeMsgType: if event.content.get("msgtype") == ServerNoticeMsgType:

View file

@ -18,8 +18,6 @@ import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set from typing import Dict, Iterable, List, Optional, Set
from six import iteritems, itervalues
import attr import attr
from frozendict import frozendict from frozendict import frozendict
from prometheus_client import Histogram from prometheus_client import Histogram
@ -144,7 +142,7 @@ class StateHandler(object):
list(state.values()), get_prev_content=False list(state.values()), get_prev_content=False
) )
state = { state = {
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
} }
return state return state
@ -423,7 +421,7 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store), state_res_store=StateResolutionStore(self.store),
) )
new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)} new_state = {key: state_map[ev_id] for key, ev_id in new_state.items()}
return new_state return new_state
@ -505,8 +503,8 @@ class StateResolutionHandler(object):
# resolve_events_with_store do it? # resolve_events_with_store do it?
new_state = {} new_state = {}
conflicted_state = False conflicted_state = False
for st in itervalues(state_groups_ids): for st in state_groups_ids.values():
for key, e_id in iteritems(st): for key, e_id in st.items():
if key in new_state: if key in new_state:
conflicted_state = True conflicted_state = True
break break
@ -520,7 +518,7 @@ class StateResolutionHandler(object):
new_state = yield resolve_events_with_store( new_state = yield resolve_events_with_store(
room_id, room_id,
room_version, room_version,
list(itervalues(state_groups_ids)), list(state_groups_ids.values()),
event_map=event_map, event_map=event_map,
state_res_store=state_res_store, state_res_store=state_res_store,
) )
@ -561,12 +559,12 @@ def _make_state_cache_entry(new_state, state_groups_ids):
# not get persisted. # not get persisted.
# first look for exact matches # first look for exact matches
new_state_event_ids = set(itervalues(new_state)) new_state_event_ids = set(new_state.values())
for sg, state in iteritems(state_groups_ids): for sg, state in state_groups_ids.items():
if len(new_state_event_ids) != len(state): if len(new_state_event_ids) != len(state):
continue continue
old_state_event_ids = set(itervalues(state)) old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids: if new_state_event_ids == old_state_event_ids:
# got an exact match. # got an exact match.
return _StateCacheEntry(state=new_state, state_group=sg) return _StateCacheEntry(state=new_state, state_group=sg)
@ -579,8 +577,8 @@ def _make_state_cache_entry(new_state, state_groups_ids):
prev_group = None prev_group = None
delta_ids = None delta_ids = None
for old_group, old_state in iteritems(state_groups_ids): for old_group, old_state in state_groups_ids.items():
n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v} n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids): if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group prev_group = old_group
delta_ids = n_delta_ids delta_ids = n_delta_ids

View file

@ -17,8 +17,6 @@ import hashlib
import logging import logging
from typing import Callable, Dict, List, Optional from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues
from twisted.internet import defer from twisted.internet import defer
from synapse import event_auth from synapse import event_auth
@ -70,11 +68,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets) unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = { needed_events = {
event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids event_id for event_ids in conflicted_state.values() for event_id in event_ids
} }
needed_event_count = len(needed_events) needed_event_count = len(needed_events)
if event_map is not None: if event_map is not None:
needed_events -= set(iterkeys(event_map)) needed_events -= set(event_map.keys())
logger.info( logger.info(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
@ -102,11 +100,11 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state, state_map unconflicted_state, conflicted_state, state_map
) )
new_needed_events = set(itervalues(auth_events)) new_needed_events = set(auth_events.values())
new_needed_event_count = len(new_needed_events) new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events new_needed_events -= needed_events
if event_map is not None: if event_map is not None:
new_needed_events -= set(iterkeys(event_map)) new_needed_events -= set(event_map.keys())
logger.info( logger.info(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
@ -152,7 +150,7 @@ def _seperate(state_sets):
conflicted_state = {} conflicted_state = {}
for state_set in state_set_iterator: for state_set in state_set_iterator:
for key, value in iteritems(state_set): for key, value in state_set.items():
# Check if there is an unconflicted entry for the state key. # Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key) unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None: if unconflicted_value is None:
@ -178,7 +176,7 @@ def _seperate(state_sets):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {} auth_events = {}
for event_ids in itervalues(conflicted_state): for event_ids in conflicted_state.values():
for event_id in event_ids: for event_id in event_ids:
if event_id in state_map: if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id]) keys = event_auth.auth_types_for_event(state_map[event_id])
@ -194,7 +192,7 @@ def _resolve_with_state(
unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
): ):
conflicted_state = {} conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids): for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map] events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1: if len(events) > 1:
conflicted_state[key] = events conflicted_state[key] = events
@ -203,7 +201,7 @@ def _resolve_with_state(
auth_events = { auth_events = {
key: state_map[ev_id] key: state_map[ev_id]
for key, ev_id in iteritems(auth_event_ids) for key, ev_id in auth_event_ids.items()
if ev_id in state_map if ev_id in state_map
} }
@ -214,7 +212,7 @@ def _resolve_with_state(
raise raise
new_state = unconflicted_state_ids new_state = unconflicted_state_ids
for key, event in iteritems(resolved_state): for key, event in resolved_state.items():
new_state[key] = event.event_id new_state[key] = event.event_id
return new_state return new_state
@ -238,21 +236,21 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state): for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules: if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events) logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events) resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state): for key, events in conflicted_state.items():
if key[0] == EventTypes.Member: if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events) logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events) resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state) auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state): for key, events in conflicted_state.items():
if key not in resolved_state: if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events) logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(events, auth_events) resolved_state[key] = _resolve_normal_events(events, auth_events)

View file

@ -18,8 +18,6 @@ import itertools
import logging import logging
from typing import Dict, List, Optional from typing import Dict, List, Optional
from six import iteritems, itervalues
from twisted.internet import defer from twisted.internet import defer
import synapse.state import synapse.state
@ -87,7 +85,7 @@ def resolve_events_with_store(
full_conflicted_set = set( full_conflicted_set = set(
itertools.chain( itertools.chain(
itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff itertools.chain.from_iterable(conflicted_state.values()), auth_diff
) )
) )
@ -572,7 +570,7 @@ def lexicographical_topological_sort(graph, key):
# `(key(node), node)` so that sorting does the right thing # `(key(node), node)` so that sorting does the right thing
zero_outdegree = [] zero_outdegree = []
for node, edges in iteritems(graph): for node, edges in graph.items():
if len(edges) == 0: if len(edges) == 0:
zero_outdegree.append((key(node), node)) zero_outdegree.append((key(node), node))

View file

@ -15,8 +15,6 @@
import logging import logging
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
@ -421,7 +419,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
): ):
self.database_engine.lock_table(txn, "user_ips") self.database_engine.lock_table(txn, "user_ips")
for entry in iteritems(to_update): for entry in to_update.items():
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try: try:
@ -530,7 +528,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
"user_agent": user_agent, "user_agent": user_agent,
"last_seen": last_seen, "last_seen": last_seen,
} }
for (access_token, ip), (user_agent, last_seen) in iteritems(results) for (access_token, ip), (user_agent, last_seen) in results.items()
] ]
@wrap_as_background_process("prune_old_user_ips") @wrap_as_background_process("prune_old_user_ips")

View file

@ -17,8 +17,6 @@
import logging import logging
from typing import List, Optional, Set, Tuple from typing import List, Optional, Set, Tuple
from six import iteritems
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
@ -208,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
# add the updated cross-signing keys to the results list # add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user): for user_id, result in cross_signing_keys_by_user.items():
result["user_id"] = user_id result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result)) results.append(("org.matrix.signing_key_update", result))
@ -269,7 +267,7 @@ class DeviceWorkerStore(SQLBaseStore):
) )
results = [] results = []
for user_id, user_devices in iteritems(devices): for user_id, user_devices in devices.items():
# The prev_id for the first row is always the last row before # The prev_id for the first row is always the last row before
# `from_stream_id` # `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user( prev_id = yield self._get_last_device_update_for_remote_user(
@ -493,7 +491,7 @@ class DeviceWorkerStore(SQLBaseStore):
if devices: if devices:
user_devices = devices[user_id] user_devices = devices[user_id]
results = [] results = []
for device_id, device in iteritems(user_devices): for device_id, device in user_devices.items():
result = {"device_id": device_id} result = {"device_id": device_id}
key_json = device.get("key_json", None) key_json = device.get("key_json", None)

View file

@ -16,8 +16,6 @@
# limitations under the License. # limitations under the License.
from typing import Dict, List from typing import Dict, List
from six import iteritems
from canonicaljson import encode_canonical_json, json from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection from twisted.enterprise.adbapi import Connection
@ -64,9 +62,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# Build the result structure, un-jsonify the results, and add the # Build the result structure, un-jsonify the results, and add the
# "unsigned" section # "unsigned" section
rv = {} rv = {}
for user_id, device_keys in iteritems(results): for user_id, device_keys in results.items():
rv[user_id] = {} rv[user_id] = {}
for device_id, device_info in iteritems(device_keys): for device_id, device_info in device_keys.items():
r = db_to_json(device_info.pop("key_json")) r = db_to_json(device_info.pop("key_json"))
r["unsigned"] = {} r["unsigned"] = {}
display_name = device_info["device_display_name"] display_name = device_info["device_display_name"]

View file

@ -16,8 +16,6 @@
import logging import logging
from typing import Dict, Tuple from typing import Dict, Tuple
from six import iteritems
import attr import attr
from canonicaljson import json from canonicaljson import json
@ -493,7 +491,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql, sql,
( (
_gen_entry(user_id, actions) _gen_entry(user_id, actions)
for user_id, actions in iteritems(user_id_actions) for user_id, actions in user_id_actions.items()
), ),
) )

View file

@ -21,7 +21,7 @@ from collections import OrderedDict, namedtuple
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
from six import integer_types, iteritems, text_type from six import integer_types, text_type
from six.moves import range from six.moves import range
import attr import attr
@ -232,10 +232,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc() event_counter.labels(event.type, origin_type, origin_entity).inc()
for room_id, new_state in iteritems(current_state_for_room): for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state) self.store.get_current_state_ids.prefill((room_id,), new_state)
for room_id, latest_event_ids in iteritems(new_forward_extremeties): for room_id, latest_event_ids in new_forward_extremeties.items():
self.store.get_latest_event_ids_in_room.prefill( self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids) (room_id,), list(latest_event_ids)
) )
@ -461,7 +461,7 @@ class PersistEventsStore:
state_delta_by_room: Dict[str, DeltaState], state_delta_by_room: Dict[str, DeltaState],
stream_id: int, stream_id: int,
): ):
for room_id, delta_state in iteritems(state_delta_by_room): for room_id, delta_state in state_delta_by_room.items():
to_delete = delta_state.to_delete to_delete = delta_state.to_delete
to_insert = delta_state.to_insert to_insert = delta_state.to_insert
@ -545,7 +545,7 @@ class PersistEventsStore:
""", """,
[ [
(room_id, key[0], key[1], ev_id, ev_id) (room_id, key[0], key[1], ev_id, ev_id)
for key, ev_id in iteritems(to_insert) for key, ev_id in to_insert.items()
], ],
) )
@ -642,7 +642,7 @@ class PersistEventsStore:
def _update_forward_extremities_txn( def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order self, txn, new_forward_extremities, max_stream_order
): ):
for room_id, new_extrem in iteritems(new_forward_extremities): for room_id, new_extrem in new_forward_extremities.items():
self.db.simple_delete_txn( self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id} txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
) )
@ -655,7 +655,7 @@ class PersistEventsStore:
table="event_forward_extremities", table="event_forward_extremities",
values=[ values=[
{"event_id": ev_id, "room_id": room_id} {"event_id": ev_id, "room_id": room_id}
for room_id, new_extrem in iteritems(new_forward_extremities) for room_id, new_extrem in new_forward_extremities.items()
for ev_id in new_extrem for ev_id in new_extrem
], ],
) )
@ -672,7 +672,7 @@ class PersistEventsStore:
"event_id": event_id, "event_id": event_id,
"stream_ordering": max_stream_order, "stream_ordering": max_stream_order,
} }
for room_id, new_extrem in iteritems(new_forward_extremities) for room_id, new_extrem in new_forward_extremities.items()
for event_id in new_extrem for event_id in new_extrem
], ],
) )
@ -727,7 +727,7 @@ class PersistEventsStore:
event.depth, depth_updates.get(event.room_id, event.depth) event.depth, depth_updates.get(event.room_id, event.depth)
) )
for room_id, depth in iteritems(depth_updates): for room_id, depth in depth_updates.items():
self._update_min_depth_for_room_txn(txn, room_id, depth) self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts): def _update_outliers_txn(self, txn, events_and_contexts):
@ -1497,11 +1497,11 @@ class PersistEventsStore:
table="event_to_state_groups", table="event_to_state_groups",
values=[ values=[
{"state_group": state_group_id, "event_id": event_id} {"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups) for event_id, state_group_id in state_groups.items()
], ],
) )
for event_id, state_group_id in iteritems(state_groups): for event_id, state_group_id in state_groups.items():
txn.call_after( txn.call_after(
self.store._get_state_group_for_event.prefill, self.store._get_state_group_for_event.prefill,
(event_id,), (event_id,),

View file

@ -19,8 +19,6 @@ import logging
import re import re
from typing import Optional from typing import Optional
from six import iterkeys
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.defer import Deferred from twisted.internet.defer import Deferred
@ -753,7 +751,7 @@ class RegistrationWorkerStore(SQLBaseStore):
last_send_attempt, validated_at last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s FROM threepid_validation_session WHERE %s
""" % ( """ % (
" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)), " AND ".join("%s = ?" % k for k in keyvalues.keys()),
) )
if validated is not None: if validated is not None:

View file

@ -17,8 +17,6 @@
import logging import logging
from typing import Iterable, List, Set from typing import Iterable, List, Set
from six import iteritems, itervalues
from canonicaljson import json from canonicaljson import json
from twisted.internet import defer from twisted.internet import defer
@ -544,7 +542,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {} users_in_room = {}
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in iteritems(current_state_ids) for key, e_id in current_state_ids.items()
if key[0] == EventTypes.Member if key[0] == EventTypes.Member
] ]
@ -561,7 +559,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res) users_in_room = dict(prev_res)
member_event_ids = [ member_event_ids = [
e_id e_id
for key, e_id in iteritems(context.delta_ids) for key, e_id in context.delta_ids.items()
if key[0] == EventTypes.Member if key[0] == EventTypes.Member
] ]
for etype, state_key in context.delta_ids: for etype, state_key in context.delta_ids:
@ -1101,7 +1099,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group: if state_entry.state_group == self.state_group:
pass pass
elif state_entry.prev_group == self.state_group: elif state_entry.prev_group == self.state_group:
for (typ, state_key), event_id in iteritems(state_entry.delta_ids): for (typ, state_key), event_id in state_entry.delta_ids.items():
if typ != EventTypes.Member: if typ != EventTypes.Member:
continue continue
@ -1131,7 +1129,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group self.state_group = state_entry.state_group
else: else:
self.state_group = object() self.state_group = object()
self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users)) self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
return frozenset(self.hosts_to_joined_users) return frozenset(self.hosts_to_joined_users)
def __len__(self): def __len__(self):

View file

@ -186,7 +186,7 @@ class UIAuthWorkerStore(SQLBaseStore):
# The clientdict gets stored as JSON. # The clientdict gets stored as JSON.
clientdict_json = json.dumps(clientdict) clientdict_json = json.dumps(clientdict)
self.db.simple_update_one( await self.db.simple_update_one(
table="ui_auth_sessions", table="ui_auth_sessions",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
updatevalues={"clientdict": clientdict_json}, updatevalues={"clientdict": clientdict_json},

View file

@ -15,8 +15,6 @@
import logging import logging
from six import iteritems
from twisted.internet import defer from twisted.internet import defer
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -280,7 +278,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
delta_state = { delta_state = {
key: value key: value
for key, value in iteritems(curr_state) for key, value in curr_state.items()
if prev_state.get(key, None) != value if prev_state.get(key, None) != value
} }
@ -316,7 +314,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in iteritems(delta_state) for key, state_id in delta_state.items()
], ],
) )

View file

@ -17,7 +17,6 @@ import logging
from collections import namedtuple from collections import namedtuple
from typing import Dict, Iterable, List, Set, Tuple from typing import Dict, Iterable, List, Set, Tuple
from six import iteritems
from six.moves import range from six.moves import range
from twisted.internet import defer from twisted.internet import defer
@ -263,7 +262,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# And finally update the result dict, by filtering out any extra # And finally update the result dict, by filtering out any extra
# stuff we pulled out of the database. # stuff we pulled out of the database.
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in group_to_state_dict.items():
# We just replace any existing entries, as we will have loaded # We just replace any existing entries, as we will have loaded
# everything we need from the database anyway. # everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict) state[group] = state_filter.filter_state(group_state_dict)
@ -341,11 +340,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
else: else:
non_member_types = non_member_filter.concrete_types() non_member_types = non_member_filter.concrete_types()
for group, group_state_dict in iteritems(group_to_state_dict): for group, group_state_dict in group_to_state_dict.items():
state_dict_members = {} state_dict_members = {}
state_dict_non_members = {} state_dict_non_members = {}
for k, v in iteritems(group_state_dict): for k, v in group_state_dict.items():
if k[0] == EventTypes.Member: if k[0] == EventTypes.Member:
state_dict_members[k] = v state_dict_members[k] = v
else: else:
@ -432,7 +431,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in iteritems(delta_ids) for key, state_id in delta_ids.items()
], ],
) )
else: else:
@ -447,7 +446,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in iteritems(current_state_ids) for key, state_id in current_state_ids.items()
], ],
) )
@ -458,7 +457,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_member_state_ids = { current_member_state_ids = {
s: ev s: ev
for (s, ev) in iteritems(current_state_ids) for (s, ev) in current_state_ids.items()
if s[0] == EventTypes.Member if s[0] == EventTypes.Member
} }
txn.call_after( txn.call_after(
@ -470,7 +469,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
current_non_member_state_ids = { current_non_member_state_ids = {
s: ev s: ev
for (s, ev) in iteritems(current_state_ids) for (s, ev) in current_state_ids.items()
if s[0] != EventTypes.Member if s[0] != EventTypes.Member
} }
txn.call_after( txn.call_after(
@ -555,7 +554,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"state_key": key[1], "state_key": key[1],
"event_id": state_id, "event_id": state_id,
} }
for key, state_id in iteritems(curr_state) for key, state_id in curr_state.items()
], ],
) )

View file

@ -29,7 +29,6 @@ from typing import (
TypeVar, TypeVar,
) )
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range from six.moves import intern, range
from prometheus_client import Histogram from prometheus_client import Histogram
@ -259,7 +258,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration_secs, limit=3): def interval(self, interval_duration_secs, limit=3):
counters = [] counters = []
for name, (count, cum_time) in iteritems(self.current_counters): for name, (count, cum_time) in self.current_counters.items():
prev_count, prev_time = self.previous_counters.get(name, (0, 0)) prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append( counters.append(
( (
@ -1053,7 +1052,7 @@ class Database(object):
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table} sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues: if keyvalues:
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
txn.execute(sql, list(keyvalues.values())) txn.execute(sql, list(keyvalues.values()))
else: else:
txn.execute(sql) txn.execute(sql)
@ -1191,7 +1190,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause] clauses = [clause]
for key, value in iteritems(keyvalues): for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -1212,7 +1211,7 @@ class Database(object):
@staticmethod @staticmethod
def simple_update_txn(txn, table, keyvalues, updatevalues): def simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues: if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)) where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
else: else:
where = "" where = ""
@ -1351,7 +1350,7 @@ class Database(object):
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable) clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause] clauses = [clause]
for key, value in iteritems(keyvalues): for key, value in keyvalues.items():
clauses.append("%s = ?" % (key,)) clauses.append("%s = ?" % (key,))
values.append(value) values.append(value)
@ -1388,7 +1387,7 @@ class Database(object):
txn.close() txn.close()
if cache: if cache:
min_val = min(itervalues(cache)) min_val = min(cache.values())
else: else:
min_val = max_value min_val = max_value

View file

@ -20,7 +20,6 @@ import logging
from collections import deque, namedtuple from collections import deque, namedtuple
from typing import Iterable, List, Optional, Set, Tuple from typing import Iterable, List, Optional, Set, Tuple
from six import iteritems
from six.moves import range from six.moves import range
from prometheus_client import Counter, Histogram from prometheus_client import Counter, Histogram
@ -218,7 +217,7 @@ class EventsPersistenceStorage(object):
partitioned.setdefault(event.room_id, []).append((event, ctx)) partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = [] deferreds = []
for room_id, evs_ctxs in iteritems(partitioned): for room_id, evs_ctxs in partitioned.items():
d = self._event_persist_queue.add_to_queue( d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs, backfilled=backfilled room_id, evs_ctxs, backfilled=backfilled
) )
@ -319,7 +318,7 @@ class EventsPersistenceStorage(object):
(event, context) (event, context)
) )
for room_id, ev_ctx_rm in iteritems(events_by_room): for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = await self.main_store.get_latest_event_ids_in_room( latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id room_id
) )
@ -674,7 +673,7 @@ class EventsPersistenceStorage(object):
to_insert = { to_insert = {
key: ev_id key: ev_id
for key, ev_id in iteritems(current_state) for key, ev_id in current_state.items()
if ev_id != existing_state.get(key) if ev_id != existing_state.get(key)
} }

View file

@ -16,8 +16,6 @@
import logging import logging
from typing import Iterable, List, TypeVar from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
import attr import attr
from twisted.internet import defer from twisted.internet import defer
@ -51,7 +49,7 @@ class StateFilter(object):
# If `include_others` is set we canonicalise the filter by removing # If `include_others` is set we canonicalise the filter by removing
# wildcards from the types dictionary # wildcards from the types dictionary
if self.include_others: if self.include_others:
self.types = {k: v for k, v in iteritems(self.types) if v is not None} self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod @staticmethod
def all(): def all():
@ -150,7 +148,7 @@ class StateFilter(object):
has_non_member_wildcard = self.include_others or any( has_non_member_wildcard = self.include_others or any(
state_keys is None state_keys is None
for t, state_keys in iteritems(self.types) for t, state_keys in self.types.items()
if t != EventTypes.Member if t != EventTypes.Member
) )
@ -199,7 +197,7 @@ class StateFilter(object):
# First we build up a lost of clauses for each type/state_key combo # First we build up a lost of clauses for each type/state_key combo
clauses = [] clauses = []
for etype, state_keys in iteritems(self.types): for etype, state_keys in self.types.items():
if state_keys is None: if state_keys is None:
clauses.append("(type = ?)") clauses.append("(type = ?)")
where_args.append(etype) where_args.append(etype)
@ -251,7 +249,7 @@ class StateFilter(object):
return dict(state_dict) return dict(state_dict)
filtered_state = {} filtered_state = {}
for k, v in iteritems(state_dict): for k, v in state_dict.items():
typ, state_key = k typ, state_key = k
if typ in self.types: if typ in self.types:
state_keys = self.types[typ] state_keys = self.types[typ]
@ -279,7 +277,7 @@ class StateFilter(object):
""" """
return self.include_others or any( return self.include_others or any(
state_keys is None for state_keys in itervalues(self.types) state_keys is None for state_keys in self.types.values()
) )
def concrete_types(self): def concrete_types(self):
@ -292,7 +290,7 @@ class StateFilter(object):
""" """
return [ return [
(t, s) (t, s)
for t, state_keys in iteritems(self.types) for t, state_keys in self.types.items()
if state_keys is not None if state_keys is not None
for s in state_keys for s in state_keys
] ]
@ -324,7 +322,7 @@ class StateFilter(object):
member_filter = StateFilter.none() member_filter = StateFilter.none()
non_member_filter = StateFilter( non_member_filter = StateFilter(
types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member}, types={k: v for k, v in self.types.items() if k != EventTypes.Member},
include_others=self.include_others, include_others=self.include_others,
) )
@ -366,7 +364,7 @@ class StateGroupStorage(object):
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups)) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups) group_to_state = yield self.stores.state._get_state_for_groups(groups)
return group_to_state return group_to_state
@ -400,8 +398,8 @@ class StateGroupStorage(object):
state_event_map = yield self.stores.main.get_events( state_event_map = yield self.stores.main.get_events(
[ [
ev_id ev_id
for group_ids in itervalues(group_to_ids) for group_ids in group_to_ids.values()
for ev_id in itervalues(group_ids) for ev_id in group_ids.values()
], ],
get_prev_content=False, get_prev_content=False,
) )
@ -409,10 +407,10 @@ class StateGroupStorage(object):
return { return {
group: [ group: [
state_event_map[v] state_event_map[v]
for v in itervalues(event_id_map) for v in event_id_map.values()
if v in state_event_map if v in state_event_map
] ]
for group, event_id_map in iteritems(group_to_ids) for group, event_id_map in group_to_ids.items()
} }
def _get_state_groups_from_groups( def _get_state_groups_from_groups(
@ -444,23 +442,23 @@ class StateGroupStorage(object):
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups)) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
state_event_map = yield self.stores.main.get_events( state_event_map = yield self.stores.main.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)], [ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False, get_prev_content=False,
) )
event_to_state = { event_to_state = {
event_id: { event_id: {
k: state_event_map[v] k: state_event_map[v]
for k, v in iteritems(group_to_state[group]) for k, v in group_to_state[group].items()
if v in state_event_map if v in state_event_map
} }
for event_id, group in iteritems(event_to_groups) for event_id, group in event_to_groups.items()
} }
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}
@ -481,14 +479,14 @@ class StateGroupStorage(object):
""" """
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids) event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups)) groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups( group_to_state = yield self.stores.state._get_state_for_groups(
groups, state_filter groups, state_filter
) )
event_to_state = { event_to_state = {
event_id: group_to_state[group] event_id: group_to_state[group]
for event_id, group in iteritems(event_to_groups) for event_id, group in event_to_groups.items()
} }
return {event: event_to_state[event] for event in event_ids} return {event: event_to_state[event] for event in event_ids}

View file

@ -21,8 +21,6 @@ import threading
from typing import Any, Tuple, Union, cast from typing import Any, Tuple, Union, cast
from weakref import WeakValueDictionary from weakref import WeakValueDictionary
from six import itervalues
from prometheus_client import Gauge from prometheus_client import Gauge
from typing_extensions import Protocol from typing_extensions import Protocol
@ -281,7 +279,7 @@ class Cache(object):
def invalidate_all(self): def invalidate_all(self):
self.check_thread() self.check_thread()
self.cache.clear() self.cache.clear()
for entry in itervalues(self._pending_deferred_cache): for entry in self._pending_deferred_cache.values():
entry.invalidate() entry.invalidate()
self._pending_deferred_cache.clear() self._pending_deferred_cache.clear()

View file

@ -16,8 +16,6 @@
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from six import iteritems, itervalues
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches import register_cache from synapse.util.caches import register_cache
@ -150,7 +148,7 @@ class ExpiringCache(object):
keys_to_delete = set() keys_to_delete = set()
for key, cache_entry in iteritems(self._cache): for key, cache_entry in self._cache.items():
if now - cache_entry.time > self._expiry_ms: if now - cache_entry.time > self._expiry_ms:
keys_to_delete.add(key) keys_to_delete.add(key)
@ -170,7 +168,7 @@ class ExpiringCache(object):
def __len__(self): def __len__(self):
if self.iterable: if self.iterable:
return sum(len(entry.value) for entry in itervalues(self._cache)) return sum(len(entry.value) for entry in self._cache.values())
else: else:
return len(self._cache) return len(self._cache)

View file

@ -1,7 +1,5 @@
from typing import Dict from typing import Dict
from six import itervalues
SENTINEL = object() SENTINEL = object()
@ -81,7 +79,7 @@ def iterate_tree_cache_entry(d):
can contain dicts. can contain dicts.
""" """
if isinstance(d, dict): if isinstance(d, dict):
for value_d in itervalues(d): for value_d in d.values():
for value in iterate_tree_cache_entry(value_d): for value in iterate_tree_cache_entry(value_d):
yield value yield value
else: else:

View file

@ -16,7 +16,6 @@
import logging import logging
import operator import operator
from six import iteritems, itervalues
from six.moves import map from six.moves import map
from twisted.internet import defer from twisted.internet import defer
@ -298,7 +297,7 @@ def filter_events_for_server(
# membership states for the requesting server to determine # membership states for the requesting server to determine
# if the server is either in the room or has been invited # if the server is either in the room or has been invited
# into the room. # into the room.
for ev in itervalues(state): for ev in state.values():
if ev.type != EventTypes.Member: if ev.type != EventTypes.Member:
continue continue
try: try:
@ -332,7 +331,7 @@ def filter_events_for_server(
) )
visibility_ids = set() visibility_ids = set()
for sids in itervalues(event_to_state_ids): for sids in event_to_state_ids.values():
hist = sids.get((EventTypes.RoomHistoryVisibility, "")) hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist: if hist:
visibility_ids.add(hist) visibility_ids.add(hist)
@ -345,7 +344,7 @@ def filter_events_for_server(
event_map = yield storage.main.get_events(visibility_ids) event_map = yield storage.main.get_events(visibility_ids)
all_open = all( all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable") e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map) for e in event_map.values()
) )
if not check_history_visibility_only: if not check_history_visibility_only:
@ -394,8 +393,8 @@ def filter_events_for_server(
# #
event_id_to_state_key = { event_id_to_state_key = {
event_id: key event_id: key
for key_to_eid in itervalues(event_to_state_ids) for key_to_eid in event_to_state_ids.values()
for key, event_id in iteritems(key_to_eid) for key, event_id in key_to_eid.items()
} }
def include(typ, state_key): def include(typ, state_key):
@ -409,20 +408,16 @@ def filter_events_for_server(
return state_key[idx + 1 :] == server_name return state_key[idx + 1 :] == server_name
event_map = yield storage.main.get_events( event_map = yield storage.main.get_events(
[ [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
e_id
for e_id, key in iteritems(event_id_to_state_key)
if include(key[0], key[1])
]
) )
event_to_state = { event_to_state = {
e_id: { e_id: {
key: event_map[inner_e_id] key: event_map[inner_e_id]
for key, inner_e_id in iteritems(key_to_eid) for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map if inner_e_id in event_map
} }
for e_id, key_to_eid in iteritems(event_to_state_ids) for e_id, key_to_eid in event_to_state_ids.items()
} }
to_return = [] to_return = []

View file

@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import ( from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow, EventsStreamCurrentStateRow,
@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event # also one state event
state_event = self._inject_state_event() state_event = self._inject_state_event()
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been # check we're testing what we think we are: no rows should yet have been
# received # received
self.assertEqual([], self.test_handler.received_rdata_rows) self.assertEqual([], self.test_handler.received_rdata_rows)
@ -174,11 +170,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# one more bit of state that doesn't get rolled back # one more bit of state that doesn't get rolled back
state2 = self._inject_state_event() state2 = self._inject_state_event()
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been # check we're testing what we think we are: no rows should yet have been
# received # received
self.assertEqual([], self.test_handler.received_rdata_rows) self.assertEqual([], self.test_handler.received_rdata_rows)
@ -327,11 +318,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
prev_events = [e.event_id] prev_events = [e.event_id]
pl_events.append(e) pl_events.append(e)
# tell the notifier to catch up to avoid duplicate rows.
# workaround for https://github.com/matrix-org/synapse/issues/7360
# FIXME remove this when the above is fixed
self.replicate()
# check we're testing what we think we are: no rows should yet have been # check we're testing what we think we are: no rows should yet have been
# received # received
self.assertEqual([], self.test_handler.received_rdata_rows) self.assertEqual([], self.test_handler.received_rdata_rows)
@ -378,6 +364,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows) self.assertEqual([], received_rows)
def test_backwards_stream_id(self):
"""
Test that RDATA that comes after the current position should be discarded.
"""
# disconnect, so that we can stack up some changes
self.disconnect()
# Generate an events. We inject them using inject_event so that they are
# not send out over replication until we call self.replicate().
event = self._inject_test_event()
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
# now reconnect to pull the updates
self.reconnect()
self.replicate()
# We should have received the expected single row (as well as various
# cache invalidation updates which we ignore).
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
# There should be a single received row.
self.assertEqual(len(received_rows), 1)
stream_name, token, row = received_rows[0]
self.assertEqual("events", stream_name)
self.assertIsInstance(row, EventsStreamRow)
self.assertEqual(row.type, "ev")
self.assertIsInstance(row.data, EventsStreamEventRow)
self.assertEqual(row.data.event_id, event.event_id)
# Reset the data.
self.test_handler.received_rdata_rows = []
# Save the current token for later.
worker_events_stream = self.worker_hs.get_replication_streams()["events"]
prev_token = worker_events_stream.current_token("master")
# Manually send an old RDATA command, which should get dropped. This
# re-uses the row from above, but with an earlier stream token.
self.hs.get_tcp_replication().send_command(
RdataCommand("events", "master", 1, row)
)
# No updates have been received (because it was discard as old).
received_rows = [
row for row in self.test_handler.received_rdata_rows if row[0] == "events"
]
self.assertEqual(len(received_rows), 0)
# Ensure the stream has not gone backwards.
current_token = worker_events_stream.current_token("master")
self.assertGreaterEqual(current_token, prev_token)
event_count = 0 event_count = 0
def _inject_test_event( def _inject_test_event(

View file

@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream from synapse.replication.tcp.streams import TypingStream
from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue" USER_ID = "@feeling:blue"
USER_ID_2 = "@da-ba-dee:blue"
ROOM_ID = "!bar:blue"
ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase): class TypingStreamTestCase(BaseStreamTestCase):
@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self): def test_typing(self):
typing = self.hs.get_typing_handler() typing = self.hs.get_typing_handler()
room_id = "!bar:blue"
self.reconnect() self.reconnect()
typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0) self.reactor.advance(0)
@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow row = rdata_rows[0] # type: TypingStream.TypingStreamRow
self.assertEqual(room_id, row.room_id) self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids) self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data. # Now let's disconnect and insert some data.
@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock() self.test_handler.on_rdata.reset_mock()
typing._push_update(member=RoomMember(room_id, USER_ID), typing=False) typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called() self.test_handler.on_rdata.assert_not_called()
@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing") self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows)) self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] row = rdata_rows[0]
self.assertEqual(room_id, row.room_id) self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([], row.user_ids) self.assertEqual([], row.user_ids)
def test_reset(self):
"""
Test what happens when a typing stream resets.
This is emulated by jumping the stream ahead, then reconnecting (which
sends the proper position and RDATA).
"""
typing = self.hs.get_typing_handler()
self.reconnect()
typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Push the stream forward a bunch so it can be reset.
for i in range(100):
typing._push_update(
member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
)
self.reactor.advance(0)
# Disconnect.
self.disconnect()
# Reset the typing handler
self.hs.get_replication_streams()["typing"].last_token = 0
self.hs.get_tcp_replication()._streams["typing"].last_token = 0
typing._latest_room_serial = 0
typing._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", typing._latest_room_serial
)
typing._reset()
# Reconnect.
self.reconnect()
self.pump(0.1)
# We should now see an attempt to connect to the master
request = self.handle_http_replication_attempt()
self.assert_request_is_get_repl_stream_updates(request, "typing")
# Reset the test code.
self.test_handler.on_rdata.reset_mock()
self.test_handler.on_rdata.assert_not_called()
# Push additional data.
typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
self.reactor.advance(0)
self.test_handler.on_rdata.assert_called_once()
stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
# The token should have been reset.
self.assertEqual(token, 1)