mirror of
https://mau.dev/maunium/synapse.git
synced 2024-12-14 23:53:51 +01:00
Add appservice worker
This commit is contained in:
parent
9da84a9a1e
commit
07229bbdae
7 changed files with 364 additions and 118 deletions
211
synapse/app/appservice.py
Normal file
211
synapse/app/appservice.py
Normal file
|
@ -0,0 +1,211 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2016 OpenMarket Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import synapse
|
||||||
|
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
from synapse.config._base import ConfigError
|
||||||
|
from synapse.config.logger import setup_logging
|
||||||
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
from synapse.http.site import SynapseSite
|
||||||
|
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
|
||||||
|
from synapse.replication.slave.storage.directory import DirectoryStore
|
||||||
|
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||||
|
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||||
|
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||||
|
from synapse.storage.engines import create_engine
|
||||||
|
from synapse.util.async import sleep
|
||||||
|
from synapse.util.httpresourcetree import create_resource_tree
|
||||||
|
from synapse.util.logcontext import LoggingContext
|
||||||
|
from synapse.util.manhole import manhole
|
||||||
|
from synapse.util.rlimit import change_resource_limit
|
||||||
|
from synapse.util.versionstring import get_version_string
|
||||||
|
|
||||||
|
from twisted.internet import reactor, defer
|
||||||
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
|
from daemonize import Daemonize
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import gc
|
||||||
|
|
||||||
|
logger = logging.getLogger("synapse.app.appservice")
|
||||||
|
|
||||||
|
|
||||||
|
class AppserviceSlaveStore(
|
||||||
|
DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
|
||||||
|
SlavedRegistrationStore,
|
||||||
|
):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AppserviceServer(HomeServer):
|
||||||
|
def get_db_conn(self, run_new_connection=True):
|
||||||
|
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||||
|
# not be passed to the database engine.
|
||||||
|
db_params = {
|
||||||
|
k: v for k, v in self.db_config.get("args", {}).items()
|
||||||
|
if not k.startswith("cp_")
|
||||||
|
}
|
||||||
|
db_conn = self.database_engine.module.connect(**db_params)
|
||||||
|
|
||||||
|
if run_new_connection:
|
||||||
|
self.database_engine.on_new_connection(db_conn)
|
||||||
|
return db_conn
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
logger.info("Setting up.")
|
||||||
|
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||||
|
logger.info("Finished setting up.")
|
||||||
|
|
||||||
|
def _listen_http(self, listener_config):
|
||||||
|
port = listener_config["port"]
|
||||||
|
bind_address = listener_config.get("bind_address", "")
|
||||||
|
site_tag = listener_config.get("tag", port)
|
||||||
|
resources = {}
|
||||||
|
for res in listener_config["resources"]:
|
||||||
|
for name in res["names"]:
|
||||||
|
if name == "metrics":
|
||||||
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
|
|
||||||
|
root_resource = create_resource_tree(resources, Resource())
|
||||||
|
reactor.listenTCP(
|
||||||
|
port,
|
||||||
|
SynapseSite(
|
||||||
|
"synapse.access.http.%s" % (site_tag,),
|
||||||
|
site_tag,
|
||||||
|
listener_config,
|
||||||
|
root_resource,
|
||||||
|
),
|
||||||
|
interface=bind_address
|
||||||
|
)
|
||||||
|
logger.info("Synapse appservice now listening on port %d", port)
|
||||||
|
|
||||||
|
def start_listening(self, listeners):
|
||||||
|
for listener in listeners:
|
||||||
|
if listener["type"] == "http":
|
||||||
|
self._listen_http(listener)
|
||||||
|
elif listener["type"] == "manhole":
|
||||||
|
reactor.listenTCP(
|
||||||
|
listener["port"],
|
||||||
|
manhole(
|
||||||
|
username="matrix",
|
||||||
|
password="rabbithole",
|
||||||
|
globals={"hs": self},
|
||||||
|
),
|
||||||
|
interface=listener.get("bind_address", '127.0.0.1')
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(self):
|
||||||
|
http_client = self.get_simple_http_client()
|
||||||
|
store = self.get_datastore()
|
||||||
|
replication_url = self.config.worker_replication_url
|
||||||
|
appservice_handler = self.get_application_service_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def replicate(results):
|
||||||
|
stream = results.get("events")
|
||||||
|
if stream:
|
||||||
|
max_stream_id = stream["position"]
|
||||||
|
yield appservice_handler.notify_interested_services(max_stream_id)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
logger.info("Hitting replication")
|
||||||
|
args = store.stream_positions()
|
||||||
|
args["timeout"] = 30000
|
||||||
|
result = yield http_client.get_json(replication_url, args=args)
|
||||||
|
logger.info("Got replication response")
|
||||||
|
yield store.process_replication(result)
|
||||||
|
replicate(result)
|
||||||
|
except:
|
||||||
|
logger.exception("Error replicating from %r", replication_url)
|
||||||
|
yield sleep(30)
|
||||||
|
|
||||||
|
|
||||||
|
def start(config_options):
|
||||||
|
try:
|
||||||
|
config = HomeServerConfig.load_config(
|
||||||
|
"Synapse appservice", config_options
|
||||||
|
)
|
||||||
|
except ConfigError as e:
|
||||||
|
sys.stderr.write("\n" + e.message + "\n")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
assert config.worker_app == "synapse.app.appservice"
|
||||||
|
|
||||||
|
setup_logging(config.worker_log_config, config.worker_log_file)
|
||||||
|
|
||||||
|
database_engine = create_engine(config.database_config)
|
||||||
|
|
||||||
|
if config.notify_appservices:
|
||||||
|
sys.stderr.write(
|
||||||
|
"\nThe appservices must be disabled in the main synapse process"
|
||||||
|
"\nbefore they can be run in a separate worker."
|
||||||
|
"\nPlease add ``notify_appservices: false`` to the main config"
|
||||||
|
"\n"
|
||||||
|
)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Force the pushers to start since they will be disabled in the main config
|
||||||
|
config.notify_appservices = True
|
||||||
|
|
||||||
|
ps = AppserviceServer(
|
||||||
|
config.server_name,
|
||||||
|
db_config=config.database_config,
|
||||||
|
config=config,
|
||||||
|
version_string="Synapse/" + get_version_string(synapse),
|
||||||
|
database_engine=database_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
ps.setup()
|
||||||
|
ps.start_listening(config.worker_listeners)
|
||||||
|
|
||||||
|
def run():
|
||||||
|
with LoggingContext("run"):
|
||||||
|
logger.info("Running")
|
||||||
|
change_resource_limit(config.soft_file_limit)
|
||||||
|
if config.gc_thresholds:
|
||||||
|
gc.set_threshold(*config.gc_thresholds)
|
||||||
|
reactor.run()
|
||||||
|
|
||||||
|
def start():
|
||||||
|
ps.replicate()
|
||||||
|
ps.get_datastore().start_profiling()
|
||||||
|
|
||||||
|
reactor.callWhenRunning(start)
|
||||||
|
|
||||||
|
if config.worker_daemonize:
|
||||||
|
daemon = Daemonize(
|
||||||
|
app="synapse-appservice",
|
||||||
|
pid=config.worker_pid_file,
|
||||||
|
action=run,
|
||||||
|
auto_close_fds=False,
|
||||||
|
verbose=True,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
daemon.start()
|
||||||
|
else:
|
||||||
|
run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
with LoggingContext("main"):
|
||||||
|
start(sys.argv[1:])
|
|
@ -28,6 +28,7 @@ class AppServiceConfig(Config):
|
||||||
|
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.app_service_config_files = config.get("app_service_config_files", [])
|
self.app_service_config_files = config.get("app_service_config_files", [])
|
||||||
|
self.notify_appservices = config.get("notify_appservices", True)
|
||||||
|
|
||||||
def default_config(cls, **kwargs):
|
def default_config(cls, **kwargs):
|
||||||
return """\
|
return """\
|
||||||
|
|
|
@ -44,6 +44,10 @@ class ApplicationServicesHandler(object):
|
||||||
self.scheduler = hs.get_application_service_scheduler()
|
self.scheduler = hs.get_application_service_scheduler()
|
||||||
self.started_scheduler = False
|
self.started_scheduler = False
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
|
self.notify_appservices = hs.config.notify_appservices
|
||||||
|
|
||||||
|
self.current_max = 0
|
||||||
|
self.is_processing = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def notify_interested_services(self, current_id):
|
def notify_interested_services(self, current_id):
|
||||||
|
@ -56,19 +60,23 @@ class ApplicationServicesHandler(object):
|
||||||
current_id(int): The current maximum ID.
|
current_id(int): The current maximum ID.
|
||||||
"""
|
"""
|
||||||
services = yield self.store.get_app_services()
|
services = yield self.store.get_app_services()
|
||||||
if not services:
|
if not services or not self.notify_appservices:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_max = max(self.current_max, current_id)
|
||||||
|
if self.is_processing:
|
||||||
return
|
return
|
||||||
|
|
||||||
with Measure(self.clock, "notify_interested_services"):
|
with Measure(self.clock, "notify_interested_services"):
|
||||||
upper_bound = current_id
|
self.is_processing = True
|
||||||
|
try:
|
||||||
|
upper_bound = self.current_max
|
||||||
limit = 100
|
limit = 100
|
||||||
while True:
|
while True:
|
||||||
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
upper_bound, events = yield self.store.get_new_events_for_appservice(
|
||||||
upper_bound, limit
|
upper_bound, limit
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Current_id: %r, upper_bound: %r", current_id, upper_bound)
|
|
||||||
|
|
||||||
if not events:
|
if not events:
|
||||||
break
|
break
|
||||||
|
|
||||||
|
@ -78,9 +86,10 @@ class ApplicationServicesHandler(object):
|
||||||
if len(services) == 0:
|
if len(services) == 0:
|
||||||
continue # no services need notifying
|
continue # no services need notifying
|
||||||
|
|
||||||
# Do we know this user exists? If not, poke the user query API for
|
# Do we know this user exists? If not, poke the user
|
||||||
# all services which match that user regex. This needs to block as
|
# query API for all services which match that user regex.
|
||||||
# these user queries need to be made BEFORE pushing the event.
|
# This needs to block as these user queries need to be
|
||||||
|
# made BEFORE pushing the event.
|
||||||
yield self._check_user_exists(event.sender)
|
yield self._check_user_exists(event.sender)
|
||||||
if event.type == EventTypes.Member:
|
if event.type == EventTypes.Member:
|
||||||
yield self._check_user_exists(event.state_key)
|
yield self._check_user_exists(event.state_key)
|
||||||
|
@ -91,12 +100,16 @@ class ApplicationServicesHandler(object):
|
||||||
|
|
||||||
# Fork off pushes to these services
|
# Fork off pushes to these services
|
||||||
for service in services:
|
for service in services:
|
||||||
preserve_fn(self.scheduler.submit_event_for_as)(service, event)
|
preserve_fn(self.scheduler.submit_event_for_as)(
|
||||||
|
service, event
|
||||||
|
)
|
||||||
|
|
||||||
yield self.store.set_appservice_last_pos(upper_bound)
|
yield self.store.set_appservice_last_pos(upper_bound)
|
||||||
|
|
||||||
if len(events) < limit:
|
if len(events) < limit:
|
||||||
break
|
break
|
||||||
|
finally:
|
||||||
|
self.is_processing = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def query_user_exists(self, user_id):
|
def query_user_exists(self, user_id):
|
||||||
|
|
|
@ -28,3 +28,13 @@ class SlavedApplicationServiceStore(BaseSlavedStore):
|
||||||
|
|
||||||
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
get_app_service_by_token = DataStore.get_app_service_by_token.__func__
|
||||||
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
get_app_service_by_user_id = DataStore.get_app_service_by_user_id.__func__
|
||||||
|
get_app_services = DataStore.get_app_services.__func__
|
||||||
|
get_new_events_for_appservice = DataStore.get_new_events_for_appservice.__func__
|
||||||
|
create_appservice_txn = DataStore.create_appservice_txn.__func__
|
||||||
|
get_appservices_by_state = DataStore.get_appservices_by_state.__func__
|
||||||
|
get_oldest_unsent_txn = DataStore.get_oldest_unsent_txn.__func__
|
||||||
|
_get_last_txn = DataStore._get_last_txn.__func__
|
||||||
|
complete_appservice_txn = DataStore.complete_appservice_txn.__func__
|
||||||
|
get_appservice_state = DataStore.get_appservice_state.__func__
|
||||||
|
set_appservice_last_pos = DataStore.set_appservice_last_pos.__func__
|
||||||
|
set_appservice_state = DataStore.set_appservice_state.__func__
|
||||||
|
|
|
@ -28,3 +28,6 @@ class SlavedRegistrationStore(BaseSlavedStore):
|
||||||
]
|
]
|
||||||
|
|
||||||
_query_for_auth = DataStore._query_for_auth.__func__
|
_query_for_auth = DataStore._query_for_auth.__func__
|
||||||
|
get_user_by_id = RegistrationStore.__dict__[
|
||||||
|
"get_user_by_id"
|
||||||
|
]
|
||||||
|
|
|
@ -218,13 +218,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
Returns:
|
Returns:
|
||||||
AppServiceTransaction: A new transaction.
|
AppServiceTransaction: A new transaction.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
def _create_appservice_txn(txn):
|
||||||
"create_appservice_txn",
|
|
||||||
self._create_appservice_txn,
|
|
||||||
service, events
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_appservice_txn(self, txn, service, events):
|
|
||||||
# work out new txn id (highest txn id for this service += 1)
|
# work out new txn id (highest txn id for this service += 1)
|
||||||
# The highest id may be the last one sent (in which case it is last_txn)
|
# The highest id may be the last one sent (in which case it is last_txn)
|
||||||
# or it may be the highest in the txns list (which are waiting to be/are
|
# or it may be the highest in the txns list (which are waiting to be/are
|
||||||
|
@ -252,6 +246,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
service=service, id=new_txn_id, events=events
|
service=service, id=new_txn_id, events=events
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"create_appservice_txn",
|
||||||
|
_create_appservice_txn,
|
||||||
|
)
|
||||||
|
|
||||||
def complete_appservice_txn(self, txn_id, service):
|
def complete_appservice_txn(self, txn_id, service):
|
||||||
"""Completes an application service transaction.
|
"""Completes an application service transaction.
|
||||||
|
|
||||||
|
@ -263,15 +262,9 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
A Deferred which resolves if this transaction was stored
|
A Deferred which resolves if this transaction was stored
|
||||||
successfully.
|
successfully.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
|
||||||
"complete_appservice_txn",
|
|
||||||
self._complete_appservice_txn,
|
|
||||||
txn_id, service
|
|
||||||
)
|
|
||||||
|
|
||||||
def _complete_appservice_txn(self, txn, txn_id, service):
|
|
||||||
txn_id = int(txn_id)
|
txn_id = int(txn_id)
|
||||||
|
|
||||||
|
def _complete_appservice_txn(txn):
|
||||||
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
||||||
# what was there before. If it isn't, we've got problems (e.g. the AS
|
# what was there before. If it isn't, we've got problems (e.g. the AS
|
||||||
# has probably missed some events), so whine loudly but still continue,
|
# has probably missed some events), so whine loudly but still continue,
|
||||||
|
@ -298,6 +291,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
dict(txn_id=txn_id, as_id=service.id)
|
dict(txn_id=txn_id, as_id=service.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return self.runInteraction(
|
||||||
|
"complete_appservice_txn",
|
||||||
|
_complete_appservice_txn,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def get_oldest_unsent_txn(self, service):
|
def get_oldest_unsent_txn(self, service):
|
||||||
"""Get the oldest transaction which has not been sent for this
|
"""Get the oldest transaction which has not been sent for this
|
||||||
|
@ -309,24 +307,7 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
A Deferred which resolves to an AppServiceTransaction or
|
A Deferred which resolves to an AppServiceTransaction or
|
||||||
None.
|
None.
|
||||||
"""
|
"""
|
||||||
entry = yield self.runInteraction(
|
def _get_oldest_unsent_txn(txn):
|
||||||
"get_oldest_unsent_appservice_txn",
|
|
||||||
self._get_oldest_unsent_txn,
|
|
||||||
service
|
|
||||||
)
|
|
||||||
|
|
||||||
if not entry:
|
|
||||||
defer.returnValue(None)
|
|
||||||
|
|
||||||
event_ids = json.loads(entry["event_ids"])
|
|
||||||
|
|
||||||
events = yield self._get_events(event_ids)
|
|
||||||
|
|
||||||
defer.returnValue(AppServiceTransaction(
|
|
||||||
service=service, id=entry["txn_id"], events=events
|
|
||||||
))
|
|
||||||
|
|
||||||
def _get_oldest_unsent_txn(self, txn, service):
|
|
||||||
# Monotonically increasing txn ids, so just select the smallest
|
# Monotonically increasing txn ids, so just select the smallest
|
||||||
# one in the txns table (we delete them when they are sent)
|
# one in the txns table (we delete them when they are sent)
|
||||||
txn.execute(
|
txn.execute(
|
||||||
|
@ -342,6 +323,22 @@ class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||||
|
|
||||||
return entry
|
return entry
|
||||||
|
|
||||||
|
entry = yield self.runInteraction(
|
||||||
|
"get_oldest_unsent_appservice_txn",
|
||||||
|
_get_oldest_unsent_txn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not entry:
|
||||||
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
event_ids = json.loads(entry["event_ids"])
|
||||||
|
|
||||||
|
events = yield self._get_events(event_ids)
|
||||||
|
|
||||||
|
defer.returnValue(AppServiceTransaction(
|
||||||
|
service=service, id=entry["txn_id"], events=events
|
||||||
|
))
|
||||||
|
|
||||||
def _get_last_txn(self, txn, service_id):
|
def _get_last_txn(self, txn, service_id):
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||||
|
|
|
@ -93,7 +93,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
desc="add_refresh_token_to_user",
|
desc="add_refresh_token_to_user",
|
||||||
)
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def register(self, user_id, token=None, password_hash=None,
|
def register(self, user_id, token=None, password_hash=None,
|
||||||
was_guest=False, make_guest=False, appservice_id=None,
|
was_guest=False, make_guest=False, appservice_id=None,
|
||||||
create_profile_with_localpart=None, admin=False):
|
create_profile_with_localpart=None, admin=False):
|
||||||
|
@ -115,7 +114,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
Raises:
|
Raises:
|
||||||
StoreError if the user_id could not be registered.
|
StoreError if the user_id could not be registered.
|
||||||
"""
|
"""
|
||||||
yield self.runInteraction(
|
return self.runInteraction(
|
||||||
"register",
|
"register",
|
||||||
self._register,
|
self._register,
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -127,8 +126,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
create_profile_with_localpart,
|
create_profile_with_localpart,
|
||||||
admin
|
admin
|
||||||
)
|
)
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
|
||||||
self.is_guest.invalidate((user_id,))
|
|
||||||
|
|
||||||
def _register(
|
def _register(
|
||||||
self,
|
self,
|
||||||
|
@ -210,6 +207,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
(create_profile_with_localpart,)
|
(create_profile_with_localpart,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_id, (user_id,)
|
||||||
|
)
|
||||||
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
def get_user_by_id(self, user_id):
|
def get_user_by_id(self, user_id):
|
||||||
return self._simple_select_one(
|
return self._simple_select_one(
|
||||||
|
@ -236,19 +238,28 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
|
|
||||||
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
return self.runInteraction("get_users_by_id_case_insensitive", f)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def user_set_password_hash(self, user_id, password_hash):
|
def user_set_password_hash(self, user_id, password_hash):
|
||||||
"""
|
"""
|
||||||
NB. This does *not* evict any cache because the one use for this
|
NB. This does *not* evict any cache because the one use for this
|
||||||
removes most of the entries subsequently anyway so it would be
|
removes most of the entries subsequently anyway so it would be
|
||||||
pointless. Use flush_user separately.
|
pointless. Use flush_user separately.
|
||||||
"""
|
"""
|
||||||
yield self._simple_update_one('users', {
|
def user_set_password_hash_txn(txn):
|
||||||
|
self._simple_update_one_txn(
|
||||||
|
txn,
|
||||||
|
'users', {
|
||||||
'name': user_id
|
'name': user_id
|
||||||
}, {
|
},
|
||||||
|
{
|
||||||
'password_hash': password_hash
|
'password_hash': password_hash
|
||||||
})
|
}
|
||||||
self.get_user_by_id.invalidate((user_id,))
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_by_id, (user_id,)
|
||||||
|
)
|
||||||
|
return self.runInteraction(
|
||||||
|
"user_set_password_hash", user_set_password_hash_txn
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
def user_delete_access_tokens(self, user_id, except_token_id=None,
|
||||||
|
|
Loading…
Reference in a new issue