mirror of
https://mau.dev/maunium/synapse.git
synced 2024-11-17 15:31:19 +01:00
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/createroom_content
This commit is contained in:
commit
d8a6c734fa
73 changed files with 1778 additions and 822 deletions
|
@ -5,7 +5,8 @@ To use it, first install prometheus by following the instructions at
|
||||||
|
|
||||||
http://prometheus.io/
|
http://prometheus.io/
|
||||||
|
|
||||||
Then add a new job to the main prometheus.conf file:
|
### for Prometheus v1
|
||||||
|
Add a new job to the main prometheus.conf file:
|
||||||
|
|
||||||
job: {
|
job: {
|
||||||
name: "synapse"
|
name: "synapse"
|
||||||
|
@ -15,6 +16,22 @@ Then add a new job to the main prometheus.conf file:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
### for Prometheus v2
|
||||||
|
Add a new job to the main prometheus.yml file:
|
||||||
|
|
||||||
|
- job_name: "synapse"
|
||||||
|
metrics_path: "/_synapse/metrics"
|
||||||
|
# when endpoint uses https:
|
||||||
|
scheme: "https"
|
||||||
|
|
||||||
|
static_configs:
|
||||||
|
- targets: ['SERVER.LOCATION:PORT']
|
||||||
|
|
||||||
|
To use `synapse.rules` add
|
||||||
|
|
||||||
|
rule_files:
|
||||||
|
- "/PATH/TO/synapse-v2.rules"
|
||||||
|
|
||||||
Metrics are disabled by default when running synapse; they must be enabled
|
Metrics are disabled by default when running synapse; they must be enabled
|
||||||
with the 'enable-metrics' option, either in the synapse config file or as a
|
with the 'enable-metrics' option, either in the synapse config file or as a
|
||||||
command-line option.
|
command-line option.
|
||||||
|
|
60
contrib/prometheus/synapse-v2.rules
Normal file
60
contrib/prometheus/synapse-v2.rules
Normal file
|
@ -0,0 +1,60 @@
|
||||||
|
groups:
|
||||||
|
- name: synapse
|
||||||
|
rules:
|
||||||
|
- record: "synapse_federation_transaction_queue_pendingEdus:total"
|
||||||
|
expr: "sum(synapse_federation_transaction_queue_pendingEdus or absent(synapse_federation_transaction_queue_pendingEdus)*0)"
|
||||||
|
- record: "synapse_federation_transaction_queue_pendingPdus:total"
|
||||||
|
expr: "sum(synapse_federation_transaction_queue_pendingPdus or absent(synapse_federation_transaction_queue_pendingPdus)*0)"
|
||||||
|
- record: 'synapse_http_server_requests:method'
|
||||||
|
labels:
|
||||||
|
servlet: ""
|
||||||
|
expr: "sum(synapse_http_server_requests) by (method)"
|
||||||
|
- record: 'synapse_http_server_requests:servlet'
|
||||||
|
labels:
|
||||||
|
method: ""
|
||||||
|
expr: 'sum(synapse_http_server_requests) by (servlet)'
|
||||||
|
|
||||||
|
- record: 'synapse_http_server_requests:total'
|
||||||
|
labels:
|
||||||
|
servlet: ""
|
||||||
|
expr: 'sum(synapse_http_server_requests:by_method) by (servlet)'
|
||||||
|
|
||||||
|
- record: 'synapse_cache:hit_ratio_5m'
|
||||||
|
expr: 'rate(synapse_util_caches_cache:hits[5m]) / rate(synapse_util_caches_cache:total[5m])'
|
||||||
|
- record: 'synapse_cache:hit_ratio_30s'
|
||||||
|
expr: 'rate(synapse_util_caches_cache:hits[30s]) / rate(synapse_util_caches_cache:total[30s])'
|
||||||
|
|
||||||
|
- record: 'synapse_federation_client_sent'
|
||||||
|
labels:
|
||||||
|
type: "EDU"
|
||||||
|
expr: 'synapse_federation_client_sent_edus + 0'
|
||||||
|
- record: 'synapse_federation_client_sent'
|
||||||
|
labels:
|
||||||
|
type: "PDU"
|
||||||
|
expr: 'synapse_federation_client_sent_pdu_destinations:count + 0'
|
||||||
|
- record: 'synapse_federation_client_sent'
|
||||||
|
labels:
|
||||||
|
type: "Query"
|
||||||
|
expr: 'sum(synapse_federation_client_sent_queries) by (job)'
|
||||||
|
|
||||||
|
- record: 'synapse_federation_server_received'
|
||||||
|
labels:
|
||||||
|
type: "EDU"
|
||||||
|
expr: 'synapse_federation_server_received_edus + 0'
|
||||||
|
- record: 'synapse_federation_server_received'
|
||||||
|
labels:
|
||||||
|
type: "PDU"
|
||||||
|
expr: 'synapse_federation_server_received_pdus + 0'
|
||||||
|
- record: 'synapse_federation_server_received'
|
||||||
|
labels:
|
||||||
|
type: "Query"
|
||||||
|
expr: 'sum(synapse_federation_server_received_queries) by (job)'
|
||||||
|
|
||||||
|
- record: 'synapse_federation_transaction_queue_pending'
|
||||||
|
labels:
|
||||||
|
type: "EDU"
|
||||||
|
expr: 'synapse_federation_transaction_queue_pending_edus + 0'
|
||||||
|
- record: 'synapse_federation_transaction_queue_pending'
|
||||||
|
labels:
|
||||||
|
type: "PDU"
|
||||||
|
expr: 'synapse_federation_transaction_queue_pending_pdus + 0'
|
|
@ -298,10 +298,6 @@ It can be used like this:
|
||||||
# this will now be logged against the request context
|
# this will now be logged against the request context
|
||||||
logger.debug("Request handling complete")
|
logger.debug("Request handling complete")
|
||||||
|
|
||||||
XXX: I think ``preserve_context_over_fn`` is supposed to do the first option,
|
|
||||||
but the fact that it does ``preserve_context_over_deferred`` on its results
|
|
||||||
means that its use is fraught with difficulty.
|
|
||||||
|
|
||||||
Passing synapse deferreds into third-party functions
|
Passing synapse deferreds into third-party functions
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
|
|
||||||
|
|
157
docs/workers.rst
157
docs/workers.rst
|
@ -1,11 +1,15 @@
|
||||||
Scaling synapse via workers
|
Scaling synapse via workers
|
||||||
---------------------------
|
===========================
|
||||||
|
|
||||||
Synapse has experimental support for splitting out functionality into
|
Synapse has experimental support for splitting out functionality into
|
||||||
multiple separate python processes, helping greatly with scalability. These
|
multiple separate python processes, helping greatly with scalability. These
|
||||||
processes are called 'workers', and are (eventually) intended to scale
|
processes are called 'workers', and are (eventually) intended to scale
|
||||||
horizontally independently.
|
horizontally independently.
|
||||||
|
|
||||||
|
All of the below is highly experimental and subject to change as Synapse evolves,
|
||||||
|
but documenting it here to help folks needing highly scalable Synapses similar
|
||||||
|
to the one running matrix.org!
|
||||||
|
|
||||||
All processes continue to share the same database instance, and as such, workers
|
All processes continue to share the same database instance, and as such, workers
|
||||||
only work with postgres based synapse deployments (sharing a single sqlite
|
only work with postgres based synapse deployments (sharing a single sqlite
|
||||||
across multiple processes is a recipe for disaster, plus you should be using
|
across multiple processes is a recipe for disaster, plus you should be using
|
||||||
|
@ -16,6 +20,16 @@ TCP protocol called 'replication' - analogous to MySQL or Postgres style
|
||||||
database replication; feeding a stream of relevant data to the workers so they
|
database replication; feeding a stream of relevant data to the workers so they
|
||||||
can be kept in sync with the main synapse process and database state.
|
can be kept in sync with the main synapse process and database state.
|
||||||
|
|
||||||
|
Configuration
|
||||||
|
-------------
|
||||||
|
|
||||||
|
To make effective use of the workers, you will need to configure an HTTP
|
||||||
|
reverse-proxy such as nginx or haproxy, which will direct incoming requests to
|
||||||
|
the correct worker, or to the main synapse instance. Note that this includes
|
||||||
|
requests made to the federation port. The caveats regarding running a
|
||||||
|
reverse-proxy on the federation port still apply (see
|
||||||
|
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
|
||||||
|
|
||||||
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
||||||
|
|
||||||
listeners:
|
listeners:
|
||||||
|
@ -27,26 +41,19 @@ Under **no circumstances** should this replication API listener be exposed to th
|
||||||
public internet; it currently implements no authentication whatsoever and is
|
public internet; it currently implements no authentication whatsoever and is
|
||||||
unencrypted.
|
unencrypted.
|
||||||
|
|
||||||
You then create a set of configs for the various worker processes. These should be
|
You then create a set of configs for the various worker processes. These
|
||||||
worker configuration files should be stored in a dedicated subdirectory, to allow
|
should be worker configuration files, and should be stored in a dedicated
|
||||||
synctl to manipulate them.
|
subdirectory, to allow synctl to manipulate them.
|
||||||
|
|
||||||
The current available worker applications are:
|
|
||||||
* synapse.app.pusher - handles sending push notifications to sygnal and email
|
|
||||||
* synapse.app.synchrotron - handles /sync endpoints. can scales horizontally through multiple instances.
|
|
||||||
* synapse.app.appservice - handles output traffic to Application Services
|
|
||||||
* synapse.app.federation_reader - handles receiving federation traffic (including public_rooms API)
|
|
||||||
* synapse.app.media_repository - handles the media repository.
|
|
||||||
* synapse.app.client_reader - handles client API endpoints like /publicRooms
|
|
||||||
|
|
||||||
Each worker configuration file inherits the configuration of the main homeserver
|
Each worker configuration file inherits the configuration of the main homeserver
|
||||||
configuration file. You can then override configuration specific to that worker,
|
configuration file. You can then override configuration specific to that worker,
|
||||||
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
|
e.g. the HTTP listener that it provides (if any); logging configuration; etc.
|
||||||
You should minimise the number of overrides though to maintain a usable config.
|
You should minimise the number of overrides though to maintain a usable config.
|
||||||
|
|
||||||
You must specify the type of worker application (worker_app) and the replication
|
You must specify the type of worker application (``worker_app``). The currently
|
||||||
endpoint that it's talking to on the main synapse process (worker_replication_host
|
available worker applications are listed below. You must also specify the
|
||||||
and worker_replication_port).
|
replication endpoint that it's talking to on the main synapse process
|
||||||
|
(``worker_replication_host`` and ``worker_replication_port``).
|
||||||
|
|
||||||
For instance::
|
For instance::
|
||||||
|
|
||||||
|
@ -68,11 +75,11 @@ For instance::
|
||||||
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
|
worker_log_config: /home/matrix/synapse/config/synchrotron_log_config.yaml
|
||||||
|
|
||||||
...is a full configuration for a synchrotron worker instance, which will expose a
|
...is a full configuration for a synchrotron worker instance, which will expose a
|
||||||
plain HTTP /sync endpoint on port 8083 separately from the /sync endpoint provided
|
plain HTTP ``/sync`` endpoint on port 8083 separately from the ``/sync`` endpoint provided
|
||||||
by the main synapse.
|
by the main synapse.
|
||||||
|
|
||||||
Obviously you should configure your loadbalancer to route the /sync endpoint to
|
Obviously you should configure your reverse-proxy to route the relevant
|
||||||
the synchrotron instance(s) in this instance.
|
endpoints to the worker (``localhost:8083`` in the above example).
|
||||||
|
|
||||||
Finally, to actually run your worker-based synapse, you must pass synctl the -a
|
Finally, to actually run your worker-based synapse, you must pass synctl the -a
|
||||||
commandline option to tell it to operate on all the worker configurations found
|
commandline option to tell it to operate on all the worker configurations found
|
||||||
|
@ -89,6 +96,114 @@ To manipulate a specific worker, you pass the -w option to synctl::
|
||||||
|
|
||||||
synctl -w $CONFIG/workers/synchrotron.yaml restart
|
synctl -w $CONFIG/workers/synchrotron.yaml restart
|
||||||
|
|
||||||
All of the above is highly experimental and subject to change as Synapse evolves,
|
|
||||||
but documenting it here to help folks needing highly scalable Synapses similar
|
Available worker applications
|
||||||
to the one running matrix.org!
|
-----------------------------
|
||||||
|
|
||||||
|
``synapse.app.pusher``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles sending push notifications to sygnal and email. Doesn't handle any
|
||||||
|
REST endpoints itself, but you should set ``start_pushers: False`` in the
|
||||||
|
shared configuration file to stop the main synapse sending these notifications.
|
||||||
|
|
||||||
|
Note this worker cannot be load-balanced: only one instance should be active.
|
||||||
|
|
||||||
|
``synapse.app.synchrotron``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
The synchrotron handles ``sync`` requests from clients. In particular, it can
|
||||||
|
handle REST endpoints matching the following regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/client/(v2_alpha|r0)/sync$
|
||||||
|
^/_matrix/client/(api/v1|v2_alpha|r0)/events$
|
||||||
|
^/_matrix/client/(api/v1|r0)/initialSync$
|
||||||
|
^/_matrix/client/(api/v1|r0)/rooms/[^/]+/initialSync$
|
||||||
|
|
||||||
|
The above endpoints should all be routed to the synchrotron worker by the
|
||||||
|
reverse-proxy configuration.
|
||||||
|
|
||||||
|
It is possible to run multiple instances of the synchrotron to scale
|
||||||
|
horizontally. In this case the reverse-proxy should be configured to
|
||||||
|
load-balance across the instances, though it will be more efficient if all
|
||||||
|
requests from a particular user are routed to a single instance. Extracting
|
||||||
|
a userid from the access token is currently left as an exercise for the reader.
|
||||||
|
|
||||||
|
``synapse.app.appservice``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles sending output traffic to Application Services. Doesn't handle any
|
||||||
|
REST endpoints itself, but you should set ``notify_appservices: False`` in the
|
||||||
|
shared configuration file to stop the main synapse sending these notifications.
|
||||||
|
|
||||||
|
Note this worker cannot be load-balanced: only one instance should be active.
|
||||||
|
|
||||||
|
``synapse.app.federation_reader``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles a subset of federation endpoints. In particular, it can handle REST
|
||||||
|
endpoints matching the following regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/federation/v1/event/
|
||||||
|
^/_matrix/federation/v1/state/
|
||||||
|
^/_matrix/federation/v1/state_ids/
|
||||||
|
^/_matrix/federation/v1/backfill/
|
||||||
|
^/_matrix/federation/v1/get_missing_events/
|
||||||
|
^/_matrix/federation/v1/publicRooms
|
||||||
|
|
||||||
|
The above endpoints should all be routed to the federation_reader worker by the
|
||||||
|
reverse-proxy configuration.
|
||||||
|
|
||||||
|
``synapse.app.federation_sender``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles sending federation traffic to other servers. Doesn't handle any
|
||||||
|
REST endpoints itself, but you should set ``send_federation: False`` in the
|
||||||
|
shared configuration file to stop the main synapse sending this traffic.
|
||||||
|
|
||||||
|
Note this worker cannot be load-balanced: only one instance should be active.
|
||||||
|
|
||||||
|
``synapse.app.media_repository``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles the media repository. It can handle all endpoints starting with::
|
||||||
|
|
||||||
|
/_matrix/media/
|
||||||
|
|
||||||
|
You should also set ``enable_media_repo: False`` in the shared configuration
|
||||||
|
file to stop the main synapse running background jobs related to managing the
|
||||||
|
media repository.
|
||||||
|
|
||||||
|
Note this worker cannot be load-balanced: only one instance should be active.
|
||||||
|
|
||||||
|
``synapse.app.client_reader``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles client API endpoints. It can handle REST endpoints matching the
|
||||||
|
following regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/publicRooms$
|
||||||
|
|
||||||
|
``synapse.app.user_dir``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Handles searches in the user directory. It can handle REST endpoints matching
|
||||||
|
the following regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/user_directory/search$
|
||||||
|
|
||||||
|
``synapse.app.frontend_proxy``
|
||||||
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Proxies some frequently-requested client endpoints to add caching and remove
|
||||||
|
load from the main synapse. It can handle REST endpoints matching the following
|
||||||
|
regular expressions::
|
||||||
|
|
||||||
|
^/_matrix/client/(api/v1|r0|unstable)/keys/upload
|
||||||
|
|
||||||
|
It will proxy any requests it cannot handle to the main synapse instance. It
|
||||||
|
must therefore be configured with the location of the main instance, via
|
||||||
|
the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
|
||||||
|
file. For example::
|
||||||
|
|
||||||
|
worker_main_http_uri: http://127.0.0.1:8008
|
||||||
|
|
|
@ -123,15 +123,25 @@ def lookup(destination, path):
|
||||||
except:
|
except:
|
||||||
return "https://%s:%d%s" % (destination, 8448, path)
|
return "https://%s:%d%s" % (destination, 8448, path)
|
||||||
|
|
||||||
def get_json(origin_name, origin_key, destination, path):
|
|
||||||
request_json = {
|
def request_json(method, origin_name, origin_key, destination, path, content):
|
||||||
"method": "GET",
|
if method is None:
|
||||||
|
if content is None:
|
||||||
|
method = "GET"
|
||||||
|
else:
|
||||||
|
method = "POST"
|
||||||
|
|
||||||
|
json_to_sign = {
|
||||||
|
"method": method,
|
||||||
"uri": path,
|
"uri": path,
|
||||||
"origin": origin_name,
|
"origin": origin_name,
|
||||||
"destination": destination,
|
"destination": destination,
|
||||||
}
|
}
|
||||||
|
|
||||||
signed_json = sign_json(request_json, origin_key, origin_name)
|
if content is not None:
|
||||||
|
json_to_sign["content"] = json.loads(content)
|
||||||
|
|
||||||
|
signed_json = sign_json(json_to_sign, origin_key, origin_name)
|
||||||
|
|
||||||
authorization_headers = []
|
authorization_headers = []
|
||||||
|
|
||||||
|
@ -145,10 +155,12 @@ def get_json(origin_name, origin_key, destination, path):
|
||||||
dest = lookup(destination, path)
|
dest = lookup(destination, path)
|
||||||
print ("Requesting %s" % dest, file=sys.stderr)
|
print ("Requesting %s" % dest, file=sys.stderr)
|
||||||
|
|
||||||
result = requests.get(
|
result = requests.request(
|
||||||
dest,
|
method=method,
|
||||||
|
url=dest,
|
||||||
headers={"Authorization": authorization_headers[0]},
|
headers={"Authorization": authorization_headers[0]},
|
||||||
verify=False,
|
verify=False,
|
||||||
|
data=content,
|
||||||
)
|
)
|
||||||
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
sys.stderr.write("Status Code: %d\n" % (result.status_code,))
|
||||||
return result.json()
|
return result.json()
|
||||||
|
@ -186,6 +198,17 @@ def main():
|
||||||
"connect appropriately.",
|
"connect appropriately.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"-X", "--method",
|
||||||
|
help="HTTP method to use for the request. Defaults to GET if --data is"
|
||||||
|
"unspecified, POST if it is."
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--body",
|
||||||
|
help="Data to send as the body of the HTTP request"
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"path",
|
"path",
|
||||||
help="request path. We will add '/_matrix/federation/v1/' to this."
|
help="request path. We will add '/_matrix/federation/v1/' to this."
|
||||||
|
@ -199,8 +222,11 @@ def main():
|
||||||
with open(args.signing_key_path) as f:
|
with open(args.signing_key_path) as f:
|
||||||
key = read_signing_keys(f)[0]
|
key = read_signing_keys(f)[0]
|
||||||
|
|
||||||
result = get_json(
|
result = request_json(
|
||||||
args.server_name, key, args.destination, "/_matrix/federation/v1/" + args.path
|
args.method,
|
||||||
|
args.server_name, key, args.destination,
|
||||||
|
"/_matrix/federation/v1/" + args.path,
|
||||||
|
content=args.body,
|
||||||
)
|
)
|
||||||
|
|
||||||
json.dump(result, sys.stdout)
|
json.dump(result, sys.stdout)
|
||||||
|
|
45
scripts/sync_room_to_group.pl
Executable file
45
scripts/sync_room_to_group.pl
Executable file
|
@ -0,0 +1,45 @@
|
||||||
|
#!/usr/bin/env perl
|
||||||
|
|
||||||
|
use strict;
|
||||||
|
use warnings;
|
||||||
|
|
||||||
|
use JSON::XS;
|
||||||
|
use LWP::UserAgent;
|
||||||
|
use URI::Escape;
|
||||||
|
|
||||||
|
if (@ARGV < 4) {
|
||||||
|
die "usage: $0 <homeserver url> <access_token> <room_id|room_alias> <group_id>\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
my ($hs, $access_token, $room_id, $group_id) = @ARGV;
|
||||||
|
my $ua = LWP::UserAgent->new();
|
||||||
|
$ua->timeout(10);
|
||||||
|
|
||||||
|
if ($room_id =~ /^#/) {
|
||||||
|
$room_id = uri_escape($room_id);
|
||||||
|
$room_id = decode_json($ua->get("${hs}/_matrix/client/r0/directory/room/${room_id}?access_token=${access_token}")->decoded_content)->{room_id};
|
||||||
|
}
|
||||||
|
|
||||||
|
my $room_users = [ keys %{decode_json($ua->get("${hs}/_matrix/client/r0/rooms/${room_id}/joined_members?access_token=${access_token}")->decoded_content)->{joined}} ];
|
||||||
|
my $group_users = [
|
||||||
|
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||||
|
(map { $_->{user_id} } @{decode_json($ua->get("${hs}/_matrix/client/unstable/groups/${group_id}/invited_users?access_token=${access_token}" )->decoded_content)->{chunk}}),
|
||||||
|
];
|
||||||
|
|
||||||
|
die "refusing to sync from empty room" unless (@$room_users);
|
||||||
|
die "refusing to sync to empty group" unless (@$group_users);
|
||||||
|
|
||||||
|
my $diff = {};
|
||||||
|
foreach my $user (@$room_users) { $diff->{$user}++ }
|
||||||
|
foreach my $user (@$group_users) { $diff->{$user}-- }
|
||||||
|
|
||||||
|
foreach my $user (keys %$diff) {
|
||||||
|
if ($diff->{$user} == 1) {
|
||||||
|
warn "inviting $user";
|
||||||
|
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/invite/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||||
|
}
|
||||||
|
elsif ($diff->{$user} == -1) {
|
||||||
|
warn "removing $user";
|
||||||
|
print STDERR $ua->put("${hs}/_matrix/client/unstable/groups/${group_id}/admin/users/remove/${user}?access_token=${access_token}", Content=>'{}')->status_line."\n";
|
||||||
|
}
|
||||||
|
}
|
|
@ -270,7 +270,11 @@ class Auth(object):
|
||||||
rights (str): The operation being performed; the access token must
|
rights (str): The operation being performed; the access token must
|
||||||
allow this.
|
allow this.
|
||||||
Returns:
|
Returns:
|
||||||
dict : dict that includes the user and the ID of their access token.
|
Deferred[dict]: dict that includes:
|
||||||
|
`user` (UserID)
|
||||||
|
`is_guest` (bool)
|
||||||
|
`token_id` (int|None): access token id. May be None if guest
|
||||||
|
`device_id` (str|None): device corresponding to access token
|
||||||
Raises:
|
Raises:
|
||||||
AuthError if no user by that token exists or the token is invalid.
|
AuthError if no user by that token exists or the token is invalid.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -140,6 +140,22 @@ class RegistrationError(SynapseError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InteractiveAuthIncompleteError(Exception):
|
||||||
|
"""An error raised when UI auth is not yet complete
|
||||||
|
|
||||||
|
(This indicates we should return a 401 with 'result' as the body)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
result (dict): the server response to the request, which should be
|
||||||
|
passed back to the client
|
||||||
|
"""
|
||||||
|
def __init__(self, result):
|
||||||
|
super(InteractiveAuthIncompleteError, self).__init__(
|
||||||
|
"Interactive auth not yet complete",
|
||||||
|
)
|
||||||
|
self.result = result
|
||||||
|
|
||||||
|
|
||||||
class UnrecognizedRequestError(SynapseError):
|
class UnrecognizedRequestError(SynapseError):
|
||||||
"""An error indicating we don't understand the request you're trying to make"""
|
"""An error indicating we don't understand the request you're trying to make"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
|
|
@ -43,7 +43,6 @@ from synapse.rest import ClientRestResource
|
||||||
from synapse.rest.key.v1.server_key_resource import LocalKey
|
from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage import are_all_users_on_domain
|
from synapse.storage import are_all_users_on_domain
|
||||||
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
|
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
|
||||||
|
@ -195,14 +194,19 @@ class SynapseHomeServer(HomeServer):
|
||||||
})
|
})
|
||||||
|
|
||||||
if name in ["media", "federation", "client"]:
|
if name in ["media", "federation", "client"]:
|
||||||
media_repo = MediaRepositoryResource(self)
|
if self.get_config().enable_media_repo:
|
||||||
resources.update({
|
media_repo = self.get_media_repository_resource()
|
||||||
MEDIA_PREFIX: media_repo,
|
resources.update({
|
||||||
LEGACY_MEDIA_PREFIX: media_repo,
|
MEDIA_PREFIX: media_repo,
|
||||||
CONTENT_REPO_PREFIX: ContentRepoResource(
|
LEGACY_MEDIA_PREFIX: media_repo,
|
||||||
self, self.config.uploads_path
|
CONTENT_REPO_PREFIX: ContentRepoResource(
|
||||||
),
|
self, self.config.uploads_path
|
||||||
})
|
),
|
||||||
|
})
|
||||||
|
elif name == "media":
|
||||||
|
raise ConfigError(
|
||||||
|
"'media' resource conflicts with enable_media_repo=False",
|
||||||
|
)
|
||||||
|
|
||||||
if name in ["keys", "federation"]:
|
if name in ["keys", "federation"]:
|
||||||
resources.update({
|
resources.update({
|
||||||
|
|
|
@ -35,7 +35,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
|
||||||
from synapse.replication.slave.storage.transactions import TransactionStore
|
from synapse.replication.slave.storage.transactions import TransactionStore
|
||||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.storage.engines import create_engine
|
from synapse.storage.engines import create_engine
|
||||||
from synapse.storage.media_repository import MediaRepositoryStore
|
from synapse.storage.media_repository import MediaRepositoryStore
|
||||||
|
@ -89,7 +88,7 @@ class MediaRepositoryServer(HomeServer):
|
||||||
if name == "metrics":
|
if name == "metrics":
|
||||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||||
elif name == "media":
|
elif name == "media":
|
||||||
media_repo = MediaRepositoryResource(self)
|
media_repo = self.get_media_repository_resource()
|
||||||
resources.update({
|
resources.update({
|
||||||
MEDIA_PREFIX: media_repo,
|
MEDIA_PREFIX: media_repo,
|
||||||
LEGACY_MEDIA_PREFIX: media_repo,
|
LEGACY_MEDIA_PREFIX: media_repo,
|
||||||
|
@ -151,6 +150,13 @@ def start(config_options):
|
||||||
|
|
||||||
assert config.worker_app == "synapse.app.media_repository"
|
assert config.worker_app == "synapse.app.media_repository"
|
||||||
|
|
||||||
|
if config.enable_media_repo:
|
||||||
|
_base.quit_with_error(
|
||||||
|
"enable_media_repo must be disabled in the main synapse process\n"
|
||||||
|
"before the media repo can be run in a separate worker.\n"
|
||||||
|
"Please add ``enable_media_repo: false`` to the main config\n"
|
||||||
|
)
|
||||||
|
|
||||||
setup_logging(config, use_worker_options=True)
|
setup_logging(config, use_worker_options=True)
|
||||||
|
|
||||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||||
|
|
|
@ -340,11 +340,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
||||||
|
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.typing_handler = hs.get_typing_handler()
|
self.typing_handler = hs.get_typing_handler()
|
||||||
|
# NB this is a SynchrotronPresence, not a normal PresenceHandler
|
||||||
self.presence_handler = hs.get_presence_handler()
|
self.presence_handler = hs.get_presence_handler()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
|
||||||
self.presence_handler.sync_callback = self.send_user_sync
|
|
||||||
|
|
||||||
def on_rdata(self, stream_name, token, rows):
|
def on_rdata(self, stream_name, token, rows):
|
||||||
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||||
|
from synapse.types import GroupID, get_domain_from_id
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
@ -81,12 +82,13 @@ class ApplicationService(object):
|
||||||
# values.
|
# values.
|
||||||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||||
|
|
||||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
|
||||||
sender=None, id=None, protocols=None, rate_limited=True):
|
sender=None, id=None, protocols=None, rate_limited=True):
|
||||||
self.token = token
|
self.token = token
|
||||||
self.url = url
|
self.url = url
|
||||||
self.hs_token = hs_token
|
self.hs_token = hs_token
|
||||||
self.sender = sender
|
self.sender = sender
|
||||||
|
self.server_name = hostname
|
||||||
self.namespaces = self._check_namespaces(namespaces)
|
self.namespaces = self._check_namespaces(namespaces)
|
||||||
self.id = id
|
self.id = id
|
||||||
|
|
||||||
|
@ -125,6 +127,24 @@ class ApplicationService(object):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Expected bool for 'exclusive' in ns '%s'" % ns
|
"Expected bool for 'exclusive' in ns '%s'" % ns
|
||||||
)
|
)
|
||||||
|
group_id = regex_obj.get("group_id")
|
||||||
|
if group_id:
|
||||||
|
if not isinstance(group_id, str):
|
||||||
|
raise ValueError(
|
||||||
|
"Expected string for 'group_id' in ns '%s'" % ns
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
GroupID.from_string(group_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected valid group ID for 'group_id' in ns '%s'" % ns
|
||||||
|
)
|
||||||
|
|
||||||
|
if get_domain_from_id(group_id) != self.server_name:
|
||||||
|
raise ValueError(
|
||||||
|
"Expected 'group_id' to be this host in ns '%s'" % ns
|
||||||
|
)
|
||||||
|
|
||||||
regex = regex_obj.get("regex")
|
regex = regex_obj.get("regex")
|
||||||
if isinstance(regex, basestring):
|
if isinstance(regex, basestring):
|
||||||
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
|
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
|
||||||
|
@ -251,6 +271,21 @@ class ApplicationService(object):
|
||||||
if regex_obj["exclusive"]
|
if regex_obj["exclusive"]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def get_groups_for_user(self, user_id):
|
||||||
|
"""Get the groups that this user is associated with by this AS
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): The ID of the user.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
iterable[str]: an iterable that yields group_id strings.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
regex_obj["group_id"]
|
||||||
|
for regex_obj in self.namespaces[ApplicationService.NS_USERS]
|
||||||
|
if "group_id" in regex_obj and regex_obj["regex"].match(user_id)
|
||||||
|
)
|
||||||
|
|
||||||
def is_rate_limited(self):
|
def is_rate_limited(self):
|
||||||
return self.rate_limited
|
return self.rate_limited
|
||||||
|
|
||||||
|
|
|
@ -154,6 +154,7 @@ def _load_appservice(hostname, as_info, config_filename):
|
||||||
)
|
)
|
||||||
return ApplicationService(
|
return ApplicationService(
|
||||||
token=as_info["as_token"],
|
token=as_info["as_token"],
|
||||||
|
hostname=hostname,
|
||||||
url=as_info["url"],
|
url=as_info["url"],
|
||||||
namespaces=as_info["namespaces"],
|
namespaces=as_info["namespaces"],
|
||||||
hs_token=as_info["hs_token"],
|
hs_token=as_info["hs_token"],
|
||||||
|
|
|
@ -36,6 +36,7 @@ from .workers import WorkerConfig
|
||||||
from .push import PushConfig
|
from .push import PushConfig
|
||||||
from .spam_checker import SpamCheckerConfig
|
from .spam_checker import SpamCheckerConfig
|
||||||
from .groups import GroupsConfig
|
from .groups import GroupsConfig
|
||||||
|
from .user_directory import UserDirectoryConfig
|
||||||
|
|
||||||
|
|
||||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
|
@ -44,7 +45,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||||
JWTConfig, PasswordConfig, EmailConfig,
|
JWTConfig, PasswordConfig, EmailConfig,
|
||||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
||||||
SpamCheckerConfig, GroupsConfig,):
|
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -18,28 +19,43 @@ from ._base import Config
|
||||||
|
|
||||||
class PushConfig(Config):
|
class PushConfig(Config):
|
||||||
def read_config(self, config):
|
def read_config(self, config):
|
||||||
self.push_redact_content = False
|
push_config = config.get("push", {})
|
||||||
|
self.push_include_content = push_config.get("include_content", True)
|
||||||
|
|
||||||
|
# There was a a 'redact_content' setting but mistakenly read from the
|
||||||
|
# 'email'section'. Check for the flag in the 'push' section, and log,
|
||||||
|
# but do not honour it to avoid nasty surprises when people upgrade.
|
||||||
|
if push_config.get("redact_content") is not None:
|
||||||
|
print(
|
||||||
|
"The push.redact_content content option has never worked. "
|
||||||
|
"Please set push.include_content if you want this behaviour"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now check for the one in the 'email' section and honour it,
|
||||||
|
# with a warning.
|
||||||
push_config = config.get("email", {})
|
push_config = config.get("email", {})
|
||||||
self.push_redact_content = push_config.get("redact_content", False)
|
redact_content = push_config.get("redact_content")
|
||||||
|
if redact_content is not None:
|
||||||
|
print(
|
||||||
|
"The 'email.redact_content' option is deprecated: "
|
||||||
|
"please set push.include_content instead"
|
||||||
|
)
|
||||||
|
self.push_include_content = not redact_content
|
||||||
|
|
||||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
return """
|
return """
|
||||||
# Control how push messages are sent to google/apple to notifications.
|
# Clients requesting push notifications can either have the body of
|
||||||
# Normally every message said in a room with one or more people using
|
# the message sent in the notification poke along with other details
|
||||||
# mobile devices will be posted to a push server hosted by matrix.org
|
# like the sender, or just the event ID and room ID (`event_id_only`).
|
||||||
# which is registered with google and apple in order to allow push
|
# If clients choose the former, this option controls whether the
|
||||||
# notifications to be sent to these mobile devices.
|
# notification request includes the content of the event (other details
|
||||||
#
|
# like the sender are still included). For `event_id_only` push, it
|
||||||
# Setting redact_content to true will make the push messages contain no
|
# has no effect.
|
||||||
# message content which will provide increased privacy. This is a
|
|
||||||
# temporary solution pending improvements to Android and iPhone apps
|
|
||||||
# to get content from the app rather than the notification.
|
|
||||||
#
|
|
||||||
# For modern android devices the notification content will still appear
|
# For modern android devices the notification content will still appear
|
||||||
# because it is loaded by the app. iPhone, however will send a
|
# because it is loaded by the app. iPhone, however will send a
|
||||||
# notification saying only that a message arrived and who it came from.
|
# notification saying only that a message arrived and who it came from.
|
||||||
#
|
#
|
||||||
#push:
|
#push:
|
||||||
# redact_content: false
|
# include_content: true
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -41,6 +41,12 @@ class ServerConfig(Config):
|
||||||
# false only if we are updating the user directory in a worker
|
# false only if we are updating the user directory in a worker
|
||||||
self.update_user_directory = config.get("update_user_directory", True)
|
self.update_user_directory = config.get("update_user_directory", True)
|
||||||
|
|
||||||
|
# whether to enable the media repository endpoints. This should be set
|
||||||
|
# to false if the media repository is running as a separate endpoint;
|
||||||
|
# doing so ensures that we will not run cache cleanup jobs on the
|
||||||
|
# master, potentially causing inconsistency.
|
||||||
|
self.enable_media_repo = config.get("enable_media_repo", True)
|
||||||
|
|
||||||
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
|
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
|
||||||
|
|
||||||
# Whether we should block invites sent to users on this server
|
# Whether we should block invites sent to users on this server
|
||||||
|
|
44
synapse/config/user_directory.py
Normal file
44
synapse/config/user_directory.py
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from ._base import Config
|
||||||
|
|
||||||
|
|
||||||
|
class UserDirectoryConfig(Config):
|
||||||
|
"""User Directory Configuration
|
||||||
|
Configuration for the behaviour of the /user_directory API
|
||||||
|
"""
|
||||||
|
|
||||||
|
def read_config(self, config):
|
||||||
|
self.user_directory_search_all_users = False
|
||||||
|
user_directory_config = config.get("user_directory", None)
|
||||||
|
if user_directory_config:
|
||||||
|
self.user_directory_search_all_users = (
|
||||||
|
user_directory_config.get("search_all_users", False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||||
|
return """
|
||||||
|
# User Directory configuration
|
||||||
|
#
|
||||||
|
# 'search_all_users' defines whether to search all users visible to your HS
|
||||||
|
# when searching the user directory, rather than limiting to users visible
|
||||||
|
# in public rooms. Defaults to false. If you set it True, you'll have to run
|
||||||
|
# UPDATE user_directory_stream_pos SET stream_id = NULL;
|
||||||
|
# on your database to tell it to rebuild the user_directory search indexes.
|
||||||
|
#
|
||||||
|
#user_directory:
|
||||||
|
# search_all_users: false
|
||||||
|
"""
|
|
@ -32,15 +32,22 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
|
||||||
"""Check whether the hash for this PDU matches the contents"""
|
"""Check whether the hash for this PDU matches the contents"""
|
||||||
name, expected_hash = compute_content_hash(event, hash_algorithm)
|
name, expected_hash = compute_content_hash(event, hash_algorithm)
|
||||||
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
|
logger.debug("Expecting hash: %s", encode_base64(expected_hash))
|
||||||
if name not in event.hashes:
|
|
||||||
|
# some malformed events lack a 'hashes'. Protect against it being missing
|
||||||
|
# or a weird type by basically treating it the same as an unhashed event.
|
||||||
|
hashes = event.get("hashes")
|
||||||
|
if not isinstance(hashes, dict):
|
||||||
|
raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
|
||||||
|
|
||||||
|
if name not in hashes:
|
||||||
raise SynapseError(
|
raise SynapseError(
|
||||||
400,
|
400,
|
||||||
"Algorithm %s not in hashes %s" % (
|
"Algorithm %s not in hashes %s" % (
|
||||||
name, list(event.hashes),
|
name, list(hashes),
|
||||||
),
|
),
|
||||||
Codes.UNAUTHORIZED,
|
Codes.UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
message_hash_base64 = event.hashes[name]
|
message_hash_base64 = hashes[name]
|
||||||
try:
|
try:
|
||||||
message_hash_bytes = decode_base64(message_hash_base64)
|
message_hash_bytes = decode_base64(message_hash_base64)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -25,7 +25,7 @@ from synapse.api.errors import (
|
||||||
from synapse.util import unwrapFirstError, logcontext
|
from synapse.util import unwrapFirstError, logcontext
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.util.logutils import log_function
|
from synapse.util.logutils import log_function
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.events import FrozenEvent, builder
|
from synapse.events import FrozenEvent, builder
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
|
||||||
|
@ -420,7 +420,7 @@ class FederationClient(FederationBase):
|
||||||
for e_id in batch
|
for e_id in batch
|
||||||
]
|
]
|
||||||
|
|
||||||
res = yield preserve_context_over_deferred(
|
res = yield make_deferred_yieldable(
|
||||||
defer.DeferredList(deferreds, consumeErrors=True)
|
defer.DeferredList(deferreds, consumeErrors=True)
|
||||||
)
|
)
|
||||||
for success, result in res:
|
for success, result in res:
|
||||||
|
|
|
@ -20,7 +20,7 @@ from .persistence import TransactionActions
|
||||||
from .units import Transaction, Edu
|
from .units import Transaction, Edu
|
||||||
|
|
||||||
from synapse.api.errors import HttpResponseException
|
from synapse.api.errors import HttpResponseException
|
||||||
from synapse.util import logcontext
|
from synapse.util import logcontext, PreserveLoggingContext
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||||
from synapse.util.metrics import measure_func
|
from synapse.util.metrics import measure_func
|
||||||
|
@ -146,7 +146,6 @@ class TransactionQueue(object):
|
||||||
else:
|
else:
|
||||||
return not destination.startswith("localhost")
|
return not destination.startswith("localhost")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def notify_new_events(self, current_id):
|
def notify_new_events(self, current_id):
|
||||||
"""This gets called when we have some new events we might want to
|
"""This gets called when we have some new events we might want to
|
||||||
send out to other servers.
|
send out to other servers.
|
||||||
|
@ -156,6 +155,13 @@ class TransactionQueue(object):
|
||||||
if self._is_processing:
|
if self._is_processing:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# fire off a processing loop in the background. It's likely it will
|
||||||
|
# outlast the current request, so run it in the sentinel logcontext.
|
||||||
|
with PreserveLoggingContext():
|
||||||
|
self._process_event_queue_loop()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _process_event_queue_loop(self):
|
||||||
try:
|
try:
|
||||||
self._is_processing = True
|
self._is_processing = True
|
||||||
while True:
|
while True:
|
||||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -159,7 +159,7 @@ class ApplicationServicesHandler(object):
|
||||||
def query_3pe(self, kind, protocol, fields):
|
def query_3pe(self, kind, protocol, fields):
|
||||||
services = yield self._get_services_for_3pn(protocol)
|
services = yield self._get_services_for_3pn(protocol)
|
||||||
|
|
||||||
results = yield preserve_context_over_deferred(defer.DeferredList([
|
results = yield make_deferred_yieldable(defer.DeferredList([
|
||||||
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
|
preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
|
||||||
for service in services
|
for service in services
|
||||||
], consumeErrors=True))
|
], consumeErrors=True))
|
||||||
|
|
|
@ -17,7 +17,10 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError
|
from synapse.api.errors import (
|
||||||
|
AuthError, Codes, InteractiveAuthIncompleteError, LoginError, StoreError,
|
||||||
|
SynapseError,
|
||||||
|
)
|
||||||
from synapse.module_api import ModuleApi
|
from synapse.module_api import ModuleApi
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
@ -46,7 +49,6 @@ class AuthHandler(BaseHandler):
|
||||||
"""
|
"""
|
||||||
super(AuthHandler, self).__init__(hs)
|
super(AuthHandler, self).__init__(hs)
|
||||||
self.checkers = {
|
self.checkers = {
|
||||||
LoginType.PASSWORD: self._check_password_auth,
|
|
||||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||||
LoginType.MSISDN: self._check_msisdn,
|
LoginType.MSISDN: self._check_msisdn,
|
||||||
|
@ -75,15 +77,76 @@ class AuthHandler(BaseHandler):
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._password_enabled = hs.config.password_enabled
|
self._password_enabled = hs.config.password_enabled
|
||||||
|
|
||||||
login_types = set()
|
# we keep this as a list despite the O(N^2) implication so that we can
|
||||||
|
# keep PASSWORD first and avoid confusing clients which pick the first
|
||||||
|
# type in the list. (NB that the spec doesn't require us to do so and
|
||||||
|
# clients which favour types that they don't understand over those that
|
||||||
|
# they do are technically broken)
|
||||||
|
login_types = []
|
||||||
if self._password_enabled:
|
if self._password_enabled:
|
||||||
login_types.add(LoginType.PASSWORD)
|
login_types.append(LoginType.PASSWORD)
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "get_supported_login_types"):
|
if hasattr(provider, "get_supported_login_types"):
|
||||||
login_types.update(
|
for t in provider.get_supported_login_types().keys():
|
||||||
provider.get_supported_login_types().keys()
|
if t not in login_types:
|
||||||
)
|
login_types.append(t)
|
||||||
self._supported_login_types = frozenset(login_types)
|
self._supported_login_types = login_types
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def validate_user_via_ui_auth(self, requester, request_body, clientip):
|
||||||
|
"""
|
||||||
|
Checks that the user is who they claim to be, via a UI auth.
|
||||||
|
|
||||||
|
This is used for things like device deletion and password reset where
|
||||||
|
the user already has a valid access token, but we want to double-check
|
||||||
|
that it isn't stolen by re-authenticating them.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requester (Requester): The user, as given by the access token
|
||||||
|
|
||||||
|
request_body (dict): The body of the request sent by the client
|
||||||
|
|
||||||
|
clientip (str): The IP address of the client.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred[dict]: the parameters for this request (which may
|
||||||
|
have been given only in a previous call).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InteractiveAuthIncompleteError if the client has not yet completed
|
||||||
|
any of the permitted login flows
|
||||||
|
|
||||||
|
AuthError if the client has completed a login flow, and it gives
|
||||||
|
a different user to `requester`
|
||||||
|
"""
|
||||||
|
|
||||||
|
# build a list of supported flows
|
||||||
|
flows = [
|
||||||
|
[login_type] for login_type in self._supported_login_types
|
||||||
|
]
|
||||||
|
|
||||||
|
result, params, _ = yield self.check_auth(
|
||||||
|
flows, request_body, clientip,
|
||||||
|
)
|
||||||
|
|
||||||
|
# find the completed login type
|
||||||
|
for login_type in self._supported_login_types:
|
||||||
|
if login_type not in result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
user_id = result[login_type]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# this can't happen
|
||||||
|
raise Exception(
|
||||||
|
"check_auth returned True but no successful login type",
|
||||||
|
)
|
||||||
|
|
||||||
|
# check that the UI auth matched the access token
|
||||||
|
if user_id != requester.user.to_string():
|
||||||
|
raise AuthError(403, "Invalid auth")
|
||||||
|
|
||||||
|
defer.returnValue(params)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def check_auth(self, flows, clientdict, clientip):
|
def check_auth(self, flows, clientdict, clientip):
|
||||||
|
@ -95,26 +158,36 @@ class AuthHandler(BaseHandler):
|
||||||
session with a map, which maps each auth-type (str) to the relevant
|
session with a map, which maps each auth-type (str) to the relevant
|
||||||
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
identity authenticated by that auth-type (mostly str, but for captcha, bool).
|
||||||
|
|
||||||
|
If no auth flows have been completed successfully, raises an
|
||||||
|
InteractiveAuthIncompleteError. To handle this, you can use
|
||||||
|
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
|
||||||
|
decorator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
flows (list): A list of login flows. Each flow is an ordered list of
|
flows (list): A list of login flows. Each flow is an ordered list of
|
||||||
strings representing auth-types. At least one full
|
strings representing auth-types. At least one full
|
||||||
flow must be completed in order for auth to be successful.
|
flow must be completed in order for auth to be successful.
|
||||||
|
|
||||||
clientdict: The dictionary from the client root level, not the
|
clientdict: The dictionary from the client root level, not the
|
||||||
'auth' key: this method prompts for auth if none is sent.
|
'auth' key: this method prompts for auth if none is sent.
|
||||||
|
|
||||||
clientip (str): The IP address of the client.
|
clientip (str): The IP address of the client.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of (authed, dict, dict, session_id) where authed is true if
|
defer.Deferred[dict, dict, str]: a deferred tuple of
|
||||||
the client has successfully completed an auth flow. If it is true
|
(creds, params, session_id).
|
||||||
the first dict contains the authenticated credentials of each stage.
|
|
||||||
|
|
||||||
If authed is false, the first dictionary is the server response to
|
'creds' contains the authenticated credentials of each stage.
|
||||||
the login request and should be passed back to the client.
|
|
||||||
|
|
||||||
In either case, the second dict contains the parameters for this
|
'params' contains the parameters for this request (which may
|
||||||
request (which may have been given only in a previous call).
|
have been given only in a previous call).
|
||||||
|
|
||||||
session_id is the ID of this session, either passed in by the client
|
'session_id' is the ID of this session, either passed in by the
|
||||||
or assigned by the call to check_auth
|
client or assigned by this call
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InteractiveAuthIncompleteError if the client has not yet completed
|
||||||
|
all the stages in any of the permitted flows.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authdict = None
|
authdict = None
|
||||||
|
@ -142,11 +215,8 @@ class AuthHandler(BaseHandler):
|
||||||
clientdict = session['clientdict']
|
clientdict = session['clientdict']
|
||||||
|
|
||||||
if not authdict:
|
if not authdict:
|
||||||
defer.returnValue(
|
raise InteractiveAuthIncompleteError(
|
||||||
(
|
self._auth_dict_for_flows(flows, session),
|
||||||
False, self._auth_dict_for_flows(flows, session),
|
|
||||||
clientdict, session['id']
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if 'creds' not in session:
|
if 'creds' not in session:
|
||||||
|
@ -157,14 +227,12 @@ class AuthHandler(BaseHandler):
|
||||||
errordict = {}
|
errordict = {}
|
||||||
if 'type' in authdict:
|
if 'type' in authdict:
|
||||||
login_type = authdict['type']
|
login_type = authdict['type']
|
||||||
if login_type not in self.checkers:
|
|
||||||
raise LoginError(400, "", Codes.UNRECOGNIZED)
|
|
||||||
try:
|
try:
|
||||||
result = yield self.checkers[login_type](authdict, clientip)
|
result = yield self._check_auth_dict(authdict, clientip)
|
||||||
if result:
|
if result:
|
||||||
creds[login_type] = result
|
creds[login_type] = result
|
||||||
self._save_session(session)
|
self._save_session(session)
|
||||||
except LoginError, e:
|
except LoginError as e:
|
||||||
if login_type == LoginType.EMAIL_IDENTITY:
|
if login_type == LoginType.EMAIL_IDENTITY:
|
||||||
# riot used to have a bug where it would request a new
|
# riot used to have a bug where it would request a new
|
||||||
# validation token (thus sending a new email) each time it
|
# validation token (thus sending a new email) each time it
|
||||||
|
@ -173,7 +241,7 @@ class AuthHandler(BaseHandler):
|
||||||
#
|
#
|
||||||
# Grandfather in the old behaviour for now to avoid
|
# Grandfather in the old behaviour for now to avoid
|
||||||
# breaking old riot deployments.
|
# breaking old riot deployments.
|
||||||
raise e
|
raise
|
||||||
|
|
||||||
# this step failed. Merge the error dict into the response
|
# this step failed. Merge the error dict into the response
|
||||||
# so that the client can have another go.
|
# so that the client can have another go.
|
||||||
|
@ -190,12 +258,14 @@ class AuthHandler(BaseHandler):
|
||||||
"Auth completed with creds: %r. Client dict has keys: %r",
|
"Auth completed with creds: %r. Client dict has keys: %r",
|
||||||
creds, clientdict.keys()
|
creds, clientdict.keys()
|
||||||
)
|
)
|
||||||
defer.returnValue((True, creds, clientdict, session['id']))
|
defer.returnValue((creds, clientdict, session['id']))
|
||||||
|
|
||||||
ret = self._auth_dict_for_flows(flows, session)
|
ret = self._auth_dict_for_flows(flows, session)
|
||||||
ret['completed'] = creds.keys()
|
ret['completed'] = creds.keys()
|
||||||
ret.update(errordict)
|
ret.update(errordict)
|
||||||
defer.returnValue((False, ret, clientdict, session['id']))
|
raise InteractiveAuthIncompleteError(
|
||||||
|
ret,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||||
|
@ -268,17 +338,35 @@ class AuthHandler(BaseHandler):
|
||||||
return sess.setdefault('serverdict', {}).get(key, default)
|
return sess.setdefault('serverdict', {}).get(key, default)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _check_password_auth(self, authdict, _):
|
def _check_auth_dict(self, authdict, clientip):
|
||||||
if "user" not in authdict or "password" not in authdict:
|
"""Attempt to validate the auth dict provided by a client
|
||||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
|
||||||
|
|
||||||
user_id = authdict["user"]
|
Args:
|
||||||
password = authdict["password"]
|
authdict (object): auth dict provided by the client
|
||||||
|
clientip (str): IP address of the client
|
||||||
|
|
||||||
(canonical_id, callback) = yield self.validate_login(user_id, {
|
Returns:
|
||||||
"type": LoginType.PASSWORD,
|
Deferred: result of the stage verification.
|
||||||
"password": password,
|
|
||||||
})
|
Raises:
|
||||||
|
StoreError if there was a problem accessing the database
|
||||||
|
SynapseError if there was a problem with the request
|
||||||
|
LoginError if there was an authentication problem.
|
||||||
|
"""
|
||||||
|
login_type = authdict['type']
|
||||||
|
checker = self.checkers.get(login_type)
|
||||||
|
if checker is not None:
|
||||||
|
res = yield checker(authdict, clientip)
|
||||||
|
defer.returnValue(res)
|
||||||
|
|
||||||
|
# build a v1-login-style dict out of the authdict and fall back to the
|
||||||
|
# v1 code
|
||||||
|
user_id = authdict.get("user")
|
||||||
|
|
||||||
|
if user_id is None:
|
||||||
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
|
|
||||||
|
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
|
||||||
defer.returnValue(canonical_id)
|
defer.returnValue(canonical_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -649,41 +737,6 @@ class AuthHandler(BaseHandler):
|
||||||
except Exception:
|
except Exception:
|
||||||
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def set_password(self, user_id, newpassword, requester=None):
|
|
||||||
password_hash = self.hash(newpassword)
|
|
||||||
|
|
||||||
except_access_token_id = requester.access_token_id if requester else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
|
||||||
except StoreError as e:
|
|
||||||
if e.code == 404:
|
|
||||||
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
|
||||||
raise e
|
|
||||||
yield self.delete_access_tokens_for_user(
|
|
||||||
user_id, except_token_id=except_access_token_id,
|
|
||||||
)
|
|
||||||
yield self.hs.get_pusherpool().remove_pushers_by_user(
|
|
||||||
user_id, except_access_token_id
|
|
||||||
)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def deactivate_account(self, user_id):
|
|
||||||
"""Deactivate a user's account
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id (str): ID of user to be deactivated
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Deferred
|
|
||||||
"""
|
|
||||||
# FIXME: Theoretically there is a race here wherein user resets
|
|
||||||
# password using threepid.
|
|
||||||
yield self.delete_access_tokens_for_user(user_id)
|
|
||||||
yield self.store.user_delete_threepids(user_id)
|
|
||||||
yield self.store.user_set_password_hash(user_id, None)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_token(self, access_token):
|
def delete_access_token(self, access_token):
|
||||||
"""Invalidate a single access token
|
"""Invalidate a single access token
|
||||||
|
@ -706,6 +759,12 @@ class AuthHandler(BaseHandler):
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# delete pushers associated with this access token
|
||||||
|
if user_info["token_id"] is not None:
|
||||||
|
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
|
str(user_info["user"]), (user_info["token_id"], )
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
|
def delete_access_tokens_for_user(self, user_id, except_token_id=None,
|
||||||
device_id=None):
|
device_id=None):
|
||||||
|
@ -728,13 +787,18 @@ class AuthHandler(BaseHandler):
|
||||||
# see if any of our auth providers want to know about this
|
# see if any of our auth providers want to know about this
|
||||||
for provider in self.password_providers:
|
for provider in self.password_providers:
|
||||||
if hasattr(provider, "on_logged_out"):
|
if hasattr(provider, "on_logged_out"):
|
||||||
for token, device_id in tokens_and_devices:
|
for token, token_id, device_id in tokens_and_devices:
|
||||||
yield provider.on_logged_out(
|
yield provider.on_logged_out(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
access_token=token,
|
access_token=token,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# delete pushers associated with the access tokens
|
||||||
|
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
|
user_id, (token_id for _, token_id, _ in tokens_and_devices),
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def add_threepid(self, user_id, medium, address, validated_at):
|
def add_threepid(self, user_id, medium, address, validated_at):
|
||||||
# 'Canonicalise' email addresses down to lower case.
|
# 'Canonicalise' email addresses down to lower case.
|
||||||
|
|
52
synapse/handlers/deactivate_account.py
Normal file
52
synapse/handlers/deactivate_account.py
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeactivateAccountHandler(BaseHandler):
|
||||||
|
"""Handler which deals with deactivating user accounts."""
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(DeactivateAccountHandler, self).__init__(hs)
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def deactivate_account(self, user_id):
|
||||||
|
"""Deactivate a user's account
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str): ID of user to be deactivated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred
|
||||||
|
"""
|
||||||
|
# FIXME: Theoretically there is a race here wherein user resets
|
||||||
|
# password using threepid.
|
||||||
|
|
||||||
|
# first delete any devices belonging to the user, which will also
|
||||||
|
# delete corresponding access tokens.
|
||||||
|
yield self._device_handler.delete_all_devices_for_user(user_id)
|
||||||
|
# then delete any remaining access tokens which weren't associated with
|
||||||
|
# a device.
|
||||||
|
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
|
|
||||||
|
yield self.store.user_delete_threepids(user_id)
|
||||||
|
yield self.store.user_set_password_hash(user_id, None)
|
|
@ -170,13 +170,31 @@ class DeviceHandler(BaseHandler):
|
||||||
|
|
||||||
yield self.notify_device_update(user_id, [device_id])
|
yield self.notify_device_update(user_id, [device_id])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def delete_all_devices_for_user(self, user_id, except_device_id=None):
|
||||||
|
"""Delete all of the user's devices
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id (str):
|
||||||
|
except_device_id (str|None): optional device id which should not
|
||||||
|
be deleted
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
defer.Deferred:
|
||||||
|
"""
|
||||||
|
device_map = yield self.store.get_devices_by_user(user_id)
|
||||||
|
device_ids = device_map.keys()
|
||||||
|
if except_device_id is not None:
|
||||||
|
device_ids = [d for d in device_ids if d != except_device_id]
|
||||||
|
yield self.delete_devices(user_id, device_ids)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_devices(self, user_id, device_ids):
|
def delete_devices(self, user_id, device_ids):
|
||||||
""" Delete several devices
|
""" Delete several devices
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id (str):
|
user_id (str):
|
||||||
device_ids (str): The list of device IDs to delete
|
device_ids (List[str]): The list of device IDs to delete
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred:
|
defer.Deferred:
|
||||||
|
|
|
@ -375,6 +375,12 @@ class GroupsLocalHandler(object):
|
||||||
def get_publicised_groups_for_user(self, user_id):
|
def get_publicised_groups_for_user(self, user_id):
|
||||||
if self.hs.is_mine_id(user_id):
|
if self.hs.is_mine_id(user_id):
|
||||||
result = yield self.store.get_publicised_groups_for_user(user_id)
|
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||||
|
|
||||||
|
# Check AS associated groups for this user - this depends on the
|
||||||
|
# RegExps in the AS registration file (under `users`)
|
||||||
|
for app_service in self.store.get_app_services():
|
||||||
|
result.extend(app_service.get_groups_for_user(user_id))
|
||||||
|
|
||||||
defer.returnValue({"groups": result})
|
defer.returnValue({"groups": result})
|
||||||
else:
|
else:
|
||||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||||
|
@ -415,4 +421,9 @@ class GroupsLocalHandler(object):
|
||||||
uid
|
uid
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check AS associated groups for this user - this depends on the
|
||||||
|
# RegExps in the AS registration file (under `users`)
|
||||||
|
for app_service in self.store.get_app_services():
|
||||||
|
results[uid].extend(app_service.get_groups_for_user(uid))
|
||||||
|
|
||||||
defer.returnValue({"users": results})
|
defer.returnValue({"users": results})
|
||||||
|
|
|
@ -27,7 +27,7 @@ from synapse.types import (
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async import concurrently_execute
|
from synapse.util.async import concurrently_execute
|
||||||
from synapse.util.caches.snapshot_cache import SnapshotCache
|
from synapse.util.caches.snapshot_cache import SnapshotCache
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.visibility import filter_events_for_client
|
from synapse.visibility import filter_events_for_client
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
@ -163,7 +163,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
lambda states: states[event.event_id]
|
lambda states: states[event.event_id]
|
||||||
)
|
)
|
||||||
|
|
||||||
(messages, token), current_state = yield preserve_context_over_deferred(
|
(messages, token), current_state = yield make_deferred_yieldable(
|
||||||
defer.gatherResults(
|
defer.gatherResults(
|
||||||
[
|
[
|
||||||
preserve_fn(self.store.get_recent_events_for_room)(
|
preserve_fn(self.store.get_recent_events_for_room)(
|
||||||
|
|
|
@ -1199,7 +1199,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
|
||||||
)
|
)
|
||||||
changed = True
|
changed = True
|
||||||
else:
|
else:
|
||||||
# We expect to be poked occaisonally by the other side.
|
# We expect to be poked occasionally by the other side.
|
||||||
# This is to protect against forgetful/buggy servers, so that
|
# This is to protect against forgetful/buggy servers, so that
|
||||||
# no one gets stuck online forever.
|
# no one gets stuck online forever.
|
||||||
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
|
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
|
||||||
|
|
|
@ -36,6 +36,8 @@ class ProfileHandler(BaseHandler):
|
||||||
"profile", self.on_profile_query
|
"profile", self.on_profile_query
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
|
||||||
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -139,6 +141,12 @@ class ProfileHandler(BaseHandler):
|
||||||
target_user.localpart, new_displayname
|
target_user.localpart, new_displayname
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.hs.config.user_directory_search_all_users:
|
||||||
|
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||||
|
yield self.user_directory_handler.handle_local_profile_change(
|
||||||
|
target_user.to_string(), profile
|
||||||
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester, target_user)
|
yield self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
@ -183,6 +191,12 @@ class ProfileHandler(BaseHandler):
|
||||||
target_user.localpart, new_avatar_url
|
target_user.localpart, new_avatar_url
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.hs.config.user_directory_search_all_users:
|
||||||
|
profile = yield self.store.get_profileinfo(target_user.localpart)
|
||||||
|
yield self.user_directory_handler.handle_local_profile_change(
|
||||||
|
target_user.to_string(), profile
|
||||||
|
)
|
||||||
|
|
||||||
yield self._update_join_states(requester, target_user)
|
yield self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -38,6 +38,7 @@ class RegistrationHandler(BaseHandler):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
self.profile_handler = hs.get_profile_handler()
|
self.profile_handler = hs.get_profile_handler()
|
||||||
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||||
|
|
||||||
self._next_generated_user_id = None
|
self._next_generated_user_id = None
|
||||||
|
@ -165,6 +166,13 @@ class RegistrationHandler(BaseHandler):
|
||||||
),
|
),
|
||||||
admin=admin,
|
admin=admin,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.hs.config.user_directory_search_all_users:
|
||||||
|
profile = yield self.store.get_profileinfo(localpart)
|
||||||
|
yield self.user_directory_handler.handle_local_profile_change(
|
||||||
|
user_id, profile
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# autogen a sequential user ID
|
# autogen a sequential user ID
|
||||||
attempts = 0
|
attempts = 0
|
||||||
|
|
|
@ -154,6 +154,8 @@ class RoomListHandler(BaseHandler):
|
||||||
# We want larger rooms to be first, hence negating num_joined_users
|
# We want larger rooms to be first, hence negating num_joined_users
|
||||||
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
|
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
|
||||||
|
|
||||||
|
logger.info("Getting ordering for %i rooms since %s",
|
||||||
|
len(room_ids), stream_token)
|
||||||
yield concurrently_execute(get_order_for_room, room_ids, 10)
|
yield concurrently_execute(get_order_for_room, room_ids, 10)
|
||||||
|
|
||||||
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
|
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
|
||||||
|
@ -181,34 +183,42 @@ class RoomListHandler(BaseHandler):
|
||||||
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
rooms_to_scan = rooms_to_scan[:since_token.current_limit]
|
||||||
rooms_to_scan.reverse()
|
rooms_to_scan.reverse()
|
||||||
|
|
||||||
# Actually generate the entries. _append_room_entry_to_chunk will append to
|
logger.info("After sorting and filtering, %i rooms remain",
|
||||||
# chunk but will stop if len(chunk) > limit
|
len(rooms_to_scan))
|
||||||
chunk = []
|
|
||||||
if limit and not search_filter:
|
# _append_room_entry_to_chunk will append to chunk but will stop if
|
||||||
|
# len(chunk) > limit
|
||||||
|
#
|
||||||
|
# Normally we will generate enough results on the first iteration here,
|
||||||
|
# but if there is a search filter, _append_room_entry_to_chunk may
|
||||||
|
# filter some results out, in which case we loop again.
|
||||||
|
#
|
||||||
|
# We don't want to scan over the entire range either as that
|
||||||
|
# would potentially waste a lot of work.
|
||||||
|
#
|
||||||
|
# XXX if there is no limit, we may end up DoSing the server with
|
||||||
|
# calls to get_current_state_ids for every single room on the
|
||||||
|
# server. Surely we should cap this somehow?
|
||||||
|
#
|
||||||
|
if limit:
|
||||||
step = limit + 1
|
step = limit + 1
|
||||||
for i in xrange(0, len(rooms_to_scan), step):
|
|
||||||
# We iterate here because the vast majority of cases we'll stop
|
|
||||||
# at first iteration, but occaisonally _append_room_entry_to_chunk
|
|
||||||
# won't append to the chunk and so we need to loop again.
|
|
||||||
# We don't want to scan over the entire range either as that
|
|
||||||
# would potentially waste a lot of work.
|
|
||||||
yield concurrently_execute(
|
|
||||||
lambda r: self._append_room_entry_to_chunk(
|
|
||||||
r, rooms_to_num_joined[r],
|
|
||||||
chunk, limit, search_filter
|
|
||||||
),
|
|
||||||
rooms_to_scan[i:i + step], 10
|
|
||||||
)
|
|
||||||
if len(chunk) >= limit + 1:
|
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
|
step = len(rooms_to_scan)
|
||||||
|
|
||||||
|
chunk = []
|
||||||
|
for i in xrange(0, len(rooms_to_scan), step):
|
||||||
|
batch = rooms_to_scan[i:i + step]
|
||||||
|
logger.info("Processing %i rooms for result", len(batch))
|
||||||
yield concurrently_execute(
|
yield concurrently_execute(
|
||||||
lambda r: self._append_room_entry_to_chunk(
|
lambda r: self._append_room_entry_to_chunk(
|
||||||
r, rooms_to_num_joined[r],
|
r, rooms_to_num_joined[r],
|
||||||
chunk, limit, search_filter
|
chunk, limit, search_filter
|
||||||
),
|
),
|
||||||
rooms_to_scan, 5
|
batch, 5,
|
||||||
)
|
)
|
||||||
|
logger.info("Now %i rooms in result", len(chunk))
|
||||||
|
if len(chunk) >= limit + 1:
|
||||||
|
break
|
||||||
|
|
||||||
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
|
chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
|
||||||
|
|
||||||
|
|
56
synapse/handlers/set_password.py
Normal file
56
synapse/handlers/set_password.py
Normal file
|
@ -0,0 +1,56 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||||
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SetPasswordHandler(BaseHandler):
|
||||||
|
"""Handler which deals with changing user account passwords"""
|
||||||
|
def __init__(self, hs):
|
||||||
|
super(SetPasswordHandler, self).__init__(hs)
|
||||||
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def set_password(self, user_id, newpassword, requester=None):
|
||||||
|
password_hash = self._auth_handler.hash(newpassword)
|
||||||
|
|
||||||
|
except_device_id = requester.device_id if requester else None
|
||||||
|
except_access_token_id = requester.access_token_id if requester else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||||
|
except StoreError as e:
|
||||||
|
if e.code == 404:
|
||||||
|
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# we want to log out all of the user's other sessions. First delete
|
||||||
|
# all his other devices.
|
||||||
|
yield self._device_handler.delete_all_devices_for_user(
|
||||||
|
user_id, except_device_id=except_device_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# and now delete any access tokens which weren't associated with
|
||||||
|
# devices (or were associated with this device).
|
||||||
|
yield self._auth_handler.delete_access_tokens_for_user(
|
||||||
|
user_id, except_token_id=except_access_token_id,
|
||||||
|
)
|
|
@ -20,12 +20,13 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||||
from synapse.storage.roommember import ProfileInfo
|
from synapse.storage.roommember import ProfileInfo
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
from synapse.util.async import sleep
|
from synapse.util.async import sleep
|
||||||
|
from synapse.types import get_localpart_from_id
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class UserDirectoyHandler(object):
|
class UserDirectoryHandler(object):
|
||||||
"""Handles querying of and keeping updated the user_directory.
|
"""Handles querying of and keeping updated the user_directory.
|
||||||
|
|
||||||
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
|
N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
|
||||||
|
@ -41,9 +42,10 @@ class UserDirectoyHandler(object):
|
||||||
one public room.
|
one public room.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
INITIAL_SLEEP_MS = 50
|
INITIAL_ROOM_SLEEP_MS = 50
|
||||||
INITIAL_SLEEP_COUNT = 100
|
INITIAL_ROOM_SLEEP_COUNT = 100
|
||||||
INITIAL_BATCH_SIZE = 100
|
INITIAL_ROOM_BATCH_SIZE = 100
|
||||||
|
INITIAL_USER_SLEEP_MS = 10
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
|
@ -53,6 +55,7 @@ class UserDirectoyHandler(object):
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
self.is_mine_id = hs.is_mine_id
|
self.is_mine_id = hs.is_mine_id
|
||||||
self.update_user_directory = hs.config.update_user_directory
|
self.update_user_directory = hs.config.update_user_directory
|
||||||
|
self.search_all_users = hs.config.user_directory_search_all_users
|
||||||
|
|
||||||
# When start up for the first time we need to populate the user_directory.
|
# When start up for the first time we need to populate the user_directory.
|
||||||
# This is a set of user_id's we've inserted already
|
# This is a set of user_id's we've inserted already
|
||||||
|
@ -110,6 +113,15 @@ class UserDirectoyHandler(object):
|
||||||
finally:
|
finally:
|
||||||
self._is_processing = False
|
self._is_processing = False
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def handle_local_profile_change(self, user_id, profile):
|
||||||
|
"""Called to update index of our local user profiles when they change
|
||||||
|
irrespective of any rooms the user may be in.
|
||||||
|
"""
|
||||||
|
yield self.store.update_profile_in_user_dir(
|
||||||
|
user_id, profile.display_name, profile.avatar_url, None,
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _unsafe_process(self):
|
def _unsafe_process(self):
|
||||||
# If self.pos is None then means we haven't fetched it from DB
|
# If self.pos is None then means we haven't fetched it from DB
|
||||||
|
@ -148,16 +160,30 @@ class UserDirectoyHandler(object):
|
||||||
room_ids = yield self.store.get_all_rooms()
|
room_ids = yield self.store.get_all_rooms()
|
||||||
|
|
||||||
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
|
logger.info("Doing initial update of user directory. %d rooms", len(room_ids))
|
||||||
num_processed_rooms = 1
|
num_processed_rooms = 0
|
||||||
|
|
||||||
for room_id in room_ids:
|
for room_id in room_ids:
|
||||||
logger.info("Handling room %d/%d", num_processed_rooms, len(room_ids))
|
logger.info("Handling room %d/%d", num_processed_rooms + 1, len(room_ids))
|
||||||
yield self._handle_initial_room(room_id)
|
yield self._handle_initial_room(room_id)
|
||||||
num_processed_rooms += 1
|
num_processed_rooms += 1
|
||||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
logger.info("Processed all rooms.")
|
logger.info("Processed all rooms.")
|
||||||
|
|
||||||
|
if self.search_all_users:
|
||||||
|
num_processed_users = 0
|
||||||
|
user_ids = yield self.store.get_all_local_users()
|
||||||
|
logger.info("Doing initial update of user directory. %d users", len(user_ids))
|
||||||
|
for user_id in user_ids:
|
||||||
|
# We add profiles for all users even if they don't match the
|
||||||
|
# include pattern, just in case we want to change it in future
|
||||||
|
logger.info("Handling user %d/%d", num_processed_users + 1, len(user_ids))
|
||||||
|
yield self._handle_local_user(user_id)
|
||||||
|
num_processed_users += 1
|
||||||
|
yield sleep(self.INITIAL_USER_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
|
logger.info("Processed all users")
|
||||||
|
|
||||||
self.initially_handled_users = None
|
self.initially_handled_users = None
|
||||||
self.initially_handled_users_in_public = None
|
self.initially_handled_users_in_public = None
|
||||||
self.initially_handled_users_share = None
|
self.initially_handled_users_share = None
|
||||||
|
@ -201,8 +227,8 @@ class UserDirectoyHandler(object):
|
||||||
to_update = set()
|
to_update = set()
|
||||||
count = 0
|
count = 0
|
||||||
for user_id in user_ids:
|
for user_id in user_ids:
|
||||||
if count % self.INITIAL_SLEEP_COUNT == 0:
|
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
|
|
||||||
if not self.is_mine_id(user_id):
|
if not self.is_mine_id(user_id):
|
||||||
count += 1
|
count += 1
|
||||||
|
@ -216,8 +242,8 @@ class UserDirectoyHandler(object):
|
||||||
if user_id == other_user_id:
|
if user_id == other_user_id:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if count % self.INITIAL_SLEEP_COUNT == 0:
|
if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
|
||||||
yield sleep(self.INITIAL_SLEEP_MS / 1000.)
|
yield sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
user_set = (user_id, other_user_id)
|
user_set = (user_id, other_user_id)
|
||||||
|
@ -237,13 +263,13 @@ class UserDirectoyHandler(object):
|
||||||
else:
|
else:
|
||||||
self.initially_handled_users_share_private_room.add(user_set)
|
self.initially_handled_users_share_private_room.add(user_set)
|
||||||
|
|
||||||
if len(to_insert) > self.INITIAL_BATCH_SIZE:
|
if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
|
||||||
yield self.store.add_users_who_share_room(
|
yield self.store.add_users_who_share_room(
|
||||||
room_id, not is_public, to_insert,
|
room_id, not is_public, to_insert,
|
||||||
)
|
)
|
||||||
to_insert.clear()
|
to_insert.clear()
|
||||||
|
|
||||||
if len(to_update) > self.INITIAL_BATCH_SIZE:
|
if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
|
||||||
yield self.store.update_users_who_share_room(
|
yield self.store.update_users_who_share_room(
|
||||||
room_id, not is_public, to_update,
|
room_id, not is_public, to_update,
|
||||||
)
|
)
|
||||||
|
@ -384,15 +410,29 @@ class UserDirectoyHandler(object):
|
||||||
for user_id in users:
|
for user_id in users:
|
||||||
yield self._handle_remove_user(room_id, user_id)
|
yield self._handle_remove_user(room_id, user_id)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def _handle_local_user(self, user_id):
|
||||||
|
"""Adds a new local roomless user into the user_directory_search table.
|
||||||
|
Used to populate up the user index when we have an
|
||||||
|
user_directory_search_all_users specified.
|
||||||
|
"""
|
||||||
|
logger.debug("Adding new local user to dir, %r", user_id)
|
||||||
|
|
||||||
|
profile = yield self.store.get_profileinfo(get_localpart_from_id(user_id))
|
||||||
|
|
||||||
|
row = yield self.store.get_user_in_directory(user_id)
|
||||||
|
if not row:
|
||||||
|
yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _handle_new_user(self, room_id, user_id, profile):
|
def _handle_new_user(self, room_id, user_id, profile):
|
||||||
"""Called when we might need to add user to directory
|
"""Called when we might need to add user to directory
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
room_id (str): room_id that user joined or started being public that
|
room_id (str): room_id that user joined or started being public
|
||||||
user_id (str)
|
user_id (str)
|
||||||
"""
|
"""
|
||||||
logger.debug("Adding user to dir, %r", user_id)
|
logger.debug("Adding new user to dir, %r", user_id)
|
||||||
|
|
||||||
row = yield self.store.get_user_in_directory(user_id)
|
row = yield self.store.get_user_in_directory(user_id)
|
||||||
if not row:
|
if not row:
|
||||||
|
@ -407,7 +447,7 @@ class UserDirectoyHandler(object):
|
||||||
if not row:
|
if not row:
|
||||||
yield self.store.add_users_to_public_room(room_id, [user_id])
|
yield self.store.add_users_to_public_room(room_id, [user_id])
|
||||||
else:
|
else:
|
||||||
logger.debug("Not adding user to public dir, %r", user_id)
|
logger.debug("Not adding new user to public dir, %r", user_id)
|
||||||
|
|
||||||
# Now we update users who share rooms with users. We do this by getting
|
# Now we update users who share rooms with users. We do this by getting
|
||||||
# all the current users in the room and seeing which aren't already
|
# all the current users in the room and seeing which aren't already
|
||||||
|
|
|
@ -362,8 +362,10 @@ def _get_hosts_for_srv_record(dns_client, host):
|
||||||
return res
|
return res
|
||||||
|
|
||||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||||
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
d1 = dns_client.lookupAddress(host).addCallbacks(
|
||||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
cb, eb, errbackArgs=("A", ))
|
||||||
|
d2 = dns_client.lookupIPV6Address(host).addCallbacks(
|
||||||
|
cb, eb, errbackArgs=("AAAA", ))
|
||||||
results = yield defer.DeferredList(
|
results = yield defer.DeferredList(
|
||||||
[d1, d2], consumeErrors=True)
|
[d1, d2], consumeErrors=True)
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ from canonicaljson import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.python import failure
|
||||||
from twisted.web import server, resource
|
from twisted.web import server, resource
|
||||||
from twisted.web.server import NOT_DONE_YET
|
from twisted.web.server import NOT_DONE_YET
|
||||||
from twisted.web.util import redirectTo
|
from twisted.web.util import redirectTo
|
||||||
|
@ -131,12 +132,17 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||||
version_string=self.version_string,
|
version_string=self.version_string,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
# failure.Failure() fishes the original Failure out
|
||||||
"Failed handle request %s.%s on %r: %r",
|
# of our stack, and thus gives us a sensible stack
|
||||||
|
# trace.
|
||||||
|
f = failure.Failure()
|
||||||
|
logger.error(
|
||||||
|
"Failed handle request %s.%s on %r: %r: %s",
|
||||||
request_handler.__module__,
|
request_handler.__module__,
|
||||||
request_handler.__name__,
|
request_handler.__name__,
|
||||||
self,
|
self,
|
||||||
request
|
request,
|
||||||
|
f.getTraceback().rstrip(),
|
||||||
)
|
)
|
||||||
respond_with_json(
|
respond_with_json(
|
||||||
request,
|
request,
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
@ -81,6 +82,7 @@ class ModuleApi(object):
|
||||||
reg = self.hs.get_handlers().registration_handler
|
reg = self.hs.get_handlers().registration_handler
|
||||||
return reg.register(localpart=localpart)
|
return reg.register(localpart=localpart)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def invalidate_access_token(self, access_token):
|
def invalidate_access_token(self, access_token):
|
||||||
"""Invalidate an access token for a user
|
"""Invalidate an access token for a user
|
||||||
|
|
||||||
|
@ -94,8 +96,16 @@ class ModuleApi(object):
|
||||||
Raises:
|
Raises:
|
||||||
synapse.api.errors.AuthError: the access token is invalid
|
synapse.api.errors.AuthError: the access token is invalid
|
||||||
"""
|
"""
|
||||||
|
# see if the access token corresponds to a device
|
||||||
return self._auth_handler.delete_access_token(access_token)
|
user_info = yield self._auth.get_user_by_access_token(access_token)
|
||||||
|
device_id = user_info.get("device_id")
|
||||||
|
user_id = user_info["user"].to_string()
|
||||||
|
if device_id:
|
||||||
|
# delete the device, which will also delete its access tokens
|
||||||
|
yield self.hs.get_device_handler().delete_device(user_id, device_id)
|
||||||
|
else:
|
||||||
|
# no associated device. Just delete the access token.
|
||||||
|
yield self._auth_handler.delete_access_token(access_token)
|
||||||
|
|
||||||
def run_db_interaction(self, desc, func, *args, **kwargs):
|
def run_db_interaction(self, desc, func, *args, **kwargs):
|
||||||
"""Run a function with a database connection
|
"""Run a function with a database connection
|
||||||
|
|
|
@ -255,9 +255,7 @@ class Notifier(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.federation_sender:
|
if self.federation_sender:
|
||||||
preserve_fn(self.federation_sender.notify_new_events)(
|
self.federation_sender.notify_new_events(room_stream_id)
|
||||||
room_stream_id
|
|
||||||
)
|
|
||||||
|
|
||||||
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
|
||||||
self._user_joined_room(event.state_key, event.room_id)
|
self._user_joined_room(event.state_key, event.room_id)
|
||||||
|
@ -297,8 +295,7 @@ class Notifier(object):
|
||||||
def on_new_replication_data(self):
|
def on_new_replication_data(self):
|
||||||
"""Used to inform replication listeners that something has happend
|
"""Used to inform replication listeners that something has happend
|
||||||
without waking up any of the normal user event streams"""
|
without waking up any of the normal user event streams"""
|
||||||
with PreserveLoggingContext():
|
self.notify_replication()
|
||||||
self.notify_replication()
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
def wait_for_events(self, user_id, timeout, callback, room_ids=None,
|
||||||
|
@ -516,8 +513,14 @@ class Notifier(object):
|
||||||
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
self.replication_deferred = ObservableDeferred(defer.Deferred())
|
||||||
deferred.callback(None)
|
deferred.callback(None)
|
||||||
|
|
||||||
for cb in self.replication_callbacks:
|
# the callbacks may well outlast the current request, so we run
|
||||||
preserve_fn(cb)()
|
# them in the sentinel logcontext.
|
||||||
|
#
|
||||||
|
# (ideally it would be up to the callbacks to know if they were
|
||||||
|
# starting off background processes and drop the logcontext
|
||||||
|
# accordingly, but that requires more changes)
|
||||||
|
for cb in self.replication_callbacks:
|
||||||
|
cb()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def wait_for_replication(self, callback, timeout):
|
def wait_for_replication(self, callback, timeout):
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
# Copyright 2015, 2016 OpenMarket Ltd
|
# Copyright 2015, 2016 OpenMarket Ltd
|
||||||
|
# Copyright 2017 New Vector Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -295,7 +296,7 @@ class HttpPusher(object):
|
||||||
if event.type == 'm.room.member':
|
if event.type == 'm.room.member':
|
||||||
d['notification']['membership'] = event.content['membership']
|
d['notification']['membership'] = event.content['membership']
|
||||||
d['notification']['user_is_target'] = event.state_key == self.user_id
|
d['notification']['user_is_target'] = event.state_key == self.user_id
|
||||||
if not self.hs.config.push_redact_content and 'content' in event:
|
if self.hs.config.push_include_content and 'content' in event:
|
||||||
d['notification']['content'] = event.content
|
d['notification']['content'] = event.content
|
||||||
|
|
||||||
# We no longer send aliases separately, instead, we send the human
|
# We no longer send aliases separately, instead, we send the human
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from .pusher import PusherFactory
|
from .pusher import PusherFactory
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -103,19 +103,25 @@ class PusherPool:
|
||||||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def remove_pushers_by_user(self, user_id, except_access_token_id=None):
|
def remove_pushers_by_access_token(self, user_id, access_tokens):
|
||||||
all = yield self.store.get_all_pushers()
|
"""Remove the pushers for a given user corresponding to a set of
|
||||||
logger.info(
|
access_tokens.
|
||||||
"Removing all pushers for user %s except access tokens id %r",
|
|
||||||
user_id, except_access_token_id
|
Args:
|
||||||
)
|
user_id (str): user to remove pushers for
|
||||||
for p in all:
|
access_tokens (Iterable[int]): access token *ids* to remove pushers
|
||||||
if p['user_name'] == user_id and p['access_token'] != except_access_token_id:
|
for
|
||||||
|
"""
|
||||||
|
tokens = set(access_tokens)
|
||||||
|
for p in (yield self.store.get_pushers_by_user_id(user_id)):
|
||||||
|
if p['access_token'] in tokens:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||||
p['app_id'], p['pushkey'], p['user_name']
|
p['app_id'], p['pushkey'], p['user_name']
|
||||||
)
|
)
|
||||||
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
yield self.remove_pusher(
|
||||||
|
p['app_id'], p['pushkey'], p['user_name'],
|
||||||
|
)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_new_notifications(self, min_stream_id, max_stream_id):
|
def on_new_notifications(self, min_stream_id, max_stream_id):
|
||||||
|
@ -136,7 +142,7 @@ class PusherPool:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
yield make_deferred_yieldable(defer.gatherResults(deferreds))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Exception in pusher on_new_notifications")
|
logger.exception("Exception in pusher on_new_notifications")
|
||||||
|
|
||||||
|
@ -161,7 +167,7 @@ class PusherPool:
|
||||||
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
|
preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults(deferreds))
|
yield make_deferred_yieldable(defer.gatherResults(deferreds))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Exception in pusher on_new_receipts")
|
logger.exception("Exception in pusher on_new_receipts")
|
||||||
|
|
||||||
|
|
|
@ -12,20 +12,18 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from ._base import BaseSlavedStore
|
import logging
|
||||||
from ._slaved_id_tracker import SlavedIdTracker
|
|
||||||
|
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.storage.roommember import RoomMemberStore
|
|
||||||
from synapse.storage.event_federation import EventFederationStore
|
from synapse.storage.event_federation import EventFederationStore
|
||||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||||
from synapse.storage.state import StateStore
|
from synapse.storage.roommember import RoomMemberStore
|
||||||
|
from synapse.storage.state import StateGroupReadStore
|
||||||
from synapse.storage.stream import StreamStore
|
from synapse.storage.stream import StreamStore
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
from ._base import BaseSlavedStore
|
||||||
import logging
|
from ._slaved_id_tracker import SlavedIdTracker
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -39,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||||
# the method descriptor on the DataStore and chuck them into our class.
|
# the method descriptor on the DataStore and chuck them into our class.
|
||||||
|
|
||||||
|
|
||||||
class SlavedEventStore(BaseSlavedStore):
|
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||||
|
@ -90,25 +88,9 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
_get_unread_counts_by_pos_txn = (
|
_get_unread_counts_by_pos_txn = (
|
||||||
DataStore._get_unread_counts_by_pos_txn.__func__
|
DataStore._get_unread_counts_by_pos_txn.__func__
|
||||||
)
|
)
|
||||||
_get_state_group_for_events = (
|
|
||||||
StateStore.__dict__["_get_state_group_for_events"]
|
|
||||||
)
|
|
||||||
_get_state_group_for_event = (
|
|
||||||
StateStore.__dict__["_get_state_group_for_event"]
|
|
||||||
)
|
|
||||||
_get_state_groups_from_groups = (
|
|
||||||
StateStore.__dict__["_get_state_groups_from_groups"]
|
|
||||||
)
|
|
||||||
_get_state_groups_from_groups_txn = (
|
|
||||||
DataStore._get_state_groups_from_groups_txn.__func__
|
|
||||||
)
|
|
||||||
get_recent_event_ids_for_room = (
|
get_recent_event_ids_for_room = (
|
||||||
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
StreamStore.__dict__["get_recent_event_ids_for_room"]
|
||||||
)
|
)
|
||||||
get_current_state_ids = (
|
|
||||||
StateStore.__dict__["get_current_state_ids"]
|
|
||||||
)
|
|
||||||
get_state_group_delta = StateStore.__dict__["get_state_group_delta"]
|
|
||||||
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
|
_get_joined_hosts_cache = RoomMemberStore.__dict__["_get_joined_hosts_cache"]
|
||||||
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
has_room_changed_since = DataStore.has_room_changed_since.__func__
|
||||||
|
|
||||||
|
@ -134,12 +116,6 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
DataStore.get_room_events_stream_for_room.__func__
|
DataStore.get_room_events_stream_for_room.__func__
|
||||||
)
|
)
|
||||||
get_events_around = DataStore.get_events_around.__func__
|
get_events_around = DataStore.get_events_around.__func__
|
||||||
get_state_for_event = DataStore.get_state_for_event.__func__
|
|
||||||
get_state_for_events = DataStore.get_state_for_events.__func__
|
|
||||||
get_state_groups = DataStore.get_state_groups.__func__
|
|
||||||
get_state_groups_ids = DataStore.get_state_groups_ids.__func__
|
|
||||||
get_state_ids_for_event = DataStore.get_state_ids_for_event.__func__
|
|
||||||
get_state_ids_for_events = DataStore.get_state_ids_for_events.__func__
|
|
||||||
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
get_joined_users_from_state = DataStore.get_joined_users_from_state.__func__
|
||||||
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
get_joined_users_from_context = DataStore.get_joined_users_from_context.__func__
|
||||||
_get_joined_users_from_context = (
|
_get_joined_users_from_context = (
|
||||||
|
@ -169,10 +145,7 @@ class SlavedEventStore(BaseSlavedStore):
|
||||||
_get_rooms_for_user_where_membership_is_txn = (
|
_get_rooms_for_user_where_membership_is_txn = (
|
||||||
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
|
||||||
)
|
)
|
||||||
_get_state_for_groups = DataStore._get_state_for_groups.__func__
|
|
||||||
_get_all_state_from_cache = DataStore._get_all_state_from_cache.__func__
|
|
||||||
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
_get_events_around_txn = DataStore._get_events_around_txn.__func__
|
||||||
_get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__
|
|
||||||
|
|
||||||
get_backfill_events = DataStore.get_backfill_events.__func__
|
get_backfill_events = DataStore.get_backfill_events.__func__
|
||||||
_get_backfill_events = DataStore._get_backfill_events.__func__
|
_get_backfill_events = DataStore._get_backfill_events.__func__
|
||||||
|
|
|
@ -216,11 +216,12 @@ class ReplicationStreamer(object):
|
||||||
self.federation_sender.federation_ack(token)
|
self.federation_sender.federation_ack(token)
|
||||||
|
|
||||||
@measure_func("repl.on_user_sync")
|
@measure_func("repl.on_user_sync")
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
|
||||||
"""A client has started/stopped syncing on a worker.
|
"""A client has started/stopped syncing on a worker.
|
||||||
"""
|
"""
|
||||||
user_sync_counter.inc()
|
user_sync_counter.inc()
|
||||||
self.presence_handler.update_external_syncs_row(
|
yield self.presence_handler.update_external_syncs_row(
|
||||||
conn_id, user_id, is_syncing, last_sync_ms,
|
conn_id, user_id, is_syncing, last_sync_ms,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -244,11 +245,12 @@ class ReplicationStreamer(object):
|
||||||
getattr(self.store, cache_func).invalidate(tuple(keys))
|
getattr(self.store, cache_func).invalidate(tuple(keys))
|
||||||
|
|
||||||
@measure_func("repl.on_user_ip")
|
@measure_func("repl.on_user_ip")
|
||||||
|
@defer.inlineCallbacks
|
||||||
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
|
||||||
"""The client saw a user request
|
"""The client saw a user request
|
||||||
"""
|
"""
|
||||||
user_ip_cache_counter.inc()
|
user_ip_cache_counter.inc()
|
||||||
self.store.insert_client_ip(
|
yield self.store.insert_client_ip(
|
||||||
user_id, access_token, ip, user_agent, device_id, last_seen,
|
user_id, access_token, ip, user_agent, device_id, last_seen,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -137,8 +137,8 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||||
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
self._auth_handler = hs.get_auth_handler()
|
|
||||||
super(DeactivateAccountRestServlet, self).__init__(hs)
|
super(DeactivateAccountRestServlet, self).__init__(hs)
|
||||||
|
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, target_user_id):
|
def on_POST(self, request, target_user_id):
|
||||||
|
@ -149,7 +149,7 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
|
||||||
if not is_admin:
|
if not is_admin:
|
||||||
raise AuthError(403, "You are not a server admin")
|
raise AuthError(403, "You are not a server admin")
|
||||||
|
|
||||||
yield self._auth_handler.deactivate_account(target_user_id)
|
yield self._deactivate_account_handler.deactivate_account(target_user_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -309,7 +309,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||||
super(ResetPasswordRestServlet, self).__init__(hs)
|
super(ResetPasswordRestServlet, self).__init__(hs)
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self._set_password_handler = hs.get_set_password_handler()
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request, target_user_id):
|
def on_POST(self, request, target_user_id):
|
||||||
|
@ -330,7 +330,7 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
logger.info("new_password: %r", new_password)
|
logger.info("new_password: %r", new_password)
|
||||||
|
|
||||||
yield self.auth_handler.set_password(
|
yield self._set_password_handler.set_password(
|
||||||
target_user_id, new_password, requester
|
target_user_id, new_password, requester
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.auth import get_access_token_from_request
|
from synapse.api.auth import get_access_token_from_request
|
||||||
|
from synapse.api.errors import AuthError
|
||||||
|
|
||||||
from .base import ClientV1RestServlet, client_path_patterns
|
from .base import ClientV1RestServlet, client_path_patterns
|
||||||
|
|
||||||
|
@ -30,15 +31,30 @@ class LogoutRestServlet(ClientV1RestServlet):
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
super(LogoutRestServlet, self).__init__(hs)
|
super(LogoutRestServlet, self).__init__(hs)
|
||||||
|
self._auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
access_token = get_access_token_from_request(request)
|
try:
|
||||||
yield self._auth_handler.delete_access_token(access_token)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
except AuthError:
|
||||||
|
# this implies the access token has already been deleted.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if requester.device_id is None:
|
||||||
|
# the acccess token wasn't associated with a device.
|
||||||
|
# Just delete the access token
|
||||||
|
access_token = get_access_token_from_request(request)
|
||||||
|
yield self._auth_handler.delete_access_token(access_token)
|
||||||
|
else:
|
||||||
|
yield self._device_handler.delete_device(
|
||||||
|
requester.user.to_string(), requester.device_id)
|
||||||
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,6 +65,7 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
||||||
super(LogoutAllRestServlet, self).__init__(hs)
|
super(LogoutAllRestServlet, self).__init__(hs)
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self._auth_handler = hs.get_auth_handler()
|
self._auth_handler = hs.get_auth_handler()
|
||||||
|
self._device_handler = hs.get_device_handler()
|
||||||
|
|
||||||
def on_OPTIONS(self, request):
|
def on_OPTIONS(self, request):
|
||||||
return (200, {})
|
return (200, {})
|
||||||
|
@ -57,6 +74,12 @@ class LogoutAllRestServlet(ClientV1RestServlet):
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
|
# first delete all of the user's devices
|
||||||
|
yield self._device_handler.delete_all_devices_for_user(user_id)
|
||||||
|
|
||||||
|
# .. and then delete any access tokens which weren't associated with
|
||||||
|
# devices.
|
||||||
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
yield self._auth_handler.delete_access_tokens_for_user(user_id)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
|
|
||||||
"""This module contains base REST classes for constructing client v1 servlets.
|
"""This module contains base REST classes for constructing client v1 servlets.
|
||||||
"""
|
"""
|
||||||
|
import logging
|
||||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import logging
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.api.errors import InteractiveAuthIncompleteError
|
||||||
|
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -57,3 +58,37 @@ def set_timeline_upper_limit(filter_json, filter_timeline_limit):
|
||||||
filter_json['room']['timeline']["limit"] = min(
|
filter_json['room']['timeline']["limit"] = min(
|
||||||
filter_json['room']['timeline']['limit'],
|
filter_json['room']['timeline']['limit'],
|
||||||
filter_timeline_limit)
|
filter_timeline_limit)
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_auth_handler(orig):
|
||||||
|
"""Wraps an on_POST method to handle InteractiveAuthIncompleteErrors
|
||||||
|
|
||||||
|
Takes a on_POST method which returns a deferred (errcode, body) response
|
||||||
|
and adds exception handling to turn a InteractiveAuthIncompleteError into
|
||||||
|
a 401 response.
|
||||||
|
|
||||||
|
Normal usage is:
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def on_POST(self, request):
|
||||||
|
# ...
|
||||||
|
yield self.auth_handler.check_auth
|
||||||
|
"""
|
||||||
|
def wrapped(*args, **kwargs):
|
||||||
|
res = defer.maybeDeferred(orig, *args, **kwargs)
|
||||||
|
res.addErrback(_catch_incomplete_interactive_auth)
|
||||||
|
return res
|
||||||
|
return wrapped
|
||||||
|
|
||||||
|
|
||||||
|
def _catch_incomplete_interactive_auth(f):
|
||||||
|
"""helper for interactive_auth_handler
|
||||||
|
|
||||||
|
Catches InteractiveAuthIncompleteErrors and turns them into 401 responses
|
||||||
|
|
||||||
|
Args:
|
||||||
|
f (failure.Failure):
|
||||||
|
"""
|
||||||
|
f.trap(InteractiveAuthIncompleteError)
|
||||||
|
return 401, f.value.result
|
||||||
|
|
|
@ -19,14 +19,14 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.auth import has_access_token
|
from synapse.api.auth import has_access_token
|
||||||
from synapse.api.constants import LoginType
|
from synapse.api.constants import LoginType
|
||||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
from synapse.api.errors import Codes, SynapseError
|
||||||
from synapse.http.servlet import (
|
from synapse.http.servlet import (
|
||||||
RestServlet, assert_params_in_request,
|
RestServlet, assert_params_in_request,
|
||||||
parse_json_object_from_request,
|
parse_json_object_from_request,
|
||||||
)
|
)
|
||||||
from synapse.util.async import run_on_reactor
|
from synapse.util.async import run_on_reactor
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -98,56 +98,61 @@ class PasswordRestServlet(RestServlet):
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
self.datastore = self.hs.get_datastore()
|
self.datastore = self.hs.get_datastore()
|
||||||
|
self._set_password_handler = hs.get_set_password_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
|
||||||
|
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
# there are two possibilities here. Either the user does not have an
|
||||||
[LoginType.PASSWORD],
|
# access token, and needs to do a password reset; or they have one and
|
||||||
[LoginType.EMAIL_IDENTITY],
|
# need to validate their identity.
|
||||||
[LoginType.MSISDN],
|
#
|
||||||
], body, self.hs.get_ip_from_request(request))
|
# In the first case, we offer a couple of means of identifying
|
||||||
|
# themselves (email and msisdn, though it's unclear if msisdn actually
|
||||||
|
# works).
|
||||||
|
#
|
||||||
|
# In the second case, we require a password to confirm their identity.
|
||||||
|
|
||||||
if not authed:
|
if has_access_token(request):
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
user_id = None
|
|
||||||
requester = None
|
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
|
||||||
# if using password, they should also be logged in
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
user_id = requester.user.to_string()
|
params = yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
if user_id != result[LoginType.PASSWORD]:
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
raise LoginError(400, "", Codes.UNKNOWN)
|
|
||||||
elif LoginType.EMAIL_IDENTITY in result:
|
|
||||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
|
||||||
if 'medium' not in threepid or 'address' not in threepid:
|
|
||||||
raise SynapseError(500, "Malformed threepid")
|
|
||||||
if threepid['medium'] == 'email':
|
|
||||||
# For emails, transform the address to lowercase.
|
|
||||||
# We store all email addreses as lowercase in the DB.
|
|
||||||
# (See add_threepid in synapse/handlers/auth.py)
|
|
||||||
threepid['address'] = threepid['address'].lower()
|
|
||||||
# if using email, we must know about the email they're authing with!
|
|
||||||
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
|
||||||
threepid['medium'], threepid['address']
|
|
||||||
)
|
)
|
||||||
if not threepid_user_id:
|
user_id = requester.user.to_string()
|
||||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
|
||||||
user_id = threepid_user_id
|
|
||||||
else:
|
else:
|
||||||
logger.error("Auth succeeded but no known type!", result.keys())
|
requester = None
|
||||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
result, params, _ = yield self.auth_handler.check_auth(
|
||||||
|
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
|
||||||
|
body, self.hs.get_ip_from_request(request),
|
||||||
|
)
|
||||||
|
|
||||||
|
if LoginType.EMAIL_IDENTITY in result:
|
||||||
|
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||||
|
if 'medium' not in threepid or 'address' not in threepid:
|
||||||
|
raise SynapseError(500, "Malformed threepid")
|
||||||
|
if threepid['medium'] == 'email':
|
||||||
|
# For emails, transform the address to lowercase.
|
||||||
|
# We store all email addreses as lowercase in the DB.
|
||||||
|
# (See add_threepid in synapse/handlers/auth.py)
|
||||||
|
threepid['address'] = threepid['address'].lower()
|
||||||
|
# if using email, we must know about the email they're authing with!
|
||||||
|
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
|
||||||
|
threepid['medium'], threepid['address']
|
||||||
|
)
|
||||||
|
if not threepid_user_id:
|
||||||
|
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||||
|
user_id = threepid_user_id
|
||||||
|
else:
|
||||||
|
logger.error("Auth succeeded but no known type!", result.keys())
|
||||||
|
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||||
|
|
||||||
if 'new_password' not in params:
|
if 'new_password' not in params:
|
||||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||||
new_password = params['new_password']
|
new_password = params['new_password']
|
||||||
|
|
||||||
yield self.auth_handler.set_password(
|
yield self._set_password_handler.set_password(
|
||||||
user_id, new_password, requester
|
user_id, new_password, requester
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -161,52 +166,32 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||||
PATTERNS = client_v2_patterns("/account/deactivate$")
|
PATTERNS = client_v2_patterns("/account/deactivate$")
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs):
|
||||||
|
super(DeactivateAccountRestServlet, self).__init__()
|
||||||
self.hs = hs
|
self.hs = hs
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
super(DeactivateAccountRestServlet, self).__init__()
|
self._deactivate_account_handler = hs.get_deactivate_account_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
body = parse_json_object_from_request(request)
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
# if the caller provides an access token, it ought to be valid.
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
requester = None
|
|
||||||
if has_access_token(request):
|
|
||||||
requester = yield self.auth.get_user_by_req(
|
|
||||||
request,
|
|
||||||
) # type: synapse.types.Requester
|
|
||||||
|
|
||||||
# allow ASes to dectivate their own users
|
# allow ASes to dectivate their own users
|
||||||
if requester and requester.app_service:
|
if requester.app_service:
|
||||||
yield self.auth_handler.deactivate_account(
|
yield self._deactivate_account_handler.deactivate_account(
|
||||||
requester.user.to_string()
|
requester.user.to_string()
|
||||||
)
|
)
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
[LoginType.PASSWORD],
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
], body, self.hs.get_ip_from_request(request))
|
)
|
||||||
|
yield self._deactivate_account_handler.deactivate_account(
|
||||||
if not authed:
|
requester.user.to_string(),
|
||||||
defer.returnValue((401, result))
|
)
|
||||||
|
|
||||||
if LoginType.PASSWORD in result:
|
|
||||||
user_id = result[LoginType.PASSWORD]
|
|
||||||
# if using password, they should also be logged in
|
|
||||||
if requester is None:
|
|
||||||
raise SynapseError(
|
|
||||||
400,
|
|
||||||
"Deactivate account requires an access_token",
|
|
||||||
errcode=Codes.MISSING_TOKEN
|
|
||||||
)
|
|
||||||
if requester.user.to_string() != user_id:
|
|
||||||
raise LoginError(400, "", Codes.UNKNOWN)
|
|
||||||
else:
|
|
||||||
logger.error("Auth succeeded but no known type!", result.keys())
|
|
||||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
|
||||||
|
|
||||||
yield self.auth_handler.deactivate_account(user_id)
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,9 +17,9 @@ import logging
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api import constants, errors
|
from synapse.api import errors
|
||||||
from synapse.http import servlet
|
from synapse.http import servlet
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -60,8 +60,11 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.auth_handler = hs.get_auth_handler()
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body = servlet.parse_json_object_from_request(request)
|
body = servlet.parse_json_object_from_request(request)
|
||||||
except errors.SynapseError as e:
|
except errors.SynapseError as e:
|
||||||
|
@ -77,14 +80,10 @@ class DeleteDevicesRestServlet(servlet.RestServlet):
|
||||||
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
400, "No devices supplied", errcode=errors.Codes.MISSING_PARAM
|
||||||
)
|
)
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
[constants.LoginType.PASSWORD],
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
], body, self.hs.get_ip_from_request(request))
|
)
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, result))
|
|
||||||
|
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
|
||||||
yield self.device_handler.delete_devices(
|
yield self.device_handler.delete_devices(
|
||||||
requester.user.to_string(),
|
requester.user.to_string(),
|
||||||
body['devices'],
|
body['devices'],
|
||||||
|
@ -115,6 +114,7 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
)
|
)
|
||||||
defer.returnValue((200, device))
|
defer.returnValue((200, device))
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_DELETE(self, request, device_id):
|
def on_DELETE(self, request, device_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request)
|
||||||
|
@ -130,19 +130,13 @@ class DeviceRestServlet(servlet.RestServlet):
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
authed, result, params, _ = yield self.auth_handler.check_auth([
|
yield self.auth_handler.validate_user_via_ui_auth(
|
||||||
[constants.LoginType.PASSWORD],
|
requester, body, self.hs.get_ip_from_request(request),
|
||||||
], body, self.hs.get_ip_from_request(request))
|
)
|
||||||
|
|
||||||
if not authed:
|
yield self.device_handler.delete_device(
|
||||||
defer.returnValue((401, result))
|
requester.user.to_string(), device_id,
|
||||||
|
)
|
||||||
# check that the UI auth matched the access token
|
|
||||||
user_id = result[constants.LoginType.PASSWORD]
|
|
||||||
if user_id != requester.user.to_string():
|
|
||||||
raise errors.AuthError(403, "Invalid auth")
|
|
||||||
|
|
||||||
yield self.device_handler.delete_device(user_id, device_id)
|
|
||||||
defer.returnValue((200, {}))
|
defer.returnValue((200, {}))
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
|
|
|
@ -38,7 +38,7 @@ class GroupServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
group_description = yield self.groups_handler.get_group_profile(
|
group_description = yield self.groups_handler.get_group_profile(
|
||||||
|
@ -74,7 +74,7 @@ class GroupSummaryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
get_group_summary = yield self.groups_handler.get_group_summary(
|
get_group_summary = yield self.groups_handler.get_group_summary(
|
||||||
|
@ -148,7 +148,7 @@ class GroupCategoryServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id, category_id):
|
def on_GET(self, request, group_id, category_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_category(
|
category = yield self.groups_handler.get_group_category(
|
||||||
|
@ -200,7 +200,7 @@ class GroupCategoriesServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_categories(
|
category = yield self.groups_handler.get_group_categories(
|
||||||
|
@ -225,7 +225,7 @@ class GroupRoleServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id, role_id):
|
def on_GET(self, request, group_id, role_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_role(
|
category = yield self.groups_handler.get_group_role(
|
||||||
|
@ -277,7 +277,7 @@ class GroupRolesServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
category = yield self.groups_handler.get_group_roles(
|
category = yield self.groups_handler.get_group_roles(
|
||||||
|
@ -348,7 +348,7 @@ class GroupRoomServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
|
||||||
|
@ -369,7 +369,7 @@ class GroupUsersServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, group_id):
|
def on_GET(self, request, group_id):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
|
||||||
|
@ -672,7 +672,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request, user_id):
|
def on_GET(self, request, user_id):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||||
user_id
|
user_id
|
||||||
|
@ -697,7 +697,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield self.auth.get_user_by_req(request)
|
yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
|
|
||||||
content = parse_json_object_from_request(request)
|
content = parse_json_object_from_request(request)
|
||||||
user_ids = content["user_ids"]
|
user_ids = content["user_ids"]
|
||||||
|
@ -724,7 +724,7 @@ class GroupsForUserServlet(RestServlet):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_GET(self, request):
|
def on_GET(self, request):
|
||||||
requester = yield self.auth.get_user_by_req(request)
|
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||||
requester_user_id = requester.user.to_string()
|
requester_user_id = requester.user.to_string()
|
||||||
|
|
||||||
result = yield self.groups_handler.get_joined_groups(requester_user_id)
|
result = yield self.groups_handler.get_joined_groups(requester_user_id)
|
||||||
|
|
|
@ -27,7 +27,7 @@ from synapse.http.servlet import (
|
||||||
)
|
)
|
||||||
from synapse.util.msisdn import phone_number_to_msisdn
|
from synapse.util.msisdn import phone_number_to_msisdn
|
||||||
|
|
||||||
from ._base import client_v2_patterns
|
from ._base import client_v2_patterns, interactive_auth_handler
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import hmac
|
import hmac
|
||||||
|
@ -176,6 +176,7 @@ class RegisterRestServlet(RestServlet):
|
||||||
self.device_handler = hs.get_device_handler()
|
self.device_handler = hs.get_device_handler()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
|
|
||||||
|
@interactive_auth_handler
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def on_POST(self, request):
|
def on_POST(self, request):
|
||||||
yield run_on_reactor()
|
yield run_on_reactor()
|
||||||
|
@ -325,14 +326,10 @@ class RegisterRestServlet(RestServlet):
|
||||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||||
])
|
])
|
||||||
|
|
||||||
authed, auth_result, params, session_id = yield self.auth_handler.check_auth(
|
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||||
flows, body, self.hs.get_ip_from_request(request)
|
flows, body, self.hs.get_ip_from_request(request)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not authed:
|
|
||||||
defer.returnValue((401, auth_result))
|
|
||||||
return
|
|
||||||
|
|
||||||
if registered_user_id is not None:
|
if registered_user_id is not None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Already registered user ID %r for this session",
|
"Already registered user ID %r for this session",
|
||||||
|
|
|
@ -30,6 +30,7 @@ class VersionsRestServlet(RestServlet):
|
||||||
"r0.0.1",
|
"r0.0.1",
|
||||||
"r0.1.0",
|
"r0.1.0",
|
||||||
"r0.2.0",
|
"r0.2.0",
|
||||||
|
"r0.3.0",
|
||||||
]
|
]
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,8 @@ from synapse.util.stringutils import random_string
|
||||||
from synapse.util.caches.expiringcache import ExpiringCache
|
from synapse.util.caches.expiringcache import ExpiringCache
|
||||||
from synapse.http.client import SpiderHttpClient
|
from synapse.http.client import SpiderHttpClient
|
||||||
from synapse.http.server import (
|
from synapse.http.server import (
|
||||||
request_handler, respond_with_json_bytes
|
request_handler, respond_with_json_bytes,
|
||||||
|
respond_with_json,
|
||||||
)
|
)
|
||||||
from synapse.util.async import ObservableDeferred
|
from synapse.util.async import ObservableDeferred
|
||||||
from synapse.util.stringutils import is_ascii
|
from synapse.util.stringutils import is_ascii
|
||||||
|
@ -78,6 +79,9 @@ class PreviewUrlResource(Resource):
|
||||||
self._expire_url_cache_data, 10 * 1000
|
self._expire_url_cache_data, 10 * 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def render_OPTIONS(self, request):
|
||||||
|
return respond_with_json(request, 200, {}, send_cors=True)
|
||||||
|
|
||||||
def render_GET(self, request):
|
def render_GET(self, request):
|
||||||
self._async_render_GET(request)
|
self._async_render_GET(request)
|
||||||
return NOT_DONE_YET
|
return NOT_DONE_YET
|
||||||
|
@ -348,11 +352,16 @@ class PreviewUrlResource(Resource):
|
||||||
def _expire_url_cache_data(self):
|
def _expire_url_cache_data(self):
|
||||||
"""Clean up expired url cache content, media and thumbnails.
|
"""Clean up expired url cache content, media and thumbnails.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# TODO: Delete from backup media store
|
# TODO: Delete from backup media store
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self.clock.time_msec()
|
||||||
|
|
||||||
|
logger.info("Running url preview cache expiry")
|
||||||
|
|
||||||
|
if not (yield self.store.has_completed_background_updates()):
|
||||||
|
logger.info("Still running DB updates; skipping expiry")
|
||||||
|
return
|
||||||
|
|
||||||
# First we delete expired url cache entries
|
# First we delete expired url cache entries
|
||||||
media_ids = yield self.store.get_expired_url_cache(now)
|
media_ids = yield self.store.get_expired_url_cache(now)
|
||||||
|
|
||||||
|
@ -426,8 +435,7 @@ class PreviewUrlResource(Resource):
|
||||||
|
|
||||||
yield self.store.delete_url_cache_media(removed_media)
|
yield self.store.delete_url_cache_media(removed_media)
|
||||||
|
|
||||||
if removed_media:
|
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
|
||||||
|
|
||||||
|
|
||||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||||
|
|
|
@ -39,18 +39,20 @@ from synapse.federation.transaction_queue import TransactionQueue
|
||||||
from synapse.handlers import Handlers
|
from synapse.handlers import Handlers
|
||||||
from synapse.handlers.appservice import ApplicationServicesHandler
|
from synapse.handlers.appservice import ApplicationServicesHandler
|
||||||
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
from synapse.handlers.auth import AuthHandler, MacaroonGeneartor
|
||||||
|
from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||||
from synapse.handlers.device import DeviceHandler
|
from synapse.handlers.device import DeviceHandler
|
||||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||||
from synapse.handlers.presence import PresenceHandler
|
from synapse.handlers.presence import PresenceHandler
|
||||||
from synapse.handlers.room_list import RoomListHandler
|
from synapse.handlers.room_list import RoomListHandler
|
||||||
|
from synapse.handlers.set_password import SetPasswordHandler
|
||||||
from synapse.handlers.sync import SyncHandler
|
from synapse.handlers.sync import SyncHandler
|
||||||
from synapse.handlers.typing import TypingHandler
|
from synapse.handlers.typing import TypingHandler
|
||||||
from synapse.handlers.events import EventHandler, EventStreamHandler
|
from synapse.handlers.events import EventHandler, EventStreamHandler
|
||||||
from synapse.handlers.initial_sync import InitialSyncHandler
|
from synapse.handlers.initial_sync import InitialSyncHandler
|
||||||
from synapse.handlers.receipts import ReceiptsHandler
|
from synapse.handlers.receipts import ReceiptsHandler
|
||||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||||
from synapse.handlers.profile import ProfileHandler
|
from synapse.handlers.profile import ProfileHandler
|
||||||
from synapse.groups.groups_server import GroupsServerHandler
|
from synapse.groups.groups_server import GroupsServerHandler
|
||||||
|
@ -60,7 +62,10 @@ from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||||
from synapse.notifier import Notifier
|
from synapse.notifier import Notifier
|
||||||
from synapse.push.action_generator import ActionGenerator
|
from synapse.push.action_generator import ActionGenerator
|
||||||
from synapse.push.pusherpool import PusherPool
|
from synapse.push.pusherpool import PusherPool
|
||||||
from synapse.rest.media.v1.media_repository import MediaRepository
|
from synapse.rest.media.v1.media_repository import (
|
||||||
|
MediaRepository,
|
||||||
|
MediaRepositoryResource,
|
||||||
|
)
|
||||||
from synapse.state import StateHandler
|
from synapse.state import StateHandler
|
||||||
from synapse.storage import DataStore
|
from synapse.storage import DataStore
|
||||||
from synapse.streams.events import EventSources
|
from synapse.streams.events import EventSources
|
||||||
|
@ -90,17 +95,12 @@ class HomeServer(object):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEPENDENCIES = [
|
DEPENDENCIES = [
|
||||||
'config',
|
|
||||||
'clock',
|
|
||||||
'http_client',
|
'http_client',
|
||||||
'db_pool',
|
'db_pool',
|
||||||
'persistence_service',
|
|
||||||
'replication_layer',
|
'replication_layer',
|
||||||
'datastore',
|
|
||||||
'handlers',
|
'handlers',
|
||||||
'v1auth',
|
'v1auth',
|
||||||
'auth',
|
'auth',
|
||||||
'rest_servlet_factory',
|
|
||||||
'state_handler',
|
'state_handler',
|
||||||
'presence_handler',
|
'presence_handler',
|
||||||
'sync_handler',
|
'sync_handler',
|
||||||
|
@ -117,19 +117,10 @@ class HomeServer(object):
|
||||||
'application_service_handler',
|
'application_service_handler',
|
||||||
'device_message_handler',
|
'device_message_handler',
|
||||||
'profile_handler',
|
'profile_handler',
|
||||||
|
'deactivate_account_handler',
|
||||||
|
'set_password_handler',
|
||||||
'notifier',
|
'notifier',
|
||||||
'distributor',
|
|
||||||
'client_resource',
|
|
||||||
'resource_for_federation',
|
|
||||||
'resource_for_static_content',
|
|
||||||
'resource_for_web_client',
|
|
||||||
'resource_for_content_repo',
|
|
||||||
'resource_for_server_key',
|
|
||||||
'resource_for_server_key_v2',
|
|
||||||
'resource_for_media_repository',
|
|
||||||
'resource_for_metrics',
|
|
||||||
'event_sources',
|
'event_sources',
|
||||||
'ratelimiter',
|
|
||||||
'keyring',
|
'keyring',
|
||||||
'pusherpool',
|
'pusherpool',
|
||||||
'event_builder_factory',
|
'event_builder_factory',
|
||||||
|
@ -137,6 +128,7 @@ class HomeServer(object):
|
||||||
'http_client_context_factory',
|
'http_client_context_factory',
|
||||||
'simple_http_client',
|
'simple_http_client',
|
||||||
'media_repository',
|
'media_repository',
|
||||||
|
'media_repository_resource',
|
||||||
'federation_transport_client',
|
'federation_transport_client',
|
||||||
'federation_sender',
|
'federation_sender',
|
||||||
'receipts_handler',
|
'receipts_handler',
|
||||||
|
@ -183,6 +175,21 @@ class HomeServer(object):
|
||||||
def is_mine_id(self, string):
|
def is_mine_id(self, string):
|
||||||
return string.split(":", 1)[1] == self.hostname
|
return string.split(":", 1)[1] == self.hostname
|
||||||
|
|
||||||
|
def get_clock(self):
|
||||||
|
return self.clock
|
||||||
|
|
||||||
|
def get_datastore(self):
|
||||||
|
return self.datastore
|
||||||
|
|
||||||
|
def get_config(self):
|
||||||
|
return self.config
|
||||||
|
|
||||||
|
def get_distributor(self):
|
||||||
|
return self.distributor
|
||||||
|
|
||||||
|
def get_ratelimiter(self):
|
||||||
|
return self.ratelimiter
|
||||||
|
|
||||||
def build_replication_layer(self):
|
def build_replication_layer(self):
|
||||||
return initialize_http_replication(self)
|
return initialize_http_replication(self)
|
||||||
|
|
||||||
|
@ -265,6 +272,12 @@ class HomeServer(object):
|
||||||
def build_profile_handler(self):
|
def build_profile_handler(self):
|
||||||
return ProfileHandler(self)
|
return ProfileHandler(self)
|
||||||
|
|
||||||
|
def build_deactivate_account_handler(self):
|
||||||
|
return DeactivateAccountHandler(self)
|
||||||
|
|
||||||
|
def build_set_password_handler(self):
|
||||||
|
return SetPasswordHandler(self)
|
||||||
|
|
||||||
def build_event_sources(self):
|
def build_event_sources(self):
|
||||||
return EventSources(self)
|
return EventSources(self)
|
||||||
|
|
||||||
|
@ -294,6 +307,11 @@ class HomeServer(object):
|
||||||
**self.db_config.get("args", {})
|
**self.db_config.get("args", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def build_media_repository_resource(self):
|
||||||
|
# build the media repo resource. This indirects through the HomeServer
|
||||||
|
# to ensure that we only have a single instance of
|
||||||
|
return MediaRepositoryResource(self)
|
||||||
|
|
||||||
def build_media_repository(self):
|
def build_media_repository(self):
|
||||||
return MediaRepository(self)
|
return MediaRepository(self)
|
||||||
|
|
||||||
|
@ -321,7 +339,7 @@ class HomeServer(object):
|
||||||
return ActionGenerator(self)
|
return ActionGenerator(self)
|
||||||
|
|
||||||
def build_user_directory_handler(self):
|
def build_user_directory_handler(self):
|
||||||
return UserDirectoyHandler(self)
|
return UserDirectoryHandler(self)
|
||||||
|
|
||||||
def build_groups_local_handler(self):
|
def build_groups_local_handler(self):
|
||||||
return GroupsLocalHandler(self)
|
return GroupsLocalHandler(self)
|
||||||
|
|
|
@ -3,10 +3,14 @@ import synapse.federation.transaction_queue
|
||||||
import synapse.federation.transport.client
|
import synapse.federation.transport.client
|
||||||
import synapse.handlers
|
import synapse.handlers
|
||||||
import synapse.handlers.auth
|
import synapse.handlers.auth
|
||||||
|
import synapse.handlers.deactivate_account
|
||||||
import synapse.handlers.device
|
import synapse.handlers.device
|
||||||
import synapse.handlers.e2e_keys
|
import synapse.handlers.e2e_keys
|
||||||
import synapse.storage
|
import synapse.handlers.set_password
|
||||||
|
import synapse.rest.media.v1.media_repository
|
||||||
import synapse.state
|
import synapse.state
|
||||||
|
import synapse.storage
|
||||||
|
|
||||||
|
|
||||||
class HomeServer(object):
|
class HomeServer(object):
|
||||||
def get_auth(self) -> synapse.api.auth.Auth:
|
def get_auth(self) -> synapse.api.auth.Auth:
|
||||||
|
@ -30,8 +34,20 @@ class HomeServer(object):
|
||||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
|
||||||
|
pass
|
||||||
|
|
||||||
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
|
||||||
|
pass
|
||||||
|
|
|
@ -16,8 +16,6 @@ import logging
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
|
||||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
|
||||||
from synapse.util.caches.dictionary_cache import DictionaryCache
|
|
||||||
from synapse.util.caches.descriptors import Cache
|
from synapse.util.caches.descriptors import Cache
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
|
@ -180,10 +178,6 @@ class SQLBaseStore(object):
|
||||||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||||
max_entries=hs.config.event_cache_size)
|
max_entries=hs.config.event_cache_size)
|
||||||
|
|
||||||
self._state_group_cache = DictionaryCache(
|
|
||||||
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
|
||||||
)
|
|
||||||
|
|
||||||
self._event_fetch_lock = threading.Condition()
|
self._event_fetch_lock = threading.Condition()
|
||||||
self._event_fetch_list = []
|
self._event_fetch_list = []
|
||||||
self._event_fetch_ongoing = 0
|
self._event_fetch_ongoing = 0
|
||||||
|
@ -475,23 +469,53 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
txn.executemany(sql, vals)
|
txn.executemany(sql, vals)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
def _simple_upsert(self, table, keyvalues, values,
|
def _simple_upsert(self, table, keyvalues, values,
|
||||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
insertion_values={}, desc="_simple_upsert", lock=True):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
`lock` should generally be set to True (the default), but can be set
|
||||||
|
to False if either of the following are true:
|
||||||
|
|
||||||
|
* there is a UNIQUE INDEX on the key columns. In this case a conflict
|
||||||
|
will cause an IntegrityError in which case this function will retry
|
||||||
|
the update.
|
||||||
|
|
||||||
|
* we somehow know that we are the only thread which will be updating
|
||||||
|
this table.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table (str): The table to upsert into
|
table (str): The table to upsert into
|
||||||
keyvalues (dict): The unique key tables and their new values
|
keyvalues (dict): The unique key tables and their new values
|
||||||
values (dict): The nonunique columns and their new values
|
values (dict): The nonunique columns and their new values
|
||||||
insertion_values (dict): key/values to use when inserting
|
insertion_values (dict): additional key/values to use only when
|
||||||
|
inserting
|
||||||
|
lock (bool): True to lock the table when doing the upsert.
|
||||||
Returns:
|
Returns:
|
||||||
Deferred(bool): True if a new entry was created, False if an
|
Deferred(bool): True if a new entry was created, False if an
|
||||||
existing one was updated.
|
existing one was updated.
|
||||||
"""
|
"""
|
||||||
return self.runInteraction(
|
attempts = 0
|
||||||
desc,
|
while True:
|
||||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
try:
|
||||||
lock
|
result = yield self.runInteraction(
|
||||||
)
|
desc,
|
||||||
|
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
||||||
|
lock=lock
|
||||||
|
)
|
||||||
|
defer.returnValue(result)
|
||||||
|
except self.database_engine.module.IntegrityError as e:
|
||||||
|
attempts += 1
|
||||||
|
if attempts >= 5:
|
||||||
|
# don't retry forever, because things other than races
|
||||||
|
# can cause IntegrityErrors
|
||||||
|
raise
|
||||||
|
|
||||||
|
# presumably we raced with another transaction: let's retry.
|
||||||
|
logger.warn(
|
||||||
|
"IntegrityError when upserting into %s; retrying: %s",
|
||||||
|
table, e
|
||||||
|
)
|
||||||
|
|
||||||
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
||||||
lock=True):
|
lock=True):
|
||||||
|
@ -499,7 +523,7 @@ class SQLBaseStore(object):
|
||||||
if lock:
|
if lock:
|
||||||
self.database_engine.lock_table(txn, table)
|
self.database_engine.lock_table(txn, table)
|
||||||
|
|
||||||
# Try to update
|
# First try to update.
|
||||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||||
table,
|
table,
|
||||||
", ".join("%s = ?" % (k,) for k in values),
|
", ".join("%s = ?" % (k,) for k in values),
|
||||||
|
@ -508,28 +532,29 @@ class SQLBaseStore(object):
|
||||||
sqlargs = values.values() + keyvalues.values()
|
sqlargs = values.values() + keyvalues.values()
|
||||||
|
|
||||||
txn.execute(sql, sqlargs)
|
txn.execute(sql, sqlargs)
|
||||||
if txn.rowcount == 0:
|
if txn.rowcount > 0:
|
||||||
# We didn't update and rows so insert a new one
|
# successfully updated at least one row.
|
||||||
allvalues = {}
|
|
||||||
allvalues.update(keyvalues)
|
|
||||||
allvalues.update(values)
|
|
||||||
allvalues.update(insertion_values)
|
|
||||||
|
|
||||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
|
||||||
table,
|
|
||||||
", ".join(k for k in allvalues),
|
|
||||||
", ".join("?" for _ in allvalues)
|
|
||||||
)
|
|
||||||
txn.execute(sql, allvalues.values())
|
|
||||||
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# We didn't update any rows so insert a new one
|
||||||
|
allvalues = {}
|
||||||
|
allvalues.update(keyvalues)
|
||||||
|
allvalues.update(values)
|
||||||
|
allvalues.update(insertion_values)
|
||||||
|
|
||||||
|
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||||
|
table,
|
||||||
|
", ".join(k for k in allvalues),
|
||||||
|
", ".join("?" for _ in allvalues)
|
||||||
|
)
|
||||||
|
txn.execute(sql, allvalues.values())
|
||||||
|
# successfully inserted
|
||||||
|
return True
|
||||||
|
|
||||||
def _simple_select_one(self, table, keyvalues, retcols,
|
def _simple_select_one(self, table, keyvalues, retcols,
|
||||||
allow_none=False, desc="_simple_select_one"):
|
allow_none=False, desc="_simple_select_one"):
|
||||||
"""Executes a SELECT query on the named table, which is expected to
|
"""Executes a SELECT query on the named table, which is expected to
|
||||||
return a single row, returning a single column from it.
|
return a single row, returning multiple columns from it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table : string giving the table name
|
table : string giving the table name
|
||||||
|
@ -582,20 +607,18 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||||
if keyvalues:
|
|
||||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
|
||||||
else:
|
|
||||||
where = ""
|
|
||||||
|
|
||||||
sql = (
|
sql = (
|
||||||
"SELECT %(retcol)s FROM %(table)s %(where)s"
|
"SELECT %(retcol)s FROM %(table)s"
|
||||||
) % {
|
) % {
|
||||||
"retcol": retcol,
|
"retcol": retcol,
|
||||||
"table": table,
|
"table": table,
|
||||||
"where": where,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
txn.execute(sql, keyvalues.values())
|
if keyvalues:
|
||||||
|
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||||
|
txn.execute(sql, keyvalues.values())
|
||||||
|
else:
|
||||||
|
txn.execute(sql)
|
||||||
|
|
||||||
return [r[0] for r in txn]
|
return [r[0] for r in txn]
|
||||||
|
|
||||||
|
@ -606,7 +629,7 @@ class SQLBaseStore(object):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
table (str): table name
|
table (str): table name
|
||||||
keyvalues (dict): column names and values to select the rows with
|
keyvalues (dict|None): column names and values to select the rows with
|
||||||
retcol (str): column whos value we wish to retrieve.
|
retcol (str): column whos value we wish to retrieve.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -222,9 +222,12 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
content_json = json.dumps(content)
|
content_json = json.dumps(content)
|
||||||
|
|
||||||
def add_account_data_txn(txn, next_id):
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
self._simple_upsert_txn(
|
# no need to lock here as room_account_data has a unique constraint
|
||||||
txn,
|
# on (user_id, room_id, account_data_type) so _simple_upsert will
|
||||||
|
# retry if there is a conflict.
|
||||||
|
yield self._simple_upsert(
|
||||||
|
desc="add_room_account_data",
|
||||||
table="room_account_data",
|
table="room_account_data",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -234,19 +237,20 @@ class AccountDataStore(SQLBaseStore):
|
||||||
values={
|
values={
|
||||||
"stream_id": next_id,
|
"stream_id": next_id,
|
||||||
"content": content_json,
|
"content": content_json,
|
||||||
}
|
},
|
||||||
|
lock=False,
|
||||||
)
|
)
|
||||||
txn.call_after(
|
|
||||||
self._account_data_stream_cache.entity_has_changed,
|
|
||||||
user_id, next_id,
|
|
||||||
)
|
|
||||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
|
||||||
self._update_max_stream_id(txn, next_id)
|
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
# it's theoretically possible for the above to succeed and the
|
||||||
yield self.runInteraction(
|
# below to fail - in which case we might reuse a stream id on
|
||||||
"add_room_account_data", add_account_data_txn, next_id
|
# restart, and the above update might not get propagated. That
|
||||||
)
|
# doesn't sound any worse than the whole update getting lost,
|
||||||
|
# which is what would happen if we combined the two into one
|
||||||
|
# transaction.
|
||||||
|
yield self._update_max_stream_id(next_id)
|
||||||
|
|
||||||
|
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||||
|
self.get_account_data_for_user.invalidate((user_id,))
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
result = self._account_data_id_gen.get_current_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
@ -263,9 +267,12 @@ class AccountDataStore(SQLBaseStore):
|
||||||
"""
|
"""
|
||||||
content_json = json.dumps(content)
|
content_json = json.dumps(content)
|
||||||
|
|
||||||
def add_account_data_txn(txn, next_id):
|
with self._account_data_id_gen.get_next() as next_id:
|
||||||
self._simple_upsert_txn(
|
# no need to lock here as account_data has a unique constraint on
|
||||||
txn,
|
# (user_id, account_data_type) so _simple_upsert will retry if
|
||||||
|
# there is a conflict.
|
||||||
|
yield self._simple_upsert(
|
||||||
|
desc="add_user_account_data",
|
||||||
table="account_data",
|
table="account_data",
|
||||||
keyvalues={
|
keyvalues={
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
@ -274,40 +281,46 @@ class AccountDataStore(SQLBaseStore):
|
||||||
values={
|
values={
|
||||||
"stream_id": next_id,
|
"stream_id": next_id,
|
||||||
"content": content_json,
|
"content": content_json,
|
||||||
}
|
},
|
||||||
|
lock=False,
|
||||||
)
|
)
|
||||||
txn.call_after(
|
|
||||||
self._account_data_stream_cache.entity_has_changed,
|
# it's theoretically possible for the above to succeed and the
|
||||||
|
# below to fail - in which case we might reuse a stream id on
|
||||||
|
# restart, and the above update might not get propagated. That
|
||||||
|
# doesn't sound any worse than the whole update getting lost,
|
||||||
|
# which is what would happen if we combined the two into one
|
||||||
|
# transaction.
|
||||||
|
yield self._update_max_stream_id(next_id)
|
||||||
|
|
||||||
|
self._account_data_stream_cache.entity_has_changed(
|
||||||
user_id, next_id,
|
user_id, next_id,
|
||||||
)
|
)
|
||||||
txn.call_after(self.get_account_data_for_user.invalidate, (user_id,))
|
self.get_account_data_for_user.invalidate((user_id,))
|
||||||
txn.call_after(
|
self.get_global_account_data_by_type_for_user.invalidate(
|
||||||
self.get_global_account_data_by_type_for_user.invalidate,
|
|
||||||
(account_data_type, user_id,)
|
(account_data_type, user_id,)
|
||||||
)
|
)
|
||||||
self._update_max_stream_id(txn, next_id)
|
|
||||||
|
|
||||||
with self._account_data_id_gen.get_next() as next_id:
|
|
||||||
yield self.runInteraction(
|
|
||||||
"add_user_account_data", add_account_data_txn, next_id
|
|
||||||
)
|
|
||||||
|
|
||||||
result = self._account_data_id_gen.get_current_token()
|
result = self._account_data_id_gen.get_current_token()
|
||||||
defer.returnValue(result)
|
defer.returnValue(result)
|
||||||
|
|
||||||
def _update_max_stream_id(self, txn, next_id):
|
def _update_max_stream_id(self, next_id):
|
||||||
"""Update the max stream_id
|
"""Update the max stream_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
txn: The database cursor
|
|
||||||
next_id(int): The the revision to advance to.
|
next_id(int): The the revision to advance to.
|
||||||
"""
|
"""
|
||||||
update_max_id_sql = (
|
def _update(txn):
|
||||||
"UPDATE account_data_max_stream_id"
|
update_max_id_sql = (
|
||||||
" SET stream_id = ?"
|
"UPDATE account_data_max_stream_id"
|
||||||
" WHERE stream_id < ?"
|
" SET stream_id = ?"
|
||||||
|
" WHERE stream_id < ?"
|
||||||
|
)
|
||||||
|
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||||
|
return self.runInteraction(
|
||||||
|
"update_account_data_max_stream_id",
|
||||||
|
_update,
|
||||||
)
|
)
|
||||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
|
||||||
|
|
||||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
||||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
||||||
|
|
|
@ -85,6 +85,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
self._background_update_performance = {}
|
self._background_update_performance = {}
|
||||||
self._background_update_queue = []
|
self._background_update_queue = []
|
||||||
self._background_update_handlers = {}
|
self._background_update_handlers = {}
|
||||||
|
self._all_done = False
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def start_doing_background_updates(self):
|
def start_doing_background_updates(self):
|
||||||
|
@ -106,8 +107,40 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
"No more background updates to do."
|
"No more background updates to do."
|
||||||
" Unscheduling background update task."
|
" Unscheduling background update task."
|
||||||
)
|
)
|
||||||
|
self._all_done = True
|
||||||
defer.returnValue(None)
|
defer.returnValue(None)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def has_completed_background_updates(self):
|
||||||
|
"""Check if all the background updates have completed
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Deferred[bool]: True if all background updates have completed
|
||||||
|
"""
|
||||||
|
# if we've previously determined that there is nothing left to do, that
|
||||||
|
# is easy
|
||||||
|
if self._all_done:
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
# obviously, if we have things in our queue, we're not done.
|
||||||
|
if self._background_update_queue:
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
|
# otherwise, check if there are updates to be run. This is important,
|
||||||
|
# as we may be running on a worker which doesn't perform the bg updates
|
||||||
|
# itself, but still wants to wait for them to happen.
|
||||||
|
updates = yield self._simple_select_onecol(
|
||||||
|
"background_updates",
|
||||||
|
keyvalues=None,
|
||||||
|
retcol="1",
|
||||||
|
desc="check_background_updates",
|
||||||
|
)
|
||||||
|
if not updates:
|
||||||
|
self._all_done = True
|
||||||
|
defer.returnValue(True)
|
||||||
|
|
||||||
|
defer.returnValue(False)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def do_next_background_update(self, desired_duration_ms):
|
def do_next_background_update(self, desired_duration_ms):
|
||||||
"""Does some amount of work on the next queued background update
|
"""Does some amount of work on the next queued background update
|
||||||
|
@ -269,7 +302,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||||
# Sqlite doesn't support concurrent creation of indexes.
|
# Sqlite doesn't support concurrent creation of indexes.
|
||||||
#
|
#
|
||||||
# We don't use partial indices on SQLite as it wasn't introduced
|
# We don't use partial indices on SQLite as it wasn't introduced
|
||||||
# until 3.8, and wheezy has 3.7
|
# until 3.8, and wheezy and CentOS 7 have 3.7
|
||||||
#
|
#
|
||||||
# We assume that sqlite doesn't give us invalid indices; however
|
# We assume that sqlite doesn't give us invalid indices; however
|
||||||
# we may still end up with the index existing but the
|
# we may still end up with the index existing but the
|
||||||
|
|
|
@ -12,13 +12,23 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
from ._base import SQLBaseStore
|
|
||||||
|
|
||||||
|
|
||||||
class MediaRepositoryStore(SQLBaseStore):
|
class MediaRepositoryStore(BackgroundUpdateStore):
|
||||||
"""Persistence for attachments and avatars"""
|
"""Persistence for attachments and avatars"""
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(MediaRepositoryStore, self).__init__(db_conn, hs)
|
||||||
|
|
||||||
|
self.register_background_index_update(
|
||||||
|
update_name='local_media_repository_url_idx',
|
||||||
|
index_name='local_media_repository_url_idx',
|
||||||
|
table='local_media_repository',
|
||||||
|
columns=['created_ts'],
|
||||||
|
where_clause='url_cache IS NOT NULL',
|
||||||
|
)
|
||||||
|
|
||||||
def get_default_thumbnails(self, top_level_type, sub_type):
|
def get_default_thumbnails(self, top_level_type, sub_type):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
|
@ -15,6 +15,9 @@
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.storage.roommember import ProfileInfo
|
||||||
|
from synapse.api.errors import StoreError
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
|
|
||||||
|
@ -26,6 +29,30 @@ class ProfileStore(SQLBaseStore):
|
||||||
desc="create_profile",
|
desc="create_profile",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_profileinfo(self, user_localpart):
|
||||||
|
try:
|
||||||
|
profile = yield self._simple_select_one(
|
||||||
|
table="profiles",
|
||||||
|
keyvalues={"user_id": user_localpart},
|
||||||
|
retcols=("displayname", "avatar_url"),
|
||||||
|
desc="get_profileinfo",
|
||||||
|
)
|
||||||
|
except StoreError as e:
|
||||||
|
if e.code == 404:
|
||||||
|
# no match
|
||||||
|
defer.returnValue(ProfileInfo(None, None))
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
defer.returnValue(
|
||||||
|
ProfileInfo(
|
||||||
|
avatar_url=profile['avatar_url'],
|
||||||
|
display_name=profile['displayname'],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def get_profile_displayname(self, user_localpart):
|
def get_profile_displayname(self, user_localpart):
|
||||||
return self._simple_select_one_onecol(
|
return self._simple_select_one_onecol(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
|
|
|
@ -204,34 +204,35 @@ class PusherStore(SQLBaseStore):
|
||||||
pushkey, pushkey_ts, lang, data, last_stream_ordering,
|
pushkey, pushkey_ts, lang, data, last_stream_ordering,
|
||||||
profile_tag=""):
|
profile_tag=""):
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with self._pushers_id_gen.get_next() as stream_id:
|
||||||
def f(txn):
|
# no need to lock because `pushers` has a unique key on
|
||||||
newly_inserted = self._simple_upsert_txn(
|
# (app_id, pushkey, user_name) so _simple_upsert will retry
|
||||||
txn,
|
newly_inserted = yield self._simple_upsert(
|
||||||
"pushers",
|
table="pushers",
|
||||||
{
|
keyvalues={
|
||||||
"app_id": app_id,
|
"app_id": app_id,
|
||||||
"pushkey": pushkey,
|
"pushkey": pushkey,
|
||||||
"user_name": user_id,
|
"user_name": user_id,
|
||||||
},
|
},
|
||||||
{
|
values={
|
||||||
"access_token": access_token,
|
"access_token": access_token,
|
||||||
"kind": kind,
|
"kind": kind,
|
||||||
"app_display_name": app_display_name,
|
"app_display_name": app_display_name,
|
||||||
"device_display_name": device_display_name,
|
"device_display_name": device_display_name,
|
||||||
"ts": pushkey_ts,
|
"ts": pushkey_ts,
|
||||||
"lang": lang,
|
"lang": lang,
|
||||||
"data": encode_canonical_json(data),
|
"data": encode_canonical_json(data),
|
||||||
"last_stream_ordering": last_stream_ordering,
|
"last_stream_ordering": last_stream_ordering,
|
||||||
"profile_tag": profile_tag,
|
"profile_tag": profile_tag,
|
||||||
"id": stream_id,
|
"id": stream_id,
|
||||||
},
|
},
|
||||||
)
|
desc="add_pusher",
|
||||||
if newly_inserted:
|
lock=False,
|
||||||
# get_if_user_has_pusher only cares if the user has
|
)
|
||||||
# at least *one* pusher.
|
|
||||||
txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,))
|
|
||||||
|
|
||||||
yield self.runInteraction("add_pusher", f)
|
if newly_inserted:
|
||||||
|
# get_if_user_has_pusher only cares if the user has
|
||||||
|
# at least *one* pusher.
|
||||||
|
self.get_if_user_has_pusher.invalidate(user_id,)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
|
||||||
|
@ -243,11 +244,19 @@ class PusherStore(SQLBaseStore):
|
||||||
"pushers",
|
"pushers",
|
||||||
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
|
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id}
|
||||||
)
|
)
|
||||||
self._simple_upsert_txn(
|
|
||||||
|
# it's possible for us to end up with duplicate rows for
|
||||||
|
# (app_id, pushkey, user_id) at different stream_ids, but that
|
||||||
|
# doesn't really matter.
|
||||||
|
self._simple_insert_txn(
|
||||||
txn,
|
txn,
|
||||||
"deleted_pushers",
|
table="deleted_pushers",
|
||||||
{"app_id": app_id, "pushkey": pushkey, "user_id": user_id},
|
values={
|
||||||
{"stream_id": stream_id},
|
"stream_id": stream_id,
|
||||||
|
"app_id": app_id,
|
||||||
|
"pushkey": pushkey,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
with self._pushers_id_gen.get_next() as stream_id:
|
with self._pushers_id_gen.get_next() as stream_id:
|
||||||
|
@ -310,9 +319,12 @@ class PusherStore(SQLBaseStore):
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def set_throttle_params(self, pusher_id, room_id, params):
|
def set_throttle_params(self, pusher_id, room_id, params):
|
||||||
|
# no need to lock because `pusher_throttle` has a primary key on
|
||||||
|
# (pusher, room_id) so _simple_upsert will retry
|
||||||
yield self._simple_upsert(
|
yield self._simple_upsert(
|
||||||
"pusher_throttle",
|
"pusher_throttle",
|
||||||
{"pusher": pusher_id, "room_id": room_id},
|
{"pusher": pusher_id, "room_id": room_id},
|
||||||
params,
|
params,
|
||||||
desc="set_throttle_params"
|
desc="set_throttle_params",
|
||||||
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
|
@ -254,8 +254,8 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
If None, tokens associated with any device (or no device) will
|
If None, tokens associated with any device (or no device) will
|
||||||
be deleted
|
be deleted
|
||||||
Returns:
|
Returns:
|
||||||
defer.Deferred[list[str, str|None]]: a list of the deleted tokens
|
defer.Deferred[list[str, int, str|None, int]]: a list of
|
||||||
and device IDs
|
(token, token id, device id) for each of the deleted tokens
|
||||||
"""
|
"""
|
||||||
def f(txn):
|
def f(txn):
|
||||||
keyvalues = {
|
keyvalues = {
|
||||||
|
@ -272,12 +272,12 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
|
||||||
values.append(except_token_id)
|
values.append(except_token_id)
|
||||||
|
|
||||||
txn.execute(
|
txn.execute(
|
||||||
"SELECT token, device_id FROM access_tokens WHERE %s" % where_clause,
|
"SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause,
|
||||||
values
|
values
|
||||||
)
|
)
|
||||||
tokens_and_devices = [(r[0], r[1]) for r in txn]
|
tokens_and_devices = [(r[0], r[1], r[2]) for r in txn]
|
||||||
|
|
||||||
for token, _ in tokens_and_devices:
|
for token, _, _ in tokens_and_devices:
|
||||||
self._invalidate_cache_and_stream(
|
self._invalidate_cache_and_stream(
|
||||||
txn, self.get_user_by_access_token, (token,)
|
txn, self.get_user_by_access_token, (token,)
|
||||||
)
|
)
|
||||||
|
|
|
@ -13,7 +13,10 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was
|
||||||
|
-- removed and replaced with 46/local_media_repository_url_idx.sql.
|
||||||
|
--
|
||||||
|
-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||||
|
|
||||||
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
||||||
-- indices on expressions until 3.9.
|
-- indices on expressions until 3.9.
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- drop the unique constraint on deleted_pushers so that we can just insert
|
||||||
|
-- into it rather than upserting.
|
||||||
|
|
||||||
|
CREATE TABLE deleted_pushers2 (
|
||||||
|
stream_id BIGINT NOT NULL,
|
||||||
|
app_id TEXT NOT NULL,
|
||||||
|
pushkey TEXT NOT NULL,
|
||||||
|
user_id TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id)
|
||||||
|
SELECT stream_id, app_id, pushkey, user_id from deleted_pushers;
|
||||||
|
|
||||||
|
DROP TABLE deleted_pushers;
|
||||||
|
ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers;
|
||||||
|
|
||||||
|
-- create the index after doing the inserts because that's more efficient.
|
||||||
|
-- it also means we can give it the same name as the old one without renaming.
|
||||||
|
CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id);
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- register a background update which will recreate the
|
||||||
|
-- local_media_repository_url_idx index.
|
||||||
|
--
|
||||||
|
-- We do this as a bg update not because it is a particularly onerous
|
||||||
|
-- operation, but because we'd like it to be a partial index if possible, and
|
||||||
|
-- the background_index_update code will understand whether we are on
|
||||||
|
-- postgres or sqlite and behave accordingly.
|
||||||
|
INSERT INTO background_updates (update_name, progress_json) VALUES
|
||||||
|
('local_media_repository_url_idx', '{}');
|
35
synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
Normal file
35
synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
/* Copyright 2017 New Vector Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- change the user_directory table to also cover global local user profiles
|
||||||
|
-- rather than just profiles within specific rooms.
|
||||||
|
|
||||||
|
CREATE TABLE user_directory2 (
|
||||||
|
user_id TEXT NOT NULL,
|
||||||
|
room_id TEXT,
|
||||||
|
display_name TEXT,
|
||||||
|
avatar_url TEXT
|
||||||
|
);
|
||||||
|
|
||||||
|
INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url)
|
||||||
|
SELECT user_id, room_id, display_name, avatar_url from user_directory;
|
||||||
|
|
||||||
|
DROP TABLE user_directory;
|
||||||
|
ALTER TABLE user_directory2 RENAME TO user_directory;
|
||||||
|
|
||||||
|
-- create indexes after doing the inserts because that's more efficient.
|
||||||
|
-- it also means we can give it the same name as the old one without renaming.
|
||||||
|
CREATE INDEX user_directory_room_idx ON user_directory(room_id);
|
||||||
|
CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id);
|
|
@ -13,16 +13,18 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from ._base import SQLBaseStore
|
from collections import namedtuple
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
import logging
|
||||||
from synapse.util.caches import intern_string
|
|
||||||
from synapse.util.stringutils import to_ascii
|
|
||||||
from synapse.storage.engines import PostgresEngine
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from collections import namedtuple
|
|
||||||
|
|
||||||
import logging
|
from synapse.storage.background_updates import BackgroundUpdateStore
|
||||||
|
from synapse.storage.engines import PostgresEngine
|
||||||
|
from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR
|
||||||
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
from synapse.util.caches.dictionary_cache import DictionaryCache
|
||||||
|
from synapse.util.stringutils import to_ascii
|
||||||
|
from ._base import SQLBaseStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -40,23 +42,11 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
|
||||||
return len(self.delta_ids) if self.delta_ids else 0
|
return len(self.delta_ids) if self.delta_ids else 0
|
||||||
|
|
||||||
|
|
||||||
class StateStore(SQLBaseStore):
|
class StateGroupReadStore(SQLBaseStore):
|
||||||
""" Keeps track of the state at a given event.
|
"""The read-only parts of StateGroupStore
|
||||||
|
|
||||||
This is done by the concept of `state groups`. Every event is a assigned
|
None of these functions write to the state tables, so are suitable for
|
||||||
a state group (identified by an arbitrary string), which references a
|
including in the SlavedStores.
|
||||||
collection of state events. The current state of an event is then the
|
|
||||||
collection of state events referenced by the event's state group.
|
|
||||||
|
|
||||||
Hence, every change in the current state causes a new state group to be
|
|
||||||
generated. However, if no change happens (e.g., if we get a message event
|
|
||||||
with only one parent it inherits the state group from its parent.)
|
|
||||||
|
|
||||||
There are three tables:
|
|
||||||
* `state_groups`: Stores group name, first event with in the group and
|
|
||||||
room id.
|
|
||||||
* `event_to_state_groups`: Maps events to state groups.
|
|
||||||
* `state_groups_state`: Maps state group to state events.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||||
|
@ -64,21 +54,10 @@ class StateStore(SQLBaseStore):
|
||||||
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||||
|
|
||||||
def __init__(self, db_conn, hs):
|
def __init__(self, db_conn, hs):
|
||||||
super(StateStore, self).__init__(db_conn, hs)
|
super(StateGroupReadStore, self).__init__(db_conn, hs)
|
||||||
self.register_background_update_handler(
|
|
||||||
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
self._state_group_cache = DictionaryCache(
|
||||||
self._background_deduplicate_state,
|
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
|
||||||
)
|
|
||||||
self.register_background_update_handler(
|
|
||||||
self.STATE_GROUP_INDEX_UPDATE_NAME,
|
|
||||||
self._background_index_state,
|
|
||||||
)
|
|
||||||
self.register_background_index_update(
|
|
||||||
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
|
||||||
index_name="current_state_events_member_index",
|
|
||||||
table="current_state_events",
|
|
||||||
columns=["state_key"],
|
|
||||||
where_clause="type='m.room.member'",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@cached(max_entries=100000, iterable=True)
|
@cached(max_entries=100000, iterable=True)
|
||||||
|
@ -190,178 +169,6 @@ class StateStore(SQLBaseStore):
|
||||||
for group, event_id_map in group_to_ids.iteritems()
|
for group, event_id_map in group_to_ids.iteritems()
|
||||||
})
|
})
|
||||||
|
|
||||||
def _have_persisted_state_group_txn(self, txn, state_group):
|
|
||||||
txn.execute(
|
|
||||||
"SELECT count(*) FROM state_groups WHERE id = ?",
|
|
||||||
(state_group,)
|
|
||||||
)
|
|
||||||
row = txn.fetchone()
|
|
||||||
return row and row[0]
|
|
||||||
|
|
||||||
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
|
||||||
state_groups = {}
|
|
||||||
for event, context in events_and_contexts:
|
|
||||||
if event.internal_metadata.is_outlier():
|
|
||||||
continue
|
|
||||||
|
|
||||||
if context.current_state_ids is None:
|
|
||||||
# AFAIK, this can never happen
|
|
||||||
logger.error(
|
|
||||||
"Non-outlier event %s had current_state_ids==None",
|
|
||||||
event.event_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# if the event was rejected, just give it the same state as its
|
|
||||||
# predecessor.
|
|
||||||
if context.rejected:
|
|
||||||
state_groups[event.event_id] = context.prev_group
|
|
||||||
continue
|
|
||||||
|
|
||||||
state_groups[event.event_id] = context.state_group
|
|
||||||
|
|
||||||
if self._have_persisted_state_group_txn(txn, context.state_group):
|
|
||||||
continue
|
|
||||||
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups",
|
|
||||||
values={
|
|
||||||
"id": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"event_id": event.event_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# We persist as a delta if we can, while also ensuring the chain
|
|
||||||
# of deltas isn't tooo long, as otherwise read performance degrades.
|
|
||||||
if context.prev_group:
|
|
||||||
is_in_db = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups",
|
|
||||||
keyvalues={"id": context.prev_group},
|
|
||||||
retcol="id",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if not is_in_db:
|
|
||||||
raise Exception(
|
|
||||||
"Trying to persist state with unpersisted prev_group: %r"
|
|
||||||
% (context.prev_group,)
|
|
||||||
)
|
|
||||||
|
|
||||||
potential_hops = self._count_state_group_hops_txn(
|
|
||||||
txn, context.prev_group
|
|
||||||
)
|
|
||||||
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
|
||||||
self._simple_insert_txn(
|
|
||||||
txn,
|
|
||||||
table="state_group_edges",
|
|
||||||
values={
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"prev_state_group": context.prev_group,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups_state",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": key[0],
|
|
||||||
"state_key": key[1],
|
|
||||||
"event_id": state_id,
|
|
||||||
}
|
|
||||||
for key, state_id in context.delta_ids.iteritems()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="state_groups_state",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"state_group": context.state_group,
|
|
||||||
"room_id": event.room_id,
|
|
||||||
"type": key[0],
|
|
||||||
"state_key": key[1],
|
|
||||||
"event_id": state_id,
|
|
||||||
}
|
|
||||||
for key, state_id in context.current_state_ids.iteritems()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
# Prefill the state group cache with this group.
|
|
||||||
# It's fine to use the sequence like this as the state group map
|
|
||||||
# is immutable. (If the map wasn't immutable then this prefill could
|
|
||||||
# race with another update)
|
|
||||||
txn.call_after(
|
|
||||||
self._state_group_cache.update,
|
|
||||||
self._state_group_cache.sequence,
|
|
||||||
key=context.state_group,
|
|
||||||
value=dict(context.current_state_ids),
|
|
||||||
full=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._simple_insert_many_txn(
|
|
||||||
txn,
|
|
||||||
table="event_to_state_groups",
|
|
||||||
values=[
|
|
||||||
{
|
|
||||||
"state_group": state_group_id,
|
|
||||||
"event_id": event_id,
|
|
||||||
}
|
|
||||||
for event_id, state_group_id in state_groups.iteritems()
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
for event_id, state_group_id in state_groups.iteritems():
|
|
||||||
txn.call_after(
|
|
||||||
self._get_state_group_for_event.prefill,
|
|
||||||
(event_id,), state_group_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _count_state_group_hops_txn(self, txn, state_group):
|
|
||||||
"""Given a state group, count how many hops there are in the tree.
|
|
||||||
|
|
||||||
This is used to ensure the delta chains don't get too long.
|
|
||||||
"""
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
|
||||||
sql = ("""
|
|
||||||
WITH RECURSIVE state(state_group) AS (
|
|
||||||
VALUES(?::bigint)
|
|
||||||
UNION ALL
|
|
||||||
SELECT prev_state_group FROM state_group_edges e, state s
|
|
||||||
WHERE s.state_group = e.state_group
|
|
||||||
)
|
|
||||||
SELECT count(*) FROM state;
|
|
||||||
""")
|
|
||||||
|
|
||||||
txn.execute(sql, (state_group,))
|
|
||||||
row = txn.fetchone()
|
|
||||||
if row and row[0]:
|
|
||||||
return row[0]
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
else:
|
|
||||||
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
|
||||||
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
|
||||||
next_group = state_group
|
|
||||||
count = 0
|
|
||||||
|
|
||||||
while next_group:
|
|
||||||
next_group = self._simple_select_one_onecol_txn(
|
|
||||||
txn,
|
|
||||||
table="state_group_edges",
|
|
||||||
keyvalues={"state_group": next_group},
|
|
||||||
retcol="prev_state_group",
|
|
||||||
allow_none=True,
|
|
||||||
)
|
|
||||||
if next_group:
|
|
||||||
count += 1
|
|
||||||
|
|
||||||
return count
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def _get_state_groups_from_groups(self, groups, types):
|
def _get_state_groups_from_groups(self, groups, types):
|
||||||
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
|
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
|
||||||
|
@ -742,6 +549,220 @@ class StateStore(SQLBaseStore):
|
||||||
|
|
||||||
defer.returnValue(results)
|
defer.returnValue(results)
|
||||||
|
|
||||||
|
|
||||||
|
class StateStore(StateGroupReadStore, BackgroundUpdateStore):
|
||||||
|
""" Keeps track of the state at a given event.
|
||||||
|
|
||||||
|
This is done by the concept of `state groups`. Every event is a assigned
|
||||||
|
a state group (identified by an arbitrary string), which references a
|
||||||
|
collection of state events. The current state of an event is then the
|
||||||
|
collection of state events referenced by the event's state group.
|
||||||
|
|
||||||
|
Hence, every change in the current state causes a new state group to be
|
||||||
|
generated. However, if no change happens (e.g., if we get a message event
|
||||||
|
with only one parent it inherits the state group from its parent.)
|
||||||
|
|
||||||
|
There are three tables:
|
||||||
|
* `state_groups`: Stores group name, first event with in the group and
|
||||||
|
room id.
|
||||||
|
* `event_to_state_groups`: Maps events to state groups.
|
||||||
|
* `state_groups_state`: Maps state group to state events.
|
||||||
|
"""
|
||||||
|
|
||||||
|
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
|
||||||
|
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
|
||||||
|
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
|
||||||
|
|
||||||
|
def __init__(self, db_conn, hs):
|
||||||
|
super(StateStore, self).__init__(db_conn, hs)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
|
||||||
|
self._background_deduplicate_state,
|
||||||
|
)
|
||||||
|
self.register_background_update_handler(
|
||||||
|
self.STATE_GROUP_INDEX_UPDATE_NAME,
|
||||||
|
self._background_index_state,
|
||||||
|
)
|
||||||
|
self.register_background_index_update(
|
||||||
|
self.CURRENT_STATE_INDEX_UPDATE_NAME,
|
||||||
|
index_name="current_state_events_member_index",
|
||||||
|
table="current_state_events",
|
||||||
|
columns=["state_key"],
|
||||||
|
where_clause="type='m.room.member'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _have_persisted_state_group_txn(self, txn, state_group):
|
||||||
|
txn.execute(
|
||||||
|
"SELECT count(*) FROM state_groups WHERE id = ?",
|
||||||
|
(state_group,)
|
||||||
|
)
|
||||||
|
row = txn.fetchone()
|
||||||
|
return row and row[0]
|
||||||
|
|
||||||
|
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
|
||||||
|
state_groups = {}
|
||||||
|
for event, context in events_and_contexts:
|
||||||
|
if event.internal_metadata.is_outlier():
|
||||||
|
continue
|
||||||
|
|
||||||
|
if context.current_state_ids is None:
|
||||||
|
# AFAIK, this can never happen
|
||||||
|
logger.error(
|
||||||
|
"Non-outlier event %s had current_state_ids==None",
|
||||||
|
event.event_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# if the event was rejected, just give it the same state as its
|
||||||
|
# predecessor.
|
||||||
|
if context.rejected:
|
||||||
|
state_groups[event.event_id] = context.prev_group
|
||||||
|
continue
|
||||||
|
|
||||||
|
state_groups[event.event_id] = context.state_group
|
||||||
|
|
||||||
|
if self._have_persisted_state_group_txn(txn, context.state_group):
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups",
|
||||||
|
values={
|
||||||
|
"id": context.state_group,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"event_id": event.event_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# We persist as a delta if we can, while also ensuring the chain
|
||||||
|
# of deltas isn't tooo long, as otherwise read performance degrades.
|
||||||
|
if context.prev_group:
|
||||||
|
is_in_db = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups",
|
||||||
|
keyvalues={"id": context.prev_group},
|
||||||
|
retcol="id",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if not is_in_db:
|
||||||
|
raise Exception(
|
||||||
|
"Trying to persist state with unpersisted prev_group: %r"
|
||||||
|
% (context.prev_group,)
|
||||||
|
)
|
||||||
|
|
||||||
|
potential_hops = self._count_state_group_hops_txn(
|
||||||
|
txn, context.prev_group
|
||||||
|
)
|
||||||
|
if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
|
||||||
|
self._simple_insert_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
values={
|
||||||
|
"state_group": context.state_group,
|
||||||
|
"prev_state_group": context.prev_group,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"state_group": context.state_group,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"type": key[0],
|
||||||
|
"state_key": key[1],
|
||||||
|
"event_id": state_id,
|
||||||
|
}
|
||||||
|
for key, state_id in context.delta_ids.iteritems()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="state_groups_state",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"state_group": context.state_group,
|
||||||
|
"room_id": event.room_id,
|
||||||
|
"type": key[0],
|
||||||
|
"state_key": key[1],
|
||||||
|
"event_id": state_id,
|
||||||
|
}
|
||||||
|
for key, state_id in context.current_state_ids.iteritems()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prefill the state group cache with this group.
|
||||||
|
# It's fine to use the sequence like this as the state group map
|
||||||
|
# is immutable. (If the map wasn't immutable then this prefill could
|
||||||
|
# race with another update)
|
||||||
|
txn.call_after(
|
||||||
|
self._state_group_cache.update,
|
||||||
|
self._state_group_cache.sequence,
|
||||||
|
key=context.state_group,
|
||||||
|
value=dict(context.current_state_ids),
|
||||||
|
full=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._simple_insert_many_txn(
|
||||||
|
txn,
|
||||||
|
table="event_to_state_groups",
|
||||||
|
values=[
|
||||||
|
{
|
||||||
|
"state_group": state_group_id,
|
||||||
|
"event_id": event_id,
|
||||||
|
}
|
||||||
|
for event_id, state_group_id in state_groups.iteritems()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
for event_id, state_group_id in state_groups.iteritems():
|
||||||
|
txn.call_after(
|
||||||
|
self._get_state_group_for_event.prefill,
|
||||||
|
(event_id,), state_group_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def _count_state_group_hops_txn(self, txn, state_group):
|
||||||
|
"""Given a state group, count how many hops there are in the tree.
|
||||||
|
|
||||||
|
This is used to ensure the delta chains don't get too long.
|
||||||
|
"""
|
||||||
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
|
sql = ("""
|
||||||
|
WITH RECURSIVE state(state_group) AS (
|
||||||
|
VALUES(?::bigint)
|
||||||
|
UNION ALL
|
||||||
|
SELECT prev_state_group FROM state_group_edges e, state s
|
||||||
|
WHERE s.state_group = e.state_group
|
||||||
|
)
|
||||||
|
SELECT count(*) FROM state;
|
||||||
|
""")
|
||||||
|
|
||||||
|
txn.execute(sql, (state_group,))
|
||||||
|
row = txn.fetchone()
|
||||||
|
if row and row[0]:
|
||||||
|
return row[0]
|
||||||
|
else:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
|
||||||
|
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
|
||||||
|
next_group = state_group
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
while next_group:
|
||||||
|
next_group = self._simple_select_one_onecol_txn(
|
||||||
|
txn,
|
||||||
|
table="state_group_edges",
|
||||||
|
keyvalues={"state_group": next_group},
|
||||||
|
retcol="prev_state_group",
|
||||||
|
allow_none=True,
|
||||||
|
)
|
||||||
|
if next_group:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
return count
|
||||||
|
|
||||||
def get_next_state_group(self):
|
def get_next_state_group(self):
|
||||||
return self._state_groups_id_gen.get_next()
|
return self._state_groups_id_gen.get_next()
|
||||||
|
|
||||||
|
|
|
@ -39,7 +39,7 @@ from ._base import SQLBaseStore
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
@ -234,7 +234,7 @@ class StreamStore(SQLBaseStore):
|
||||||
results = {}
|
results = {}
|
||||||
room_ids = list(room_ids)
|
room_ids = list(room_ids)
|
||||||
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
|
||||||
res = yield preserve_context_over_deferred(defer.gatherResults([
|
res = yield make_deferred_yieldable(defer.gatherResults([
|
||||||
preserve_fn(self.get_room_events_stream_for_room)(
|
preserve_fn(self.get_room_events_stream_for_room)(
|
||||||
room_id, from_key, to_key, limit, order=order,
|
room_id, from_key, to_key, limit, order=order,
|
||||||
)
|
)
|
||||||
|
|
|
@ -164,7 +164,7 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
# We weight the loclpart most highly, then display name and finally
|
# We weight the localpart most highly, then display name and finally
|
||||||
# server name
|
# server name
|
||||||
if new_entry:
|
if new_entry:
|
||||||
sql = """
|
sql = """
|
||||||
|
@ -317,6 +317,16 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
rows = yield self._execute("get_all_rooms", None, sql)
|
rows = yield self._execute("get_all_rooms", None, sql)
|
||||||
defer.returnValue([room_id for room_id, in rows])
|
defer.returnValue([room_id for room_id, in rows])
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def get_all_local_users(self):
|
||||||
|
"""Get all local users
|
||||||
|
"""
|
||||||
|
sql = """
|
||||||
|
SELECT name FROM users
|
||||||
|
"""
|
||||||
|
rows = yield self._execute("get_all_local_users", None, sql)
|
||||||
|
defer.returnValue([name for name, in rows])
|
||||||
|
|
||||||
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
|
def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
|
||||||
"""Insert entries into the users_who_share_rooms table. The first
|
"""Insert entries into the users_who_share_rooms table. The first
|
||||||
user should be a local user.
|
user should be a local user.
|
||||||
|
@ -629,6 +639,20 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.hs.config.user_directory_search_all_users:
|
||||||
|
join_clause = ""
|
||||||
|
where_clause = "?<>''" # naughty hack to keep the same number of binds
|
||||||
|
else:
|
||||||
|
join_clause = """
|
||||||
|
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
||||||
|
LEFT JOIN (
|
||||||
|
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
||||||
|
WHERE user_id = ? AND share_private
|
||||||
|
) AS s USING (user_id)
|
||||||
|
"""
|
||||||
|
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
|
||||||
|
|
||||||
if isinstance(self.database_engine, PostgresEngine):
|
if isinstance(self.database_engine, PostgresEngine):
|
||||||
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
|
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
|
||||||
|
|
||||||
|
@ -641,13 +665,9 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
SELECT d.user_id, display_name, avatar_url
|
SELECT d.user_id, display_name, avatar_url
|
||||||
FROM user_directory_search
|
FROM user_directory_search
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
%s
|
||||||
LEFT JOIN (
|
|
||||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
|
||||||
WHERE user_id = ? AND share_private
|
|
||||||
) AS s USING (user_id)
|
|
||||||
WHERE
|
WHERE
|
||||||
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
%s
|
||||||
AND vector @@ to_tsquery('english', ?)
|
AND vector @@ to_tsquery('english', ?)
|
||||||
ORDER BY
|
ORDER BY
|
||||||
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
|
(CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
|
||||||
|
@ -671,7 +691,7 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
display_name IS NULL,
|
display_name IS NULL,
|
||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
|
||||||
elif isinstance(self.database_engine, Sqlite3Engine):
|
elif isinstance(self.database_engine, Sqlite3Engine):
|
||||||
search_query = _parse_query_sqlite(search_term)
|
search_query = _parse_query_sqlite(search_term)
|
||||||
|
@ -680,20 +700,16 @@ class UserDirectoryStore(SQLBaseStore):
|
||||||
SELECT d.user_id, display_name, avatar_url
|
SELECT d.user_id, display_name, avatar_url
|
||||||
FROM user_directory_search
|
FROM user_directory_search
|
||||||
INNER JOIN user_directory AS d USING (user_id)
|
INNER JOIN user_directory AS d USING (user_id)
|
||||||
LEFT JOIN users_in_public_rooms AS p USING (user_id)
|
%s
|
||||||
LEFT JOIN (
|
|
||||||
SELECT other_user_id AS user_id FROM users_who_share_rooms
|
|
||||||
WHERE user_id = ? AND share_private
|
|
||||||
) AS s USING (user_id)
|
|
||||||
WHERE
|
WHERE
|
||||||
(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)
|
%s
|
||||||
AND value MATCH ?
|
AND value MATCH ?
|
||||||
ORDER BY
|
ORDER BY
|
||||||
rank(matchinfo(user_directory_search)) DESC,
|
rank(matchinfo(user_directory_search)) DESC,
|
||||||
display_name IS NULL,
|
display_name IS NULL,
|
||||||
avatar_url IS NULL
|
avatar_url IS NULL
|
||||||
LIMIT ?
|
LIMIT ?
|
||||||
"""
|
""" % (join_clause, where_clause)
|
||||||
args = (user_id, search_query, limit + 1)
|
args = (user_id, search_query, limit + 1)
|
||||||
else:
|
else:
|
||||||
# This should be unreachable.
|
# This should be unreachable.
|
||||||
|
@ -723,7 +739,7 @@ def _parse_query_sqlite(search_term):
|
||||||
|
|
||||||
# Pull out the individual words, discarding any non-word characters.
|
# Pull out the individual words, discarding any non-word characters.
|
||||||
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
||||||
return " & ".join("(%s* | %s)" % (result, result,) for result in results)
|
return " & ".join("(%s* OR %s)" % (result, result,) for result in results)
|
||||||
|
|
||||||
|
|
||||||
def _parse_query_postgres(search_term):
|
def _parse_query_postgres(search_term):
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
from twisted.internet import defer, reactor
|
from twisted.internet import defer, reactor
|
||||||
|
|
||||||
from .logcontext import (
|
from .logcontext import (
|
||||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
PreserveLoggingContext, make_deferred_yieldable, preserve_fn
|
||||||
)
|
)
|
||||||
from synapse.util import logcontext, unwrapFirstError
|
from synapse.util import logcontext, unwrapFirstError
|
||||||
|
|
||||||
|
@ -351,7 +351,7 @@ class ReadWriteLock(object):
|
||||||
|
|
||||||
# We wait for the latest writer to finish writing. We can safely ignore
|
# We wait for the latest writer to finish writing. We can safely ignore
|
||||||
# any existing readers... as they're readers.
|
# any existing readers... as they're readers.
|
||||||
yield curr_writer
|
yield make_deferred_yieldable(curr_writer)
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
|
@ -380,7 +380,7 @@ class ReadWriteLock(object):
|
||||||
curr_readers.clear()
|
curr_readers.clear()
|
||||||
self.key_to_current_writer[key] = new_defer
|
self.key_to_current_writer[key] = new_defer
|
||||||
|
|
||||||
yield preserve_context_over_deferred(defer.gatherResults(to_wait_on))
|
yield make_deferred_yieldable(defer.gatherResults(to_wait_on))
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _ctx_manager():
|
def _ctx_manager():
|
||||||
|
|
|
@ -13,32 +13,24 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.util.logcontext import (
|
|
||||||
PreserveLoggingContext, preserve_context_over_fn
|
|
||||||
)
|
|
||||||
|
|
||||||
from synapse.util import unwrapFirstError
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from twisted.internet import defer
|
||||||
|
|
||||||
|
from synapse.util import unwrapFirstError
|
||||||
|
from synapse.util.logcontext import PreserveLoggingContext
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def user_left_room(distributor, user, room_id):
|
def user_left_room(distributor, user, room_id):
|
||||||
return preserve_context_over_fn(
|
with PreserveLoggingContext():
|
||||||
distributor.fire,
|
distributor.fire("user_left_room", user=user, room_id=room_id)
|
||||||
"user_left_room", user=user, room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def user_joined_room(distributor, user, room_id):
|
def user_joined_room(distributor, user, room_id):
|
||||||
return preserve_context_over_fn(
|
with PreserveLoggingContext():
|
||||||
distributor.fire,
|
distributor.fire("user_joined_room", user=user, room_id=room_id)
|
||||||
"user_joined_room", user=user, room_id=room_id
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Distributor(object):
|
class Distributor(object):
|
||||||
|
|
|
@ -261,67 +261,6 @@ class PreserveLoggingContext(object):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class _PreservingContextDeferred(defer.Deferred):
|
|
||||||
"""A deferred that ensures that all callbacks and errbacks are called with
|
|
||||||
the given logging context.
|
|
||||||
"""
|
|
||||||
def __init__(self, context):
|
|
||||||
self._log_context = context
|
|
||||||
defer.Deferred.__init__(self)
|
|
||||||
|
|
||||||
def addCallbacks(self, callback, errback=None,
|
|
||||||
callbackArgs=None, callbackKeywords=None,
|
|
||||||
errbackArgs=None, errbackKeywords=None):
|
|
||||||
callback = self._wrap_callback(callback)
|
|
||||||
errback = self._wrap_callback(errback)
|
|
||||||
return defer.Deferred.addCallbacks(
|
|
||||||
self, callback,
|
|
||||||
errback=errback,
|
|
||||||
callbackArgs=callbackArgs,
|
|
||||||
callbackKeywords=callbackKeywords,
|
|
||||||
errbackArgs=errbackArgs,
|
|
||||||
errbackKeywords=errbackKeywords,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _wrap_callback(self, f):
|
|
||||||
def g(res, *args, **kwargs):
|
|
||||||
with PreserveLoggingContext(self._log_context):
|
|
||||||
res = f(res, *args, **kwargs)
|
|
||||||
return res
|
|
||||||
return g
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_context_over_fn(fn, *args, **kwargs):
|
|
||||||
"""Takes a function and invokes it with the given arguments, but removes
|
|
||||||
and restores the current logging context while doing so.
|
|
||||||
|
|
||||||
If the result is a deferred, call preserve_context_over_deferred before
|
|
||||||
returning it.
|
|
||||||
"""
|
|
||||||
with PreserveLoggingContext():
|
|
||||||
res = fn(*args, **kwargs)
|
|
||||||
|
|
||||||
if isinstance(res, defer.Deferred):
|
|
||||||
return preserve_context_over_deferred(res)
|
|
||||||
else:
|
|
||||||
return res
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_context_over_deferred(deferred, context=None):
|
|
||||||
"""Given a deferred wrap it such that any callbacks added later to it will
|
|
||||||
be invoked with the current context.
|
|
||||||
|
|
||||||
Deprecated: this almost certainly doesn't do want you want, ie make
|
|
||||||
the deferred follow the synapse logcontext rules: try
|
|
||||||
``make_deferred_yieldable`` instead.
|
|
||||||
"""
|
|
||||||
if context is None:
|
|
||||||
context = LoggingContext.current_context()
|
|
||||||
d = _PreservingContextDeferred(context)
|
|
||||||
deferred.chainDeferred(d)
|
|
||||||
return d
|
|
||||||
|
|
||||||
|
|
||||||
def preserve_fn(f):
|
def preserve_fn(f):
|
||||||
"""Wraps a function, to ensure that the current context is restored after
|
"""Wraps a function, to ensure that the current context is restored after
|
||||||
return from the function, and that the sentinel context is set once the
|
return from the function, and that the sentinel context is set once the
|
||||||
|
|
|
@ -17,7 +17,7 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.api.constants import Membership, EventTypes
|
from synapse.api.constants import Membership, EventTypes
|
||||||
|
|
||||||
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
|
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
@ -58,7 +58,7 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state,
|
||||||
always_include_ids (set(event_id)): set of event ids to specifically
|
always_include_ids (set(event_id)): set of event ids to specifically
|
||||||
include (unless sender is ignored)
|
include (unless sender is ignored)
|
||||||
"""
|
"""
|
||||||
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
|
forgotten = yield make_deferred_yieldable(defer.gatherResults([
|
||||||
defer.maybeDeferred(
|
defer.maybeDeferred(
|
||||||
preserve_fn(store.who_forgot_in_room),
|
preserve_fn(store.who_forgot_in_room),
|
||||||
room_id,
|
room_id,
|
||||||
|
|
|
@ -36,6 +36,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||||
id="unique_identifier",
|
id="unique_identifier",
|
||||||
url="some_url",
|
url="some_url",
|
||||||
token="some_token",
|
token="some_token",
|
||||||
|
hostname="matrix.org", # only used by get_groups_for_user
|
||||||
namespaces={
|
namespaces={
|
||||||
ApplicationService.NS_USERS: [],
|
ApplicationService.NS_USERS: [],
|
||||||
ApplicationService.NS_ROOMS: [],
|
ApplicationService.NS_ROOMS: [],
|
||||||
|
|
|
@ -58,7 +58,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
|
|
||||||
self.mock_federation_resource = MockHttpResource()
|
self.mock_federation_resource = MockHttpResource()
|
||||||
|
|
||||||
mock_notifier = Mock(spec=["on_new_event"])
|
mock_notifier = Mock()
|
||||||
self.on_new_event = mock_notifier.on_new_event
|
self.on_new_event = mock_notifier.on_new_event
|
||||||
|
|
||||||
self.auth = Mock(spec=[])
|
self.auth = Mock(spec=[])
|
||||||
|
@ -76,6 +76,9 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
"set_received_txn_response",
|
"set_received_txn_response",
|
||||||
"get_destination_retry_timings",
|
"get_destination_retry_timings",
|
||||||
"get_devices_by_remote",
|
"get_devices_by_remote",
|
||||||
|
# Bits that user_directory needs
|
||||||
|
"get_user_directory_stream_pos",
|
||||||
|
"get_current_state_deltas",
|
||||||
]),
|
]),
|
||||||
state_handler=self.state_handler,
|
state_handler=self.state_handler,
|
||||||
handlers=None,
|
handlers=None,
|
||||||
|
@ -122,6 +125,15 @@ class TypingNotificationsTestCase(unittest.TestCase):
|
||||||
return set(str(u) for u in self.room_members)
|
return set(str(u) for u in self.room_members)
|
||||||
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
self.state_handler.get_current_user_in_room = get_current_user_in_room
|
||||||
|
|
||||||
|
self.datastore.get_user_directory_stream_pos.return_value = (
|
||||||
|
# we deliberately return a non-None stream pos to avoid doing an initial_spam
|
||||||
|
defer.succeed(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.datastore.get_current_state_deltas.return_value = (
|
||||||
|
None
|
||||||
|
)
|
||||||
|
|
||||||
self.auth.check_joined_room = check_joined_room
|
self.auth.check_joined_room = check_joined_room
|
||||||
|
|
||||||
self.datastore.get_to_device_stream_token = lambda: 0
|
self.datastore.get_to_device_stream_token = lambda: 0
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
|
from twisted.python import failure
|
||||||
|
|
||||||
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
|
||||||
from synapse.api.errors import SynapseError
|
from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
from mock import Mock
|
from mock import Mock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
@ -24,7 +26,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
side_effect=lambda x: self.appservice)
|
side_effect=lambda x: self.appservice)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.auth_result = (False, None, None, None)
|
self.auth_result = failure.Failure(InteractiveAuthIncompleteError(None))
|
||||||
self.auth_handler = Mock(
|
self.auth_handler = Mock(
|
||||||
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
check_auth=Mock(side_effect=lambda x, y, z: self.auth_result),
|
||||||
get_session_data=Mock(return_value=None)
|
get_session_data=Mock(return_value=None)
|
||||||
|
@ -86,6 +88,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
self.request.args = {
|
self.request.args = {
|
||||||
"access_token": "i_am_an_app_service"
|
"access_token": "i_am_an_app_service"
|
||||||
}
|
}
|
||||||
|
|
||||||
self.request_data = json.dumps({
|
self.request_data = json.dumps({
|
||||||
"username": "kermit"
|
"username": "kermit"
|
||||||
})
|
})
|
||||||
|
@ -120,7 +123,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
})
|
})
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
self.registration_handler.check_username = Mock(return_value=True)
|
||||||
self.auth_result = (True, None, {
|
self.auth_result = (None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
|
@ -150,7 +153,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
})
|
})
|
||||||
self.registration_handler.check_username = Mock(return_value=True)
|
self.registration_handler.check_username = Mock(return_value=True)
|
||||||
self.auth_result = (True, None, {
|
self.auth_result = (None, {
|
||||||
"username": "kermit",
|
"username": "kermit",
|
||||||
"password": "monkey"
|
"password": "monkey"
|
||||||
}, None)
|
}, None)
|
||||||
|
|
Loading…
Reference in a new issue