mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-22 07:10:06 +01:00
Merge branch 'develop' into markjh/bearer_token
This commit is contained in:
commit
4a32d25d4c
9 changed files with 107 additions and 44 deletions
|
@ -242,6 +242,9 @@ class SynchrotronTyping(object):
|
||||||
self._room_typing = {}
|
self._room_typing = {}
|
||||||
|
|
||||||
def stream_positions(self):
|
def stream_positions(self):
|
||||||
|
# We must update this typing token from the response of the previous
|
||||||
|
# sync. In particular, the stream id may "reset" back to zero/a low
|
||||||
|
# value which we *must* use for the next replication request.
|
||||||
return {"typing": self._latest_room_serial}
|
return {"typing": self._latest_room_serial}
|
||||||
|
|
||||||
def process_replication(self, result):
|
def process_replication(self, result):
|
||||||
|
|
|
@ -122,8 +122,12 @@ class FederationClient(FederationBase):
|
||||||
pdu.event_id
|
pdu.event_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def send_presence(self, destination, states):
|
||||||
|
if destination != self.server_name:
|
||||||
|
self._transaction_queue.enqueue_presence(destination, states)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
def send_edu(self, destination, edu_type, content):
|
def send_edu(self, destination, edu_type, content, key=None):
|
||||||
edu = Edu(
|
edu = Edu(
|
||||||
origin=self.server_name,
|
origin=self.server_name,
|
||||||
destination=destination,
|
destination=destination,
|
||||||
|
@ -134,7 +138,7 @@ class FederationClient(FederationBase):
|
||||||
sent_edus_counter.inc()
|
sent_edus_counter.inc()
|
||||||
|
|
||||||
# TODO, add errback, etc.
|
# TODO, add errback, etc.
|
||||||
self._transaction_queue.enqueue_edu(edu)
|
self._transaction_queue.enqueue_edu(edu, key=key)
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
|
|
||||||
@log_function
|
@log_function
|
||||||
|
|
|
@ -26,6 +26,7 @@ from synapse.util.retryutils import (
|
||||||
get_retry_limiter, NotRetryingDestination,
|
get_retry_limiter, NotRetryingDestination,
|
||||||
)
|
)
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
from synapse.handlers.presence import format_user_presence_state
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -69,13 +70,21 @@ class TransactionQueue(object):
|
||||||
# destination -> list of tuple(edu, deferred)
|
# destination -> list of tuple(edu, deferred)
|
||||||
self.pending_edus_by_dest = edus = {}
|
self.pending_edus_by_dest = edus = {}
|
||||||
|
|
||||||
|
# Presence needs to be separate as we send single aggragate EDUs
|
||||||
|
self.pending_presence_by_dest = presence = {}
|
||||||
|
self.pending_edus_keyed_by_dest = edus_keyed = {}
|
||||||
|
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"pending_pdus",
|
"pending_pdus",
|
||||||
lambda: sum(map(len, pdus.values())),
|
lambda: sum(map(len, pdus.values())),
|
||||||
)
|
)
|
||||||
metrics.register_callback(
|
metrics.register_callback(
|
||||||
"pending_edus",
|
"pending_edus",
|
||||||
lambda: sum(map(len, edus.values())),
|
lambda: (
|
||||||
|
sum(map(len, edus.values()))
|
||||||
|
+ sum(map(len, presence.values()))
|
||||||
|
+ sum(map(len, edus_keyed.values()))
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# destination -> list of tuple(failure, deferred)
|
# destination -> list of tuple(failure, deferred)
|
||||||
|
@ -130,12 +139,26 @@ class TransactionQueue(object):
|
||||||
self._attempt_new_transaction, destination
|
self._attempt_new_transaction, destination
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue_edu(self, edu):
|
def enqueue_presence(self, destination, states):
|
||||||
|
self.pending_presence_by_dest.setdefault(destination, {}).update({
|
||||||
|
state.user_id: state for state in states
|
||||||
|
})
|
||||||
|
|
||||||
|
preserve_context_over_fn(
|
||||||
|
self._attempt_new_transaction, destination
|
||||||
|
)
|
||||||
|
|
||||||
|
def enqueue_edu(self, edu, key=None):
|
||||||
destination = edu.destination
|
destination = edu.destination
|
||||||
|
|
||||||
if not self.can_send_to(destination):
|
if not self.can_send_to(destination):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if key:
|
||||||
|
self.pending_edus_keyed_by_dest.setdefault(
|
||||||
|
destination, {}
|
||||||
|
)[(edu.edu_type, key)] = edu
|
||||||
|
else:
|
||||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||||
|
|
||||||
preserve_context_over_fn(
|
preserve_context_over_fn(
|
||||||
|
@ -190,8 +213,13 @@ class TransactionQueue(object):
|
||||||
while True:
|
while True:
|
||||||
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
|
||||||
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
pending_edus = self.pending_edus_by_dest.pop(destination, [])
|
||||||
|
pending_presence = self.pending_presence_by_dest.pop(destination, {})
|
||||||
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
pending_failures = self.pending_failures_by_dest.pop(destination, [])
|
||||||
|
|
||||||
|
pending_edus.extend(
|
||||||
|
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
|
||||||
|
)
|
||||||
|
|
||||||
limiter = yield get_retry_limiter(
|
limiter = yield get_retry_limiter(
|
||||||
destination,
|
destination,
|
||||||
self.clock,
|
self.clock,
|
||||||
|
@ -203,6 +231,22 @@ class TransactionQueue(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
pending_edus.extend(device_message_edus)
|
pending_edus.extend(device_message_edus)
|
||||||
|
if pending_presence:
|
||||||
|
pending_edus.append(
|
||||||
|
Edu(
|
||||||
|
origin=self.server_name,
|
||||||
|
destination=destination,
|
||||||
|
edu_type="m.presence",
|
||||||
|
content={
|
||||||
|
"push": [
|
||||||
|
format_user_presence_state(
|
||||||
|
presence, self.clock.time_msec()
|
||||||
|
)
|
||||||
|
for presence in pending_presence.values()
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if pending_pdus:
|
if pending_pdus:
|
||||||
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
|
||||||
|
|
|
@ -625,18 +625,8 @@ class PresenceHandler(object):
|
||||||
Args:
|
Args:
|
||||||
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
|
hosts_to_states (dict): Mapping `server_name` -> `[UserPresenceState]`
|
||||||
"""
|
"""
|
||||||
now = self.clock.time_msec()
|
|
||||||
for host, states in hosts_to_states.items():
|
for host, states in hosts_to_states.items():
|
||||||
self.federation.send_edu(
|
self.federation.send_presence(host, states)
|
||||||
destination=host,
|
|
||||||
edu_type="m.presence",
|
|
||||||
content={
|
|
||||||
"push": [
|
|
||||||
_format_user_presence_state(state, now)
|
|
||||||
for state in states
|
|
||||||
]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def incoming_presence(self, origin, content):
|
def incoming_presence(self, origin, content):
|
||||||
|
@ -723,13 +713,13 @@ class PresenceHandler(object):
|
||||||
defer.returnValue([
|
defer.returnValue([
|
||||||
{
|
{
|
||||||
"type": "m.presence",
|
"type": "m.presence",
|
||||||
"content": _format_user_presence_state(state, now),
|
"content": format_user_presence_state(state, now),
|
||||||
}
|
}
|
||||||
for state in updates
|
for state in updates
|
||||||
])
|
])
|
||||||
else:
|
else:
|
||||||
defer.returnValue([
|
defer.returnValue([
|
||||||
_format_user_presence_state(state, now) for state in updates
|
format_user_presence_state(state, now) for state in updates
|
||||||
])
|
])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -988,7 +978,7 @@ def should_notify(old_state, new_state):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _format_user_presence_state(state, now):
|
def format_user_presence_state(state, now):
|
||||||
"""Convert UserPresenceState to a format that can be sent down to clients
|
"""Convert UserPresenceState to a format that can be sent down to clients
|
||||||
and to other servers.
|
and to other servers.
|
||||||
"""
|
"""
|
||||||
|
@ -1101,7 +1091,7 @@ class PresenceEventSource(object):
|
||||||
defer.returnValue(([
|
defer.returnValue(([
|
||||||
{
|
{
|
||||||
"type": "m.presence",
|
"type": "m.presence",
|
||||||
"content": _format_user_presence_state(s, now),
|
"content": format_user_presence_state(s, now),
|
||||||
}
|
}
|
||||||
for s in updates.values()
|
for s in updates.values()
|
||||||
if include_offline or s.state != PresenceState.OFFLINE
|
if include_offline or s.state != PresenceState.OFFLINE
|
||||||
|
|
|
@ -156,6 +156,7 @@ class ReceiptsHandler(BaseHandler):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
key=(room_id, receipt_type, user_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -187,6 +187,7 @@ class TypingHandler(object):
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"typing": typing,
|
"typing": typing,
|
||||||
},
|
},
|
||||||
|
key=(room_id, user_id),
|
||||||
))
|
))
|
||||||
|
|
||||||
yield preserve_context_over_deferred(
|
yield preserve_context_over_deferred(
|
||||||
|
|
|
@ -274,11 +274,18 @@ class ReplicationResource(Resource):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def typing(self, writer, current_token, request_streams):
|
def typing(self, writer, current_token, request_streams):
|
||||||
current_position = current_token.presence
|
current_position = current_token.typing
|
||||||
|
|
||||||
request_typing = request_streams.get("typing")
|
request_typing = request_streams.get("typing")
|
||||||
|
|
||||||
if request_typing is not None:
|
if request_typing is not None:
|
||||||
|
# If they have a higher token than current max, we can assume that
|
||||||
|
# they had been talking to a previous instance of the master. Since
|
||||||
|
# we reset the token on restart, the best (but hacky) thing we can
|
||||||
|
# do is to simply resend down all the typing notifications.
|
||||||
|
if request_typing > current_position:
|
||||||
|
request_typing = 0
|
||||||
|
|
||||||
typing_rows = yield self.typing_handler.get_all_typing_updates(
|
typing_rows = yield self.typing_handler.get_all_typing_updates(
|
||||||
request_typing, current_position
|
request_typing, current_position
|
||||||
)
|
)
|
||||||
|
|
|
@ -318,7 +318,7 @@ class CasRedirectServlet(ClientV1RestServlet):
|
||||||
service_param = urllib.urlencode({
|
service_param = urllib.urlencode({
|
||||||
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
"service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
|
||||||
})
|
})
|
||||||
request.redirect("%s?%s" % (self.cas_server_url, service_param))
|
request.redirect("%s/login?%s" % (self.cas_server_url, service_param))
|
||||||
finish_request(request)
|
finish_request(request)
|
||||||
|
|
||||||
|
|
||||||
|
@ -385,7 +385,7 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def parse_cas_response(self, cas_response_body):
|
def parse_cas_response(self, cas_response_body):
|
||||||
user = None
|
user = None
|
||||||
attributes = None
|
attributes = {}
|
||||||
try:
|
try:
|
||||||
root = ET.fromstring(cas_response_body)
|
root = ET.fromstring(cas_response_body)
|
||||||
if not root.tag.endswith("serviceResponse"):
|
if not root.tag.endswith("serviceResponse"):
|
||||||
|
@ -395,7 +395,6 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
if child.tag.endswith("user"):
|
if child.tag.endswith("user"):
|
||||||
user = child.text
|
user = child.text
|
||||||
if child.tag.endswith("attributes"):
|
if child.tag.endswith("attributes"):
|
||||||
attributes = {}
|
|
||||||
for attribute in child:
|
for attribute in child:
|
||||||
# ElementTree library expands the namespace in
|
# ElementTree library expands the namespace in
|
||||||
# attribute tags to the full URL of the namespace.
|
# attribute tags to the full URL of the namespace.
|
||||||
|
@ -407,8 +406,6 @@ class CasTicketServlet(ClientV1RestServlet):
|
||||||
attributes[tag] = attribute.text
|
attributes[tag] = attribute.text
|
||||||
if user is None:
|
if user is None:
|
||||||
raise Exception("CAS response does not contain user")
|
raise Exception("CAS response does not contain user")
|
||||||
if attributes is None:
|
|
||||||
raise Exception("CAS response does not contain attributes")
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.error("Error parsing CAS response", exc_info=1)
|
logger.error("Error parsing CAS response", exc_info=1)
|
||||||
raise LoginError(401, "Invalid CAS response",
|
raise LoginError(401, "Invalid CAS response",
|
||||||
|
|
|
@ -306,13 +306,6 @@ class StateStore(SQLBaseStore):
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
|
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
|
||||||
if types is not None:
|
|
||||||
where_clause = "AND (%s)" % (
|
|
||||||
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
where_clause = ""
|
|
||||||
|
|
||||||
results = {group: {} for group in groups}
|
results = {group: {} for group in groups}
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
# Temporarily disable sequential scans in this transaction. This is
|
# Temporarily disable sequential scans in this transaction. This is
|
||||||
|
@ -342,20 +335,43 @@ class StateStore(SQLBaseStore):
|
||||||
WHERE state_group IN (
|
WHERE state_group IN (
|
||||||
SELECT state_group FROM state
|
SELECT state_group FROM state
|
||||||
)
|
)
|
||||||
%s;
|
%s
|
||||||
""") % (where_clause,)
|
""")
|
||||||
|
|
||||||
|
# Turns out that postgres doesn't like doing a list of OR's and
|
||||||
|
# is about 1000x slower, so we just issue a query for each specific
|
||||||
|
# type seperately.
|
||||||
|
if types:
|
||||||
|
clause_to_args = [
|
||||||
|
(
|
||||||
|
"AND type = ? AND state_key = ?",
|
||||||
|
(etype, state_key)
|
||||||
|
)
|
||||||
|
for etype, state_key in types
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
# If types is None we fetch all the state, and so just use an
|
||||||
|
# empty where clause with no extra args.
|
||||||
|
clause_to_args = [("", [])]
|
||||||
|
|
||||||
|
for where_clause, where_args in clause_to_args:
|
||||||
for group in groups:
|
for group in groups:
|
||||||
args = [group]
|
args = [group]
|
||||||
if types is not None:
|
args.extend(where_args)
|
||||||
args.extend([i for typ in types for i in typ])
|
|
||||||
|
|
||||||
txn.execute(sql, args)
|
txn.execute(sql % (where_clause,), args)
|
||||||
rows = self.cursor_to_dict(txn)
|
rows = self.cursor_to_dict(txn)
|
||||||
for row in rows:
|
for row in rows:
|
||||||
key = (row["type"], row["state_key"])
|
key = (row["type"], row["state_key"])
|
||||||
results[group][key] = row["event_id"]
|
results[group][key] = row["event_id"]
|
||||||
else:
|
else:
|
||||||
|
if types is not None:
|
||||||
|
where_clause = "AND (%s)" % (
|
||||||
|
" OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
where_clause = ""
|
||||||
|
|
||||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||||
for group in groups:
|
for group in groups:
|
||||||
|
|
Loading…
Add table
Reference in a new issue