mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-30 17:44:07 +01:00
Support enabling/disabling pushers (from MSC3881) (#13799)
Partial implementation of MSC3881
This commit is contained in:
parent
6bd8763804
commit
8ae42ab8fa
15 changed files with 294 additions and 71 deletions
1
changelog.d/13799.feature
Normal file
1
changelog.d/13799.feature
Normal file
|
@ -0,0 +1 @@
|
|||
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
|
|
@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
|
|||
"e2e_fallback_keys_json": ["used"],
|
||||
"access_tokens": ["used"],
|
||||
"device_lists_changes_in_room": ["converted_to_destinations"],
|
||||
"pushers": ["enabled"],
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -93,3 +93,6 @@ class ExperimentalConfig(Config):
|
|||
|
||||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||
|
||||
# MSC3881: Remotely toggle push notifications for another client
|
||||
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||
|
|
|
@ -997,7 +997,7 @@ class RegistrationHandler:
|
|||
assert user_tuple
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="email",
|
||||
|
@ -1005,7 +1005,7 @@ class RegistrationHandler:
|
|||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
lang=None,
|
||||
data={},
|
||||
)
|
||||
|
||||
|
|
|
@ -116,6 +116,7 @@ class PusherConfig:
|
|||
last_stream_ordering: int
|
||||
last_success: Optional[int]
|
||||
failing_since: Optional[int]
|
||||
enabled: bool
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Information that can be retrieved about a pusher after creation."""
|
||||
|
@ -128,6 +129,7 @@ class PusherConfig:
|
|||
"lang": self.lang,
|
||||
"profile_tag": self.profile_tag,
|
||||
"pushkey": self.pushkey,
|
||||
"enabled": self.enabled,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ class PusherPool:
|
|||
return
|
||||
run_as_background_process("start_pushers", self._start_pushers)
|
||||
|
||||
async def add_pusher(
|
||||
async def add_or_update_pusher(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: Optional[int],
|
||||
|
@ -106,6 +106,7 @@ class PusherPool:
|
|||
lang: Optional[str],
|
||||
data: JsonDict,
|
||||
profile_tag: str = "",
|
||||
enabled: bool = True,
|
||||
) -> Optional[Pusher]:
|
||||
"""Creates a new pusher and adds it to the pool
|
||||
|
||||
|
@ -147,9 +148,20 @@ class PusherPool:
|
|||
last_stream_ordering=last_stream_ordering,
|
||||
last_success=None,
|
||||
failing_since=None,
|
||||
enabled=enabled,
|
||||
)
|
||||
)
|
||||
|
||||
# Before we actually persist the pusher, we check if the user already has one
|
||||
# for this app ID and pushkey. If so, we want to keep the access token in place,
|
||||
# since this could be one device modifying (e.g. enabling/disabling) another
|
||||
# device's pusher.
|
||||
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
user_id, app_id, pushkey
|
||||
)
|
||||
if existing_config:
|
||||
access_token = existing_config.access_token
|
||||
|
||||
await self.store.add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
|
@ -163,8 +175,9 @@ class PusherPool:
|
|||
data=data,
|
||||
last_stream_ordering=last_stream_ordering,
|
||||
profile_tag=profile_tag,
|
||||
enabled=enabled,
|
||||
)
|
||||
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||
|
||||
return pusher
|
||||
|
||||
|
@ -276,10 +289,25 @@ class PusherPool:
|
|||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
async def start_pusher_by_id(
|
||||
async def _get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
self, user_id: str, app_id: str, pushkey: str
|
||||
) -> Optional[PusherConfig]:
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
|
||||
pusher_config = None
|
||||
for r in resultlist:
|
||||
if r.user_name == user_id:
|
||||
pusher_config = r
|
||||
|
||||
return pusher_config
|
||||
|
||||
async def process_pusher_change_by_id(
|
||||
self, app_id: str, pushkey: str, user_id: str
|
||||
) -> Optional[Pusher]:
|
||||
"""Look up the details for the given pusher, and start it
|
||||
"""Look up the details for the given pusher, and either start it if its
|
||||
"enabled" flag is True, or try to stop it otherwise.
|
||||
|
||||
If the pusher is new and its "enabled" flag is False, the stop is a noop.
|
||||
|
||||
Returns:
|
||||
The pusher started, if any
|
||||
|
@ -290,12 +318,13 @@ class PusherPool:
|
|||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||
return None
|
||||
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
user_id, app_id, pushkey
|
||||
)
|
||||
|
||||
pusher_config = None
|
||||
for r in resultlist:
|
||||
if r.user_name == user_id:
|
||||
pusher_config = r
|
||||
if pusher_config and not pusher_config.enabled:
|
||||
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||
return None
|
||||
|
||||
pusher = None
|
||||
if pusher_config:
|
||||
|
@ -305,7 +334,7 @@ class PusherPool:
|
|||
|
||||
async def _start_pushers(self) -> None:
|
||||
"""Start all the pushers"""
|
||||
pushers = await self.store.get_all_pushers()
|
||||
pushers = await self.store.get_enabled_pushers()
|
||||
|
||||
# Stagger starting up the pushers so we don't completely drown the
|
||||
# process on start up.
|
||||
|
@ -363,6 +392,8 @@ class PusherPool:
|
|||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
|
||||
|
||||
logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
|
||||
|
||||
# Check if there *may* be push to process. We do this as this check is a
|
||||
# lot cheaper to do than actually fetching the exact rows we need to
|
||||
# push.
|
||||
|
@ -382,16 +413,7 @@ class PusherPool:
|
|||
return pusher
|
||||
|
||||
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
byuser = self.pushers.get(user_id, {})
|
||||
|
||||
if appid_pushkey in byuser:
|
||||
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
||||
pusher = byuser.pop(appid_pushkey)
|
||||
pusher.on_stop()
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||
|
||||
# We can only delete pushers on master.
|
||||
if self._remove_pusher_client:
|
||||
|
@ -402,3 +424,22 @@ class PusherPool:
|
|||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id, pushkey, user_id
|
||||
)
|
||||
|
||||
def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||
"""Stops a pusher with the given app ID and push key if one is running.
|
||||
|
||||
Args:
|
||||
app_id: the pusher's app ID.
|
||||
pushkey: the pusher's push key.
|
||||
user_id: the user the pusher belongs to. Only used for logging.
|
||||
"""
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
byuser = self.pushers.get(user_id, {})
|
||||
|
||||
if appid_pushkey in byuser:
|
||||
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
||||
pusher = byuser.pop(appid_pushkey)
|
||||
pusher.on_stop()
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||
|
|
|
@ -189,7 +189,9 @@ class ReplicationDataHandler:
|
|||
if row.deleted:
|
||||
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
|
||||
else:
|
||||
await self.start_pusher(row.user_id, row.app_id, row.pushkey)
|
||||
await self.process_pusher_change(
|
||||
row.user_id, row.app_id, row.pushkey
|
||||
)
|
||||
elif stream_name == EventsStream.NAME:
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
# we don't need to optimise this for multiple rows.
|
||||
|
@ -334,13 +336,15 @@ class ReplicationDataHandler:
|
|||
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||
pusher.on_stop()
|
||||
|
||||
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||
async def process_pusher_change(
|
||||
self, user_id: str, app_id: str, pushkey: str
|
||||
) -> None:
|
||||
if not self._notify_pushers:
|
||||
return
|
||||
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
logger.info("Starting pusher %r / %r", user_id, key)
|
||||
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||
|
||||
|
||||
class FederationSenderHandler:
|
||||
|
|
|
@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
|
|||
and self.hs.config.email.email_notif_for_new_users
|
||||
and medium == "email"
|
||||
):
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=None,
|
||||
kind="email",
|
||||
|
@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
|
|||
app_display_name="Email Notifications",
|
||||
device_display_name=address,
|
||||
pushkey=address,
|
||||
lang=None, # We don't know a user's language here
|
||||
lang=None,
|
||||
data={},
|
||||
)
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
|
|||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet):
|
|||
user.to_string()
|
||||
)
|
||||
|
||||
filtered_pushers = [p.as_dict() for p in pushers]
|
||||
pusher_dicts = [p.as_dict() for p in pushers]
|
||||
|
||||
return 200, {"pushers": filtered_pushers}
|
||||
for pusher in pusher_dicts:
|
||||
if self._msc3881_enabled:
|
||||
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
|
||||
del pusher["enabled"]
|
||||
|
||||
return 200, {"pushers": pusher_dicts}
|
||||
|
||||
|
||||
class PushersSetRestServlet(RestServlet):
|
||||
|
@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet):
|
|||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.pusher_pool = self.hs.get_pusherpool()
|
||||
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet):
|
|||
if "append" in content:
|
||||
append = content["append"]
|
||||
|
||||
enabled = True
|
||||
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
|
||||
enabled = content["org.matrix.msc3881.enabled"]
|
||||
|
||||
if not append:
|
||||
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
app_id=content["app_id"],
|
||||
|
@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet):
|
|||
)
|
||||
|
||||
try:
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user.to_string(),
|
||||
access_token=requester.access_token_id,
|
||||
kind=content["kind"],
|
||||
|
@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet):
|
|||
lang=content["lang"],
|
||||
data=content["data"],
|
||||
profile_tag=content.get("profile_tag", ""),
|
||||
enabled=enabled,
|
||||
)
|
||||
except PusherConfigException as pce:
|
||||
raise SynapseError(
|
||||
|
|
|
@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
)
|
||||
continue
|
||||
|
||||
# If we're using SQLite, then boolean values are integers. This is
|
||||
# troublesome since some code using the return value of this method might
|
||||
# expect it to be a boolean, or will expose it to clients (in responses).
|
||||
r["enabled"] = bool(r["enabled"])
|
||||
|
||||
yield PusherConfig(**r)
|
||||
|
||||
async def get_pushers_by_app_id_and_pushkey(
|
||||
|
@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
|
|||
return await self.get_pushers_by({"user_name": user_id})
|
||||
|
||||
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
|
||||
ret = await self.db_pool.simple_select_list(
|
||||
"pushers",
|
||||
keyvalues,
|
||||
[
|
||||
"id",
|
||||
"user_name",
|
||||
"access_token",
|
||||
"profile_tag",
|
||||
"kind",
|
||||
"app_id",
|
||||
"app_display_name",
|
||||
"device_display_name",
|
||||
"pushkey",
|
||||
"ts",
|
||||
"lang",
|
||||
"data",
|
||||
"last_stream_ordering",
|
||||
"last_success",
|
||||
"failing_since",
|
||||
],
|
||||
"""Retrieve pushers that match the given criteria.
|
||||
|
||||
Args:
|
||||
keyvalues: A {column: value} dictionary.
|
||||
|
||||
Returns:
|
||||
The pushers for which the given columns have the given values.
|
||||
"""
|
||||
|
||||
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
# We could technically use simple_select_list here, but we need to call
|
||||
# COALESCE on the 'enabled' column. While it is technically possible to give
|
||||
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
|
||||
# feels a bit hacky, so it's probably better to just inline the query.
|
||||
sql = """
|
||||
SELECT
|
||||
id, user_name, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name, pushkey, ts, lang, data,
|
||||
last_stream_ordering, last_success, failing_since,
|
||||
COALESCE(enabled, TRUE) AS enabled
|
||||
FROM pushers
|
||||
"""
|
||||
|
||||
sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
ret = await self.db_pool.runInteraction(
|
||||
desc="get_pushers_by",
|
||||
func=get_pushers_by_txn,
|
||||
)
|
||||
|
||||
return self._decode_pushers_rows(ret)
|
||||
|
||||
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
||||
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||
txn.execute("SELECT * FROM pushers")
|
||||
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
|
||||
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self._decode_pushers_rows(rows)
|
||||
|
||||
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_enabled_pushers", get_enabled_pushers_txn
|
||||
)
|
||||
|
||||
async def get_all_updated_pushers_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
|
@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore):
|
|||
data: Optional[JsonDict],
|
||||
last_stream_ordering: int,
|
||||
profile_tag: str = "",
|
||||
enabled: bool = True,
|
||||
) -> None:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
|
@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore):
|
|||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
"enabled": enabled,
|
||||
},
|
||||
desc="add_pusher",
|
||||
lock=False,
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;
|
|
@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||
)
|
||||
|
||||
self.pusher = self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.user_id,
|
||||
access_token=self.token_id,
|
||||
kind="email",
|
||||
|
@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||
"""
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.get_success_or_raise(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.user_id,
|
||||
access_token=self.token_id,
|
||||
kind="email",
|
||||
|
|
|
@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor
|
|||
|
||||
import synapse.rest.admin
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.rest.client import login, push_rule, receipts, room
|
||||
from synapse.push import PusherConfig, PusherConfigException
|
||||
from synapse.rest.client import login, push_rule, pusher, receipts, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
login.register_servlets,
|
||||
receipts.register_servlets,
|
||||
push_rule.register_servlets,
|
||||
pusher.register_servlets,
|
||||
]
|
||||
user_id = True
|
||||
hijack_auth = False
|
||||
|
@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
|
||||
def test_data(data: Optional[JsonDict]) -> None:
|
||||
self.get_failure(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
|
||||
def _make_user_with_pusher(
|
||||
self, username: str, enabled: bool = True
|
||||
) -> Tuple[str, str]:
|
||||
"""Registers a user and creates a pusher for them.
|
||||
|
||||
Args:
|
||||
username: the localpart of the new user's Matrix ID.
|
||||
enabled: whether to create the pusher in an enabled or disabled state.
|
||||
"""
|
||||
user_id = self.register_user(username, "pass")
|
||||
access_token = self.login(username, "pass")
|
||||
|
||||
# Register the pusher
|
||||
self._set_pusher(user_id, access_token, enabled)
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
|
||||
"""Creates or updates the pusher for the given user.
|
||||
|
||||
Args:
|
||||
user_id: the user's Matrix ID.
|
||||
access_token: the access token associated with the pusher.
|
||||
enabled: whether to enable or disable the pusher.
|
||||
"""
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
enabled=enabled,
|
||||
)
|
||||
)
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def test_dont_notify_rule_overrides_message(self) -> None:
|
||||
"""
|
||||
The override push rule will suppress notification
|
||||
|
@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||
# The user sends a message back (sends a notification)
|
||||
self.helper.send(room, body="Hello", tok=access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_disable(self) -> None:
|
||||
"""Tests that disabling a pusher means it's not pushed to anymore."""
|
||||
user_id, access_token = self._make_user_with_pusher("user")
|
||||
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||
|
||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||
|
||||
# Send a message and check that it generated a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Disable the pusher.
|
||||
self._set_pusher(user_id, access_token, enabled=False)
|
||||
|
||||
# Send another message and check that it did not generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Get the pushers for the user and check that it is marked as disabled.
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
|
||||
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||
self.assertFalse(enabled)
|
||||
self.assertTrue(isinstance(enabled, bool))
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_enable(self) -> None:
|
||||
"""Tests that enabling a disabled pusher means it gets pushed to."""
|
||||
# Create the user with the pusher already disabled.
|
||||
user_id, access_token = self._make_user_with_pusher("user", enabled=False)
|
||||
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||
|
||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||
|
||||
# Send a message and check that it did not generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 0)
|
||||
|
||||
# Enable the pusher.
|
||||
self._set_pusher(user_id, access_token, enabled=True)
|
||||
|
||||
# Send another message and check that it did generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Get the pushers for the user and check that it is marked as enabled.
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
|
||||
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||
self.assertTrue(enabled)
|
||||
self.assertTrue(isinstance(enabled, bool))
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_null_enabled(self) -> None:
|
||||
"""Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
|
||||
created before the column was introduced) is considered enabled.
|
||||
"""
|
||||
# We intentionally set 'enabled' to None so that it's stored as NULL in the
|
||||
# database.
|
||||
user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
|
||||
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
|
||||
|
||||
def test_update_different_device_access_token(self) -> None:
|
||||
"""Tests that if we create a pusher from one device, the update it from another
|
||||
device, the access token associated with the pusher stays the same.
|
||||
"""
|
||||
# Create a user with a pusher.
|
||||
user_id, access_token = self._make_user_with_pusher("user")
|
||||
|
||||
# Get the token ID for the current access token, since that's what we store in
|
||||
# the pushers table.
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
# Generate a new access token, and update the pusher with it.
|
||||
new_token = self.login("user", "pass")
|
||||
self._set_pusher(user_id, new_token, enabled=False)
|
||||
|
||||
# Get the current list of pushers for the user.
|
||||
ret = self.get_success(
|
||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||
)
|
||||
pushers: List[PusherConfig] = list(ret)
|
||||
|
||||
# Check that we still have one pusher, and that the access token associated with
|
||||
# it didn't change.
|
||||
self.assertEqual(len(pushers), 1)
|
||||
self.assertEqual(pushers[0].access_token, token_id)
|
||||
|
|
|
@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||
token_id = user_dict.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
|
|
@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
|
|||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.other_user,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
|
Loading…
Reference in a new issue