mirror of
https://mau.dev/maunium/synapse.git
synced 2025-01-20 06:12:03 +01:00
Support enabling/disabling pushers (from MSC3881) (#13799)
Partial implementation of MSC3881
This commit is contained in:
parent
6bd8763804
commit
8ae42ab8fa
15 changed files with 294 additions and 71 deletions
1
changelog.d/13799.feature
Normal file
1
changelog.d/13799.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
|
|
@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
|
||||||
"e2e_fallback_keys_json": ["used"],
|
"e2e_fallback_keys_json": ["used"],
|
||||||
"access_tokens": ["used"],
|
"access_tokens": ["used"],
|
||||||
"device_lists_changes_in_room": ["converted_to_destinations"],
|
"device_lists_changes_in_room": ["converted_to_destinations"],
|
||||||
|
"pushers": ["enabled"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
|
||||||
|
|
||||||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||||
|
|
||||||
|
# MSC3881: Remotely toggle push notifications for another client
|
||||||
|
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||||
|
|
|
@ -997,7 +997,7 @@ class RegistrationHandler:
|
||||||
assert user_tuple
|
assert user_tuple
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="email",
|
kind="email",
|
||||||
|
@ -1005,7 +1005,7 @@ class RegistrationHandler:
|
||||||
app_display_name="Email Notifications",
|
app_display_name="Email Notifications",
|
||||||
device_display_name=threepid["address"],
|
device_display_name=threepid["address"],
|
||||||
pushkey=threepid["address"],
|
pushkey=threepid["address"],
|
||||||
lang=None, # We don't know a user's language here
|
lang=None,
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -116,6 +116,7 @@ class PusherConfig:
|
||||||
last_stream_ordering: int
|
last_stream_ordering: int
|
||||||
last_success: Optional[int]
|
last_success: Optional[int]
|
||||||
failing_since: Optional[int]
|
failing_since: Optional[int]
|
||||||
|
enabled: bool
|
||||||
|
|
||||||
def as_dict(self) -> Dict[str, Any]:
|
def as_dict(self) -> Dict[str, Any]:
|
||||||
"""Information that can be retrieved about a pusher after creation."""
|
"""Information that can be retrieved about a pusher after creation."""
|
||||||
|
@ -128,6 +129,7 @@ class PusherConfig:
|
||||||
"lang": self.lang,
|
"lang": self.lang,
|
||||||
"profile_tag": self.profile_tag,
|
"profile_tag": self.profile_tag,
|
||||||
"pushkey": self.pushkey,
|
"pushkey": self.pushkey,
|
||||||
|
"enabled": self.enabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -94,7 +94,7 @@ class PusherPool:
|
||||||
return
|
return
|
||||||
run_as_background_process("start_pushers", self._start_pushers)
|
run_as_background_process("start_pushers", self._start_pushers)
|
||||||
|
|
||||||
async def add_pusher(
|
async def add_or_update_pusher(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
access_token: Optional[int],
|
access_token: Optional[int],
|
||||||
|
@ -106,6 +106,7 @@ class PusherPool:
|
||||||
lang: Optional[str],
|
lang: Optional[str],
|
||||||
data: JsonDict,
|
data: JsonDict,
|
||||||
profile_tag: str = "",
|
profile_tag: str = "",
|
||||||
|
enabled: bool = True,
|
||||||
) -> Optional[Pusher]:
|
) -> Optional[Pusher]:
|
||||||
"""Creates a new pusher and adds it to the pool
|
"""Creates a new pusher and adds it to the pool
|
||||||
|
|
||||||
|
@ -147,9 +148,20 @@ class PusherPool:
|
||||||
last_stream_ordering=last_stream_ordering,
|
last_stream_ordering=last_stream_ordering,
|
||||||
last_success=None,
|
last_success=None,
|
||||||
failing_since=None,
|
failing_since=None,
|
||||||
|
enabled=enabled,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Before we actually persist the pusher, we check if the user already has one
|
||||||
|
# for this app ID and pushkey. If so, we want to keep the access token in place,
|
||||||
|
# since this could be one device modifying (e.g. enabling/disabling) another
|
||||||
|
# device's pusher.
|
||||||
|
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||||
|
user_id, app_id, pushkey
|
||||||
|
)
|
||||||
|
if existing_config:
|
||||||
|
access_token = existing_config.access_token
|
||||||
|
|
||||||
await self.store.add_pusher(
|
await self.store.add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
|
@ -163,8 +175,9 @@ class PusherPool:
|
||||||
data=data,
|
data=data,
|
||||||
last_stream_ordering=last_stream_ordering,
|
last_stream_ordering=last_stream_ordering,
|
||||||
profile_tag=profile_tag,
|
profile_tag=profile_tag,
|
||||||
|
enabled=enabled,
|
||||||
)
|
)
|
||||||
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
|
pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||||
|
|
||||||
return pusher
|
return pusher
|
||||||
|
|
||||||
|
@ -276,10 +289,25 @@ class PusherPool:
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Exception in pusher on_new_receipts")
|
logger.exception("Exception in pusher on_new_receipts")
|
||||||
|
|
||||||
async def start_pusher_by_id(
|
async def _get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||||
|
self, user_id: str, app_id: str, pushkey: str
|
||||||
|
) -> Optional[PusherConfig]:
|
||||||
|
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||||
|
|
||||||
|
pusher_config = None
|
||||||
|
for r in resultlist:
|
||||||
|
if r.user_name == user_id:
|
||||||
|
pusher_config = r
|
||||||
|
|
||||||
|
return pusher_config
|
||||||
|
|
||||||
|
async def process_pusher_change_by_id(
|
||||||
self, app_id: str, pushkey: str, user_id: str
|
self, app_id: str, pushkey: str, user_id: str
|
||||||
) -> Optional[Pusher]:
|
) -> Optional[Pusher]:
|
||||||
"""Look up the details for the given pusher, and start it
|
"""Look up the details for the given pusher, and either start it if its
|
||||||
|
"enabled" flag is True, or try to stop it otherwise.
|
||||||
|
|
||||||
|
If the pusher is new and its "enabled" flag is False, the stop is a noop.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The pusher started, if any
|
The pusher started, if any
|
||||||
|
@ -290,12 +318,13 @@ class PusherPool:
|
||||||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||||
|
user_id, app_id, pushkey
|
||||||
|
)
|
||||||
|
|
||||||
pusher_config = None
|
if pusher_config and not pusher_config.enabled:
|
||||||
for r in resultlist:
|
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||||
if r.user_name == user_id:
|
return None
|
||||||
pusher_config = r
|
|
||||||
|
|
||||||
pusher = None
|
pusher = None
|
||||||
if pusher_config:
|
if pusher_config:
|
||||||
|
@ -305,7 +334,7 @@ class PusherPool:
|
||||||
|
|
||||||
async def _start_pushers(self) -> None:
|
async def _start_pushers(self) -> None:
|
||||||
"""Start all the pushers"""
|
"""Start all the pushers"""
|
||||||
pushers = await self.store.get_all_pushers()
|
pushers = await self.store.get_enabled_pushers()
|
||||||
|
|
||||||
# Stagger starting up the pushers so we don't completely drown the
|
# Stagger starting up the pushers so we don't completely drown the
|
||||||
# process on start up.
|
# process on start up.
|
||||||
|
@ -363,6 +392,8 @@ class PusherPool:
|
||||||
|
|
||||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
|
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
|
||||||
|
|
||||||
|
logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
|
||||||
|
|
||||||
# Check if there *may* be push to process. We do this as this check is a
|
# Check if there *may* be push to process. We do this as this check is a
|
||||||
# lot cheaper to do than actually fetching the exact rows we need to
|
# lot cheaper to do than actually fetching the exact rows we need to
|
||||||
# push.
|
# push.
|
||||||
|
@ -382,16 +413,7 @@ class PusherPool:
|
||||||
return pusher
|
return pusher
|
||||||
|
|
||||||
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||||
|
|
||||||
byuser = self.pushers.get(user_id, {})
|
|
||||||
|
|
||||||
if appid_pushkey in byuser:
|
|
||||||
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
|
||||||
pusher = byuser.pop(appid_pushkey)
|
|
||||||
pusher.on_stop()
|
|
||||||
|
|
||||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
|
||||||
|
|
||||||
# We can only delete pushers on master.
|
# We can only delete pushers on master.
|
||||||
if self._remove_pusher_client:
|
if self._remove_pusher_client:
|
||||||
|
@ -402,3 +424,22 @@ class PusherPool:
|
||||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||||
app_id, pushkey, user_id
|
app_id, pushkey, user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||||
|
"""Stops a pusher with the given app ID and push key if one is running.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: the pusher's app ID.
|
||||||
|
pushkey: the pusher's push key.
|
||||||
|
user_id: the user the pusher belongs to. Only used for logging.
|
||||||
|
"""
|
||||||
|
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||||
|
|
||||||
|
byuser = self.pushers.get(user_id, {})
|
||||||
|
|
||||||
|
if appid_pushkey in byuser:
|
||||||
|
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
||||||
|
pusher = byuser.pop(appid_pushkey)
|
||||||
|
pusher.on_stop()
|
||||||
|
|
||||||
|
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||||
|
|
|
@ -189,7 +189,9 @@ class ReplicationDataHandler:
|
||||||
if row.deleted:
|
if row.deleted:
|
||||||
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
|
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
|
||||||
else:
|
else:
|
||||||
await self.start_pusher(row.user_id, row.app_id, row.pushkey)
|
await self.process_pusher_change(
|
||||||
|
row.user_id, row.app_id, row.pushkey
|
||||||
|
)
|
||||||
elif stream_name == EventsStream.NAME:
|
elif stream_name == EventsStream.NAME:
|
||||||
# We shouldn't get multiple rows per token for events stream, so
|
# We shouldn't get multiple rows per token for events stream, so
|
||||||
# we don't need to optimise this for multiple rows.
|
# we don't need to optimise this for multiple rows.
|
||||||
|
@ -334,13 +336,15 @@ class ReplicationDataHandler:
|
||||||
logger.info("Stopping pusher %r / %r", user_id, key)
|
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||||
pusher.on_stop()
|
pusher.on_stop()
|
||||||
|
|
||||||
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
async def process_pusher_change(
|
||||||
|
self, user_id: str, app_id: str, pushkey: str
|
||||||
|
) -> None:
|
||||||
if not self._notify_pushers:
|
if not self._notify_pushers:
|
||||||
return
|
return
|
||||||
|
|
||||||
key = "%s:%s" % (app_id, pushkey)
|
key = "%s:%s" % (app_id, pushkey)
|
||||||
logger.info("Starting pusher %r / %r", user_id, key)
|
logger.info("Starting pusher %r / %r", user_id, key)
|
||||||
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||||
|
|
||||||
|
|
||||||
class FederationSenderHandler:
|
class FederationSenderHandler:
|
||||||
|
|
|
@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
|
||||||
and self.hs.config.email.email_notif_for_new_users
|
and self.hs.config.email.email_notif_for_new_users
|
||||||
and medium == "email"
|
and medium == "email"
|
||||||
):
|
):
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=None,
|
access_token=None,
|
||||||
kind="email",
|
kind="email",
|
||||||
|
@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
|
||||||
app_display_name="Email Notifications",
|
app_display_name="Email Notifications",
|
||||||
device_display_name=address,
|
device_display_name=address,
|
||||||
pushkey=address,
|
pushkey=address,
|
||||||
lang=None, # We don't know a user's language here
|
lang=None,
|
||||||
data={},
|
data={},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
|
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||||
|
|
||||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet):
|
||||||
user.to_string()
|
user.to_string()
|
||||||
)
|
)
|
||||||
|
|
||||||
filtered_pushers = [p.as_dict() for p in pushers]
|
pusher_dicts = [p.as_dict() for p in pushers]
|
||||||
|
|
||||||
return 200, {"pushers": filtered_pushers}
|
for pusher in pusher_dicts:
|
||||||
|
if self._msc3881_enabled:
|
||||||
|
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
|
||||||
|
del pusher["enabled"]
|
||||||
|
|
||||||
|
return 200, {"pushers": pusher_dicts}
|
||||||
|
|
||||||
|
|
||||||
class PushersSetRestServlet(RestServlet):
|
class PushersSetRestServlet(RestServlet):
|
||||||
|
@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.pusher_pool = self.hs.get_pusherpool()
|
self.pusher_pool = self.hs.get_pusherpool()
|
||||||
|
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||||
|
|
||||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||||
requester = await self.auth.get_user_by_req(request)
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet):
|
||||||
if "append" in content:
|
if "append" in content:
|
||||||
append = content["append"]
|
append = content["append"]
|
||||||
|
|
||||||
|
enabled = True
|
||||||
|
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
|
||||||
|
enabled = content["org.matrix.msc3881.enabled"]
|
||||||
|
|
||||||
if not append:
|
if not append:
|
||||||
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||||
app_id=content["app_id"],
|
app_id=content["app_id"],
|
||||||
|
@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet):
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_or_update_pusher(
|
||||||
user_id=user.to_string(),
|
user_id=user.to_string(),
|
||||||
access_token=requester.access_token_id,
|
access_token=requester.access_token_id,
|
||||||
kind=content["kind"],
|
kind=content["kind"],
|
||||||
|
@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet):
|
||||||
lang=content["lang"],
|
lang=content["lang"],
|
||||||
data=content["data"],
|
data=content["data"],
|
||||||
profile_tag=content.get("profile_tag", ""),
|
profile_tag=content.get("profile_tag", ""),
|
||||||
|
enabled=enabled,
|
||||||
)
|
)
|
||||||
except PusherConfigException as pce:
|
except PusherConfigException as pce:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
|
|
|
@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# If we're using SQLite, then boolean values are integers. This is
|
||||||
|
# troublesome since some code using the return value of this method might
|
||||||
|
# expect it to be a boolean, or will expose it to clients (in responses).
|
||||||
|
r["enabled"] = bool(r["enabled"])
|
||||||
|
|
||||||
yield PusherConfig(**r)
|
yield PusherConfig(**r)
|
||||||
|
|
||||||
async def get_pushers_by_app_id_and_pushkey(
|
async def get_pushers_by_app_id_and_pushkey(
|
||||||
|
@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
|
||||||
return await self.get_pushers_by({"user_name": user_id})
|
return await self.get_pushers_by({"user_name": user_id})
|
||||||
|
|
||||||
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
|
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
|
||||||
ret = await self.db_pool.simple_select_list(
|
"""Retrieve pushers that match the given criteria.
|
||||||
"pushers",
|
|
||||||
keyvalues,
|
Args:
|
||||||
[
|
keyvalues: A {column: value} dictionary.
|
||||||
"id",
|
|
||||||
"user_name",
|
Returns:
|
||||||
"access_token",
|
The pushers for which the given columns have the given values.
|
||||||
"profile_tag",
|
"""
|
||||||
"kind",
|
|
||||||
"app_id",
|
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||||
"app_display_name",
|
# We could technically use simple_select_list here, but we need to call
|
||||||
"device_display_name",
|
# COALESCE on the 'enabled' column. While it is technically possible to give
|
||||||
"pushkey",
|
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
|
||||||
"ts",
|
# feels a bit hacky, so it's probably better to just inline the query.
|
||||||
"lang",
|
sql = """
|
||||||
"data",
|
SELECT
|
||||||
"last_stream_ordering",
|
id, user_name, access_token, profile_tag, kind, app_id,
|
||||||
"last_success",
|
app_display_name, device_display_name, pushkey, ts, lang, data,
|
||||||
"failing_since",
|
last_stream_ordering, last_success, failing_since,
|
||||||
],
|
COALESCE(enabled, TRUE) AS enabled
|
||||||
|
FROM pushers
|
||||||
|
"""
|
||||||
|
|
||||||
|
sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
|
||||||
|
|
||||||
|
txn.execute(sql, list(keyvalues.values()))
|
||||||
|
|
||||||
|
return self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
|
ret = await self.db_pool.runInteraction(
|
||||||
desc="get_pushers_by",
|
desc="get_pushers_by",
|
||||||
|
func=get_pushers_by_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self._decode_pushers_rows(ret)
|
return self._decode_pushers_rows(ret)
|
||||||
|
|
||||||
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
|
||||||
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||||
txn.execute("SELECT * FROM pushers")
|
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
|
|
||||||
return self._decode_pushers_rows(rows)
|
return self._decode_pushers_rows(rows)
|
||||||
|
|
||||||
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
|
return await self.db_pool.runInteraction(
|
||||||
|
"get_enabled_pushers", get_enabled_pushers_txn
|
||||||
|
)
|
||||||
|
|
||||||
async def get_all_updated_pushers_rows(
|
async def get_all_updated_pushers_rows(
|
||||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||||
|
@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
data: Optional[JsonDict],
|
data: Optional[JsonDict],
|
||||||
last_stream_ordering: int,
|
last_stream_ordering: int,
|
||||||
profile_tag: str = "",
|
profile_tag: str = "",
|
||||||
|
enabled: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
async with self._pushers_id_gen.get_next() as stream_id:
|
async with self._pushers_id_gen.get_next() as stream_id:
|
||||||
# no need to lock because `pushers` has a unique key on
|
# no need to lock because `pushers` has a unique key on
|
||||||
|
@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
"last_stream_ordering": last_stream_ordering,
|
"last_stream_ordering": last_stream_ordering,
|
||||||
"profile_tag": profile_tag,
|
"profile_tag": profile_tag,
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
|
"enabled": enabled,
|
||||||
},
|
},
|
||||||
desc="add_pusher",
|
desc="add_pusher",
|
||||||
lock=False,
|
lock=False,
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;
|
|
@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.pusher = self.get_success(
|
self.pusher = self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
access_token=self.token_id,
|
access_token=self.token_id,
|
||||||
kind="email",
|
kind="email",
|
||||||
|
@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
with self.assertRaises(SynapseError) as cm:
|
with self.assertRaises(SynapseError) as cm:
|
||||||
self.get_success_or_raise(
|
self.get_success_or_raise(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
access_token=self.token_id,
|
access_token=self.token_id,
|
||||||
kind="email",
|
kind="email",
|
||||||
|
|
|
@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.push import PusherConfigException
|
from synapse.push import PusherConfig, PusherConfigException
|
||||||
from synapse.rest.client import login, push_rule, receipts, room
|
from synapse.rest.client import login, push_rule, pusher, receipts, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
login.register_servlets,
|
login.register_servlets,
|
||||||
receipts.register_servlets,
|
receipts.register_servlets,
|
||||||
push_rule.register_servlets,
|
push_rule.register_servlets,
|
||||||
|
pusher.register_servlets,
|
||||||
]
|
]
|
||||||
user_id = True
|
user_id = True
|
||||||
hijack_auth = False
|
hijack_auth = False
|
||||||
|
@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
|
|
||||||
def test_data(data: Optional[JsonDict]) -> None:
|
def test_data(data: Optional[JsonDict]) -> None:
|
||||||
self.get_failure(
|
self.get_failure(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(channel.code, 200, channel.json_body)
|
self.assertEqual(channel.code, 200, channel.json_body)
|
||||||
|
|
||||||
def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
|
def _make_user_with_pusher(
|
||||||
|
self, username: str, enabled: bool = True
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""Registers a user and creates a pusher for them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
username: the localpart of the new user's Matrix ID.
|
||||||
|
enabled: whether to create the pusher in an enabled or disabled state.
|
||||||
|
"""
|
||||||
user_id = self.register_user(username, "pass")
|
user_id = self.register_user(username, "pass")
|
||||||
access_token = self.login(username, "pass")
|
access_token = self.login(username, "pass")
|
||||||
|
|
||||||
# Register the pusher
|
# Register the pusher
|
||||||
|
self._set_pusher(user_id, access_token, enabled)
|
||||||
|
|
||||||
|
return user_id, access_token
|
||||||
|
|
||||||
|
def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
|
||||||
|
"""Creates or updates the pusher for the given user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: the user's Matrix ID.
|
||||||
|
access_token: the access token associated with the pusher.
|
||||||
|
enabled: whether to enable or disable the pusher.
|
||||||
|
"""
|
||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
pushkey="a@example.com",
|
pushkey="a@example.com",
|
||||||
lang=None,
|
lang=None,
|
||||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||||
|
enabled=enabled,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return user_id, access_token
|
|
||||||
|
|
||||||
def test_dont_notify_rule_overrides_message(self) -> None:
|
def test_dont_notify_rule_overrides_message(self) -> None:
|
||||||
"""
|
"""
|
||||||
The override push rule will suppress notification
|
The override push rule will suppress notification
|
||||||
|
@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||||
# The user sends a message back (sends a notification)
|
# The user sends a message back (sends a notification)
|
||||||
self.helper.send(room, body="Hello", tok=access_token)
|
self.helper.send(room, body="Hello", tok=access_token)
|
||||||
self.assertEqual(len(self.push_attempts), 1)
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||||
|
def test_disable(self) -> None:
|
||||||
|
"""Tests that disabling a pusher means it's not pushed to anymore."""
|
||||||
|
user_id, access_token = self._make_user_with_pusher("user")
|
||||||
|
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||||
|
|
||||||
|
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||||
|
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||||
|
|
||||||
|
# Send a message and check that it generated a push.
|
||||||
|
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||||
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
|
|
||||||
|
# Disable the pusher.
|
||||||
|
self._set_pusher(user_id, access_token, enabled=False)
|
||||||
|
|
||||||
|
# Send another message and check that it did not generate a push.
|
||||||
|
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||||
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
|
|
||||||
|
# Get the pushers for the user and check that it is marked as disabled.
|
||||||
|
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||||
|
|
||||||
|
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||||
|
self.assertFalse(enabled)
|
||||||
|
self.assertTrue(isinstance(enabled, bool))
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||||
|
def test_enable(self) -> None:
|
||||||
|
"""Tests that enabling a disabled pusher means it gets pushed to."""
|
||||||
|
# Create the user with the pusher already disabled.
|
||||||
|
user_id, access_token = self._make_user_with_pusher("user", enabled=False)
|
||||||
|
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||||
|
|
||||||
|
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||||
|
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||||
|
|
||||||
|
# Send a message and check that it did not generate a push.
|
||||||
|
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||||
|
self.assertEqual(len(self.push_attempts), 0)
|
||||||
|
|
||||||
|
# Enable the pusher.
|
||||||
|
self._set_pusher(user_id, access_token, enabled=True)
|
||||||
|
|
||||||
|
# Send another message and check that it did generate a push.
|
||||||
|
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||||
|
self.assertEqual(len(self.push_attempts), 1)
|
||||||
|
|
||||||
|
# Get the pushers for the user and check that it is marked as enabled.
|
||||||
|
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||||
|
|
||||||
|
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||||
|
self.assertTrue(enabled)
|
||||||
|
self.assertTrue(isinstance(enabled, bool))
|
||||||
|
|
||||||
|
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||||
|
def test_null_enabled(self) -> None:
|
||||||
|
"""Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
|
||||||
|
created before the column was introduced) is considered enabled.
|
||||||
|
"""
|
||||||
|
# We intentionally set 'enabled' to None so that it's stored as NULL in the
|
||||||
|
# database.
|
||||||
|
user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||||
|
self.assertEqual(channel.code, 200)
|
||||||
|
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||||
|
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
|
||||||
|
|
||||||
|
def test_update_different_device_access_token(self) -> None:
|
||||||
|
"""Tests that if we create a pusher from one device, the update it from another
|
||||||
|
device, the access token associated with the pusher stays the same.
|
||||||
|
"""
|
||||||
|
# Create a user with a pusher.
|
||||||
|
user_id, access_token = self._make_user_with_pusher("user")
|
||||||
|
|
||||||
|
# Get the token ID for the current access token, since that's what we store in
|
||||||
|
# the pushers table.
|
||||||
|
user_tuple = self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||||
|
)
|
||||||
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
|
# Generate a new access token, and update the pusher with it.
|
||||||
|
new_token = self.login("user", "pass")
|
||||||
|
self._set_pusher(user_id, new_token, enabled=False)
|
||||||
|
|
||||||
|
# Get the current list of pushers for the user.
|
||||||
|
ret = self.get_success(
|
||||||
|
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||||
|
)
|
||||||
|
pushers: List[PusherConfig] = list(ret)
|
||||||
|
|
||||||
|
# Check that we still have one pusher, and that the access token associated with
|
||||||
|
# it didn't change.
|
||||||
|
self.assertEqual(len(pushers), 1)
|
||||||
|
self.assertEqual(pushers[0].access_token, token_id)
|
||||||
|
|
|
@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
token_id = user_dict.token_id
|
token_id = user_dict.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
|
|
@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
|
||||||
token_id = user_tuple.token_id
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_or_update_pusher(
|
||||||
user_id=self.other_user,
|
user_id=self.other_user,
|
||||||
access_token=token_id,
|
access_token=token_id,
|
||||||
kind="http",
|
kind="http",
|
||||||
|
|
Loading…
Add table
Reference in a new issue