forked from MirrorHub/synapse
		
	Merge branch 'develop' into matthew/gin_work_mem
This commit is contained in:
		
				commit
				
					
						bb9f0f3cdb
					
				
			
		
					 87 changed files with 2822 additions and 1122 deletions
				
			
		CHANGES.rstREADME.rst
docs
scripts
synapse
__init__.pyserver.pystate.py
api
app
appservice.pyclient_reader.pyfederation_reader.pyfederation_sender.pyfrontend_proxy.pyhomeserver.pymedia_repository.pypusher.pysynchrotron.pysynctl.pyuser_dir.py
config
event_auth.pyfederation
handlers
appservice.pyauth.pydevice.pydevicemessage.pye2e_keys.pyfederation.pygroups_local.pyregister.pyroom_list.pyset_password.py
http
metrics
push
replication/tcp
rest
client
key/v2
media/v1
storage
util
tests
							
								
								
									
										53
									
								
								CHANGES.rst
									
										
									
									
									
								
							
							
						
						
									
										53
									
								
								CHANGES.rst
									
										
									
									
									
								
							|  | @ -1,3 +1,56 @@ | |||
| Unreleased | ||||
| ========== | ||||
| 
 | ||||
| synctl no longer starts the main synapse when using ``-a`` option with workers. | ||||
| A new worker file should be added with ``worker_app: synapse.app.homeserver``. | ||||
| 
 | ||||
| This release also begins the process of renaming a number of the metrics | ||||
| reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_. | ||||
| 
 | ||||
| 
 | ||||
| Changes in synapse v0.26.0 (2018-01-05) | ||||
| ======================================= | ||||
| 
 | ||||
| No changes since v0.26.0-rc1 | ||||
| 
 | ||||
| 
 | ||||
| Changes in synapse v0.26.0-rc1 (2017-12-13) | ||||
| =========================================== | ||||
| 
 | ||||
| Features: | ||||
| 
 | ||||
| * Add ability for ASes to publicise groups for their users (PR #2686) | ||||
| * Add all local users to the user_directory and optionally search them (PR | ||||
|   #2723) | ||||
| * Add support for custom login types for validating users (PR #2729) | ||||
| 
 | ||||
| 
 | ||||
| Changes: | ||||
| 
 | ||||
| * Update example Prometheus config to new format (PR #2648) Thanks to | ||||
|   @krombel! | ||||
| * Rename redact_content option to include_content in Push API (PR #2650) | ||||
| * Declare support for r0.3.0 (PR #2677) | ||||
| * Improve upserts (PR #2684, #2688, #2689, #2713) | ||||
| * Improve documentation of workers (PR #2700) | ||||
| * Improve tracebacks on exceptions (PR #2705) | ||||
| * Allow guest access to group APIs for reading (PR #2715) | ||||
| * Support for posting content in federation_client script (PR #2716) | ||||
| * Delete devices and pushers on logouts etc (PR #2722) | ||||
| 
 | ||||
| 
 | ||||
| Bug fixes: | ||||
| 
 | ||||
| * Fix database port script (PR #2673) | ||||
| * Fix internal server error on login with ldap_auth_provider (PR #2678) Thanks | ||||
|   to @jkolo! | ||||
| * Fix error on sqlite 3.7 (PR #2697) | ||||
| * Fix OPTIONS on preview_url (PR #2707) | ||||
| * Fix error handling on dns lookup (PR #2711) | ||||
| * Fix wrong avatars when inviting multiple users when creating room (PR #2717) | ||||
| * Fix 500 when joining matrix-dev (PR #2719) | ||||
| 
 | ||||
| 
 | ||||
| Changes in synapse v0.25.1 (2017-11-17) | ||||
| ======================================= | ||||
| 
 | ||||
|  |  | |||
|  | @ -632,6 +632,11 @@ largest boxes pause for thought.) | |||
| 
 | ||||
| Troubleshooting | ||||
| --------------- | ||||
| 
 | ||||
| You can use the federation tester to check if your homeserver is all set: | ||||
| ``https://matrix.org/federationtester/api/report?server_name=<your_server_name>`` | ||||
| If any of the attributes under "checks" is false, federation won't work. | ||||
| 
 | ||||
| The typical failure mode with federation is that when you try to join a room, | ||||
| it is rejected with "401: Unauthorized". Generally this means that other | ||||
| servers in the room couldn't access yours. (Joining a room over federation is a | ||||
|  |  | |||
|  | @ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus | |||
|      metrics_port: 9092 | ||||
| 
 | ||||
|    Also ensure that ``enable_metrics`` is set to ``True``. | ||||
|    | ||||
| 
 | ||||
|    Restart synapse. | ||||
| 
 | ||||
| 3. Add a prometheus target for synapse. | ||||
|  | @ -28,11 +28,58 @@ How to monitor Synapse metrics using Prometheus | |||
|       static_configs: | ||||
|         - targets: ["my.server.here:9092"] | ||||
| 
 | ||||
|    If your prometheus is older than 1.5.2, you will need to replace  | ||||
|    If your prometheus is older than 1.5.2, you will need to replace | ||||
|    ``static_configs`` in the above with ``target_groups``. | ||||
|     | ||||
| 
 | ||||
|    Restart prometheus. | ||||
| 
 | ||||
| 
 | ||||
| Block and response metrics renamed for 0.27.0 | ||||
| --------------------------------------------- | ||||
| 
 | ||||
| Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count`` | ||||
| metrics reported for the resource tracking for code blocks and HTTP requests. | ||||
| 
 | ||||
| At the same time, the corresponding ``*:total`` metrics are being renamed, as | ||||
| the ``:total`` suffix no longer makes sense in the absence of a corresponding | ||||
| ``:count`` metric. | ||||
| 
 | ||||
| To enable a graceful migration path, this release just adds new names for the | ||||
| metrics being renamed. A future release will remove the old ones. | ||||
| 
 | ||||
| The following table shows the new metrics, and the old metrics which they are | ||||
| replacing. | ||||
| 
 | ||||
| ==================================================== =================================================== | ||||
| New name                                             Old name | ||||
| ==================================================== =================================================== | ||||
| synapse_util_metrics_block_count                     synapse_util_metrics_block_timer:count | ||||
| synapse_util_metrics_block_count                     synapse_util_metrics_block_ru_utime:count | ||||
| synapse_util_metrics_block_count                     synapse_util_metrics_block_ru_stime:count | ||||
| synapse_util_metrics_block_count                     synapse_util_metrics_block_db_txn_count:count | ||||
| synapse_util_metrics_block_count                     synapse_util_metrics_block_db_txn_duration:count | ||||
| 
 | ||||
| synapse_util_metrics_block_time_seconds              synapse_util_metrics_block_timer:total | ||||
| synapse_util_metrics_block_ru_utime_seconds          synapse_util_metrics_block_ru_utime:total | ||||
| synapse_util_metrics_block_ru_stime_seconds          synapse_util_metrics_block_ru_stime:total | ||||
| synapse_util_metrics_block_db_txn_count              synapse_util_metrics_block_db_txn_count:total | ||||
| synapse_util_metrics_block_db_txn_duration_seconds   synapse_util_metrics_block_db_txn_duration:total | ||||
| 
 | ||||
| synapse_http_server_response_count                   synapse_http_server_requests | ||||
| synapse_http_server_response_count                   synapse_http_server_response_time:count | ||||
| synapse_http_server_response_count                   synapse_http_server_response_ru_utime:count | ||||
| synapse_http_server_response_count                   synapse_http_server_response_ru_stime:count | ||||
| synapse_http_server_response_count                   synapse_http_server_response_db_txn_count:count | ||||
| synapse_http_server_response_count                   synapse_http_server_response_db_txn_duration:count | ||||
| 
 | ||||
| synapse_http_server_response_time_seconds            synapse_http_server_response_time:total | ||||
| synapse_http_server_response_ru_utime_seconds        synapse_http_server_response_ru_utime:total | ||||
| synapse_http_server_response_ru_stime_seconds        synapse_http_server_response_ru_stime:total | ||||
| synapse_http_server_response_db_txn_count            synapse_http_server_response_db_txn_count:total | ||||
| synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total | ||||
| ==================================================== =================================================== | ||||
| 
 | ||||
| 
 | ||||
| Standard Metric Names | ||||
| --------------------- | ||||
| 
 | ||||
|  | @ -42,7 +89,7 @@ have been changed to seconds, from miliseconds. | |||
| 
 | ||||
| ================================== ============================= | ||||
| New name                           Old name | ||||
| ---------------------------------- ----------------------------- | ||||
| ================================== ============================= | ||||
| process_cpu_user_seconds_total     process_resource_utime / 1000 | ||||
| process_cpu_system_seconds_total   process_resource_stime / 1000 | ||||
| process_open_fds (no 'type' label) process_fds | ||||
|  | @ -52,8 +99,8 @@ The python-specific counts of garbage collector performance have been renamed. | |||
| 
 | ||||
| =========================== ====================== | ||||
| New name                    Old name | ||||
| --------------------------- ---------------------- | ||||
| python_gc_time              reactor_gc_time       | ||||
| =========================== ====================== | ||||
| python_gc_time              reactor_gc_time | ||||
| python_gc_unreachable_total reactor_gc_unreachable | ||||
| python_gc_counts            reactor_gc_counts | ||||
| =========================== ====================== | ||||
|  | @ -62,7 +109,7 @@ The twisted-specific reactor metrics have been renamed. | |||
| 
 | ||||
| ==================================== ===================== | ||||
| New name                             Old name | ||||
| ------------------------------------ --------------------- | ||||
| ==================================== ===================== | ||||
| python_twisted_reactor_pending_calls reactor_pending_calls | ||||
| python_twisted_reactor_tick_time     reactor_tick_time | ||||
| ==================================== ===================== | ||||
|  |  | |||
							
								
								
									
										133
									
								
								scripts/move_remote_media_to_new_store.py
									
										
									
									
									
										Executable file
									
								
							
							
						
						
									
										133
									
								
								scripts/move_remote_media_to_new_store.py
									
										
									
									
									
										Executable file
									
								
							|  | @ -0,0 +1,133 @@ | |||
| #!/usr/bin/env python | ||||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2017 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| """ | ||||
| Moves a list of remote media from one media store to another. | ||||
| 
 | ||||
| The input should be a list of media files to be moved, one per line. Each line | ||||
| should be formatted:: | ||||
| 
 | ||||
|     <origin server>|<file id> | ||||
| 
 | ||||
| This can be extracted from postgres with:: | ||||
| 
 | ||||
|     psql --tuples-only -A -c "select media_origin, filesystem_id from | ||||
|         matrix.remote_media_cache where ..." | ||||
| 
 | ||||
| To use, pipe the above into:: | ||||
| 
 | ||||
|     PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo> | ||||
| """ | ||||
| 
 | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import argparse | ||||
| import logging | ||||
| 
 | ||||
| import sys | ||||
| 
 | ||||
| import os | ||||
| 
 | ||||
| import shutil | ||||
| 
 | ||||
| from synapse.rest.media.v1.filepath import MediaFilePaths | ||||
| 
 | ||||
| logger = logging.getLogger() | ||||
| 
 | ||||
| 
 | ||||
| def main(src_repo, dest_repo): | ||||
|     src_paths = MediaFilePaths(src_repo) | ||||
|     dest_paths = MediaFilePaths(dest_repo) | ||||
|     for line in sys.stdin: | ||||
|         line = line.strip() | ||||
|         parts = line.split('|') | ||||
|         if len(parts) != 2: | ||||
|             print("Unable to parse input line %s" % line, file=sys.stderr) | ||||
|             exit(1) | ||||
| 
 | ||||
|         move_media(parts[0], parts[1], src_paths, dest_paths) | ||||
| 
 | ||||
| 
 | ||||
| def move_media(origin_server, file_id, src_paths, dest_paths): | ||||
|     """Move the given file, and any thumbnails, to the dest repo | ||||
| 
 | ||||
|     Args: | ||||
|         origin_server (str): | ||||
|         file_id (str): | ||||
|         src_paths (MediaFilePaths): | ||||
|         dest_paths (MediaFilePaths): | ||||
|     """ | ||||
|     logger.info("%s/%s", origin_server, file_id) | ||||
| 
 | ||||
|     # check that the original exists | ||||
|     original_file = src_paths.remote_media_filepath(origin_server, file_id) | ||||
|     if not os.path.exists(original_file): | ||||
|         logger.warn( | ||||
|             "Original for %s/%s (%s) does not exist", | ||||
|             origin_server, file_id, original_file, | ||||
|         ) | ||||
|     else: | ||||
|         mkdir_and_move( | ||||
|             original_file, | ||||
|             dest_paths.remote_media_filepath(origin_server, file_id), | ||||
|         ) | ||||
| 
 | ||||
|     # now look for thumbnails | ||||
|     original_thumb_dir = src_paths.remote_media_thumbnail_dir( | ||||
|         origin_server, file_id, | ||||
|     ) | ||||
|     if not os.path.exists(original_thumb_dir): | ||||
|         return | ||||
| 
 | ||||
|     mkdir_and_move( | ||||
|         original_thumb_dir, | ||||
|         dest_paths.remote_media_thumbnail_dir(origin_server, file_id) | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def mkdir_and_move(original_file, dest_file): | ||||
|     dirname = os.path.dirname(dest_file) | ||||
|     if not os.path.exists(dirname): | ||||
|         logger.debug("mkdir %s", dirname) | ||||
|         os.makedirs(dirname) | ||||
|     logger.debug("mv %s %s", original_file, dest_file) | ||||
|     shutil.move(original_file, dest_file) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description=__doc__, | ||||
|         formatter_class = argparse.RawDescriptionHelpFormatter, | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "-v", action='store_true', help='enable debug logging') | ||||
|     parser.add_argument( | ||||
|         "src_repo", | ||||
|         help="Path to source content repo", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "dest_repo", | ||||
|         help="Path to source content repo", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
| 
 | ||||
|     logging_config = { | ||||
|         "level": logging.DEBUG if args.v else logging.INFO, | ||||
|         "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s" | ||||
|     } | ||||
|     logging.basicConfig(**logging_config) | ||||
| 
 | ||||
|     main(args.src_repo, args.dest_repo) | ||||
|  | @ -16,4 +16,4 @@ | |||
| """ This is a reference implementation of a Matrix home server. | ||||
| """ | ||||
| 
 | ||||
| __version__ = "0.25.1" | ||||
| __version__ = "0.26.0" | ||||
|  |  | |||
|  | @ -46,6 +46,7 @@ class Codes(object): | |||
|     THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" | ||||
|     THREEPID_IN_USE = "M_THREEPID_IN_USE" | ||||
|     THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" | ||||
|     THREEPID_DENIED = "M_THREEPID_DENIED" | ||||
|     INVALID_USERNAME = "M_INVALID_USERNAME" | ||||
|     SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" | ||||
| 
 | ||||
|  | @ -140,6 +141,32 @@ class RegistrationError(SynapseError): | |||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| class FederationDeniedError(SynapseError): | ||||
|     """An error raised when the server tries to federate with a server which | ||||
|     is not on its federation whitelist. | ||||
| 
 | ||||
|     Attributes: | ||||
|         destination (str): The destination which has been denied | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, destination): | ||||
|         """Raised by federation client or server to indicate that we are | ||||
|         are deliberately not attempting to contact a given server because it is | ||||
|         not on our federation whitelist. | ||||
| 
 | ||||
|         Args: | ||||
|             destination (str): the domain in question | ||||
|         """ | ||||
| 
 | ||||
|         self.destination = destination | ||||
| 
 | ||||
|         super(FederationDeniedError, self).__init__( | ||||
|             code=403, | ||||
|             msg="Federation denied with %s." % (self.destination,), | ||||
|             errcode=Codes.FORBIDDEN, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class InteractiveAuthIncompleteError(Exception): | ||||
|     """An error raised when UI auth is not yet complete | ||||
| 
 | ||||
|  |  | |||
|  | @ -49,19 +49,6 @@ class AppserviceSlaveStore( | |||
| 
 | ||||
| 
 | ||||
| class AppserviceServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = AppserviceSlaveStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -64,19 +64,6 @@ class ClientReaderSlavedStore( | |||
| 
 | ||||
| 
 | ||||
| class ClientReaderServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -58,19 +58,6 @@ class FederationReaderSlavedStore( | |||
| 
 | ||||
| 
 | ||||
| class FederationReaderServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -76,19 +76,6 @@ class FederationSenderSlaveStore( | |||
| 
 | ||||
| 
 | ||||
| class FederationSenderServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -118,19 +118,6 @@ class FrontendProxySlavedStore( | |||
| 
 | ||||
| 
 | ||||
| class FrontendProxyServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer): | |||
|         except IncorrectDatabaseSetup as e: | ||||
|             quit_with_error(e.message) | ||||
| 
 | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
| 
 | ||||
| def setup(config_options): | ||||
|     """ | ||||
|  |  | |||
|  | @ -60,19 +60,6 @@ class MediaRepositorySlavedStore( | |||
| 
 | ||||
| 
 | ||||
| class MediaRepositoryServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -81,19 +81,6 @@ class PusherSlaveStore( | |||
| 
 | ||||
| 
 | ||||
| class PusherServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = PusherSlaveStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -246,19 +246,6 @@ class SynchrotronApplicationService(object): | |||
| 
 | ||||
| 
 | ||||
| class SynchrotronServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -184,6 +184,9 @@ def main(): | |||
|         worker_configfiles.append(worker_configfile) | ||||
| 
 | ||||
|     if options.all_processes: | ||||
|         # To start the main synapse with -a you need to add a worker file | ||||
|         # with worker_app == "synapse.app.homeserver" | ||||
|         start_stop_synapse = False | ||||
|         worker_configdir = options.all_processes | ||||
|         if not os.path.isdir(worker_configdir): | ||||
|             write( | ||||
|  | @ -200,11 +203,29 @@ def main(): | |||
|         with open(worker_configfile) as stream: | ||||
|             worker_config = yaml.load(stream) | ||||
|         worker_app = worker_config["worker_app"] | ||||
|         worker_pidfile = worker_config["worker_pid_file"] | ||||
|         worker_daemonize = worker_config["worker_daemonize"] | ||||
|         assert worker_daemonize, "In config %r: expected '%s' to be True" % ( | ||||
|             worker_configfile, "worker_daemonize") | ||||
|         worker_cache_factor = worker_config.get("synctl_cache_factor") | ||||
|         if worker_app == "synapse.app.homeserver": | ||||
|             # We need to special case all of this to pick up options that may | ||||
|             # be set in the main config file or in this worker config file. | ||||
|             worker_pidfile = ( | ||||
|                 worker_config.get("pid_file") | ||||
|                 or pidfile | ||||
|             ) | ||||
|             worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor | ||||
|             daemonize = worker_config.get("daemonize") or config.get("daemonize") | ||||
|             assert daemonize, "Main process must have daemonize set to true" | ||||
| 
 | ||||
|             # The master process doesn't support using worker_* config. | ||||
|             for key in worker_config: | ||||
|                 if key == "worker_app":  # But we allow worker_app | ||||
|                     continue | ||||
|                 assert not key.startswith("worker_"), \ | ||||
|                     "Main process cannot use worker_* config" | ||||
|         else: | ||||
|             worker_pidfile = worker_config["worker_pid_file"] | ||||
|             worker_daemonize = worker_config["worker_daemonize"] | ||||
|             assert worker_daemonize, "In config %r: expected '%s' to be True" % ( | ||||
|                 worker_configfile, "worker_daemonize") | ||||
|             worker_cache_factor = worker_config.get("synctl_cache_factor") | ||||
|         workers.append(Worker( | ||||
|             worker_app, worker_configfile, worker_pidfile, worker_cache_factor, | ||||
|         )) | ||||
|  |  | |||
|  | @ -92,19 +92,6 @@ class UserDirectorySlaveStore( | |||
| 
 | ||||
| 
 | ||||
| class UserDirectoryServer(HomeServer): | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
| 
 | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def setup(self): | ||||
|         logger.info("Setting up.") | ||||
|         self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self) | ||||
|  |  | |||
|  | @ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template(""" | |||
| version: 1 | ||||
| 
 | ||||
| formatters: | ||||
|   precise: | ||||
|    format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\ | ||||
| - %(message)s' | ||||
|     precise: | ||||
|         format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \ | ||||
| %(request)s - %(message)s' | ||||
| 
 | ||||
| filters: | ||||
|   context: | ||||
|     (): synapse.util.logcontext.LoggingContextFilter | ||||
|     request: "" | ||||
|     context: | ||||
|         (): synapse.util.logcontext.LoggingContextFilter | ||||
|         request: "" | ||||
| 
 | ||||
| handlers: | ||||
|   file: | ||||
|     class: logging.handlers.RotatingFileHandler | ||||
|     formatter: precise | ||||
|     filename: ${log_file} | ||||
|     maxBytes: 104857600 | ||||
|     backupCount: 10 | ||||
|     filters: [context] | ||||
|   console: | ||||
|     class: logging.StreamHandler | ||||
|     formatter: precise | ||||
|     filters: [context] | ||||
|     file: | ||||
|         class: logging.handlers.RotatingFileHandler | ||||
|         formatter: precise | ||||
|         filename: ${log_file} | ||||
|         maxBytes: 104857600 | ||||
|         backupCount: 10 | ||||
|         filters: [context] | ||||
|     console: | ||||
|         class: logging.StreamHandler | ||||
|         formatter: precise | ||||
|         filters: [context] | ||||
| 
 | ||||
| loggers: | ||||
|     synapse: | ||||
|  | @ -74,17 +74,10 @@ class LoggingConfig(Config): | |||
|         self.log_file = self.abspath(config.get("log_file")) | ||||
| 
 | ||||
|     def default_config(self, config_dir_path, server_name, **kwargs): | ||||
|         log_file = self.abspath("homeserver.log") | ||||
|         log_config = self.abspath( | ||||
|             os.path.join(config_dir_path, server_name + ".log.config") | ||||
|         ) | ||||
|         return """ | ||||
|         # Logging verbosity level. Ignored if log_config is specified. | ||||
|         verbose: 0 | ||||
| 
 | ||||
|         # File to write logging to. Ignored if log_config is specified. | ||||
|         log_file: "%(log_file)s" | ||||
| 
 | ||||
|         # A yaml python logging config file | ||||
|         log_config: "%(log_config)s" | ||||
|         """ % locals() | ||||
|  | @ -123,9 +116,10 @@ class LoggingConfig(Config): | |||
|     def generate_files(self, config): | ||||
|         log_config = config.get("log_config") | ||||
|         if log_config and not os.path.exists(log_config): | ||||
|             log_file = self.abspath("homeserver.log") | ||||
|             with open(log_config, "wb") as log_config_file: | ||||
|                 log_config_file.write( | ||||
|                     DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"]) | ||||
|                     DEFAULT_LOG_CONFIG.substitute(log_file=log_file) | ||||
|                 ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -150,6 +144,9 @@ def setup_logging(config, use_worker_options=False): | |||
|     ) | ||||
| 
 | ||||
|     if log_config is None: | ||||
|         # We don't have a logfile, so fall back to the 'verbosity' param from | ||||
|         # the config or cmdline. (Note that we generate a log config for new | ||||
|         # installs, so this will be an unusual case) | ||||
|         level = logging.INFO | ||||
|         level_for_storage = logging.INFO | ||||
|         if config.verbosity: | ||||
|  | @ -157,11 +154,10 @@ def setup_logging(config, use_worker_options=False): | |||
|             if config.verbosity > 1: | ||||
|                 level_for_storage = logging.DEBUG | ||||
| 
 | ||||
|         # FIXME: we need a logging.WARN for a -q quiet option | ||||
|         logger = logging.getLogger('') | ||||
|         logger.setLevel(level) | ||||
| 
 | ||||
|         logging.getLogger('synapse.storage').setLevel(level_for_storage) | ||||
|         logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage) | ||||
| 
 | ||||
|         formatter = logging.Formatter(log_format) | ||||
|         if log_file: | ||||
|  |  | |||
|  | @ -31,6 +31,8 @@ class RegistrationConfig(Config): | |||
|                 strtobool(str(config["disable_registration"])) | ||||
|             ) | ||||
| 
 | ||||
|         self.registrations_require_3pid = config.get("registrations_require_3pid", []) | ||||
|         self.allowed_local_3pids = config.get("allowed_local_3pids", []) | ||||
|         self.registration_shared_secret = config.get("registration_shared_secret") | ||||
| 
 | ||||
|         self.bcrypt_rounds = config.get("bcrypt_rounds", 12) | ||||
|  | @ -52,6 +54,23 @@ class RegistrationConfig(Config): | |||
|         # Enable registration for new users. | ||||
|         enable_registration: False | ||||
| 
 | ||||
|         # The user must provide all of the below types of 3PID when registering. | ||||
|         # | ||||
|         # registrations_require_3pid: | ||||
|         #     - email | ||||
|         #     - msisdn | ||||
| 
 | ||||
|         # Mandate that users are only allowed to associate certain formats of | ||||
|         # 3PIDs with accounts on this server. | ||||
|         # | ||||
|         # allowed_local_3pids: | ||||
|         #     - medium: email | ||||
|         #       pattern: ".*@matrix\\.org" | ||||
|         #     - medium: email | ||||
|         #       pattern: ".*@vector\\.im" | ||||
|         #     - medium: msisdn | ||||
|         #       pattern: "\\+44" | ||||
| 
 | ||||
|         # If set, allows registration by anyone who also has the shared | ||||
|         # secret, even if registration is otherwise disabled. | ||||
|         registration_shared_secret: "%(registration_shared_secret)s" | ||||
|  |  | |||
|  | @ -16,6 +16,8 @@ | |||
| from ._base import Config, ConfigError | ||||
| from collections import namedtuple | ||||
| 
 | ||||
| from synapse.util.module_loader import load_module | ||||
| 
 | ||||
| 
 | ||||
| MISSING_NETADDR = ( | ||||
|     "Missing netaddr library. This is required for URL preview API." | ||||
|  | @ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple( | |||
|     "ThumbnailRequirement", ["width", "height", "method", "media_type"] | ||||
| ) | ||||
| 
 | ||||
| MediaStorageProviderConfig = namedtuple( | ||||
|     "MediaStorageProviderConfig", ( | ||||
|         "store_local",  # Whether to store newly uploaded local files | ||||
|         "store_remote",  # Whether to store newly downloaded remote files | ||||
|         "store_synchronous",  # Whether to wait for successful storage for local uploads | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def parse_thumbnail_requirements(thumbnail_sizes): | ||||
|     """ Takes a list of dictionaries with "width", "height", and "method" keys | ||||
|  | @ -73,16 +83,61 @@ class ContentRepositoryConfig(Config): | |||
| 
 | ||||
|         self.media_store_path = self.ensure_directory(config["media_store_path"]) | ||||
| 
 | ||||
|         self.backup_media_store_path = config.get("backup_media_store_path") | ||||
|         if self.backup_media_store_path: | ||||
|             self.backup_media_store_path = self.ensure_directory( | ||||
|                 self.backup_media_store_path | ||||
|             ) | ||||
|         backup_media_store_path = config.get("backup_media_store_path") | ||||
| 
 | ||||
|         self.synchronous_backup_media_store = config.get( | ||||
|         synchronous_backup_media_store = config.get( | ||||
|             "synchronous_backup_media_store", False | ||||
|         ) | ||||
| 
 | ||||
|         storage_providers = config.get("media_storage_providers", []) | ||||
| 
 | ||||
|         if backup_media_store_path: | ||||
|             if storage_providers: | ||||
|                 raise ConfigError( | ||||
|                     "Cannot use both 'backup_media_store_path' and 'storage_providers'" | ||||
|                 ) | ||||
| 
 | ||||
|             storage_providers = [{ | ||||
|                 "module": "file_system", | ||||
|                 "store_local": True, | ||||
|                 "store_synchronous": synchronous_backup_media_store, | ||||
|                 "store_remote": True, | ||||
|                 "config": { | ||||
|                     "directory": backup_media_store_path, | ||||
|                 } | ||||
|             }] | ||||
| 
 | ||||
|         # This is a list of config that can be used to create the storage | ||||
|         # providers. The entries are tuples of (Class, class_config, | ||||
|         # MediaStorageProviderConfig), where Class is the class of the provider, | ||||
|         # the class_config the config to pass to it, and | ||||
|         # MediaStorageProviderConfig are options for StorageProviderWrapper. | ||||
|         # | ||||
|         # We don't create the storage providers here as not all workers need | ||||
|         # them to be started. | ||||
|         self.media_storage_providers = [] | ||||
| 
 | ||||
|         for provider_config in storage_providers: | ||||
|             # We special case the module "file_system" so as not to need to | ||||
|             # expose FileStorageProviderBackend | ||||
|             if provider_config["module"] == "file_system": | ||||
|                 provider_config["module"] = ( | ||||
|                     "synapse.rest.media.v1.storage_provider" | ||||
|                     ".FileStorageProviderBackend" | ||||
|                 ) | ||||
| 
 | ||||
|             provider_class, parsed_config = load_module(provider_config) | ||||
| 
 | ||||
|             wrapper_config = MediaStorageProviderConfig( | ||||
|                 provider_config.get("store_local", False), | ||||
|                 provider_config.get("store_remote", False), | ||||
|                 provider_config.get("store_synchronous", False), | ||||
|             ) | ||||
| 
 | ||||
|             self.media_storage_providers.append( | ||||
|                 (provider_class, parsed_config, wrapper_config,) | ||||
|             ) | ||||
| 
 | ||||
|         self.uploads_path = self.ensure_directory(config["uploads_path"]) | ||||
|         self.dynamic_thumbnails = config["dynamic_thumbnails"] | ||||
|         self.thumbnail_requirements = parse_thumbnail_requirements( | ||||
|  | @ -127,13 +182,19 @@ class ContentRepositoryConfig(Config): | |||
|         # Directory where uploaded images and attachments are stored. | ||||
|         media_store_path: "%(media_store)s" | ||||
| 
 | ||||
|         # A secondary directory where uploaded images and attachments are | ||||
|         # stored as a backup. | ||||
|         # backup_media_store_path: "%(media_store)s" | ||||
| 
 | ||||
|         # Whether to wait for successful write to backup media store before | ||||
|         # returning successfully. | ||||
|         # synchronous_backup_media_store: false | ||||
|         # Media storage providers allow media to be stored in different | ||||
|         # locations. | ||||
|         # media_storage_providers: | ||||
|         # - module: file_system | ||||
|         #   # Whether to write new local files. | ||||
|         #   store_local: false | ||||
|         #   # Whether to write new remote media | ||||
|         #   store_remote: false | ||||
|         #   # Whether to block upload requests waiting for write to this | ||||
|         #   # provider to complete | ||||
|         #   store_synchronous: false | ||||
|         #   config: | ||||
|         #     directory: /mnt/some/other/directory | ||||
| 
 | ||||
|         # Directory where in-progress uploads are stored. | ||||
|         uploads_path: "%(uploads_path)s" | ||||
|  |  | |||
|  | @ -55,6 +55,17 @@ class ServerConfig(Config): | |||
|             "block_non_admin_invites", False, | ||||
|         ) | ||||
| 
 | ||||
|         # FIXME: federation_domain_whitelist needs sytests | ||||
|         self.federation_domain_whitelist = None | ||||
|         federation_domain_whitelist = config.get( | ||||
|             "federation_domain_whitelist", None | ||||
|         ) | ||||
|         # turn the whitelist into a hash for speed of lookup | ||||
|         if federation_domain_whitelist is not None: | ||||
|             self.federation_domain_whitelist = {} | ||||
|             for domain in federation_domain_whitelist: | ||||
|                 self.federation_domain_whitelist[domain] = True | ||||
| 
 | ||||
|         if self.public_baseurl is not None: | ||||
|             if self.public_baseurl[-1] != '/': | ||||
|                 self.public_baseurl += '/' | ||||
|  | @ -210,6 +221,17 @@ class ServerConfig(Config): | |||
|         # (except those sent by local server admins). The default is False. | ||||
|         # block_non_admin_invites: True | ||||
| 
 | ||||
|         # Restrict federation to the following whitelist of domains. | ||||
|         # N.B. we recommend also firewalling your federation listener to limit | ||||
|         # inbound federation traffic as early as possible, rather than relying | ||||
|         # purely on this application-layer restriction.  If not specified, the | ||||
|         # default is to whitelist everything. | ||||
|         # | ||||
|         # federation_domain_whitelist: | ||||
|         #  - lon.example.com | ||||
|         #  - nyc.example.com | ||||
|         #  - syd.example.com | ||||
| 
 | ||||
|         # List of ports that Synapse should listen on, their purpose and their | ||||
|         # configuration. | ||||
|         listeners: | ||||
|  |  | |||
|  | @ -96,7 +96,7 @@ class TlsConfig(Config): | |||
|         # certificates returned by this server match one of the fingerprints. | ||||
|         # | ||||
|         # Synapse automatically adds the fingerprint of its own certificate | ||||
|         # to the list. So if federation traffic is handle directly by synapse | ||||
|         # to the list. So if federation traffic is handled directly by synapse | ||||
|         # then no modification to the list is required. | ||||
|         # | ||||
|         # If synapse is run behind a load balancer that handles the TLS then it | ||||
|  |  | |||
|  | @ -23,6 +23,11 @@ class WorkerConfig(Config): | |||
| 
 | ||||
|     def read_config(self, config): | ||||
|         self.worker_app = config.get("worker_app") | ||||
| 
 | ||||
|         # Canonicalise worker_app so that master always has None | ||||
|         if self.worker_app == "synapse.app.homeserver": | ||||
|             self.worker_app = None | ||||
| 
 | ||||
|         self.worker_listeners = config.get("worker_listeners") | ||||
|         self.worker_daemonize = config.get("worker_daemonize") | ||||
|         self.worker_pid_file = config.get("worker_pid_file") | ||||
|  |  | |||
|  | @ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events): | |||
|         # TODO (erikj): Implement kicks. | ||||
|         if target_banned and user_level < ban_level: | ||||
|             raise AuthError( | ||||
|                 403, "You cannot unban user &s." % (target_user_id,) | ||||
|                 403, "You cannot unban user %s." % (target_user_id,) | ||||
|             ) | ||||
|         elif target_user_id != event.user_id: | ||||
|             kick_level = _get_named_level(auth_events, "kick", 50) | ||||
|  |  | |||
|  | @ -16,7 +16,9 @@ import logging | |||
| 
 | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.crypto.event_signing import check_event_content_hash | ||||
| from synapse.events import FrozenEvent | ||||
| from synapse.events.utils import prune_event | ||||
| from synapse.http.servlet import assert_params_in_request | ||||
| from synapse.util import unwrapFirstError, logcontext | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
|  | @ -169,3 +171,28 @@ class FederationBase(object): | |||
|             ) | ||||
| 
 | ||||
|         return deferreds | ||||
| 
 | ||||
| 
 | ||||
| def event_from_pdu_json(pdu_json, outlier=False): | ||||
|     """Construct a FrozenEvent from an event json received over federation | ||||
| 
 | ||||
|     Args: | ||||
|         pdu_json (object): pdu as received over federation | ||||
|         outlier (bool): True to mark this event as an outlier | ||||
| 
 | ||||
|     Returns: | ||||
|         FrozenEvent | ||||
| 
 | ||||
|     Raises: | ||||
|         SynapseError: if the pdu is missing required fields | ||||
|     """ | ||||
|     # we could probably enforce a bunch of other fields here (room_id, sender, | ||||
|     # origin, etc etc) | ||||
|     assert_params_in_request(pdu_json, ('event_id', 'type')) | ||||
|     event = FrozenEvent( | ||||
|         pdu_json | ||||
|     ) | ||||
| 
 | ||||
|     event.internal_metadata.outlier = outlier | ||||
| 
 | ||||
|     return event | ||||
|  |  | |||
|  | @ -14,28 +14,28 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from .federation_base import FederationBase | ||||
| from synapse.api.constants import Membership | ||||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|     CodeMessageException, HttpResponseException, SynapseError, | ||||
| ) | ||||
| from synapse.util import unwrapFirstError, logcontext | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| from synapse.events import FrozenEvent, builder | ||||
| import synapse.metrics | ||||
| 
 | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
| import copy | ||||
| import itertools | ||||
| import logging | ||||
| import random | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.constants import Membership | ||||
| from synapse.api.errors import ( | ||||
|     CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError | ||||
| ) | ||||
| from synapse.events import builder | ||||
| from synapse.federation.federation_base import ( | ||||
|     FederationBase, | ||||
|     event_from_pdu_json, | ||||
| ) | ||||
| import synapse.metrics | ||||
| from synapse.util import logcontext, unwrapFirstError | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -184,7 +184,7 @@ class FederationClient(FederationBase): | |||
|         logger.debug("backfill transaction_data=%s", repr(transaction_data)) | ||||
| 
 | ||||
|         pdus = [ | ||||
|             self.event_from_pdu_json(p, outlier=False) | ||||
|             event_from_pdu_json(p, outlier=False) | ||||
|             for p in transaction_data["pdus"] | ||||
|         ] | ||||
| 
 | ||||
|  | @ -244,7 +244,7 @@ class FederationClient(FederationBase): | |||
|                 logger.debug("transaction_data %r", transaction_data) | ||||
| 
 | ||||
|                 pdu_list = [ | ||||
|                     self.event_from_pdu_json(p, outlier=outlier) | ||||
|                     event_from_pdu_json(p, outlier=outlier) | ||||
|                     for p in transaction_data["pdus"] | ||||
|                 ] | ||||
| 
 | ||||
|  | @ -266,6 +266,9 @@ class FederationClient(FederationBase): | |||
|             except NotRetryingDestination as e: | ||||
|                 logger.info(e.message) | ||||
|                 continue | ||||
|             except FederationDeniedError as e: | ||||
|                 logger.info(e.message) | ||||
|                 continue | ||||
|             except Exception as e: | ||||
|                 pdu_attempts[destination] = now | ||||
| 
 | ||||
|  | @ -336,11 +339,11 @@ class FederationClient(FederationBase): | |||
|         ) | ||||
| 
 | ||||
|         pdus = [ | ||||
|             self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] | ||||
|             event_from_pdu_json(p, outlier=True) for p in result["pdus"] | ||||
|         ] | ||||
| 
 | ||||
|         auth_chain = [ | ||||
|             self.event_from_pdu_json(p, outlier=True) | ||||
|             event_from_pdu_json(p, outlier=True) | ||||
|             for p in result.get("auth_chain", []) | ||||
|         ] | ||||
| 
 | ||||
|  | @ -441,7 +444,7 @@ class FederationClient(FederationBase): | |||
|         ) | ||||
| 
 | ||||
|         auth_chain = [ | ||||
|             self.event_from_pdu_json(p, outlier=True) | ||||
|             event_from_pdu_json(p, outlier=True) | ||||
|             for p in res["auth_chain"] | ||||
|         ] | ||||
| 
 | ||||
|  | @ -570,12 +573,12 @@ class FederationClient(FederationBase): | |||
|                 logger.debug("Got content: %s", content) | ||||
| 
 | ||||
|                 state = [ | ||||
|                     self.event_from_pdu_json(p, outlier=True) | ||||
|                     event_from_pdu_json(p, outlier=True) | ||||
|                     for p in content.get("state", []) | ||||
|                 ] | ||||
| 
 | ||||
|                 auth_chain = [ | ||||
|                     self.event_from_pdu_json(p, outlier=True) | ||||
|                     event_from_pdu_json(p, outlier=True) | ||||
|                     for p in content.get("auth_chain", []) | ||||
|                 ] | ||||
| 
 | ||||
|  | @ -650,7 +653,7 @@ class FederationClient(FederationBase): | |||
| 
 | ||||
|         logger.debug("Got response to send_invite: %s", pdu_dict) | ||||
| 
 | ||||
|         pdu = self.event_from_pdu_json(pdu_dict) | ||||
|         pdu = event_from_pdu_json(pdu_dict) | ||||
| 
 | ||||
|         # Check signatures are correct. | ||||
|         pdu = yield self._check_sigs_and_hash(pdu) | ||||
|  | @ -740,7 +743,7 @@ class FederationClient(FederationBase): | |||
|         ) | ||||
| 
 | ||||
|         auth_chain = [ | ||||
|             self.event_from_pdu_json(e) | ||||
|             event_from_pdu_json(e) | ||||
|             for e in content["auth_chain"] | ||||
|         ] | ||||
| 
 | ||||
|  | @ -788,7 +791,7 @@ class FederationClient(FederationBase): | |||
|             ) | ||||
| 
 | ||||
|             events = [ | ||||
|                 self.event_from_pdu_json(e) | ||||
|                 event_from_pdu_json(e) | ||||
|                 for e in content.get("events", []) | ||||
|             ] | ||||
| 
 | ||||
|  | @ -805,15 +808,6 @@ class FederationClient(FederationBase): | |||
| 
 | ||||
|         defer.returnValue(signed_events) | ||||
| 
 | ||||
|     def event_from_pdu_json(self, pdu_json, outlier=False): | ||||
|         event = FrozenEvent( | ||||
|             pdu_json | ||||
|         ) | ||||
| 
 | ||||
|         event.internal_metadata.outlier = outlier | ||||
| 
 | ||||
|         return event | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def forward_third_party_invite(self, destinations, room_id, event_dict): | ||||
|         for destination in destinations: | ||||
|  |  | |||
|  | @ -12,25 +12,24 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from .federation_base import FederationBase | ||||
| from .units import Transaction, Edu | ||||
| 
 | ||||
| from synapse.util import async | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| from synapse.util.logutils import log_function | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.events import FrozenEvent | ||||
| from synapse.types import get_domain_from_id | ||||
| import synapse.metrics | ||||
| 
 | ||||
| from synapse.api.errors import AuthError, FederationError, SynapseError | ||||
| 
 | ||||
| from synapse.crypto.event_signing import compute_event_signature | ||||
| import logging | ||||
| 
 | ||||
| import simplejson as json | ||||
| import logging | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import AuthError, FederationError, SynapseError | ||||
| from synapse.crypto.event_signing import compute_event_signature | ||||
| from synapse.federation.federation_base import ( | ||||
|     FederationBase, | ||||
|     event_from_pdu_json, | ||||
| ) | ||||
| from synapse.federation.units import Edu, Transaction | ||||
| import synapse.metrics | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.util import async | ||||
| from synapse.util.caches.response_cache import ResponseCache | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| from synapse.util.logutils import log_function | ||||
| 
 | ||||
| # when processing incoming transactions, we try to handle multiple rooms in | ||||
| # parallel, up to this limit. | ||||
|  | @ -172,7 +171,7 @@ class FederationServer(FederationBase): | |||
|                 p["age_ts"] = request_time - int(p["age"]) | ||||
|                 del p["age"] | ||||
| 
 | ||||
|             event = self.event_from_pdu_json(p) | ||||
|             event = event_from_pdu_json(p) | ||||
|             room_id = event.room_id | ||||
|             pdus_by_room.setdefault(room_id, []).append(event) | ||||
| 
 | ||||
|  | @ -346,7 +345,7 @@ class FederationServer(FederationBase): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_invite_request(self, origin, content): | ||||
|         pdu = self.event_from_pdu_json(content) | ||||
|         pdu = event_from_pdu_json(content) | ||||
|         ret_pdu = yield self.handler.on_invite_request(origin, pdu) | ||||
|         time_now = self._clock.time_msec() | ||||
|         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) | ||||
|  | @ -354,7 +353,7 @@ class FederationServer(FederationBase): | |||
|     @defer.inlineCallbacks | ||||
|     def on_send_join_request(self, origin, content): | ||||
|         logger.debug("on_send_join_request: content: %s", content) | ||||
|         pdu = self.event_from_pdu_json(content) | ||||
|         pdu = event_from_pdu_json(content) | ||||
|         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) | ||||
|         res_pdus = yield self.handler.on_send_join_request(origin, pdu) | ||||
|         time_now = self._clock.time_msec() | ||||
|  | @ -374,7 +373,7 @@ class FederationServer(FederationBase): | |||
|     @defer.inlineCallbacks | ||||
|     def on_send_leave_request(self, origin, content): | ||||
|         logger.debug("on_send_leave_request: content: %s", content) | ||||
|         pdu = self.event_from_pdu_json(content) | ||||
|         pdu = event_from_pdu_json(content) | ||||
|         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) | ||||
|         yield self.handler.on_send_leave_request(origin, pdu) | ||||
|         defer.returnValue((200, {})) | ||||
|  | @ -411,7 +410,7 @@ class FederationServer(FederationBase): | |||
|         """ | ||||
|         with (yield self._server_linearizer.queue((origin, room_id))): | ||||
|             auth_chain = [ | ||||
|                 self.event_from_pdu_json(e) | ||||
|                 event_from_pdu_json(e) | ||||
|                 for e in content["auth_chain"] | ||||
|             ] | ||||
| 
 | ||||
|  | @ -586,15 +585,6 @@ class FederationServer(FederationBase): | |||
|     def __str__(self): | ||||
|         return "<ReplicationLayer(%s)>" % self.server_name | ||||
| 
 | ||||
|     def event_from_pdu_json(self, pdu_json, outlier=False): | ||||
|         event = FrozenEvent( | ||||
|             pdu_json | ||||
|         ) | ||||
| 
 | ||||
|         event.internal_metadata.outlier = outlier | ||||
| 
 | ||||
|         return event | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def exchange_third_party_invite( | ||||
|             self, | ||||
|  |  | |||
|  | @ -19,7 +19,7 @@ from twisted.internet import defer | |||
| from .persistence import TransactionActions | ||||
| from .units import Transaction, Edu | ||||
| 
 | ||||
| from synapse.api.errors import HttpResponseException | ||||
| from synapse.api.errors import HttpResponseException, FederationDeniedError | ||||
| from synapse.util import logcontext, PreserveLoggingContext | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter | ||||
|  | @ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus") | |||
| 
 | ||||
| sent_transactions_counter = client_metrics.register_counter("sent_transactions") | ||||
| 
 | ||||
| events_processed_counter = client_metrics.register_counter("events_processed") | ||||
| 
 | ||||
| 
 | ||||
| class TransactionQueue(object): | ||||
|     """This class makes sure we only have one transaction in flight at | ||||
|  | @ -205,6 +207,8 @@ class TransactionQueue(object): | |||
| 
 | ||||
|                     self._send_pdu(event, destinations) | ||||
| 
 | ||||
|                 events_processed_counter.inc_by(len(events)) | ||||
| 
 | ||||
|                 yield self.store.update_federation_out_pos( | ||||
|                     "events", next_token | ||||
|                 ) | ||||
|  | @ -486,6 +490,8 @@ class TransactionQueue(object): | |||
|                     (e.retry_last_ts + e.retry_interval) / 1000.0 | ||||
|                 ), | ||||
|             ) | ||||
|         except FederationDeniedError as e: | ||||
|             logger.info(e) | ||||
|         except Exception as e: | ||||
|             logger.warn( | ||||
|                 "TX [%s] Failed to send transaction: %s", | ||||
|  |  | |||
|  | @ -212,6 +212,9 @@ class TransportLayerClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if the remote destination | ||||
|             is not in our federation whitelist | ||||
|         """ | ||||
|         valid_memberships = {Membership.JOIN, Membership.LEAVE} | ||||
|         if membership not in valid_memberships: | ||||
|  |  | |||
|  | @ -16,7 +16,7 @@ | |||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.urls import FEDERATION_PREFIX as PREFIX | ||||
| from synapse.api.errors import Codes, SynapseError | ||||
| from synapse.api.errors import Codes, SynapseError, FederationDeniedError | ||||
| from synapse.http.server import JsonResource | ||||
| from synapse.http.servlet import ( | ||||
|     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args, | ||||
|  | @ -81,6 +81,7 @@ class Authenticator(object): | |||
|         self.keyring = hs.get_keyring() | ||||
|         self.server_name = hs.hostname | ||||
|         self.store = hs.get_datastore() | ||||
|         self.federation_domain_whitelist = hs.config.federation_domain_whitelist | ||||
| 
 | ||||
|     # A method just so we can pass 'self' as the authenticator to the Servlets | ||||
|     @defer.inlineCallbacks | ||||
|  | @ -92,6 +93,12 @@ class Authenticator(object): | |||
|             "signatures": {}, | ||||
|         } | ||||
| 
 | ||||
|         if ( | ||||
|             self.federation_domain_whitelist is not None and | ||||
|             self.server_name not in self.federation_domain_whitelist | ||||
|         ): | ||||
|             raise FederationDeniedError(self.server_name) | ||||
| 
 | ||||
|         if content is not None: | ||||
|             json_request["content"] = content | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,6 +15,7 @@ | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| import synapse | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.util.metrics import Measure | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
|  | @ -23,6 +24,10 @@ import logging | |||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| metrics = synapse.metrics.get_metrics_for(__name__) | ||||
| 
 | ||||
| events_processed_counter = metrics.register_counter("events_processed") | ||||
| 
 | ||||
| 
 | ||||
| def log_failure(failure): | ||||
|     logger.error( | ||||
|  | @ -103,6 +108,8 @@ class ApplicationServicesHandler(object): | |||
|                                 service, event | ||||
|                             ) | ||||
| 
 | ||||
|                     events_processed_counter.inc_by(len(events)) | ||||
| 
 | ||||
|                     yield self.store.set_appservice_last_pos(upper_bound) | ||||
|             finally: | ||||
|                 self.is_processing = False | ||||
|  |  | |||
|  | @ -13,7 +13,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| from twisted.internet import defer | ||||
| from twisted.internet import defer, threads | ||||
| 
 | ||||
| from ._base import BaseHandler | ||||
| from synapse.api.constants import LoginType | ||||
|  | @ -25,6 +25,7 @@ from synapse.module_api import ModuleApi | |||
| from synapse.types import UserID | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
| from synapse.util.logcontext import make_deferred_yieldable | ||||
| 
 | ||||
| from twisted.web.client import PartialDownloadError | ||||
| 
 | ||||
|  | @ -714,7 +715,7 @@ class AuthHandler(BaseHandler): | |||
|         if not lookupres: | ||||
|             defer.returnValue(None) | ||||
|         (user_id, password_hash) = lookupres | ||||
|         result = self.validate_hash(password, password_hash) | ||||
|         result = yield self.validate_hash(password, password_hash) | ||||
|         if not result: | ||||
|             logger.warn("Failed password login for user %s", user_id) | ||||
|             defer.returnValue(None) | ||||
|  | @ -842,10 +843,13 @@ class AuthHandler(BaseHandler): | |||
|             password (str): Password to hash. | ||||
| 
 | ||||
|         Returns: | ||||
|             Hashed password (str). | ||||
|             Deferred(str): Hashed password. | ||||
|         """ | ||||
|         return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, | ||||
|                              bcrypt.gensalt(self.bcrypt_rounds)) | ||||
|         def _do_hash(): | ||||
|             return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, | ||||
|                                  bcrypt.gensalt(self.bcrypt_rounds)) | ||||
| 
 | ||||
|         return make_deferred_yieldable(threads.deferToThread(_do_hash)) | ||||
| 
 | ||||
|     def validate_hash(self, password, stored_hash): | ||||
|         """Validates that self.hash(password) == stored_hash. | ||||
|  | @ -855,13 +859,17 @@ class AuthHandler(BaseHandler): | |||
|             stored_hash (str): Expected hash value. | ||||
| 
 | ||||
|         Returns: | ||||
|             Whether self.hash(password) == stored_hash (bool). | ||||
|             Deferred(bool): Whether self.hash(password) == stored_hash. | ||||
|         """ | ||||
|         if stored_hash: | ||||
| 
 | ||||
|         def _do_validate_hash(): | ||||
|             return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper, | ||||
|                                  stored_hash.encode('utf8')) == stored_hash | ||||
| 
 | ||||
|         if stored_hash: | ||||
|             return make_deferred_yieldable(threads.deferToThread(_do_validate_hash)) | ||||
|         else: | ||||
|             return False | ||||
|             return defer.succeed(False) | ||||
| 
 | ||||
| 
 | ||||
| class MacaroonGeneartor(object): | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ | |||
| # limitations under the License. | ||||
| from synapse.api import errors | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.errors import FederationDeniedError | ||||
| from synapse.util import stringutils | ||||
| from synapse.util.async import Linearizer | ||||
| from synapse.util.caches.expiringcache import ExpiringCache | ||||
|  | @ -513,6 +514,9 @@ class DeviceListEduUpdater(object): | |||
|                     # This makes it more likely that the device lists will | ||||
|                     # eventually become consistent. | ||||
|                     return | ||||
|                 except FederationDeniedError as e: | ||||
|                     logger.info(e) | ||||
|                     return | ||||
|                 except Exception: | ||||
|                     # TODO: Remember that we are now out of sync and try again | ||||
|                     # later | ||||
|  |  | |||
|  | @ -17,7 +17,8 @@ import logging | |||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.types import get_domain_from_id, UserID | ||||
| from synapse.util.stringutils import random_string | ||||
| 
 | ||||
| 
 | ||||
|  | @ -33,7 +34,7 @@ class DeviceMessageHandler(object): | |||
|         """ | ||||
|         self.store = hs.get_datastore() | ||||
|         self.notifier = hs.get_notifier() | ||||
|         self.is_mine_id = hs.is_mine_id | ||||
|         self.is_mine = hs.is_mine | ||||
|         self.federation = hs.get_federation_sender() | ||||
| 
 | ||||
|         hs.get_replication_layer().register_edu_handler( | ||||
|  | @ -52,6 +53,12 @@ class DeviceMessageHandler(object): | |||
|         message_type = content["type"] | ||||
|         message_id = content["message_id"] | ||||
|         for user_id, by_device in content["messages"].items(): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if not self.is_mine(UserID.from_string(user_id)): | ||||
|                 logger.warning("Request for keys for non-local user %s", | ||||
|                                user_id) | ||||
|                 raise SynapseError(400, "Not a user here") | ||||
| 
 | ||||
|             messages_by_device = { | ||||
|                 device_id: { | ||||
|                     "content": message_content, | ||||
|  | @ -77,7 +84,8 @@ class DeviceMessageHandler(object): | |||
|         local_messages = {} | ||||
|         remote_messages = {} | ||||
|         for user_id, by_device in messages.items(): | ||||
|             if self.is_mine_id(user_id): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if self.is_mine(UserID.from_string(user_id)): | ||||
|                 messages_by_device = { | ||||
|                     device_id: { | ||||
|                         "content": message_content, | ||||
|  |  | |||
|  | @ -19,8 +19,10 @@ import logging | |||
| from canonicaljson import encode_canonical_json | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.api.errors import SynapseError, CodeMessageException | ||||
| from synapse.types import get_domain_from_id | ||||
| from synapse.api.errors import ( | ||||
|     SynapseError, CodeMessageException, FederationDeniedError, | ||||
| ) | ||||
| from synapse.types import get_domain_from_id, UserID | ||||
| from synapse.util.logcontext import preserve_fn, make_deferred_yieldable | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
|  | @ -32,7 +34,7 @@ class E2eKeysHandler(object): | |||
|         self.store = hs.get_datastore() | ||||
|         self.federation = hs.get_replication_layer() | ||||
|         self.device_handler = hs.get_device_handler() | ||||
|         self.is_mine_id = hs.is_mine_id | ||||
|         self.is_mine = hs.is_mine | ||||
|         self.clock = hs.get_clock() | ||||
| 
 | ||||
|         # doesn't really work as part of the generic query API, because the | ||||
|  | @ -70,7 +72,8 @@ class E2eKeysHandler(object): | |||
|         remote_queries = {} | ||||
| 
 | ||||
|         for user_id, device_ids in device_keys_query.items(): | ||||
|             if self.is_mine_id(user_id): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if self.is_mine(UserID.from_string(user_id)): | ||||
|                 local_query[user_id] = device_ids | ||||
|             else: | ||||
|                 remote_queries[user_id] = device_ids | ||||
|  | @ -139,6 +142,10 @@ class E2eKeysHandler(object): | |||
|                 failures[destination] = { | ||||
|                     "status": 503, "message": "Not ready for retry", | ||||
|                 } | ||||
|             except FederationDeniedError as e: | ||||
|                 failures[destination] = { | ||||
|                     "status": 403, "message": "Federation Denied", | ||||
|                 } | ||||
|             except Exception as e: | ||||
|                 # include ConnectionRefused and other errors | ||||
|                 failures[destination] = { | ||||
|  | @ -170,7 +177,8 @@ class E2eKeysHandler(object): | |||
| 
 | ||||
|         result_dict = {} | ||||
|         for user_id, device_ids in query.items(): | ||||
|             if not self.is_mine_id(user_id): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if not self.is_mine(UserID.from_string(user_id)): | ||||
|                 logger.warning("Request for keys for non-local user %s", | ||||
|                                user_id) | ||||
|                 raise SynapseError(400, "Not a user here") | ||||
|  | @ -213,7 +221,8 @@ class E2eKeysHandler(object): | |||
|         remote_queries = {} | ||||
| 
 | ||||
|         for user_id, device_keys in query.get("one_time_keys", {}).items(): | ||||
|             if self.is_mine_id(user_id): | ||||
|             # we use UserID.from_string to catch invalid user ids | ||||
|             if self.is_mine(UserID.from_string(user_id)): | ||||
|                 for device_id, algorithm in device_keys.items(): | ||||
|                     local_query.append((user_id, device_id, algorithm)) | ||||
|             else: | ||||
|  |  | |||
|  | @ -22,6 +22,7 @@ from ._base import BaseHandler | |||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|     AuthError, FederationError, StoreError, CodeMessageException, SynapseError, | ||||
|     FederationDeniedError, | ||||
| ) | ||||
| from synapse.api.constants import EventTypes, Membership, RejectedReason | ||||
| from synapse.events.validator import EventValidator | ||||
|  | @ -782,6 +783,9 @@ class FederationHandler(BaseHandler): | |||
|                 except NotRetryingDestination as e: | ||||
|                     logger.info(e.message) | ||||
|                     continue | ||||
|                 except FederationDeniedError as e: | ||||
|                     logger.info(e) | ||||
|                     continue | ||||
|                 except Exception as e: | ||||
|                     logger.exception( | ||||
|                         "Failed to backfill from %s because %s", | ||||
|  |  | |||
|  | @ -383,11 +383,12 @@ class GroupsLocalHandler(object): | |||
| 
 | ||||
|             defer.returnValue({"groups": result}) | ||||
|         else: | ||||
|             result = yield self.transport_client.get_publicised_groups_for_user( | ||||
|                 get_domain_from_id(user_id), user_id | ||||
|             bulk_result = yield self.transport_client.bulk_get_publicised_groups( | ||||
|                 get_domain_from_id(user_id), [user_id], | ||||
|             ) | ||||
|             result = bulk_result.get("users", {}).get(user_id) | ||||
|             # TODO: Verify attestations | ||||
|             defer.returnValue(result) | ||||
|             defer.returnValue({"groups": result}) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def bulk_get_publicised_groups(self, user_ids, proxy=True): | ||||
|  |  | |||
|  | @ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient | |||
| from synapse import types | ||||
| from synapse.types import UserID | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.threepids import check_3pid_allowed | ||||
| from ._base import BaseHandler | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -131,7 +132,7 @@ class RegistrationHandler(BaseHandler): | |||
|         yield run_on_reactor() | ||||
|         password_hash = None | ||||
|         if password: | ||||
|             password_hash = self.auth_handler().hash(password) | ||||
|             password_hash = yield self.auth_handler().hash(password) | ||||
| 
 | ||||
|         if localpart: | ||||
|             yield self.check_username(localpart, guest_access_token=guest_access_token) | ||||
|  | @ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler): | |||
|         """ | ||||
| 
 | ||||
|         for c in threepidCreds: | ||||
|             logger.info("validating theeepidcred sid %s on id server %s", | ||||
|             logger.info("validating threepidcred sid %s on id server %s", | ||||
|                         c['sid'], c['idServer']) | ||||
|             try: | ||||
|                 identity_handler = self.hs.get_handlers().identity_handler | ||||
|  | @ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler): | |||
|             logger.info("got threepid with medium '%s' and address '%s'", | ||||
|                         threepid['medium'], threepid['address']) | ||||
| 
 | ||||
|             if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']): | ||||
|                 raise RegistrationError( | ||||
|                     403, "Third party identifier is not allowed" | ||||
|                 ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def bind_emails(self, user_id, threepidCreds): | ||||
|         """Links emails with a user ID and informs an identity server. | ||||
|  |  | |||
|  | @ -203,7 +203,8 @@ class RoomListHandler(BaseHandler): | |||
|         if limit: | ||||
|             step = limit + 1 | ||||
|         else: | ||||
|             step = len(rooms_to_scan) | ||||
|             # step cannot be zero | ||||
|             step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1 | ||||
| 
 | ||||
|         chunk = [] | ||||
|         for i in xrange(0, len(rooms_to_scan), step): | ||||
|  |  | |||
|  | @ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def set_password(self, user_id, newpassword, requester=None): | ||||
|         password_hash = self._auth_handler.hash(newpassword) | ||||
|         password_hash = yield self._auth_handler.hash(newpassword) | ||||
| 
 | ||||
|         except_device_id = requester.device_id if requester else None | ||||
|         except_access_token_id = requester.access_token_id if requester else None | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE | |||
| from synapse.api.errors import ( | ||||
|     CodeMessageException, MatrixCodeMessageException, SynapseError, Codes, | ||||
| ) | ||||
| from synapse.util.caches import CACHE_SIZE_FACTOR | ||||
| from synapse.util.logcontext import make_deferred_yieldable | ||||
| from synapse.util import logcontext | ||||
| import synapse.metrics | ||||
|  | @ -30,6 +31,7 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS | |||
| from twisted.web.client import ( | ||||
|     BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent, | ||||
|     readBody, PartialDownloadError, | ||||
|     HTTPConnectionPool, | ||||
| ) | ||||
| from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer | ||||
| from twisted.web.http import PotentialDataLoss | ||||
|  | @ -64,13 +66,23 @@ class SimpleHttpClient(object): | |||
|     """ | ||||
|     def __init__(self, hs): | ||||
|         self.hs = hs | ||||
| 
 | ||||
|         pool = HTTPConnectionPool(reactor) | ||||
| 
 | ||||
|         # the pusher makes lots of concurrent SSL connections to sygnal, and | ||||
|         # tends to do so in batches, so we need to allow the pool to keep lots | ||||
|         # of idle connections around. | ||||
|         pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) | ||||
|         pool.cachedConnectionTimeout = 2 * 60 | ||||
| 
 | ||||
|         # The default context factory in Twisted 14.0.0 (which we require) is | ||||
|         # BrowserLikePolicyForHTTPS which will do regular cert validation | ||||
|         # 'like a browser' | ||||
|         self.agent = Agent( | ||||
|             reactor, | ||||
|             connectTimeout=15, | ||||
|             contextFactory=hs.get_http_client_context_factory() | ||||
|             contextFactory=hs.get_http_client_context_factory(), | ||||
|             pool=pool, | ||||
|         ) | ||||
|         self.user_agent = hs.version_string | ||||
|         self.clock = hs.get_clock() | ||||
|  |  | |||
|  | @ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host): | |||
|     def eb(res, record_type): | ||||
|         if res.check(DNSNameError): | ||||
|             return [] | ||||
|         logger.warn("Error looking up %s for %s: %s", | ||||
|                     record_type, host, res, res.value) | ||||
|         logger.warn("Error looking up %s for %s: %s", record_type, host, res) | ||||
|         return res | ||||
| 
 | ||||
|     # no logcontexts here, so we can safely fire these off and gatherResults | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ import synapse.metrics | |||
| from canonicaljson import encode_canonical_json | ||||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|     SynapseError, Codes, HttpResponseException, | ||||
|     SynapseError, Codes, HttpResponseException, FederationDeniedError, | ||||
| ) | ||||
| 
 | ||||
| from signedjson.sign import sign_json | ||||
|  | @ -123,11 +123,22 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``HTTPRequestException``: if we get an HTTP response | ||||
|                 code >= 300. | ||||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|                 to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|                 is not on our federation whitelist | ||||
| 
 | ||||
|             (May also fail with plenty of other Exceptions for things like DNS | ||||
|                 failures, connection failures, SSL failures.) | ||||
|         """ | ||||
|         if ( | ||||
|             self.hs.config.federation_domain_whitelist and | ||||
|             destination not in self.hs.config.federation_domain_whitelist | ||||
|         ): | ||||
|             raise FederationDeniedError(destination) | ||||
| 
 | ||||
|         limiter = yield synapse.util.retryutils.get_retry_limiter( | ||||
|             destination, | ||||
|             self.clock, | ||||
|  | @ -308,6 +319,9 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|             is not on our federation whitelist | ||||
|         """ | ||||
| 
 | ||||
|         if not json_data_callback: | ||||
|  | @ -368,6 +382,9 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|             is not on our federation whitelist | ||||
|         """ | ||||
| 
 | ||||
|         def body_callback(method, url_bytes, headers_dict): | ||||
|  | @ -422,6 +439,9 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|             is not on our federation whitelist | ||||
|         """ | ||||
|         logger.debug("get_json args: %s", args) | ||||
| 
 | ||||
|  | @ -475,6 +495,9 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|             is not on our federation whitelist | ||||
|         """ | ||||
| 
 | ||||
|         response = yield self._request( | ||||
|  | @ -518,6 +541,9 @@ class MatrixFederationHttpClient(object): | |||
| 
 | ||||
|             Fails with ``NotRetryingDestination`` if we are not yet ready | ||||
|             to retry this server. | ||||
| 
 | ||||
|             Fails with ``FederationDeniedError`` if this destination | ||||
|             is not on our federation whitelist | ||||
|         """ | ||||
| 
 | ||||
|         encoded_args = {} | ||||
|  |  | |||
|  | @ -42,36 +42,70 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| metrics = synapse.metrics.get_metrics_for(__name__) | ||||
| 
 | ||||
| incoming_requests_counter = metrics.register_counter( | ||||
|     "requests", | ||||
| # total number of responses served, split by method/servlet/tag | ||||
| response_count = metrics.register_counter( | ||||
|     "response_count", | ||||
|     labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         # the following are all deprecated aliases for the same metric | ||||
|         metrics.name_prefix + x for x in ( | ||||
|             "_requests", | ||||
|             "_response_time:count", | ||||
|             "_response_ru_utime:count", | ||||
|             "_response_ru_stime:count", | ||||
|             "_response_db_txn_count:count", | ||||
|             "_response_db_txn_duration:count", | ||||
|         ) | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
| outgoing_responses_counter = metrics.register_counter( | ||||
|     "responses", | ||||
|     labels=["method", "code"], | ||||
| ) | ||||
| 
 | ||||
| response_timer = metrics.register_distribution( | ||||
|     "response_time", | ||||
|     labels=["method", "servlet", "tag"] | ||||
| response_timer = metrics.register_counter( | ||||
|     "response_time_seconds", | ||||
|     labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_response_time:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| response_ru_utime = metrics.register_distribution( | ||||
|     "response_ru_utime", labels=["method", "servlet", "tag"] | ||||
| response_ru_utime = metrics.register_counter( | ||||
|     "response_ru_utime_seconds", labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_response_ru_utime:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| response_ru_stime = metrics.register_distribution( | ||||
|     "response_ru_stime", labels=["method", "servlet", "tag"] | ||||
| response_ru_stime = metrics.register_counter( | ||||
|     "response_ru_stime_seconds", labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_response_ru_stime:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| response_db_txn_count = metrics.register_distribution( | ||||
|     "response_db_txn_count", labels=["method", "servlet", "tag"] | ||||
| response_db_txn_count = metrics.register_counter( | ||||
|     "response_db_txn_count", labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_response_db_txn_count:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| response_db_txn_duration = metrics.register_distribution( | ||||
|     "response_db_txn_duration", labels=["method", "servlet", "tag"] | ||||
| # seconds spent waiting for db txns, excluding scheduling time, when processing | ||||
| # this request | ||||
| response_db_txn_duration = metrics.register_counter( | ||||
|     "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_response_db_txn_duration:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| # seconds spent waiting for a db connection, when processing this request | ||||
| response_db_sched_duration = metrics.register_counter( | ||||
|     "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"] | ||||
| ) | ||||
| 
 | ||||
| _next_request_id = 0 | ||||
| 
 | ||||
|  | @ -107,6 +141,10 @@ def wrap_request_handler(request_handler, include_metrics=False): | |||
|         with LoggingContext(request_id) as request_context: | ||||
|             with Measure(self.clock, "wrapped_request_handler"): | ||||
|                 request_metrics = RequestMetrics() | ||||
|                 # we start the request metrics timer here with an initial stab | ||||
|                 # at the servlet name. For most requests that name will be | ||||
|                 # JsonResource (or a subclass), and JsonResource._async_render | ||||
|                 # will update it once it picks a servlet. | ||||
|                 request_metrics.start(self.clock, name=self.__class__.__name__) | ||||
| 
 | ||||
|                 request_context.request = request_id | ||||
|  | @ -249,12 +287,23 @@ class JsonResource(HttpServer, resource.Resource): | |||
|             if not m: | ||||
|                 continue | ||||
| 
 | ||||
|             # We found a match! Trigger callback and then return the | ||||
|             # returned response. We pass both the request and any | ||||
|             # matched groups from the regex to the callback. | ||||
|             # We found a match! First update the metrics object to indicate | ||||
|             # which servlet is handling the request. | ||||
| 
 | ||||
|             callback = path_entry.callback | ||||
| 
 | ||||
|             servlet_instance = getattr(callback, "__self__", None) | ||||
|             if servlet_instance is not None: | ||||
|                 servlet_classname = servlet_instance.__class__.__name__ | ||||
|             else: | ||||
|                 servlet_classname = "%r" % callback | ||||
| 
 | ||||
|             request_metrics.name = servlet_classname | ||||
| 
 | ||||
|             # Now trigger the callback. If it returns a response, we send it | ||||
|             # here. If it throws an exception, that is handled by the wrapper | ||||
|             # installed by @request_handler. | ||||
| 
 | ||||
|             kwargs = intern_dict({ | ||||
|                 name: urllib.unquote(value).decode("UTF-8") if value else value | ||||
|                 for name, value in m.groupdict().items() | ||||
|  | @ -265,30 +314,14 @@ class JsonResource(HttpServer, resource.Resource): | |||
|                 code, response = callback_return | ||||
|                 self._send_response(request, code, response) | ||||
| 
 | ||||
|             servlet_instance = getattr(callback, "__self__", None) | ||||
|             if servlet_instance is not None: | ||||
|                 servlet_classname = servlet_instance.__class__.__name__ | ||||
|             else: | ||||
|                 servlet_classname = "%r" % callback | ||||
| 
 | ||||
|             request_metrics.name = servlet_classname | ||||
| 
 | ||||
|             return | ||||
| 
 | ||||
|         # Huh. No one wanted to handle that? Fiiiiiine. Send 400. | ||||
|         request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest" | ||||
|         raise UnrecognizedRequestError() | ||||
| 
 | ||||
|     def _send_response(self, request, code, response_json_object, | ||||
|                        response_code_message=None): | ||||
|         # could alternatively use request.notifyFinish() and flip a flag when | ||||
|         # the Deferred fires, but since the flag is RIGHT THERE it seems like | ||||
|         # a waste. | ||||
|         if request._disconnected: | ||||
|             logger.warn( | ||||
|                 "Not sending response to request %s, already disconnected.", | ||||
|                 request) | ||||
|             return | ||||
| 
 | ||||
|         outgoing_responses_counter.inc(request.method, str(code)) | ||||
| 
 | ||||
|         # TODO: Only enable CORS for the requests that need it. | ||||
|  | @ -322,7 +355,7 @@ class RequestMetrics(object): | |||
|                 ) | ||||
|                 return | ||||
| 
 | ||||
|         incoming_requests_counter.inc(request.method, self.name, tag) | ||||
|         response_count.inc(request.method, self.name, tag) | ||||
| 
 | ||||
|         response_timer.inc_by( | ||||
|             clock.time_msec() - self.start, request.method, | ||||
|  | @ -341,7 +374,10 @@ class RequestMetrics(object): | |||
|             context.db_txn_count, request.method, self.name, tag | ||||
|         ) | ||||
|         response_db_txn_duration.inc_by( | ||||
|             context.db_txn_duration, request.method, self.name, tag | ||||
|             context.db_txn_duration_ms / 1000., request.method, self.name, tag | ||||
|         ) | ||||
|         response_db_sched_duration.inc_by( | ||||
|             context.db_sched_duration_ms / 1000., request.method, self.name, tag | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -364,6 +400,15 @@ class RootRedirect(resource.Resource): | |||
| def respond_with_json(request, code, json_object, send_cors=False, | ||||
|                       response_code_message=None, pretty_print=False, | ||||
|                       version_string="", canonical_json=True): | ||||
|     # could alternatively use request.notifyFinish() and flip a flag when | ||||
|     # the Deferred fires, but since the flag is RIGHT THERE it seems like | ||||
|     # a waste. | ||||
|     if request._disconnected: | ||||
|         logger.warn( | ||||
|             "Not sending response to request %s, already disconnected.", | ||||
|             request) | ||||
|         return | ||||
| 
 | ||||
|     if pretty_print: | ||||
|         json_bytes = encode_pretty_printed_json(json_object) + "\n" | ||||
|     else: | ||||
|  |  | |||
|  | @ -66,14 +66,15 @@ class SynapseRequest(Request): | |||
|             context = LoggingContext.current_context() | ||||
|             ru_utime, ru_stime = context.get_resource_usage() | ||||
|             db_txn_count = context.db_txn_count | ||||
|             db_txn_duration = context.db_txn_duration | ||||
|             db_txn_duration_ms = context.db_txn_duration_ms | ||||
|             db_sched_duration_ms = context.db_sched_duration_ms | ||||
|         except Exception: | ||||
|             ru_utime, ru_stime = (0, 0) | ||||
|             db_txn_count, db_txn_duration = (0, 0) | ||||
|             db_txn_count, db_txn_duration_ms = (0, 0) | ||||
| 
 | ||||
|         self.site.access_logger.info( | ||||
|             "%s - %s - {%s}" | ||||
|             " Processed request: %dms (%dms, %dms) (%dms/%d)" | ||||
|             " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)" | ||||
|             " %sB %s \"%s %s %s\" \"%s\"", | ||||
|             self.getClientIP(), | ||||
|             self.site.site_tag, | ||||
|  | @ -81,7 +82,8 @@ class SynapseRequest(Request): | |||
|             int(time.time() * 1000) - self.start_time, | ||||
|             int(ru_utime * 1000), | ||||
|             int(ru_stime * 1000), | ||||
|             int(db_txn_duration * 1000), | ||||
|             db_sched_duration_ms, | ||||
|             db_txn_duration_ms, | ||||
|             int(db_txn_count), | ||||
|             self.sentLength, | ||||
|             self.code, | ||||
|  |  | |||
|  | @ -146,10 +146,15 @@ def runUntilCurrentTimer(func): | |||
|             num_pending += 1 | ||||
| 
 | ||||
|         num_pending += len(reactor.threadCallQueue) | ||||
| 
 | ||||
|         start = time.time() * 1000 | ||||
|         ret = func(*args, **kwargs) | ||||
|         end = time.time() * 1000 | ||||
| 
 | ||||
|         # record the amount of wallclock time spent running pending calls. | ||||
|         # This is a proxy for the actual amount of time between reactor polls, | ||||
|         # since about 25% of time is actually spent running things triggered by | ||||
|         # I/O events, but that is harder to capture without rewriting half the | ||||
|         # reactor. | ||||
|         tick_time.inc_by(end - start) | ||||
|         pending_calls_metric.inc_by(num_pending) | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,18 +15,38 @@ | |||
| 
 | ||||
| 
 | ||||
| from itertools import chain | ||||
| import logging | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| # TODO(paul): I can't believe Python doesn't have one of these | ||||
| def map_concat(func, items): | ||||
|     # flatten a list-of-lists | ||||
|     return list(chain.from_iterable(map(func, items))) | ||||
| def flatten(items): | ||||
|     """Flatten a list of lists | ||||
| 
 | ||||
|     Args: | ||||
|         items: iterable[iterable[X]] | ||||
| 
 | ||||
|     Returns: | ||||
|         list[X]: flattened list | ||||
|     """ | ||||
|     return list(chain.from_iterable(items)) | ||||
| 
 | ||||
| 
 | ||||
| class BaseMetric(object): | ||||
|     """Base class for metrics which report a single value per label set | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, name, labels=[]): | ||||
|         self.name = name | ||||
|     def __init__(self, name, labels=[], alternative_names=[]): | ||||
|         """ | ||||
|         Args: | ||||
|             name (str): principal name for this metric | ||||
|             labels (list(str)): names of the labels which will be reported | ||||
|                 for this metric | ||||
|             alternative_names (iterable(str)): list of alternative names for | ||||
|                  this metric. This can be useful to provide a migration path | ||||
|                 when renaming metrics. | ||||
|         """ | ||||
|         self._names = [name] + list(alternative_names) | ||||
|         self.labels = labels  # OK not to clone as we never write it | ||||
| 
 | ||||
|     def dimension(self): | ||||
|  | @ -36,7 +56,7 @@ class BaseMetric(object): | |||
|         return not len(self.labels) | ||||
| 
 | ||||
|     def _render_labelvalue(self, value): | ||||
|         # TODO: some kind of value escape | ||||
|         # TODO: escape backslashes, quotes and newlines | ||||
|         return '"%s"' % (value) | ||||
| 
 | ||||
|     def _render_key(self, values): | ||||
|  | @ -47,19 +67,60 @@ class BaseMetric(object): | |||
|                       for k, v in zip(self.labels, values)]) | ||||
|         ) | ||||
| 
 | ||||
|     def _render_for_labels(self, label_values, value): | ||||
|         """Render this metric for a single set of labels | ||||
| 
 | ||||
|         Args: | ||||
|             label_values (list[str]): values for each of the labels | ||||
|             value: value of the metric at with these labels | ||||
| 
 | ||||
|         Returns: | ||||
|             iterable[str]: rendered metric | ||||
|         """ | ||||
|         rendered_labels = self._render_key(label_values) | ||||
|         return ( | ||||
|             "%s%s %.12g" % (name, rendered_labels, value) | ||||
|             for name in self._names | ||||
|         ) | ||||
| 
 | ||||
|     def render(self): | ||||
|         """Render this metric | ||||
| 
 | ||||
|         Each metric is rendered as: | ||||
| 
 | ||||
|             name{label1="val1",label2="val2"} value | ||||
| 
 | ||||
|         https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details | ||||
| 
 | ||||
|         Returns: | ||||
|             iterable[str]: rendered metrics | ||||
|         """ | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
| 
 | ||||
| class CounterMetric(BaseMetric): | ||||
|     """The simplest kind of metric; one that stores a monotonically-increasing | ||||
|     integer that counts events.""" | ||||
|     value that counts events or running totals. | ||||
| 
 | ||||
|     Example use cases for Counters: | ||||
|     - Number of requests processed | ||||
|     - Number of items that were inserted into a queue | ||||
|     - Total amount of data that a system has processed | ||||
|     Counters can only go up (and be reset when the process restarts). | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super(CounterMetric, self).__init__(*args, **kwargs) | ||||
| 
 | ||||
|         # dict[list[str]]: value for each set of label values. the keys are the | ||||
|         # label values, in the same order as the labels in self.labels. | ||||
|         # | ||||
|         # (if the metric is a scalar, the (single) key is the empty list). | ||||
|         self.counts = {} | ||||
| 
 | ||||
|         # Scalar metrics are never empty | ||||
|         if self.is_scalar(): | ||||
|             self.counts[()] = 0 | ||||
|             self.counts[()] = 0. | ||||
| 
 | ||||
|     def inc_by(self, incr, *values): | ||||
|         if len(values) != self.dimension(): | ||||
|  | @ -77,11 +138,11 @@ class CounterMetric(BaseMetric): | |||
|     def inc(self, *values): | ||||
|         self.inc_by(1, *values) | ||||
| 
 | ||||
|     def render_item(self, k): | ||||
|         return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])] | ||||
| 
 | ||||
|     def render(self): | ||||
|         return map_concat(self.render_item, sorted(self.counts.keys())) | ||||
|         return flatten( | ||||
|             self._render_for_labels(k, self.counts[k]) | ||||
|             for k in sorted(self.counts.keys()) | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class CallbackMetric(BaseMetric): | ||||
|  | @ -95,13 +156,19 @@ class CallbackMetric(BaseMetric): | |||
|         self.callback = callback | ||||
| 
 | ||||
|     def render(self): | ||||
|         value = self.callback() | ||||
|         try: | ||||
|             value = self.callback() | ||||
|         except Exception: | ||||
|             logger.exception("Failed to render %s", self.name) | ||||
|             return ["# FAILED to render " + self.name] | ||||
| 
 | ||||
|         if self.is_scalar(): | ||||
|             return ["%s %.12g" % (self.name, value)] | ||||
|             return list(self._render_for_labels([], value)) | ||||
| 
 | ||||
|         return ["%s%s %.12g" % (self.name, self._render_key(k), value[k]) | ||||
|                 for k in sorted(value.keys())] | ||||
|         return flatten( | ||||
|             self._render_for_labels(k, value[k]) | ||||
|             for k in sorted(value.keys()) | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class DistributionMetric(object): | ||||
|  |  | |||
|  | @ -13,21 +13,30 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from synapse.push import PusherConfigException | ||||
| import logging | ||||
| 
 | ||||
| from twisted.internet import defer, reactor | ||||
| from twisted.internet.error import AlreadyCalled, AlreadyCancelled | ||||
| 
 | ||||
| import logging | ||||
| import push_rule_evaluator | ||||
| import push_tools | ||||
| 
 | ||||
| import synapse | ||||
| from synapse.push import PusherConfigException | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.metrics import Measure | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| metrics = synapse.metrics.get_metrics_for(__name__) | ||||
| 
 | ||||
| http_push_processed_counter = metrics.register_counter( | ||||
|     "http_pushes_processed", | ||||
| ) | ||||
| 
 | ||||
| http_push_failed_counter = metrics.register_counter( | ||||
|     "http_pushes_failed", | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class HttpPusher(object): | ||||
|     INITIAL_BACKOFF_SEC = 1  # in seconds because that's what Twisted takes | ||||
|  | @ -152,9 +161,16 @@ class HttpPusher(object): | |||
|             self.user_id, self.last_stream_ordering, self.max_stream_ordering | ||||
|         ) | ||||
| 
 | ||||
|         logger.info( | ||||
|             "Processing %i unprocessed push actions for %s starting at " | ||||
|             "stream_ordering %s", | ||||
|             len(unprocessed), self.name, self.last_stream_ordering, | ||||
|         ) | ||||
| 
 | ||||
|         for push_action in unprocessed: | ||||
|             processed = yield self._process_one(push_action) | ||||
|             if processed: | ||||
|                 http_push_processed_counter.inc() | ||||
|                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC | ||||
|                 self.last_stream_ordering = push_action['stream_ordering'] | ||||
|                 yield self.store.update_pusher_last_stream_ordering_and_success( | ||||
|  | @ -169,6 +185,7 @@ class HttpPusher(object): | |||
|                         self.failing_since | ||||
|                     ) | ||||
|             else: | ||||
|                 http_push_failed_counter.inc() | ||||
|                 if not self.failing_since: | ||||
|                     self.failing_since = self.clock.time_msec() | ||||
|                     yield self.store.update_pusher_failing_since( | ||||
|  | @ -316,7 +333,10 @@ class HttpPusher(object): | |||
|         try: | ||||
|             resp = yield self.http_client.post_json_get_json(self.url, notification_dict) | ||||
|         except Exception: | ||||
|             logger.warn("Failed to push %s ", self.url) | ||||
|             logger.warn( | ||||
|                 "Failed to push event %s to %s", | ||||
|                 event.event_id, self.name, exc_info=True, | ||||
|             ) | ||||
|             defer.returnValue(False) | ||||
|         rejected = [] | ||||
|         if 'rejected' in resp: | ||||
|  | @ -325,7 +345,7 @@ class HttpPusher(object): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _send_badge(self, badge): | ||||
|         logger.info("Sending updated badge count %d to %r", badge, self.user_id) | ||||
|         logger.info("Sending updated badge count %d to %s", badge, self.name) | ||||
|         d = { | ||||
|             'notification': { | ||||
|                 'id': '', | ||||
|  | @ -347,7 +367,10 @@ class HttpPusher(object): | |||
|         try: | ||||
|             resp = yield self.http_client.post_json_get_json(self.url, d) | ||||
|         except Exception: | ||||
|             logger.exception("Failed to push %s ", self.url) | ||||
|             logger.warn( | ||||
|                 "Failed to send badge count to %s", | ||||
|                 self.name, exc_info=True, | ||||
|             ) | ||||
|             defer.returnValue(False) | ||||
|         rejected = [] | ||||
|         if 'rejected' in resp: | ||||
|  |  | |||
|  | @ -517,25 +517,28 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol): | |||
|             self.send_error("Wrong remote") | ||||
| 
 | ||||
|     def on_RDATA(self, cmd): | ||||
|         stream_name = cmd.stream_name | ||||
|         inbound_rdata_count.inc(stream_name) | ||||
| 
 | ||||
|         try: | ||||
|             row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row) | ||||
|             row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row) | ||||
|         except Exception: | ||||
|             logger.exception( | ||||
|                 "[%s] Failed to parse RDATA: %r %r", | ||||
|                 self.id(), cmd.stream_name, cmd.row | ||||
|                 self.id(), stream_name, cmd.row | ||||
|             ) | ||||
|             raise | ||||
| 
 | ||||
|         if cmd.token is None: | ||||
|             # I.e. this is part of a batch of updates for this stream. Batch | ||||
|             # until we get an update for the stream with a non None token | ||||
|             self.pending_batches.setdefault(cmd.stream_name, []).append(row) | ||||
|             self.pending_batches.setdefault(stream_name, []).append(row) | ||||
|         else: | ||||
|             # Check if this is the last of a batch of updates | ||||
|             rows = self.pending_batches.pop(cmd.stream_name, []) | ||||
|             rows = self.pending_batches.pop(stream_name, []) | ||||
|             rows.append(row) | ||||
| 
 | ||||
|             self.handler.on_rdata(cmd.stream_name, cmd.token, rows) | ||||
|             self.handler.on_rdata(stream_name, cmd.token, rows) | ||||
| 
 | ||||
|     def on_POSITION(self, cmd): | ||||
|         self.handler.on_position(cmd.stream_name, cmd.token) | ||||
|  | @ -644,3 +647,9 @@ metrics.register_callback( | |||
|     }, | ||||
|     labels=["command", "name", "conn_id"], | ||||
| ) | ||||
| 
 | ||||
| # number of updates received for each RDATA stream | ||||
| inbound_rdata_count = metrics.register_counter( | ||||
|     "inbound_rdata_count", | ||||
|     labels=["stream_name"], | ||||
| ) | ||||
|  |  | |||
|  | @ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet): | |||
| 
 | ||||
|         # convert threepid identifiers to user IDs | ||||
|         if identifier["type"] == "m.id.thirdparty": | ||||
|             if 'medium' not in identifier or 'address' not in identifier: | ||||
|             address = identifier.get('address') | ||||
|             medium = identifier.get('medium') | ||||
| 
 | ||||
|             if medium is None or address is None: | ||||
|                 raise SynapseError(400, "Invalid thirdparty identifier") | ||||
| 
 | ||||
|             address = identifier['address'] | ||||
|             if identifier['medium'] == 'email': | ||||
|             if medium == 'email': | ||||
|                 # For emails, transform the address to lowercase. | ||||
|                 # We store all email addreses as lowercase in the DB. | ||||
|                 # (See add_threepid in synapse/handlers/auth.py) | ||||
|                 address = address.lower() | ||||
|             user_id = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|                 identifier['medium'], address | ||||
|                 medium, address, | ||||
|             ) | ||||
|             if not user_id: | ||||
|                 logger.warn( | ||||
|                     "unknown 3pid identifier medium %s, address %r", | ||||
|                     medium, address, | ||||
|                 ) | ||||
|                 raise LoginError(403, "", errcode=Codes.FORBIDDEN) | ||||
| 
 | ||||
|             identifier = { | ||||
|  |  | |||
|  | @ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet): | |||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|     def on_GET(self, request): | ||||
| 
 | ||||
|         require_email = 'email' in self.hs.config.registrations_require_3pid | ||||
|         require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid | ||||
| 
 | ||||
|         flows = [] | ||||
|         if self.hs.config.enable_registration_captcha: | ||||
|             return ( | ||||
|                 200, | ||||
|                 {"flows": [ | ||||
|             # only support the email-only flow if we don't require MSISDN 3PIDs | ||||
|             if not require_msisdn: | ||||
|                 flows.extend([ | ||||
|                     { | ||||
|                         "type": LoginType.RECAPTCHA, | ||||
|                         "stages": [ | ||||
|  | @ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet): | |||
|                             LoginType.PASSWORD | ||||
|                         ] | ||||
|                     }, | ||||
|                 ]) | ||||
|             # only support 3PIDless registration if no 3PIDs are required | ||||
|             if not require_email and not require_msisdn: | ||||
|                 flows.extend([ | ||||
|                     { | ||||
|                         "type": LoginType.RECAPTCHA, | ||||
|                         "stages": [LoginType.RECAPTCHA, LoginType.PASSWORD] | ||||
|                     } | ||||
|                 ]} | ||||
|             ) | ||||
|                 ]) | ||||
|         else: | ||||
|             return ( | ||||
|                 200, | ||||
|                 {"flows": [ | ||||
|             # only support the email-only flow if we don't require MSISDN 3PIDs | ||||
|             if require_email or not require_msisdn: | ||||
|                 flows.extend([ | ||||
|                     { | ||||
|                         "type": LoginType.EMAIL_IDENTITY, | ||||
|                         "stages": [ | ||||
|                             LoginType.EMAIL_IDENTITY, LoginType.PASSWORD | ||||
|                         ] | ||||
|                     }, | ||||
|                     } | ||||
|                 ]) | ||||
|             # only support 3PIDless registration if no 3PIDs are required | ||||
|             if not require_email and not require_msisdn: | ||||
|                 flows.extend([ | ||||
|                     { | ||||
|                         "type": LoginType.PASSWORD | ||||
|                     } | ||||
|                 ]} | ||||
|             ) | ||||
|                 ]) | ||||
|         return (200, {"flows": flows}) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_POST(self, request): | ||||
|  |  | |||
|  | @ -195,15 +195,20 @@ class RoomSendEventRestServlet(ClientV1RestServlet): | |||
|         requester = yield self.auth.get_user_by_req(request, allow_guest=True) | ||||
|         content = parse_json_object_from_request(request) | ||||
| 
 | ||||
|         event_dict = { | ||||
|             "type": event_type, | ||||
|             "content": content, | ||||
|             "room_id": room_id, | ||||
|             "sender": requester.user.to_string(), | ||||
|         } | ||||
| 
 | ||||
|         if 'ts' in request.args and requester.app_service: | ||||
|             event_dict['origin_server_ts'] = parse_integer(request, "ts", 0) | ||||
| 
 | ||||
|         msg_handler = self.handlers.message_handler | ||||
|         event = yield msg_handler.create_and_send_nonmember_event( | ||||
|             requester, | ||||
|             { | ||||
|                 "type": event_type, | ||||
|                 "content": content, | ||||
|                 "room_id": room_id, | ||||
|                 "sender": requester.user.to_string(), | ||||
|             }, | ||||
|             event_dict, | ||||
|             txn_id=txn_id, | ||||
|         ) | ||||
| 
 | ||||
|  | @ -487,13 +492,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): | |||
|         defer.returnValue((200, content)) | ||||
| 
 | ||||
| 
 | ||||
| class RoomEventContext(ClientV1RestServlet): | ||||
| class RoomEventServlet(ClientV1RestServlet): | ||||
|     PATTERNS = client_path_patterns( | ||||
|         "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RoomEventServlet, self).__init__(hs) | ||||
|         self.clock = hs.get_clock() | ||||
|         self.event_handler = hs.get_event_handler() | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def on_GET(self, request, room_id, event_id): | ||||
|         requester = yield self.auth.get_user_by_req(request) | ||||
|         event = yield self.event_handler.get_event(requester.user, event_id) | ||||
| 
 | ||||
|         time_now = self.clock.time_msec() | ||||
|         if event: | ||||
|             defer.returnValue((200, serialize_event(event, time_now))) | ||||
|         else: | ||||
|             defer.returnValue((404, "Event not found.")) | ||||
| 
 | ||||
| 
 | ||||
| class RoomEventContextServlet(ClientV1RestServlet): | ||||
|     PATTERNS = client_path_patterns( | ||||
|         "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self, hs): | ||||
|         super(RoomEventContext, self).__init__(hs) | ||||
|         super(RoomEventContextServlet, self).__init__(hs) | ||||
|         self.clock = hs.get_clock() | ||||
|         self.handlers = hs.get_handlers() | ||||
| 
 | ||||
|  | @ -803,4 +830,5 @@ def register_servlets(hs, http_server): | |||
|     RoomTypingRestServlet(hs).register(http_server) | ||||
|     SearchRestServlet(hs).register(http_server) | ||||
|     JoinedRoomsRestServlet(hs).register(http_server) | ||||
|     RoomEventContext(hs).register(http_server) | ||||
|     RoomEventServlet(hs).register(http_server) | ||||
|     RoomEventContextServlet(hs).register(http_server) | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ from synapse.http.servlet import ( | |||
| ) | ||||
| from synapse.util.async import run_on_reactor | ||||
| from synapse.util.msisdn import phone_number_to_msisdn | ||||
| from synapse.util.threepids import check_3pid_allowed | ||||
| from ._base import client_v2_patterns, interactive_auth_handler | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): | |||
|             'id_server', 'client_secret', 'email', 'send_attempt' | ||||
|         ]) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "email", body['email']): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
|         ) | ||||
|  | @ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): | |||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "msisdn", msisdn): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
|  | @ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): | |||
|         if absent: | ||||
|             raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "email", body['email']): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
|         ) | ||||
|  | @ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): | |||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "msisdn", msisdn): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.datastore.get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
|  |  | |||
|  | @ -26,6 +26,7 @@ from synapse.http.servlet import ( | |||
|     RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string | ||||
| ) | ||||
| from synapse.util.msisdn import phone_number_to_msisdn | ||||
| from synapse.util.threepids import check_3pid_allowed | ||||
| 
 | ||||
| from ._base import client_v2_patterns, interactive_auth_handler | ||||
| 
 | ||||
|  | @ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): | |||
|             'id_server', 'client_secret', 'email', 'send_attempt' | ||||
|         ]) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "email", body['email']): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'email', body['email'] | ||||
|         ) | ||||
|  | @ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): | |||
| 
 | ||||
|         msisdn = phone_number_to_msisdn(body['country'], body['phone_number']) | ||||
| 
 | ||||
|         if not check_3pid_allowed(self.hs, "msisdn", msisdn): | ||||
|             raise SynapseError( | ||||
|                 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED, | ||||
|             ) | ||||
| 
 | ||||
|         existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | ||||
|             'msisdn', msisdn | ||||
|         ) | ||||
|  | @ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet): | |||
|         if 'x_show_msisdn' in body and body['x_show_msisdn']: | ||||
|             show_msisdn = True | ||||
| 
 | ||||
|         # FIXME: need a better error than "no auth flow found" for scenarios | ||||
|         # where we required 3PID for registration but the user didn't give one | ||||
|         require_email = 'email' in self.hs.config.registrations_require_3pid | ||||
|         require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid | ||||
| 
 | ||||
|         flows = [] | ||||
|         if self.hs.config.enable_registration_captcha: | ||||
|             flows = [ | ||||
|                 [LoginType.RECAPTCHA], | ||||
|                 [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||
|             ] | ||||
|             # only support 3PIDless registration if no 3PIDs are required | ||||
|             if not require_email and not require_msisdn: | ||||
|                 flows.extend([[LoginType.RECAPTCHA]]) | ||||
|             # only support the email-only flow if we don't require MSISDN 3PIDs | ||||
|             if not require_msisdn: | ||||
|                 flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) | ||||
| 
 | ||||
|             if show_msisdn: | ||||
|                 # only support the MSISDN-only flow if we don't require email 3PIDs | ||||
|                 if not require_email: | ||||
|                     flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) | ||||
|                 # always let users provide both MSISDN & email | ||||
|                 flows.extend([ | ||||
|                     [LoginType.MSISDN, LoginType.RECAPTCHA], | ||||
|                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], | ||||
|                 ]) | ||||
|         else: | ||||
|             flows = [ | ||||
|                 [LoginType.DUMMY], | ||||
|                 [LoginType.EMAIL_IDENTITY], | ||||
|             ] | ||||
|             # only support 3PIDless registration if no 3PIDs are required | ||||
|             if not require_email and not require_msisdn: | ||||
|                 flows.extend([[LoginType.DUMMY]]) | ||||
|             # only support the email-only flow if we don't require MSISDN 3PIDs | ||||
|             if not require_msisdn: | ||||
|                 flows.extend([[LoginType.EMAIL_IDENTITY]]) | ||||
| 
 | ||||
|             if show_msisdn: | ||||
|                 # only support the MSISDN-only flow if we don't require email 3PIDs | ||||
|                 if not require_email or require_msisdn: | ||||
|                     flows.extend([[LoginType.MSISDN]]) | ||||
|                 # always let users provide both MSISDN & email | ||||
|                 flows.extend([ | ||||
|                     [LoginType.MSISDN], | ||||
|                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY], | ||||
|                     [LoginType.MSISDN, LoginType.EMAIL_IDENTITY] | ||||
|                 ]) | ||||
| 
 | ||||
|         auth_result, params, session_id = yield self.auth_handler.check_auth( | ||||
|             flows, body, self.hs.get_ip_from_request(request) | ||||
|         ) | ||||
| 
 | ||||
|         # Check that we're not trying to register a denied 3pid. | ||||
|         # | ||||
|         # the user-facing checks will probably already have happened in | ||||
|         # /register/email/requestToken when we requested a 3pid, but that's not | ||||
|         # guaranteed. | ||||
| 
 | ||||
|         if auth_result: | ||||
|             for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: | ||||
|                 if login_type in auth_result: | ||||
|                     medium = auth_result[login_type]['medium'] | ||||
|                     address = auth_result[login_type]['address'] | ||||
| 
 | ||||
|                     if not check_3pid_allowed(self.hs, medium, address): | ||||
|                         raise SynapseError( | ||||
|                             403, "Third party identifier is not allowed", | ||||
|                             Codes.THREEPID_DENIED, | ||||
|                         ) | ||||
| 
 | ||||
|         if registered_user_id is not None: | ||||
|             logger.info( | ||||
|                 "Already registered user ID %r for this session", | ||||
|  |  | |||
|  | @ -93,6 +93,7 @@ class RemoteKey(Resource): | |||
|         self.store = hs.get_datastore() | ||||
|         self.version_string = hs.version_string | ||||
|         self.clock = hs.get_clock() | ||||
|         self.federation_domain_whitelist = hs.config.federation_domain_whitelist | ||||
| 
 | ||||
|     def render_GET(self, request): | ||||
|         self.async_render_GET(request) | ||||
|  | @ -137,6 +138,13 @@ class RemoteKey(Resource): | |||
|         logger.info("Handling query for keys %r", query) | ||||
|         store_queries = [] | ||||
|         for server_name, key_ids in query.items(): | ||||
|             if ( | ||||
|                 self.federation_domain_whitelist is not None and | ||||
|                 server_name not in self.federation_domain_whitelist | ||||
|             ): | ||||
|                 logger.debug("Federation denied with %s", server_name) | ||||
|                 continue | ||||
| 
 | ||||
|             if not key_ids: | ||||
|                 key_ids = (None,) | ||||
|             for key_id in key_ids: | ||||
|  |  | |||
|  | @ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path, | |||
|     logger.debug("Responding with %r", file_path) | ||||
| 
 | ||||
|     if os.path.isfile(file_path): | ||||
|         request.setHeader(b"Content-Type", media_type.encode("UTF-8")) | ||||
|         if upload_name: | ||||
|             if is_ascii(upload_name): | ||||
|                 request.setHeader( | ||||
|                     b"Content-Disposition", | ||||
|                     b"inline; filename=%s" % ( | ||||
|                         urllib.quote(upload_name.encode("utf-8")), | ||||
|                     ), | ||||
|                 ) | ||||
|             else: | ||||
|                 request.setHeader( | ||||
|                     b"Content-Disposition", | ||||
|                     b"inline; filename*=utf-8''%s" % ( | ||||
|                         urllib.quote(upload_name.encode("utf-8")), | ||||
|                     ), | ||||
|                 ) | ||||
| 
 | ||||
|         # cache for at least a day. | ||||
|         # XXX: we might want to turn this off for data we don't want to | ||||
|         # recommend caching as it's sensitive or private - or at least | ||||
|         # select private. don't bother setting Expires as all our | ||||
|         # clients are smart enough to be happy with Cache-Control | ||||
|         request.setHeader( | ||||
|             b"Cache-Control", b"public,max-age=86400,s-maxage=86400" | ||||
|         ) | ||||
|         if file_size is None: | ||||
|             stat = os.stat(file_path) | ||||
|             file_size = stat.st_size | ||||
| 
 | ||||
|         request.setHeader( | ||||
|             b"Content-Length", b"%d" % (file_size,) | ||||
|         ) | ||||
|         add_file_headers(request, media_type, file_size, upload_name) | ||||
| 
 | ||||
|         with open(file_path, "rb") as f: | ||||
|             yield logcontext.make_deferred_yieldable( | ||||
|  | @ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path, | |||
|         finish_request(request) | ||||
|     else: | ||||
|         respond_404(request) | ||||
| 
 | ||||
| 
 | ||||
| def add_file_headers(request, media_type, file_size, upload_name): | ||||
|     """Adds the correct response headers in preparation for responding with the | ||||
|     media. | ||||
| 
 | ||||
|     Args: | ||||
|         request (twisted.web.http.Request) | ||||
|         media_type (str): The media/content type. | ||||
|         file_size (int): Size in bytes of the media, if known. | ||||
|         upload_name (str): The name of the requested file, if any. | ||||
|     """ | ||||
|     request.setHeader(b"Content-Type", media_type.encode("UTF-8")) | ||||
|     if upload_name: | ||||
|         if is_ascii(upload_name): | ||||
|             request.setHeader( | ||||
|                 b"Content-Disposition", | ||||
|                 b"inline; filename=%s" % ( | ||||
|                     urllib.quote(upload_name.encode("utf-8")), | ||||
|                 ), | ||||
|             ) | ||||
|         else: | ||||
|             request.setHeader( | ||||
|                 b"Content-Disposition", | ||||
|                 b"inline; filename*=utf-8''%s" % ( | ||||
|                     urllib.quote(upload_name.encode("utf-8")), | ||||
|                 ), | ||||
|             ) | ||||
| 
 | ||||
|     # cache for at least a day. | ||||
|     # XXX: we might want to turn this off for data we don't want to | ||||
|     # recommend caching as it's sensitive or private - or at least | ||||
|     # select private. don't bother setting Expires as all our | ||||
|     # clients are smart enough to be happy with Cache-Control | ||||
|     request.setHeader( | ||||
|         b"Cache-Control", b"public,max-age=86400,s-maxage=86400" | ||||
|     ) | ||||
| 
 | ||||
|     request.setHeader( | ||||
|         b"Content-Length", b"%d" % (file_size,) | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def respond_with_responder(request, responder, media_type, file_size, upload_name=None): | ||||
|     """Responds to the request with given responder. If responder is None then | ||||
|     returns 404. | ||||
| 
 | ||||
|     Args: | ||||
|         request (twisted.web.http.Request) | ||||
|         responder (Responder|None) | ||||
|         media_type (str): The media/content type. | ||||
|         file_size (int|None): Size in bytes of the media. If not known it should be None | ||||
|         upload_name (str|None): The name of the requested file, if any. | ||||
|     """ | ||||
|     if not responder: | ||||
|         respond_404(request) | ||||
|         return | ||||
| 
 | ||||
|     add_file_headers(request, media_type, file_size, upload_name) | ||||
|     with responder: | ||||
|         yield responder.write_to_consumer(request) | ||||
|     finish_request(request) | ||||
| 
 | ||||
| 
 | ||||
| class Responder(object): | ||||
|     """Represents a response that can be streamed to the requester. | ||||
| 
 | ||||
|     Responder is a context manager which *must* be used, so that any resources | ||||
|     held can be cleaned up. | ||||
|     """ | ||||
|     def write_to_consumer(self, consumer): | ||||
|         """Stream response into consumer | ||||
| 
 | ||||
|         Args: | ||||
|             consumer (IConsumer) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once the response has finished being written | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
|     def __enter__(self): | ||||
|         pass | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
| class FileInfo(object): | ||||
|     """Details about a requested/uploaded file. | ||||
| 
 | ||||
|     Attributes: | ||||
|         server_name (str): The server name where the media originated from, | ||||
|             or None if local. | ||||
|         file_id (str): The local ID of the file. For local files this is the | ||||
|             same as the media_id | ||||
|         url_cache (bool): If the file is for the url preview cache | ||||
|         thumbnail (bool): Whether the file is a thumbnail or not. | ||||
|         thumbnail_width (int) | ||||
|         thumbnail_height (int) | ||||
|         thumbnail_method (str) | ||||
|         thumbnail_type (str): Content type of thumbnail, e.g. image/png | ||||
|     """ | ||||
|     def __init__(self, server_name, file_id, url_cache=False, | ||||
|                  thumbnail=False, thumbnail_width=None, thumbnail_height=None, | ||||
|                  thumbnail_method=None, thumbnail_type=None): | ||||
|         self.server_name = server_name | ||||
|         self.file_id = file_id | ||||
|         self.url_cache = url_cache | ||||
|         self.thumbnail = thumbnail | ||||
|         self.thumbnail_width = thumbnail_width | ||||
|         self.thumbnail_height = thumbnail_height | ||||
|         self.thumbnail_method = thumbnail_method | ||||
|         self.thumbnail_type = thumbnail_type | ||||
|  |  | |||
|  | @ -14,7 +14,7 @@ | |||
| # limitations under the License. | ||||
| import synapse.http.servlet | ||||
| 
 | ||||
| from ._base import parse_media_id, respond_with_file, respond_404 | ||||
| from ._base import parse_media_id, respond_404 | ||||
| from twisted.web.resource import Resource | ||||
| from synapse.http.server import request_handler, set_cors_headers | ||||
| 
 | ||||
|  | @ -32,12 +32,12 @@ class DownloadResource(Resource): | |||
|     def __init__(self, hs, media_repo): | ||||
|         Resource.__init__(self) | ||||
| 
 | ||||
|         self.filepaths = media_repo.filepaths | ||||
|         self.media_repo = media_repo | ||||
|         self.server_name = hs.hostname | ||||
|         self.store = hs.get_datastore() | ||||
|         self.version_string = hs.version_string | ||||
| 
 | ||||
|         # Both of these are expected by @request_handler() | ||||
|         self.clock = hs.get_clock() | ||||
|         self.version_string = hs.version_string | ||||
| 
 | ||||
|     def render_GET(self, request): | ||||
|         self._async_render_GET(request) | ||||
|  | @ -57,59 +57,16 @@ class DownloadResource(Resource): | |||
|         ) | ||||
|         server_name, media_id, name = parse_media_id(request) | ||||
|         if server_name == self.server_name: | ||||
|             yield self._respond_local_file(request, media_id, name) | ||||
|             yield self.media_repo.get_local_media(request, media_id, name) | ||||
|         else: | ||||
|             yield self._respond_remote_file( | ||||
|                 request, server_name, media_id, name | ||||
|             ) | ||||
|             allow_remote = synapse.http.servlet.parse_boolean( | ||||
|                 request, "allow_remote", default=True) | ||||
|             if not allow_remote: | ||||
|                 logger.info( | ||||
|                     "Rejecting request for remote media %s/%s due to allow_remote", | ||||
|                     server_name, media_id, | ||||
|                 ) | ||||
|                 respond_404(request) | ||||
|                 return | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_local_file(self, request, media_id, name): | ||||
|         media_info = yield self.store.get_local_media(media_id) | ||||
|         if not media_info or media_info["quarantined_by"]: | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         media_type = media_info["media_type"] | ||||
|         media_length = media_info["media_length"] | ||||
|         upload_name = name if name else media_info["upload_name"] | ||||
|         if media_info["url_cache"]: | ||||
|             # TODO: Check the file still exists, if it doesn't we can redownload | ||||
|             # it from the url `media_info["url_cache"]` | ||||
|             file_path = self.filepaths.url_cache_filepath(media_id) | ||||
|         else: | ||||
|             file_path = self.filepaths.local_media_filepath(media_id) | ||||
| 
 | ||||
|         yield respond_with_file( | ||||
|             request, media_type, file_path, media_length, | ||||
|             upload_name=upload_name, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_remote_file(self, request, server_name, media_id, name): | ||||
|         # don't forward requests for remote media if allow_remote is false | ||||
|         allow_remote = synapse.http.servlet.parse_boolean( | ||||
|             request, "allow_remote", default=True) | ||||
|         if not allow_remote: | ||||
|             logger.info( | ||||
|                 "Rejecting request for remote media %s/%s due to allow_remote", | ||||
|                 server_name, media_id, | ||||
|             ) | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         media_info = yield self.media_repo.get_remote_media(server_name, media_id) | ||||
| 
 | ||||
|         media_type = media_info["media_type"] | ||||
|         media_length = media_info["media_length"] | ||||
|         filesystem_id = media_info["filesystem_id"] | ||||
|         upload_name = name if name else media_info["upload_name"] | ||||
| 
 | ||||
|         file_path = self.filepaths.remote_media_filepath( | ||||
|             server_name, filesystem_id | ||||
|         ) | ||||
| 
 | ||||
|         yield respond_with_file( | ||||
|             request, media_type, file_path, media_length, | ||||
|             upload_name=upload_name, | ||||
|         ) | ||||
|             yield self.media_repo.get_remote_media(request, server_name, media_id, name) | ||||
|  |  | |||
|  | @ -1,5 +1,6 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2014-2016 OpenMarket Ltd | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
|  | @ -18,6 +19,7 @@ import twisted.internet.error | |||
| import twisted.web.http | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| from ._base import respond_404, FileInfo, respond_with_responder | ||||
| from .upload_resource import UploadResource | ||||
| from .download_resource import DownloadResource | ||||
| from .thumbnail_resource import ThumbnailResource | ||||
|  | @ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource | |||
| from .preview_url_resource import PreviewUrlResource | ||||
| from .filepath import MediaFilePaths | ||||
| from .thumbnailer import Thumbnailer | ||||
| from .storage_provider import StorageProviderWrapper | ||||
| from .media_storage import MediaStorage | ||||
| 
 | ||||
| from synapse.http.matrixfederationclient import MatrixFederationHttpClient | ||||
| from synapse.util.stringutils import random_string | ||||
| from synapse.api.errors import SynapseError, HttpResponseException, \ | ||||
|     NotFoundError | ||||
| from synapse.api.errors import ( | ||||
|     SynapseError, HttpResponseException, NotFoundError, FederationDeniedError, | ||||
| ) | ||||
| 
 | ||||
| from synapse.util.async import Linearizer | ||||
| from synapse.util.stringutils import is_ascii | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| from synapse.util.logcontext import make_deferred_yieldable | ||||
| from synapse.util.retryutils import NotRetryingDestination | ||||
| 
 | ||||
| import os | ||||
|  | @ -47,7 +52,7 @@ import urlparse | |||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 | ||||
| UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 | ||||
| 
 | ||||
| 
 | ||||
| class MediaRepository(object): | ||||
|  | @ -63,96 +68,62 @@ class MediaRepository(object): | |||
|         self.primary_base_path = hs.config.media_store_path | ||||
|         self.filepaths = MediaFilePaths(self.primary_base_path) | ||||
| 
 | ||||
|         self.backup_base_path = hs.config.backup_media_store_path | ||||
| 
 | ||||
|         self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store | ||||
| 
 | ||||
|         self.dynamic_thumbnails = hs.config.dynamic_thumbnails | ||||
|         self.thumbnail_requirements = hs.config.thumbnail_requirements | ||||
| 
 | ||||
|         self.remote_media_linearizer = Linearizer(name="media_remote") | ||||
| 
 | ||||
|         self.recently_accessed_remotes = set() | ||||
|         self.recently_accessed_locals = set() | ||||
| 
 | ||||
|         self.federation_domain_whitelist = hs.config.federation_domain_whitelist | ||||
| 
 | ||||
|         # List of StorageProviders where we should search for media and | ||||
|         # potentially upload to. | ||||
|         storage_providers = [] | ||||
| 
 | ||||
|         for clz, provider_config, wrapper_config in hs.config.media_storage_providers: | ||||
|             backend = clz(hs, provider_config) | ||||
|             provider = StorageProviderWrapper( | ||||
|                 backend, | ||||
|                 store_local=wrapper_config.store_local, | ||||
|                 store_remote=wrapper_config.store_remote, | ||||
|                 store_synchronous=wrapper_config.store_synchronous, | ||||
|             ) | ||||
|             storage_providers.append(provider) | ||||
| 
 | ||||
|         self.media_storage = MediaStorage( | ||||
|             self.primary_base_path, self.filepaths, storage_providers, | ||||
|         ) | ||||
| 
 | ||||
|         self.clock.looping_call( | ||||
|             self._update_recently_accessed_remotes, | ||||
|             UPDATE_RECENTLY_ACCESSED_REMOTES_TS | ||||
|             self._update_recently_accessed, | ||||
|             UPDATE_RECENTLY_ACCESSED_TS, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _update_recently_accessed_remotes(self): | ||||
|         media = self.recently_accessed_remotes | ||||
|     def _update_recently_accessed(self): | ||||
|         remote_media = self.recently_accessed_remotes | ||||
|         self.recently_accessed_remotes = set() | ||||
| 
 | ||||
|         local_media = self.recently_accessed_locals | ||||
|         self.recently_accessed_locals = set() | ||||
| 
 | ||||
|         yield self.store.update_cached_last_access_time( | ||||
|             media, self.clock.time_msec() | ||||
|             local_media, remote_media, self.clock.time_msec() | ||||
|         ) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _makedirs(filepath): | ||||
|         dirname = os.path.dirname(filepath) | ||||
|         if not os.path.exists(dirname): | ||||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def _write_file_synchronously(source, fname): | ||||
|         """Write `source` to the path `fname` synchronously. Should be called | ||||
|         from a thread. | ||||
|     def mark_recently_accessed(self, server_name, media_id): | ||||
|         """Mark the given media as recently accessed. | ||||
| 
 | ||||
|         Args: | ||||
|             source: A file like object to be written | ||||
|             fname (str): Path to write to | ||||
|             server_name (str|None): Origin server of media, or None if local | ||||
|             media_id (str): The media ID of the content | ||||
|         """ | ||||
|         MediaRepository._makedirs(fname) | ||||
|         source.seek(0)  # Ensure we read from the start of the file | ||||
|         with open(fname, "wb") as f: | ||||
|             shutil.copyfileobj(source, f) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def write_to_file_and_backup(self, source, path): | ||||
|         """Write `source` to the on disk media store, and also the backup store | ||||
|         if configured. | ||||
| 
 | ||||
|         Args: | ||||
|             source: A file like object that should be written | ||||
|             path (str): Relative path to write file to | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[str]: the file path written to in the primary media store | ||||
|         """ | ||||
|         fname = os.path.join(self.primary_base_path, path) | ||||
| 
 | ||||
|         # Write to the main repository | ||||
|         yield make_deferred_yieldable(threads.deferToThread( | ||||
|             self._write_file_synchronously, source, fname, | ||||
|         )) | ||||
| 
 | ||||
|         # Write to backup repository | ||||
|         yield self.copy_to_backup(path) | ||||
| 
 | ||||
|         defer.returnValue(fname) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def copy_to_backup(self, path): | ||||
|         """Copy a file from the primary to backup media store, if configured. | ||||
| 
 | ||||
|         Args: | ||||
|             path(str): Relative path to write file to | ||||
|         """ | ||||
|         if self.backup_base_path: | ||||
|             primary_fname = os.path.join(self.primary_base_path, path) | ||||
|             backup_fname = os.path.join(self.backup_base_path, path) | ||||
| 
 | ||||
|             # We can either wait for successful writing to the backup repository | ||||
|             # or write in the background and immediately return | ||||
|             if self.synchronous_backup_media_store: | ||||
|                 yield make_deferred_yieldable(threads.deferToThread( | ||||
|                     shutil.copyfile, primary_fname, backup_fname, | ||||
|                 )) | ||||
|             else: | ||||
|                 preserve_fn(threads.deferToThread)( | ||||
|                     shutil.copyfile, primary_fname, backup_fname, | ||||
|                 ) | ||||
|         if server_name: | ||||
|             self.recently_accessed_remotes.add((server_name, media_id)) | ||||
|         else: | ||||
|             self.recently_accessed_locals.add(media_id) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def create_content(self, media_type, upload_name, content, content_length, | ||||
|  | @ -171,10 +142,13 @@ class MediaRepository(object): | |||
|         """ | ||||
|         media_id = random_string(24) | ||||
| 
 | ||||
|         fname = yield self.write_to_file_and_backup( | ||||
|             content, self.filepaths.local_media_filepath_rel(media_id) | ||||
|         file_info = FileInfo( | ||||
|             server_name=None, | ||||
|             file_id=media_id, | ||||
|         ) | ||||
| 
 | ||||
|         fname = yield self.media_storage.store_file(content, file_info) | ||||
| 
 | ||||
|         logger.info("Stored local media in file %r", fname) | ||||
| 
 | ||||
|         yield self.store.store_local_media( | ||||
|  | @ -185,134 +159,275 @@ class MediaRepository(object): | |||
|             media_length=content_length, | ||||
|             user_id=auth_user, | ||||
|         ) | ||||
|         media_info = { | ||||
|             "media_type": media_type, | ||||
|             "media_length": content_length, | ||||
|         } | ||||
| 
 | ||||
|         yield self._generate_thumbnails(None, media_id, media_info) | ||||
|         yield self._generate_thumbnails( | ||||
|             None, media_id, media_id, media_type, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_remote_media(self, server_name, media_id): | ||||
|     def get_local_media(self, request, media_id, name): | ||||
|         """Responds to reqests for local media, if exists, or returns 404. | ||||
| 
 | ||||
|         Args: | ||||
|             request(twisted.web.http.Request) | ||||
|             media_id (str): The media ID of the content. (This is the same as | ||||
|                 the file_id for local content.) | ||||
|             name (str|None): Optional name that, if specified, will be used as | ||||
|                 the filename in the Content-Disposition header of the response. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once a response has successfully been written | ||||
|                 to request | ||||
|         """ | ||||
|         media_info = yield self.store.get_local_media(media_id) | ||||
|         if not media_info or media_info["quarantined_by"]: | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         self.mark_recently_accessed(None, media_id) | ||||
| 
 | ||||
|         media_type = media_info["media_type"] | ||||
|         media_length = media_info["media_length"] | ||||
|         upload_name = name if name else media_info["upload_name"] | ||||
|         url_cache = media_info["url_cache"] | ||||
| 
 | ||||
|         file_info = FileInfo( | ||||
|             None, media_id, | ||||
|             url_cache=url_cache, | ||||
|         ) | ||||
| 
 | ||||
|         responder = yield self.media_storage.fetch_media(file_info) | ||||
|         yield respond_with_responder( | ||||
|             request, responder, media_type, media_length, upload_name, | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_remote_media(self, request, server_name, media_id, name): | ||||
|         """Respond to requests for remote media. | ||||
| 
 | ||||
|         Args: | ||||
|             request(twisted.web.http.Request) | ||||
|             server_name (str): Remote server_name where the media originated. | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). | ||||
|             name (str|None): Optional name that, if specified, will be used as | ||||
|                 the filename in the Content-Disposition header of the response. | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: Resolves once a response has successfully been written | ||||
|                 to request | ||||
|         """ | ||||
|         if ( | ||||
|             self.federation_domain_whitelist is not None and | ||||
|             server_name not in self.federation_domain_whitelist | ||||
|         ): | ||||
|             raise FederationDeniedError(server_name) | ||||
| 
 | ||||
|         self.mark_recently_accessed(server_name, media_id) | ||||
| 
 | ||||
|         # We linearize here to ensure that we don't try and download remote | ||||
|         # media multiple times concurrently | ||||
|         key = (server_name, media_id) | ||||
|         with (yield self.remote_media_linearizer.queue(key)): | ||||
|             media_info = yield self._get_remote_media_impl(server_name, media_id) | ||||
|             responder, media_info = yield self._get_remote_media_impl( | ||||
|                 server_name, media_id, | ||||
|             ) | ||||
| 
 | ||||
|         # We deliberately stream the file outside the lock | ||||
|         if responder: | ||||
|             media_type = media_info["media_type"] | ||||
|             media_length = media_info["media_length"] | ||||
|             upload_name = name if name else media_info["upload_name"] | ||||
|             yield respond_with_responder( | ||||
|                 request, responder, media_type, media_length, upload_name, | ||||
|             ) | ||||
|         else: | ||||
|             respond_404(request) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_remote_media_info(self, server_name, media_id): | ||||
|         """Gets the media info associated with the remote file, downloading | ||||
|         if necessary. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name (str): Remote server_name where the media originated. | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict]: The media_info of the file | ||||
|         """ | ||||
|         if ( | ||||
|             self.federation_domain_whitelist is not None and | ||||
|             server_name not in self.federation_domain_whitelist | ||||
|         ): | ||||
|             raise FederationDeniedError(server_name) | ||||
| 
 | ||||
|         # We linearize here to ensure that we don't try and download remote | ||||
|         # media multiple times concurrently | ||||
|         key = (server_name, media_id) | ||||
|         with (yield self.remote_media_linearizer.queue(key)): | ||||
|             responder, media_info = yield self._get_remote_media_impl( | ||||
|                 server_name, media_id, | ||||
|             ) | ||||
| 
 | ||||
|         # Ensure we actually use the responder so that it releases resources | ||||
|         if responder: | ||||
|             with responder: | ||||
|                 pass | ||||
| 
 | ||||
|         defer.returnValue(media_info) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _get_remote_media_impl(self, server_name, media_id): | ||||
|         """Looks for media in local cache, if not there then attempt to | ||||
|         download from remote server. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name (str): Remote server_name where the media originated. | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[(Responder, media_info)] | ||||
|         """ | ||||
|         media_info = yield self.store.get_cached_remote_media( | ||||
|             server_name, media_id | ||||
|         ) | ||||
|         if not media_info: | ||||
|             media_info = yield self._download_remote_file( | ||||
|                 server_name, media_id | ||||
|             ) | ||||
|         elif media_info["quarantined_by"]: | ||||
|             raise NotFoundError() | ||||
| 
 | ||||
|         # file_id is the ID we use to track the file locally. If we've already | ||||
|         # seen the file then reuse the existing ID, otherwise genereate a new | ||||
|         # one. | ||||
|         if media_info: | ||||
|             file_id = media_info["filesystem_id"] | ||||
|         else: | ||||
|             self.recently_accessed_remotes.add((server_name, media_id)) | ||||
|             yield self.store.update_cached_last_access_time( | ||||
|                 [(server_name, media_id)], self.clock.time_msec() | ||||
|             ) | ||||
|         defer.returnValue(media_info) | ||||
|             file_id = random_string(24) | ||||
| 
 | ||||
|         file_info = FileInfo(server_name, file_id) | ||||
| 
 | ||||
|         # If we have an entry in the DB, try and look for it | ||||
|         if media_info: | ||||
|             if media_info["quarantined_by"]: | ||||
|                 logger.info("Media is quarantined") | ||||
|                 raise NotFoundError() | ||||
| 
 | ||||
|             responder = yield self.media_storage.fetch_media(file_info) | ||||
|             if responder: | ||||
|                 defer.returnValue((responder, media_info)) | ||||
| 
 | ||||
|         # Failed to find the file anywhere, lets download it. | ||||
| 
 | ||||
|         media_info = yield self._download_remote_file( | ||||
|             server_name, media_id, file_id | ||||
|         ) | ||||
| 
 | ||||
|         responder = yield self.media_storage.fetch_media(file_info) | ||||
|         defer.returnValue((responder, media_info)) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _download_remote_file(self, server_name, media_id): | ||||
|         file_id = random_string(24) | ||||
|     def _download_remote_file(self, server_name, media_id, file_id): | ||||
|         """Attempt to download the remote file from the given server name, | ||||
|         using the given file_id as the local id. | ||||
| 
 | ||||
|         fpath = self.filepaths.remote_media_filepath_rel( | ||||
|             server_name, file_id | ||||
|         Args: | ||||
|             server_name (str): Originating server | ||||
|             media_id (str): The media ID of the content (as defined by the | ||||
|                 remote server). This is different than the file_id, which is | ||||
|                 locally generated. | ||||
|             file_id (str): Local file ID | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[MediaInfo] | ||||
|         """ | ||||
| 
 | ||||
|         file_info = FileInfo( | ||||
|             server_name=server_name, | ||||
|             file_id=file_id, | ||||
|         ) | ||||
|         fname = os.path.join(self.primary_base_path, fpath) | ||||
|         self._makedirs(fname) | ||||
| 
 | ||||
|         try: | ||||
|             with open(fname, "wb") as f: | ||||
|                 request_path = "/".join(( | ||||
|                     "/_matrix/media/v1/download", server_name, media_id, | ||||
|                 )) | ||||
|         with self.media_storage.store_into_file(file_info) as (f, fname, finish): | ||||
|             request_path = "/".join(( | ||||
|                 "/_matrix/media/v1/download", server_name, media_id, | ||||
|             )) | ||||
|             try: | ||||
|                 length, headers = yield self.client.get_file( | ||||
|                     server_name, request_path, output_stream=f, | ||||
|                     max_size=self.max_upload_size, args={ | ||||
|                         # tell the remote server to 404 if it doesn't | ||||
|                         # recognise the server_name, to make sure we don't | ||||
|                         # end up with a routing loop. | ||||
|                         "allow_remote": "false", | ||||
|                     } | ||||
|                 ) | ||||
|             except twisted.internet.error.DNSLookupError as e: | ||||
|                 logger.warn("HTTP error fetching remote media %s/%s: %r", | ||||
|                             server_name, media_id, e) | ||||
|                 raise NotFoundError() | ||||
| 
 | ||||
|             except HttpResponseException as e: | ||||
|                 logger.warn("HTTP error fetching remote media %s/%s: %s", | ||||
|                             server_name, media_id, e.response) | ||||
|                 if e.code == twisted.web.http.NOT_FOUND: | ||||
|                     raise SynapseError.from_http_response_exception(e) | ||||
|                 raise SynapseError(502, "Failed to fetch remote media") | ||||
| 
 | ||||
|             except SynapseError: | ||||
|                 logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                  server_name, media_id) | ||||
|                 raise | ||||
|             except NotRetryingDestination: | ||||
|                 logger.warn("Not retrying destination %r", server_name) | ||||
|                 raise SynapseError(502, "Failed to fetch remote media") | ||||
|             except Exception: | ||||
|                 logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                  server_name, media_id) | ||||
|                 raise SynapseError(502, "Failed to fetch remote media") | ||||
| 
 | ||||
|             yield finish() | ||||
| 
 | ||||
|         media_type = headers["Content-Type"][0] | ||||
| 
 | ||||
|         time_now_ms = self.clock.time_msec() | ||||
| 
 | ||||
|         content_disposition = headers.get("Content-Disposition", None) | ||||
|         if content_disposition: | ||||
|             _, params = cgi.parse_header(content_disposition[0],) | ||||
|             upload_name = None | ||||
| 
 | ||||
|             # First check if there is a valid UTF-8 filename | ||||
|             upload_name_utf8 = params.get("filename*", None) | ||||
|             if upload_name_utf8: | ||||
|                 if upload_name_utf8.lower().startswith("utf-8''"): | ||||
|                     upload_name = upload_name_utf8[7:] | ||||
| 
 | ||||
|             # If there isn't check for an ascii name. | ||||
|             if not upload_name: | ||||
|                 upload_name_ascii = params.get("filename", None) | ||||
|                 if upload_name_ascii and is_ascii(upload_name_ascii): | ||||
|                     upload_name = upload_name_ascii | ||||
| 
 | ||||
|             if upload_name: | ||||
|                 upload_name = urlparse.unquote(upload_name) | ||||
|                 try: | ||||
|                     length, headers = yield self.client.get_file( | ||||
|                         server_name, request_path, output_stream=f, | ||||
|                         max_size=self.max_upload_size, args={ | ||||
|                             # tell the remote server to 404 if it doesn't | ||||
|                             # recognise the server_name, to make sure we don't | ||||
|                             # end up with a routing loop. | ||||
|                             "allow_remote": "false", | ||||
|                         } | ||||
|                     ) | ||||
|                 except twisted.internet.error.DNSLookupError as e: | ||||
|                     logger.warn("HTTP error fetching remote media %s/%s: %r", | ||||
|                                 server_name, media_id, e) | ||||
|                     raise NotFoundError() | ||||
|                     upload_name = upload_name.decode("utf-8") | ||||
|                 except UnicodeDecodeError: | ||||
|                     upload_name = None | ||||
|         else: | ||||
|             upload_name = None | ||||
| 
 | ||||
|                 except HttpResponseException as e: | ||||
|                     logger.warn("HTTP error fetching remote media %s/%s: %s", | ||||
|                                 server_name, media_id, e.response) | ||||
|                     if e.code == twisted.web.http.NOT_FOUND: | ||||
|                         raise SynapseError.from_http_response_exception(e) | ||||
|                     raise SynapseError(502, "Failed to fetch remote media") | ||||
|         logger.info("Stored remote media in file %r", fname) | ||||
| 
 | ||||
|                 except SynapseError: | ||||
|                     logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                      server_name, media_id) | ||||
|                     raise | ||||
|                 except NotRetryingDestination: | ||||
|                     logger.warn("Not retrying destination %r", server_name) | ||||
|                     raise SynapseError(502, "Failed to fetch remote media") | ||||
|                 except Exception: | ||||
|                     logger.exception("Failed to fetch remote media %s/%s", | ||||
|                                      server_name, media_id) | ||||
|                     raise SynapseError(502, "Failed to fetch remote media") | ||||
| 
 | ||||
|             yield self.copy_to_backup(fpath) | ||||
| 
 | ||||
|             media_type = headers["Content-Type"][0] | ||||
|             time_now_ms = self.clock.time_msec() | ||||
| 
 | ||||
|             content_disposition = headers.get("Content-Disposition", None) | ||||
|             if content_disposition: | ||||
|                 _, params = cgi.parse_header(content_disposition[0],) | ||||
|                 upload_name = None | ||||
| 
 | ||||
|                 # First check if there is a valid UTF-8 filename | ||||
|                 upload_name_utf8 = params.get("filename*", None) | ||||
|                 if upload_name_utf8: | ||||
|                     if upload_name_utf8.lower().startswith("utf-8''"): | ||||
|                         upload_name = upload_name_utf8[7:] | ||||
| 
 | ||||
|                 # If there isn't check for an ascii name. | ||||
|                 if not upload_name: | ||||
|                     upload_name_ascii = params.get("filename", None) | ||||
|                     if upload_name_ascii and is_ascii(upload_name_ascii): | ||||
|                         upload_name = upload_name_ascii | ||||
| 
 | ||||
|                 if upload_name: | ||||
|                     upload_name = urlparse.unquote(upload_name) | ||||
|                     try: | ||||
|                         upload_name = upload_name.decode("utf-8") | ||||
|                     except UnicodeDecodeError: | ||||
|                         upload_name = None | ||||
|             else: | ||||
|                 upload_name = None | ||||
| 
 | ||||
|             logger.info("Stored remote media in file %r", fname) | ||||
| 
 | ||||
|             yield self.store.store_cached_remote_media( | ||||
|                 origin=server_name, | ||||
|                 media_id=media_id, | ||||
|                 media_type=media_type, | ||||
|                 time_now_ms=self.clock.time_msec(), | ||||
|                 upload_name=upload_name, | ||||
|                 media_length=length, | ||||
|                 filesystem_id=file_id, | ||||
|             ) | ||||
|         except Exception: | ||||
|             os.remove(fname) | ||||
|             raise | ||||
|         yield self.store.store_cached_remote_media( | ||||
|             origin=server_name, | ||||
|             media_id=media_id, | ||||
|             media_type=media_type, | ||||
|             time_now_ms=self.clock.time_msec(), | ||||
|             upload_name=upload_name, | ||||
|             media_length=length, | ||||
|             filesystem_id=file_id, | ||||
|         ) | ||||
| 
 | ||||
|         media_info = { | ||||
|             "media_type": media_type, | ||||
|  | @ -323,7 +438,7 @@ class MediaRepository(object): | |||
|         } | ||||
| 
 | ||||
|         yield self._generate_thumbnails( | ||||
|             server_name, media_id, media_info | ||||
|             server_name, media_id, file_id, media_type, | ||||
|         ) | ||||
| 
 | ||||
|         defer.returnValue(media_info) | ||||
|  | @ -368,11 +483,18 @@ class MediaRepository(object): | |||
| 
 | ||||
|         if t_byte_source: | ||||
|             try: | ||||
|                 output_path = yield self.write_to_file_and_backup( | ||||
|                     t_byte_source, | ||||
|                     self.filepaths.local_media_thumbnail_rel( | ||||
|                         media_id, t_width, t_height, t_type, t_method | ||||
|                     ) | ||||
|                 file_info = FileInfo( | ||||
|                     server_name=None, | ||||
|                     file_id=media_id, | ||||
|                     thumbnail=True, | ||||
|                     thumbnail_width=t_width, | ||||
|                     thumbnail_height=t_height, | ||||
|                     thumbnail_method=t_method, | ||||
|                     thumbnail_type=t_type, | ||||
|                 ) | ||||
| 
 | ||||
|                 output_path = yield self.media_storage.store_file( | ||||
|                     t_byte_source, file_info, | ||||
|                 ) | ||||
|             finally: | ||||
|                 t_byte_source.close() | ||||
|  | @ -400,11 +522,18 @@ class MediaRepository(object): | |||
| 
 | ||||
|         if t_byte_source: | ||||
|             try: | ||||
|                 output_path = yield self.write_to_file_and_backup( | ||||
|                     t_byte_source, | ||||
|                     self.filepaths.remote_media_thumbnail_rel( | ||||
|                         server_name, file_id, t_width, t_height, t_type, t_method | ||||
|                     ) | ||||
|                 file_info = FileInfo( | ||||
|                     server_name=server_name, | ||||
|                     file_id=media_id, | ||||
|                     thumbnail=True, | ||||
|                     thumbnail_width=t_width, | ||||
|                     thumbnail_height=t_height, | ||||
|                     thumbnail_method=t_method, | ||||
|                     thumbnail_type=t_type, | ||||
|                 ) | ||||
| 
 | ||||
|                 output_path = yield self.media_storage.store_file( | ||||
|                     t_byte_source, file_info, | ||||
|                 ) | ||||
|             finally: | ||||
|                 t_byte_source.close() | ||||
|  | @ -421,21 +550,22 @@ class MediaRepository(object): | |||
|             defer.returnValue(output_path) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False): | ||||
|     def _generate_thumbnails(self, server_name, media_id, file_id, media_type, | ||||
|                              url_cache=False): | ||||
|         """Generate and store thumbnails for an image. | ||||
| 
 | ||||
|         Args: | ||||
|             server_name(str|None): The server name if remote media, else None if local | ||||
|             media_id(str) | ||||
|             media_info(dict) | ||||
|             url_cache(bool): If we are thumbnailing images downloaded for the URL cache, | ||||
|             server_name (str|None): The server name if remote media, else None if local | ||||
|             media_id (str): The media ID of the content. (This is the same as | ||||
|                 the file_id for local content) | ||||
|             file_id (str): Local file ID | ||||
|             media_type (str): The content type of the file | ||||
|             url_cache (bool): If we are thumbnailing images downloaded for the URL cache, | ||||
|                 used exclusively by the url previewer | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict]: Dict with "width" and "height" keys of original image | ||||
|         """ | ||||
|         media_type = media_info["media_type"] | ||||
|         file_id = media_info.get("filesystem_id") | ||||
|         requirements = self._get_thumbnail_requirements(media_type) | ||||
|         if not requirements: | ||||
|             return | ||||
|  | @ -472,20 +602,6 @@ class MediaRepository(object): | |||
| 
 | ||||
|         # Now we generate the thumbnails for each dimension, store it | ||||
|         for (t_width, t_height, t_type), t_method in thumbnails.iteritems(): | ||||
|             # Work out the correct file name for thumbnail | ||||
|             if server_name: | ||||
|                 file_path = self.filepaths.remote_media_thumbnail_rel( | ||||
|                     server_name, file_id, t_width, t_height, t_type, t_method | ||||
|                 ) | ||||
|             elif url_cache: | ||||
|                 file_path = self.filepaths.url_cache_thumbnail_rel( | ||||
|                     media_id, t_width, t_height, t_type, t_method | ||||
|                 ) | ||||
|             else: | ||||
|                 file_path = self.filepaths.local_media_thumbnail_rel( | ||||
|                     media_id, t_width, t_height, t_type, t_method | ||||
|                 ) | ||||
| 
 | ||||
|             # Generate the thumbnail | ||||
|             if t_method == "crop": | ||||
|                 t_byte_source = yield make_deferred_yieldable(threads.deferToThread( | ||||
|  | @ -505,9 +621,19 @@ class MediaRepository(object): | |||
|                 continue | ||||
| 
 | ||||
|             try: | ||||
|                 # Write to disk | ||||
|                 output_path = yield self.write_to_file_and_backup( | ||||
|                     t_byte_source, file_path, | ||||
|                 file_info = FileInfo( | ||||
|                     server_name=server_name, | ||||
|                     file_id=file_id, | ||||
|                     thumbnail=True, | ||||
|                     thumbnail_width=t_width, | ||||
|                     thumbnail_height=t_height, | ||||
|                     thumbnail_method=t_method, | ||||
|                     thumbnail_type=t_type, | ||||
|                     url_cache=url_cache, | ||||
|                 ) | ||||
| 
 | ||||
|                 output_path = yield self.media_storage.store_file( | ||||
|                     t_byte_source, file_info, | ||||
|                 ) | ||||
|             finally: | ||||
|                 t_byte_source.close() | ||||
|  | @ -620,7 +746,11 @@ class MediaRepositoryResource(Resource): | |||
| 
 | ||||
|         self.putChild("upload", UploadResource(hs, media_repo)) | ||||
|         self.putChild("download", DownloadResource(hs, media_repo)) | ||||
|         self.putChild("thumbnail", ThumbnailResource(hs, media_repo)) | ||||
|         self.putChild("thumbnail", ThumbnailResource( | ||||
|             hs, media_repo, media_repo.media_storage, | ||||
|         )) | ||||
|         self.putChild("identicon", IdenticonResource()) | ||||
|         if hs.config.url_preview_enabled: | ||||
|             self.putChild("preview_url", PreviewUrlResource(hs, media_repo)) | ||||
|             self.putChild("preview_url", PreviewUrlResource( | ||||
|                 hs, media_repo, media_repo.media_storage, | ||||
|             )) | ||||
|  |  | |||
							
								
								
									
										236
									
								
								synapse/rest/media/v1/media_storage.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										236
									
								
								synapse/rest/media/v1/media_storage.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,236 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vecotr Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import defer, threads | ||||
| from twisted.protocols.basic import FileSender | ||||
| 
 | ||||
| from ._base import Responder | ||||
| 
 | ||||
| from synapse.util.logcontext import make_deferred_yieldable | ||||
| 
 | ||||
| import contextlib | ||||
| import os | ||||
| import logging | ||||
| import shutil | ||||
| import sys | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class MediaStorage(object): | ||||
|     """Responsible for storing/fetching files from local sources. | ||||
| 
 | ||||
|     Args: | ||||
|         local_media_directory (str): Base path where we store media on disk | ||||
|         filepaths (MediaFilePaths) | ||||
|         storage_providers ([StorageProvider]): List of StorageProvider that are | ||||
|             used to fetch and store files. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, local_media_directory, filepaths, storage_providers): | ||||
|         self.local_media_directory = local_media_directory | ||||
|         self.filepaths = filepaths | ||||
|         self.storage_providers = storage_providers | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def store_file(self, source, file_info): | ||||
|         """Write `source` to the on disk media store, and also any other | ||||
|         configured storage providers | ||||
| 
 | ||||
|         Args: | ||||
|             source: A file like object that should be written | ||||
|             file_info (FileInfo): Info about the file to store | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[str]: the file path written to in the primary media store | ||||
|         """ | ||||
|         path = self._file_info_to_path(file_info) | ||||
|         fname = os.path.join(self.local_media_directory, path) | ||||
| 
 | ||||
|         dirname = os.path.dirname(fname) | ||||
|         if not os.path.exists(dirname): | ||||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|         # Write to the main repository | ||||
|         yield make_deferred_yieldable(threads.deferToThread( | ||||
|             _write_file_synchronously, source, fname, | ||||
|         )) | ||||
| 
 | ||||
|         defer.returnValue(fname) | ||||
| 
 | ||||
|     @contextlib.contextmanager | ||||
|     def store_into_file(self, file_info): | ||||
|         """Context manager used to get a file like object to write into, as | ||||
|         described by file_info. | ||||
| 
 | ||||
|         Actually yields a 3-tuple (file, fname, finish_cb), where file is a file | ||||
|         like object that can be written to, fname is the absolute path of file | ||||
|         on disk, and finish_cb is a function that returns a Deferred. | ||||
| 
 | ||||
|         fname can be used to read the contents from after upload, e.g. to | ||||
|         generate thumbnails. | ||||
| 
 | ||||
|         finish_cb must be called and waited on after the file has been | ||||
|         successfully been written to. Should not be called if there was an | ||||
|         error. | ||||
| 
 | ||||
|         Args: | ||||
|             file_info (FileInfo): Info about the file to store | ||||
| 
 | ||||
|         Example: | ||||
| 
 | ||||
|             with media_storage.store_into_file(info) as (f, fname, finish_cb): | ||||
|                 # .. write into f ... | ||||
|                 yield finish_cb() | ||||
|         """ | ||||
| 
 | ||||
|         path = self._file_info_to_path(file_info) | ||||
|         fname = os.path.join(self.local_media_directory, path) | ||||
| 
 | ||||
|         dirname = os.path.dirname(fname) | ||||
|         if not os.path.exists(dirname): | ||||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|         finished_called = [False] | ||||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def finish(): | ||||
|             for provider in self.storage_providers: | ||||
|                 yield provider.store_file(path, file_info) | ||||
| 
 | ||||
|             finished_called[0] = True | ||||
| 
 | ||||
|         try: | ||||
|             with open(fname, "wb") as f: | ||||
|                 yield f, fname, finish | ||||
|         except Exception: | ||||
|             t, v, tb = sys.exc_info() | ||||
|             try: | ||||
|                 os.remove(fname) | ||||
|             except Exception: | ||||
|                 pass | ||||
|             raise t, v, tb | ||||
| 
 | ||||
|         if not finished_called: | ||||
|             raise Exception("Finished callback not called") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def fetch_media(self, file_info): | ||||
|         """Attempts to fetch media described by file_info from the local cache | ||||
|         and configured storage providers. | ||||
| 
 | ||||
|         Args: | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[Responder|None]: Returns a Responder if the file was found, | ||||
|                 otherwise None. | ||||
|         """ | ||||
| 
 | ||||
|         path = self._file_info_to_path(file_info) | ||||
|         local_path = os.path.join(self.local_media_directory, path) | ||||
|         if os.path.exists(local_path): | ||||
|             defer.returnValue(FileResponder(open(local_path, "rb"))) | ||||
| 
 | ||||
|         for provider in self.storage_providers: | ||||
|             res = yield provider.fetch(path, file_info) | ||||
|             if res: | ||||
|                 defer.returnValue(res) | ||||
| 
 | ||||
|         defer.returnValue(None) | ||||
| 
 | ||||
|     def _file_info_to_path(self, file_info): | ||||
|         """Converts file_info into a relative path. | ||||
| 
 | ||||
|         The path is suitable for storing files under a directory, e.g. used to | ||||
|         store files on local FS under the base media repository directory. | ||||
| 
 | ||||
|         Args: | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             str | ||||
|         """ | ||||
|         if file_info.url_cache: | ||||
|             if file_info.thumbnail: | ||||
|                 return self.filepaths.url_cache_thumbnail_rel( | ||||
|                     media_id=file_info.file_id, | ||||
|                     width=file_info.thumbnail_width, | ||||
|                     height=file_info.thumbnail_height, | ||||
|                     content_type=file_info.thumbnail_type, | ||||
|                     method=file_info.thumbnail_method, | ||||
|                 ) | ||||
|             return self.filepaths.url_cache_filepath_rel(file_info.file_id) | ||||
| 
 | ||||
|         if file_info.server_name: | ||||
|             if file_info.thumbnail: | ||||
|                 return self.filepaths.remote_media_thumbnail_rel( | ||||
|                     server_name=file_info.server_name, | ||||
|                     file_id=file_info.file_id, | ||||
|                     width=file_info.thumbnail_width, | ||||
|                     height=file_info.thumbnail_height, | ||||
|                     content_type=file_info.thumbnail_type, | ||||
|                     method=file_info.thumbnail_method | ||||
|                 ) | ||||
|             return self.filepaths.remote_media_filepath_rel( | ||||
|                 file_info.server_name, file_info.file_id, | ||||
|             ) | ||||
| 
 | ||||
|         if file_info.thumbnail: | ||||
|             return self.filepaths.local_media_thumbnail_rel( | ||||
|                 media_id=file_info.file_id, | ||||
|                 width=file_info.thumbnail_width, | ||||
|                 height=file_info.thumbnail_height, | ||||
|                 content_type=file_info.thumbnail_type, | ||||
|                 method=file_info.thumbnail_method | ||||
|             ) | ||||
|         return self.filepaths.local_media_filepath_rel( | ||||
|             file_info.file_id, | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| def _write_file_synchronously(source, fname): | ||||
|     """Write `source` to the path `fname` synchronously. Should be called | ||||
|     from a thread. | ||||
| 
 | ||||
|     Args: | ||||
|         source: A file like object to be written | ||||
|         fname (str): Path to write to | ||||
|     """ | ||||
|     dirname = os.path.dirname(fname) | ||||
|     if not os.path.exists(dirname): | ||||
|         os.makedirs(dirname) | ||||
| 
 | ||||
|     source.seek(0)  # Ensure we read from the start of the file | ||||
|     with open(fname, "wb") as f: | ||||
|         shutil.copyfileobj(source, f) | ||||
| 
 | ||||
| 
 | ||||
| class FileResponder(Responder): | ||||
|     """Wraps an open file that can be sent to a request. | ||||
| 
 | ||||
|     Args: | ||||
|         open_file (file): A file like object to be streamed ot the client, | ||||
|             is closed when finished streaming. | ||||
|     """ | ||||
|     def __init__(self, open_file): | ||||
|         self.open_file = open_file | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def write_to_consumer(self, consumer): | ||||
|         yield FileSender().beginFileTransfer(self.open_file, consumer) | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         self.open_file.close() | ||||
|  | @ -17,6 +17,8 @@ from twisted.web.server import NOT_DONE_YET | |||
| from twisted.internet import defer | ||||
| from twisted.web.resource import Resource | ||||
| 
 | ||||
| from ._base import FileInfo | ||||
| 
 | ||||
| from synapse.api.errors import ( | ||||
|     SynapseError, Codes, | ||||
| ) | ||||
|  | @ -49,7 +51,7 @@ logger = logging.getLogger(__name__) | |||
| class PreviewUrlResource(Resource): | ||||
|     isLeaf = True | ||||
| 
 | ||||
|     def __init__(self, hs, media_repo): | ||||
|     def __init__(self, hs, media_repo, media_storage): | ||||
|         Resource.__init__(self) | ||||
| 
 | ||||
|         self.auth = hs.get_auth() | ||||
|  | @ -62,6 +64,7 @@ class PreviewUrlResource(Resource): | |||
|         self.client = SpiderHttpClient(hs) | ||||
|         self.media_repo = media_repo | ||||
|         self.primary_base_path = media_repo.primary_base_path | ||||
|         self.media_storage = media_storage | ||||
| 
 | ||||
|         self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist | ||||
| 
 | ||||
|  | @ -182,8 +185,10 @@ class PreviewUrlResource(Resource): | |||
|         logger.debug("got media_info of '%s'" % media_info) | ||||
| 
 | ||||
|         if _is_media(media_info['media_type']): | ||||
|             file_id = media_info['filesystem_id'] | ||||
|             dims = yield self.media_repo._generate_thumbnails( | ||||
|                 None, media_info['filesystem_id'], media_info, url_cache=True, | ||||
|                 None, file_id, file_id, media_info["media_type"], | ||||
|                 url_cache=True, | ||||
|             ) | ||||
| 
 | ||||
|             og = { | ||||
|  | @ -228,8 +233,10 @@ class PreviewUrlResource(Resource): | |||
| 
 | ||||
|                 if _is_media(image_info['media_type']): | ||||
|                     # TODO: make sure we don't choke on white-on-transparent images | ||||
|                     file_id = image_info['filesystem_id'] | ||||
|                     dims = yield self.media_repo._generate_thumbnails( | ||||
|                         None, image_info['filesystem_id'], image_info, url_cache=True, | ||||
|                         None, file_id, file_id, image_info["media_type"], | ||||
|                         url_cache=True, | ||||
|                     ) | ||||
|                     if dims: | ||||
|                         og["og:image:width"] = dims['width'] | ||||
|  | @ -273,19 +280,21 @@ class PreviewUrlResource(Resource): | |||
| 
 | ||||
|         file_id = datetime.date.today().isoformat() + '_' + random_string(16) | ||||
| 
 | ||||
|         fpath = self.filepaths.url_cache_filepath_rel(file_id) | ||||
|         fname = os.path.join(self.primary_base_path, fpath) | ||||
|         self.media_repo._makedirs(fname) | ||||
|         file_info = FileInfo( | ||||
|             server_name=None, | ||||
|             file_id=file_id, | ||||
|             url_cache=True, | ||||
|         ) | ||||
| 
 | ||||
|         try: | ||||
|             with open(fname, "wb") as f: | ||||
|             with self.media_storage.store_into_file(file_info) as (f, fname, finish): | ||||
|                 logger.debug("Trying to get url '%s'" % url) | ||||
|                 length, headers, uri, code = yield self.client.get_file( | ||||
|                     url, output_stream=f, max_size=self.max_spider_size, | ||||
|                 ) | ||||
|                 # FIXME: pass through 404s and other error messages nicely | ||||
| 
 | ||||
|             yield self.media_repo.copy_to_backup(fpath) | ||||
|                 yield finish() | ||||
| 
 | ||||
|             media_type = headers["Content-Type"][0] | ||||
|             time_now_ms = self.clock.time_msec() | ||||
|  | @ -327,7 +336,6 @@ class PreviewUrlResource(Resource): | |||
|             ) | ||||
| 
 | ||||
|         except Exception as e: | ||||
|             os.remove(fname) | ||||
|             raise SynapseError( | ||||
|                 500, ("Failed to download content: %s" % e), | ||||
|                 Codes.UNKNOWN | ||||
|  |  | |||
							
								
								
									
										140
									
								
								synapse/rest/media/v1/storage_provider.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										140
									
								
								synapse/rest/media/v1/storage_provider.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,140 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import defer, threads | ||||
| 
 | ||||
| from .media_storage import FileResponder | ||||
| 
 | ||||
| from synapse.config._base import Config | ||||
| from synapse.util.logcontext import preserve_fn | ||||
| 
 | ||||
| import logging | ||||
| import os | ||||
| import shutil | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class StorageProvider(object): | ||||
|     """A storage provider is a service that can store uploaded media and | ||||
|     retrieve them. | ||||
|     """ | ||||
|     def store_file(self, path, file_info): | ||||
|         """Store the file described by file_info. The actual contents can be | ||||
|         retrieved by reading the file in file_info.upload_path. | ||||
| 
 | ||||
|         Args: | ||||
|             path (str): Relative path of file in local cache | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|         """Attempt to fetch the file described by file_info and stream it | ||||
|         into writer. | ||||
| 
 | ||||
|         Args: | ||||
|             path (str): Relative path of file in local cache | ||||
|             file_info (FileInfo) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred(Responder): Returns a Responder if the provider has the file, | ||||
|                 otherwise returns None. | ||||
|         """ | ||||
|         pass | ||||
| 
 | ||||
| 
 | ||||
| class StorageProviderWrapper(StorageProvider): | ||||
|     """Wraps a storage provider and provides various config options | ||||
| 
 | ||||
|     Args: | ||||
|         backend (StorageProvider) | ||||
|         store_local (bool): Whether to store new local files or not. | ||||
|         store_synchronous (bool): Whether to wait for file to be successfully | ||||
|             uploaded, or todo the upload in the backgroud. | ||||
|         store_remote (bool): Whether remote media should be uploaded | ||||
|     """ | ||||
|     def __init__(self, backend, store_local, store_synchronous, store_remote): | ||||
|         self.backend = backend | ||||
|         self.store_local = store_local | ||||
|         self.store_synchronous = store_synchronous | ||||
|         self.store_remote = store_remote | ||||
| 
 | ||||
|     def store_file(self, path, file_info): | ||||
|         if not file_info.server_name and not self.store_local: | ||||
|             return defer.succeed(None) | ||||
| 
 | ||||
|         if file_info.server_name and not self.store_remote: | ||||
|             return defer.succeed(None) | ||||
| 
 | ||||
|         if self.store_synchronous: | ||||
|             return self.backend.store_file(path, file_info) | ||||
|         else: | ||||
|             # TODO: Handle errors. | ||||
|             preserve_fn(self.backend.store_file)(path, file_info) | ||||
|             return defer.succeed(None) | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|         return self.backend.fetch(path, file_info) | ||||
| 
 | ||||
| 
 | ||||
| class FileStorageProviderBackend(StorageProvider): | ||||
|     """A storage provider that stores files in a directory on a filesystem. | ||||
| 
 | ||||
|     Args: | ||||
|         hs (HomeServer) | ||||
|         config: The config returned by `parse_config`. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, hs, config): | ||||
|         self.cache_directory = hs.config.media_store_path | ||||
|         self.base_directory = config | ||||
| 
 | ||||
|     def store_file(self, path, file_info): | ||||
|         """See StorageProvider.store_file""" | ||||
| 
 | ||||
|         primary_fname = os.path.join(self.cache_directory, path) | ||||
|         backup_fname = os.path.join(self.base_directory, path) | ||||
| 
 | ||||
|         dirname = os.path.dirname(backup_fname) | ||||
|         if not os.path.exists(dirname): | ||||
|             os.makedirs(dirname) | ||||
| 
 | ||||
|         return threads.deferToThread( | ||||
|             shutil.copyfile, primary_fname, backup_fname, | ||||
|         ) | ||||
| 
 | ||||
|     def fetch(self, path, file_info): | ||||
|         """See StorageProvider.fetch""" | ||||
| 
 | ||||
|         backup_fname = os.path.join(self.base_directory, path) | ||||
|         if os.path.isfile(backup_fname): | ||||
|             return FileResponder(open(backup_fname, "rb")) | ||||
| 
 | ||||
|     @staticmethod | ||||
|     def parse_config(config): | ||||
|         """Called on startup to parse config supplied. This should parse | ||||
|         the config and raise if there is a problem. | ||||
| 
 | ||||
|         The returned value is passed into the constructor. | ||||
| 
 | ||||
|         In this case we only care about a single param, the directory, so let's | ||||
|         just pull that out. | ||||
|         """ | ||||
|         return Config.ensure_directory(config["directory"]) | ||||
|  | @ -14,7 +14,10 @@ | |||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| from ._base import parse_media_id, respond_404, respond_with_file | ||||
| from ._base import ( | ||||
|     parse_media_id, respond_404, respond_with_file, FileInfo, | ||||
|     respond_with_responder, | ||||
| ) | ||||
| from twisted.web.resource import Resource | ||||
| from synapse.http.servlet import parse_string, parse_integer | ||||
| from synapse.http.server import request_handler, set_cors_headers | ||||
|  | @ -30,12 +33,12 @@ logger = logging.getLogger(__name__) | |||
| class ThumbnailResource(Resource): | ||||
|     isLeaf = True | ||||
| 
 | ||||
|     def __init__(self, hs, media_repo): | ||||
|     def __init__(self, hs, media_repo, media_storage): | ||||
|         Resource.__init__(self) | ||||
| 
 | ||||
|         self.store = hs.get_datastore() | ||||
|         self.filepaths = media_repo.filepaths | ||||
|         self.media_repo = media_repo | ||||
|         self.media_storage = media_storage | ||||
|         self.dynamic_thumbnails = hs.config.dynamic_thumbnails | ||||
|         self.server_name = hs.hostname | ||||
|         self.version_string = hs.version_string | ||||
|  | @ -64,6 +67,7 @@ class ThumbnailResource(Resource): | |||
|                 yield self._respond_local_thumbnail( | ||||
|                     request, media_id, width, height, method, m_type | ||||
|                 ) | ||||
|             self.media_repo.mark_recently_accessed(None, media_id) | ||||
|         else: | ||||
|             if self.dynamic_thumbnails: | ||||
|                 yield self._select_or_generate_remote_thumbnail( | ||||
|  | @ -75,20 +79,20 @@ class ThumbnailResource(Resource): | |||
|                     request, server_name, media_id, | ||||
|                     width, height, method, m_type | ||||
|                 ) | ||||
|             self.media_repo.mark_recently_accessed(server_name, media_id) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_local_thumbnail(self, request, media_id, width, height, | ||||
|                                  method, m_type): | ||||
|         media_info = yield self.store.get_local_media(media_id) | ||||
| 
 | ||||
|         if not media_info or media_info["quarantined_by"]: | ||||
|         if not media_info: | ||||
|             respond_404(request) | ||||
|             return | ||||
|         if media_info["quarantined_by"]: | ||||
|             logger.info("Media is quarantined") | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         # if media_info["media_type"] == "image/svg+xml": | ||||
|         #     file_path = self.filepaths.local_media_filepath(media_id) | ||||
|         #     yield respond_with_file(request, media_info["media_type"], file_path) | ||||
|         #     return | ||||
| 
 | ||||
|         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) | ||||
| 
 | ||||
|  | @ -96,42 +100,39 @@ class ThumbnailResource(Resource): | |||
|             thumbnail_info = self._select_thumbnail( | ||||
|                 width, height, method, m_type, thumbnail_infos | ||||
|             ) | ||||
|             t_width = thumbnail_info["thumbnail_width"] | ||||
|             t_height = thumbnail_info["thumbnail_height"] | ||||
|             t_type = thumbnail_info["thumbnail_type"] | ||||
|             t_method = thumbnail_info["thumbnail_method"] | ||||
| 
 | ||||
|             if media_info["url_cache"]: | ||||
|                 # TODO: Check the file still exists, if it doesn't we can redownload | ||||
|                 # it from the url `media_info["url_cache"]` | ||||
|                 file_path = self.filepaths.url_cache_thumbnail( | ||||
|                     media_id, t_width, t_height, t_type, t_method, | ||||
|                 ) | ||||
|             else: | ||||
|                 file_path = self.filepaths.local_media_thumbnail( | ||||
|                     media_id, t_width, t_height, t_type, t_method, | ||||
|                 ) | ||||
|             yield respond_with_file(request, t_type, file_path) | ||||
| 
 | ||||
|         else: | ||||
|             yield self._respond_default_thumbnail( | ||||
|                 request, media_info, width, height, method, m_type, | ||||
|             file_info = FileInfo( | ||||
|                 server_name=None, file_id=media_id, | ||||
|                 url_cache=media_info["url_cache"], | ||||
|                 thumbnail=True, | ||||
|                 thumbnail_width=thumbnail_info["thumbnail_width"], | ||||
|                 thumbnail_height=thumbnail_info["thumbnail_height"], | ||||
|                 thumbnail_type=thumbnail_info["thumbnail_type"], | ||||
|                 thumbnail_method=thumbnail_info["thumbnail_method"], | ||||
|             ) | ||||
| 
 | ||||
|             t_type = file_info.thumbnail_type | ||||
|             t_length = thumbnail_info["thumbnail_length"] | ||||
| 
 | ||||
|             responder = yield self.media_storage.fetch_media(file_info) | ||||
|             yield respond_with_responder(request, responder, t_type, t_length) | ||||
|         else: | ||||
|             logger.info("Couldn't find any generated thumbnails") | ||||
|             respond_404(request) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _select_or_generate_local_thumbnail(self, request, media_id, desired_width, | ||||
|                                             desired_height, desired_method, | ||||
|                                             desired_type): | ||||
|         media_info = yield self.store.get_local_media(media_id) | ||||
| 
 | ||||
|         if not media_info or media_info["quarantined_by"]: | ||||
|         if not media_info: | ||||
|             respond_404(request) | ||||
|             return | ||||
|         if media_info["quarantined_by"]: | ||||
|             logger.info("Media is quarantined") | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         # if media_info["media_type"] == "image/svg+xml": | ||||
|         #     file_path = self.filepaths.local_media_filepath(media_id) | ||||
|         #     yield respond_with_file(request, media_info["media_type"], file_path) | ||||
|         #     return | ||||
| 
 | ||||
|         thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id) | ||||
|         for info in thumbnail_infos: | ||||
|  | @ -141,22 +142,25 @@ class ThumbnailResource(Resource): | |||
|             t_type = info["thumbnail_type"] == desired_type | ||||
| 
 | ||||
|             if t_w and t_h and t_method and t_type: | ||||
|                 if media_info["url_cache"]: | ||||
|                     # TODO: Check the file still exists, if it doesn't we can redownload | ||||
|                     # it from the url `media_info["url_cache"]` | ||||
|                     file_path = self.filepaths.url_cache_thumbnail( | ||||
|                         media_id, desired_width, desired_height, desired_type, | ||||
|                         desired_method, | ||||
|                     ) | ||||
|                 else: | ||||
|                     file_path = self.filepaths.local_media_thumbnail( | ||||
|                         media_id, desired_width, desired_height, desired_type, | ||||
|                         desired_method, | ||||
|                     ) | ||||
|                 yield respond_with_file(request, desired_type, file_path) | ||||
|                 return | ||||
|                 file_info = FileInfo( | ||||
|                     server_name=None, file_id=media_id, | ||||
|                     url_cache=media_info["url_cache"], | ||||
|                     thumbnail=True, | ||||
|                     thumbnail_width=info["thumbnail_width"], | ||||
|                     thumbnail_height=info["thumbnail_height"], | ||||
|                     thumbnail_type=info["thumbnail_type"], | ||||
|                     thumbnail_method=info["thumbnail_method"], | ||||
|                 ) | ||||
| 
 | ||||
|         logger.debug("We don't have a local thumbnail of that size. Generating") | ||||
|                 t_type = file_info.thumbnail_type | ||||
|                 t_length = info["thumbnail_length"] | ||||
| 
 | ||||
|                 responder = yield self.media_storage.fetch_media(file_info) | ||||
|                 if responder: | ||||
|                     yield respond_with_responder(request, responder, t_type, t_length) | ||||
|                     return | ||||
| 
 | ||||
|         logger.debug("We don't have a thumbnail of that size. Generating") | ||||
| 
 | ||||
|         # Okay, so we generate one. | ||||
|         file_path = yield self.media_repo.generate_local_exact_thumbnail( | ||||
|  | @ -166,21 +170,14 @@ class ThumbnailResource(Resource): | |||
|         if file_path: | ||||
|             yield respond_with_file(request, desired_type, file_path) | ||||
|         else: | ||||
|             yield self._respond_default_thumbnail( | ||||
|                 request, media_info, desired_width, desired_height, | ||||
|                 desired_method, desired_type, | ||||
|             ) | ||||
|             logger.warn("Failed to generate thumbnail") | ||||
|             respond_404(request) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _select_or_generate_remote_thumbnail(self, request, server_name, media_id, | ||||
|                                              desired_width, desired_height, | ||||
|                                              desired_method, desired_type): | ||||
|         media_info = yield self.media_repo.get_remote_media(server_name, media_id) | ||||
| 
 | ||||
|         # if media_info["media_type"] == "image/svg+xml": | ||||
|         #     file_path = self.filepaths.remote_media_filepath(server_name, media_id) | ||||
|         #     yield respond_with_file(request, media_info["media_type"], file_path) | ||||
|         #     return | ||||
|         media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) | ||||
| 
 | ||||
|         thumbnail_infos = yield self.store.get_remote_media_thumbnails( | ||||
|             server_name, media_id, | ||||
|  | @ -195,14 +192,24 @@ class ThumbnailResource(Resource): | |||
|             t_type = info["thumbnail_type"] == desired_type | ||||
| 
 | ||||
|             if t_w and t_h and t_method and t_type: | ||||
|                 file_path = self.filepaths.remote_media_thumbnail( | ||||
|                     server_name, file_id, desired_width, desired_height, | ||||
|                     desired_type, desired_method, | ||||
|                 file_info = FileInfo( | ||||
|                     server_name=server_name, file_id=media_info["filesystem_id"], | ||||
|                     thumbnail=True, | ||||
|                     thumbnail_width=info["thumbnail_width"], | ||||
|                     thumbnail_height=info["thumbnail_height"], | ||||
|                     thumbnail_type=info["thumbnail_type"], | ||||
|                     thumbnail_method=info["thumbnail_method"], | ||||
|                 ) | ||||
|                 yield respond_with_file(request, desired_type, file_path) | ||||
|                 return | ||||
| 
 | ||||
|         logger.debug("We don't have a local thumbnail of that size. Generating") | ||||
|                 t_type = file_info.thumbnail_type | ||||
|                 t_length = info["thumbnail_length"] | ||||
| 
 | ||||
|                 responder = yield self.media_storage.fetch_media(file_info) | ||||
|                 if responder: | ||||
|                     yield respond_with_responder(request, responder, t_type, t_length) | ||||
|                     return | ||||
| 
 | ||||
|         logger.debug("We don't have a thumbnail of that size. Generating") | ||||
| 
 | ||||
|         # Okay, so we generate one. | ||||
|         file_path = yield self.media_repo.generate_remote_exact_thumbnail( | ||||
|  | @ -213,22 +220,16 @@ class ThumbnailResource(Resource): | |||
|         if file_path: | ||||
|             yield respond_with_file(request, desired_type, file_path) | ||||
|         else: | ||||
|             yield self._respond_default_thumbnail( | ||||
|                 request, media_info, desired_width, desired_height, | ||||
|                 desired_method, desired_type, | ||||
|             ) | ||||
|             logger.warn("Failed to generate thumbnail") | ||||
|             respond_404(request) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_remote_thumbnail(self, request, server_name, media_id, width, | ||||
|                                   height, method, m_type): | ||||
|         # TODO: Don't download the whole remote file | ||||
|         # We should proxy the thumbnail from the remote server instead. | ||||
|         media_info = yield self.media_repo.get_remote_media(server_name, media_id) | ||||
| 
 | ||||
|         # if media_info["media_type"] == "image/svg+xml": | ||||
|         #     file_path = self.filepaths.remote_media_filepath(server_name, media_id) | ||||
|         #     yield respond_with_file(request, media_info["media_type"], file_path) | ||||
|         #     return | ||||
|         # We should proxy the thumbnail from the remote server instead of | ||||
|         # downloading the remote file and generating our own thumbnails. | ||||
|         media_info = yield self.media_repo.get_remote_media_info(server_name, media_id) | ||||
| 
 | ||||
|         thumbnail_infos = yield self.store.get_remote_media_thumbnails( | ||||
|             server_name, media_id, | ||||
|  | @ -238,59 +239,23 @@ class ThumbnailResource(Resource): | |||
|             thumbnail_info = self._select_thumbnail( | ||||
|                 width, height, method, m_type, thumbnail_infos | ||||
|             ) | ||||
|             t_width = thumbnail_info["thumbnail_width"] | ||||
|             t_height = thumbnail_info["thumbnail_height"] | ||||
|             t_type = thumbnail_info["thumbnail_type"] | ||||
|             t_method = thumbnail_info["thumbnail_method"] | ||||
|             file_id = thumbnail_info["filesystem_id"] | ||||
|             file_info = FileInfo( | ||||
|                 server_name=server_name, file_id=media_info["filesystem_id"], | ||||
|                 thumbnail=True, | ||||
|                 thumbnail_width=thumbnail_info["thumbnail_width"], | ||||
|                 thumbnail_height=thumbnail_info["thumbnail_height"], | ||||
|                 thumbnail_type=thumbnail_info["thumbnail_type"], | ||||
|                 thumbnail_method=thumbnail_info["thumbnail_method"], | ||||
|             ) | ||||
| 
 | ||||
|             t_type = file_info.thumbnail_type | ||||
|             t_length = thumbnail_info["thumbnail_length"] | ||||
| 
 | ||||
|             file_path = self.filepaths.remote_media_thumbnail( | ||||
|                 server_name, file_id, t_width, t_height, t_type, t_method, | ||||
|             ) | ||||
|             yield respond_with_file(request, t_type, file_path, t_length) | ||||
|             responder = yield self.media_storage.fetch_media(file_info) | ||||
|             yield respond_with_responder(request, responder, t_type, t_length) | ||||
|         else: | ||||
|             yield self._respond_default_thumbnail( | ||||
|                 request, media_info, width, height, method, m_type, | ||||
|             ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def _respond_default_thumbnail(self, request, media_info, width, height, | ||||
|                                    method, m_type): | ||||
|         # XXX: how is this meant to work? store.get_default_thumbnails | ||||
|         # appears to always return [] so won't this always 404? | ||||
|         media_type = media_info["media_type"] | ||||
|         top_level_type = media_type.split("/")[0] | ||||
|         sub_type = media_type.split("/")[-1].split(";")[0] | ||||
|         thumbnail_infos = yield self.store.get_default_thumbnails( | ||||
|             top_level_type, sub_type, | ||||
|         ) | ||||
|         if not thumbnail_infos: | ||||
|             thumbnail_infos = yield self.store.get_default_thumbnails( | ||||
|                 top_level_type, "_default", | ||||
|             ) | ||||
|         if not thumbnail_infos: | ||||
|             thumbnail_infos = yield self.store.get_default_thumbnails( | ||||
|                 "_default", "_default", | ||||
|             ) | ||||
|         if not thumbnail_infos: | ||||
|             logger.info("Failed to find any generated thumbnails") | ||||
|             respond_404(request) | ||||
|             return | ||||
| 
 | ||||
|         thumbnail_info = self._select_thumbnail( | ||||
|             width, height, "crop", m_type, thumbnail_infos | ||||
|         ) | ||||
| 
 | ||||
|         t_width = thumbnail_info["thumbnail_width"] | ||||
|         t_height = thumbnail_info["thumbnail_height"] | ||||
|         t_type = thumbnail_info["thumbnail_type"] | ||||
|         t_method = thumbnail_info["thumbnail_method"] | ||||
|         t_length = thumbnail_info["thumbnail_length"] | ||||
| 
 | ||||
|         file_path = self.filepaths.default_thumbnail( | ||||
|             top_level_type, sub_type, t_width, t_height, t_type, t_method, | ||||
|         ) | ||||
|         yield respond_with_file(request, t_type, file_path, t_length) | ||||
| 
 | ||||
|     def _select_thumbnail(self, desired_width, desired_height, desired_method, | ||||
|                           desired_type, thumbnail_infos): | ||||
|  |  | |||
|  | @ -307,6 +307,23 @@ class HomeServer(object): | |||
|             **self.db_config.get("args", {}) | ||||
|         ) | ||||
| 
 | ||||
|     def get_db_conn(self, run_new_connection=True): | ||||
|         """Makes a new connection to the database, skipping the db pool | ||||
| 
 | ||||
|         Returns: | ||||
|             Connection: a connection object implementing the PEP-249 spec | ||||
|         """ | ||||
|         # Any param beginning with cp_ is a parameter for adbapi, and should | ||||
|         # not be passed to the database engine. | ||||
|         db_params = { | ||||
|             k: v for k, v in self.db_config.get("args", {}).items() | ||||
|             if not k.startswith("cp_") | ||||
|         } | ||||
|         db_conn = self.database_engine.module.connect(**db_params) | ||||
|         if run_new_connection: | ||||
|             self.database_engine.on_new_connection(db_conn) | ||||
|         return db_conn | ||||
| 
 | ||||
|     def build_media_repository_resource(self): | ||||
|         # build the media repo resource. This indirects through the HomeServer | ||||
|         # to ensure that we only have a single instance of | ||||
|  |  | |||
|  | @ -146,8 +146,20 @@ class StateHandler(object): | |||
|         defer.returnValue(state) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def get_current_state_ids(self, room_id, event_type=None, state_key="", | ||||
|                               latest_event_ids=None): | ||||
|     def get_current_state_ids(self, room_id, latest_event_ids=None): | ||||
|         """Get the current state, or the state at a set of events, for a room | ||||
| 
 | ||||
|         Args: | ||||
|             room_id (str): | ||||
| 
 | ||||
|             latest_event_ids (iterable[str]|None): if given, the forward | ||||
|                 extremities to resolve. If None, we look them up from the | ||||
|                 database (via a cache) | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred[dict[(str, str), str)]]: the state dict, mapping from | ||||
|                 (event_type, state_key) -> event_id | ||||
|         """ | ||||
|         if not latest_event_ids: | ||||
|             latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) | ||||
| 
 | ||||
|  | @ -155,10 +167,6 @@ class StateHandler(object): | |||
|         ret = yield self.resolve_state_groups(room_id, latest_event_ids) | ||||
|         state = ret.state | ||||
| 
 | ||||
|         if event_type: | ||||
|             defer.returnValue(state.get((event_type, state_key))) | ||||
|             return | ||||
| 
 | ||||
|         defer.returnValue(state) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | @ -341,7 +349,7 @@ class StateHandler(object): | |||
|             if conflicted_state: | ||||
|                 logger.info("Resolving conflicted state for %r", room_id) | ||||
|                 with Measure(self.clock, "state._resolve_events"): | ||||
|                     new_state = yield resolve_events( | ||||
|                     new_state = yield resolve_events_with_factory( | ||||
|                         state_groups_ids.values(), | ||||
|                         state_map_factory=lambda ev_ids: self.store.get_events( | ||||
|                             ev_ids, get_prev_content=False, check_redacted=False, | ||||
|  | @ -404,7 +412,7 @@ class StateHandler(object): | |||
|         } | ||||
| 
 | ||||
|         with Measure(self.clock, "state._resolve_events"): | ||||
|             new_state = resolve_events(state_set_ids, state_map) | ||||
|             new_state = resolve_events_with_state_map(state_set_ids, state_map) | ||||
| 
 | ||||
|         new_state = { | ||||
|             key: state_map[ev_id] for key, ev_id in new_state.items() | ||||
|  | @ -420,19 +428,17 @@ def _ordered_events(events): | |||
|     return sorted(events, key=key_func) | ||||
| 
 | ||||
| 
 | ||||
| def resolve_events(state_sets, state_map_factory): | ||||
| def resolve_events_with_state_map(state_sets, state_map): | ||||
|     """ | ||||
|     Args: | ||||
|         state_sets(list): List of dicts of (type, state_key) -> event_id, | ||||
|             which are the different state groups to resolve. | ||||
|         state_map_factory(dict|callable): If callable, then will be called | ||||
|             with a list of event_ids that are needed, and should return with | ||||
|             a Deferred of dict of event_id to event. Otherwise, should be | ||||
|             a dict from event_id to event of all events in state_sets. | ||||
|         state_map(dict): a dict from event_id to event, for all events in | ||||
|             state_sets. | ||||
| 
 | ||||
|     Returns | ||||
|         dict[(str, str), synapse.events.FrozenEvent] is a map from | ||||
|         (type, state_key) to event. | ||||
|         dict[(str, str), synapse.events.FrozenEvent]: | ||||
|             a map from (type, state_key) to event. | ||||
|     """ | ||||
|     if len(state_sets) == 1: | ||||
|         return state_sets[0] | ||||
|  | @ -441,13 +447,6 @@ def resolve_events(state_sets, state_map_factory): | |||
|         state_sets, | ||||
|     ) | ||||
| 
 | ||||
|     if callable(state_map_factory): | ||||
|         return _resolve_with_state_fac( | ||||
|             unconflicted_state, conflicted_state, state_map_factory | ||||
|         ) | ||||
| 
 | ||||
|     state_map = state_map_factory | ||||
| 
 | ||||
|     auth_events = _create_auth_events_from_maps( | ||||
|         unconflicted_state, conflicted_state, state_map | ||||
|     ) | ||||
|  | @ -491,8 +490,26 @@ def _seperate(state_sets): | |||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
| def _resolve_with_state_fac(unconflicted_state, conflicted_state, | ||||
|                             state_map_factory): | ||||
| def resolve_events_with_factory(state_sets, state_map_factory): | ||||
|     """ | ||||
|     Args: | ||||
|         state_sets(list): List of dicts of (type, state_key) -> event_id, | ||||
|             which are the different state groups to resolve. | ||||
|         state_map_factory(func): will be called | ||||
|             with a list of event_ids that are needed, and should return with | ||||
|             a Deferred of dict of event_id to event. | ||||
| 
 | ||||
|     Returns | ||||
|         Deferred[dict[(str, str), synapse.events.FrozenEvent]]: | ||||
|             a map from (type, state_key) to event. | ||||
|     """ | ||||
|     if len(state_sets) == 1: | ||||
|         defer.returnValue(state_sets[0]) | ||||
| 
 | ||||
|     unconflicted_state, conflicted_state = _seperate( | ||||
|         state_sets, | ||||
|     ) | ||||
| 
 | ||||
|     needed_events = set( | ||||
|         event_id | ||||
|         for event_ids in conflicted_state.itervalues() | ||||
|  |  | |||
|  | @ -291,33 +291,33 @@ class SQLBaseStore(object): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def runInteraction(self, desc, func, *args, **kwargs): | ||||
|         """Wraps the .runInteraction() method on the underlying db_pool.""" | ||||
|         current_context = LoggingContext.current_context() | ||||
|         """Starts a transaction on the database and runs a given function | ||||
| 
 | ||||
|         start_time = time.time() * 1000 | ||||
|         Arguments: | ||||
|             desc (str): description of the transaction, for logging and metrics | ||||
|             func (func): callback function, which will be called with a | ||||
|                 database transaction (twisted.enterprise.adbapi.Transaction) as | ||||
|                 its first argument, followed by `args` and `kwargs`. | ||||
| 
 | ||||
|             args (list): positional args to pass to `func` | ||||
|             kwargs (dict): named args to pass to `func` | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: The result of func | ||||
|         """ | ||||
|         current_context = LoggingContext.current_context() | ||||
| 
 | ||||
|         after_callbacks = [] | ||||
|         final_callbacks = [] | ||||
| 
 | ||||
|         def inner_func(conn, *args, **kwargs): | ||||
|             with LoggingContext("runInteraction") as context: | ||||
|                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) | ||||
| 
 | ||||
|                 if self.database_engine.is_connection_closed(conn): | ||||
|                     logger.debug("Reconnecting closed database connection") | ||||
|                     conn.reconnect() | ||||
| 
 | ||||
|                 current_context.copy_to(context) | ||||
|                 return self._new_transaction( | ||||
|                     conn, desc, after_callbacks, final_callbacks, current_context, | ||||
|                     func, *args, **kwargs | ||||
|                 ) | ||||
|             return self._new_transaction( | ||||
|                 conn, desc, after_callbacks, final_callbacks, current_context, | ||||
|                 func, *args, **kwargs | ||||
|             ) | ||||
| 
 | ||||
|         try: | ||||
|             with PreserveLoggingContext(): | ||||
|                 result = yield self._db_pool.runWithConnection( | ||||
|                     inner_func, *args, **kwargs | ||||
|                 ) | ||||
|             result = yield self.runWithConnection(inner_func, *args, **kwargs) | ||||
| 
 | ||||
|             for after_callback, after_args, after_kwargs in after_callbacks: | ||||
|                 after_callback(*after_args, **after_kwargs) | ||||
|  | @ -329,14 +329,27 @@ class SQLBaseStore(object): | |||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def runWithConnection(self, func, *args, **kwargs): | ||||
|         """Wraps the .runInteraction() method on the underlying db_pool.""" | ||||
|         """Wraps the .runWithConnection() method on the underlying db_pool. | ||||
| 
 | ||||
|         Arguments: | ||||
|             func (func): callback function, which will be called with a | ||||
|                 database connection (twisted.enterprise.adbapi.Connection) as | ||||
|                 its first argument, followed by `args` and `kwargs`. | ||||
|             args (list): positional args to pass to `func` | ||||
|             kwargs (dict): named args to pass to `func` | ||||
| 
 | ||||
|         Returns: | ||||
|             Deferred: The result of func | ||||
|         """ | ||||
|         current_context = LoggingContext.current_context() | ||||
| 
 | ||||
|         start_time = time.time() * 1000 | ||||
| 
 | ||||
|         def inner_func(conn, *args, **kwargs): | ||||
|             with LoggingContext("runWithConnection") as context: | ||||
|                 sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) | ||||
|                 sched_duration_ms = time.time() * 1000 - start_time | ||||
|                 sql_scheduling_timer.inc_by(sched_duration_ms) | ||||
|                 current_context.add_database_scheduled(sched_duration_ms) | ||||
| 
 | ||||
|                 if self.database_engine.is_connection_closed(conn): | ||||
|                     logger.debug("Reconnecting closed database connection") | ||||
|  |  | |||
|  | @ -27,7 +27,7 @@ from synapse.util.logutils import log_function | |||
| from synapse.util.metrics import Measure | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.api.errors import SynapseError | ||||
| from synapse.state import resolve_events | ||||
| from synapse.state import resolve_events_with_factory | ||||
| from synapse.util.caches.descriptors import cached | ||||
| from synapse.types import get_domain_from_id | ||||
| 
 | ||||
|  | @ -110,7 +110,7 @@ class _EventPeristenceQueue(object): | |||
|                 end_item.events_and_contexts.extend(events_and_contexts) | ||||
|                 return end_item.deferred.observe() | ||||
| 
 | ||||
|         deferred = ObservableDeferred(defer.Deferred()) | ||||
|         deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) | ||||
| 
 | ||||
|         queue.append(self._EventPersistQueueItem( | ||||
|             events_and_contexts=events_and_contexts, | ||||
|  | @ -146,18 +146,25 @@ class _EventPeristenceQueue(object): | |||
|             try: | ||||
|                 queue = self._get_drainining_queue(room_id) | ||||
|                 for item in queue: | ||||
|                     # handle_queue_loop runs in the sentinel logcontext, so | ||||
|                     # there is no need to preserve_fn when running the | ||||
|                     # callbacks on the deferred. | ||||
|                     try: | ||||
|                         ret = yield per_item_callback(item) | ||||
|                         item.deferred.callback(ret) | ||||
|                     except Exception as e: | ||||
|                         item.deferred.errback(e) | ||||
|                     except Exception: | ||||
|                         item.deferred.errback() | ||||
|             finally: | ||||
|                 queue = self._event_persist_queues.pop(room_id, None) | ||||
|                 if queue: | ||||
|                     self._event_persist_queues[room_id] = queue | ||||
|                 self._currently_persisting_rooms.discard(room_id) | ||||
| 
 | ||||
|         preserve_fn(handle_queue_loop)() | ||||
|         # set handle_queue_loop off on the background. We don't want to | ||||
|         # attribute work done in it to the current request, so we drop the | ||||
|         # logcontext altogether. | ||||
|         with PreserveLoggingContext(): | ||||
|             handle_queue_loop() | ||||
| 
 | ||||
|     def _get_drainining_queue(self, room_id): | ||||
|         queue = self._event_persist_queues.setdefault(room_id, deque()) | ||||
|  | @ -528,6 +535,12 @@ class EventsStore(SQLBaseStore): | |||
|                 # the events we have yet to persist, so we need a slightly more | ||||
|                 # complicated event lookup function than simply looking the events | ||||
|                 # up in the db. | ||||
| 
 | ||||
|                 logger.info( | ||||
|                     "Resolving state for %s with %i state sets", | ||||
|                     room_id, len(state_sets), | ||||
|                 ) | ||||
| 
 | ||||
|                 events_map = {ev.event_id: ev for ev, _ in events_context} | ||||
| 
 | ||||
|                 @defer.inlineCallbacks | ||||
|  | @ -550,7 +563,7 @@ class EventsStore(SQLBaseStore): | |||
|                         to_return.update(evs) | ||||
|                     defer.returnValue(to_return) | ||||
| 
 | ||||
|                 current_state = yield resolve_events( | ||||
|                 current_state = yield resolve_events_with_factory( | ||||
|                     state_sets, | ||||
|                     state_map_factory=get_events, | ||||
|                 ) | ||||
|  |  | |||
|  | @ -29,9 +29,6 @@ class MediaRepositoryStore(BackgroundUpdateStore): | |||
|             where_clause='url_cache IS NOT NULL', | ||||
|         ) | ||||
| 
 | ||||
|     def get_default_thumbnails(self, top_level_type, sub_type): | ||||
|         return [] | ||||
| 
 | ||||
|     def get_local_media(self, media_id): | ||||
|         """Get the metadata for a local piece of media | ||||
|         Returns: | ||||
|  | @ -176,7 +173,14 @@ class MediaRepositoryStore(BackgroundUpdateStore): | |||
|             desc="store_cached_remote_media", | ||||
|         ) | ||||
| 
 | ||||
|     def update_cached_last_access_time(self, origin_id_tuples, time_ts): | ||||
|     def update_cached_last_access_time(self, local_media, remote_media, time_ms): | ||||
|         """Updates the last access time of the given media | ||||
| 
 | ||||
|         Args: | ||||
|             local_media (iterable[str]): Set of media_ids | ||||
|             remote_media (iterable[(str, str)]): Set of (server_name, media_id) | ||||
|             time_ms: Current time in milliseconds | ||||
|         """ | ||||
|         def update_cache_txn(txn): | ||||
|             sql = ( | ||||
|                 "UPDATE remote_media_cache SET last_access_ts = ?" | ||||
|  | @ -184,8 +188,18 @@ class MediaRepositoryStore(BackgroundUpdateStore): | |||
|             ) | ||||
| 
 | ||||
|             txn.executemany(sql, ( | ||||
|                 (time_ts, media_origin, media_id) | ||||
|                 for media_origin, media_id in origin_id_tuples | ||||
|                 (time_ms, media_origin, media_id) | ||||
|                 for media_origin, media_id in remote_media | ||||
|             )) | ||||
| 
 | ||||
|             sql = ( | ||||
|                 "UPDATE local_media_repository SET last_access_ts = ?" | ||||
|                 " WHERE media_id = ?" | ||||
|             ) | ||||
| 
 | ||||
|             txn.executemany(sql, ( | ||||
|                 (time_ms, media_id) | ||||
|                 for media_id in local_media | ||||
|             )) | ||||
| 
 | ||||
|         return self.runInteraction("update_cached_last_access_time", update_cache_txn) | ||||
|  |  | |||
|  | @ -25,7 +25,7 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| # Remember to update this number every time a change is made to database | ||||
| # schema files, so the users will be informed on server restarts. | ||||
| SCHEMA_VERSION = 46 | ||||
| SCHEMA_VERSION = 47 | ||||
| 
 | ||||
| dir_path = os.path.abspath(os.path.dirname(__file__)) | ||||
| 
 | ||||
|  |  | |||
|  | @ -591,7 +591,7 @@ class RoomStore(SQLBaseStore): | |||
|                     """ | ||||
|                         UPDATE remote_media_cache | ||||
|                         SET quarantined_by = ? | ||||
|                         WHERE media_origin AND media_id = ? | ||||
|                         WHERE media_origin = ? AND media_id = ? | ||||
|                     """, | ||||
|                     ( | ||||
|                         (quarantined_by, origin, media_id) | ||||
|  |  | |||
							
								
								
									
										16
									
								
								synapse/storage/schema/delta/47/last_access_media.sql
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								synapse/storage/schema/delta/47/last_access_media.sql
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | |||
| /* Copyright 2018 New Vector Ltd | ||||
|  * | ||||
|  * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
|  * you may not use this file except in compliance with the License. | ||||
|  * You may obtain a copy of the License at | ||||
|  * | ||||
|  *    http://www.apache.org/licenses/LICENSE-2.0 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, | ||||
|  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
|  * See the License for the specific language governing permissions and | ||||
|  * limitations under the License. | ||||
|  */ | ||||
| 
 | ||||
| ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; | ||||
|  | @ -641,8 +641,12 @@ class UserDirectoryStore(SQLBaseStore): | |||
|         """ | ||||
| 
 | ||||
|         if self.hs.config.user_directory_search_all_users: | ||||
|             join_clause = "" | ||||
|             where_clause = "?<>''"  # naughty hack to keep the same number of binds | ||||
|             # make s.user_id null to keep the ordering algorithm happy | ||||
|             join_clause = """ | ||||
|                 CROSS JOIN (SELECT NULL as user_id) AS s | ||||
|             """ | ||||
|             join_args = () | ||||
|             where_clause = "1=1" | ||||
|         else: | ||||
|             join_clause = """ | ||||
|                 LEFT JOIN users_in_public_rooms AS p USING (user_id) | ||||
|  | @ -651,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore): | |||
|                     WHERE user_id = ? AND share_private | ||||
|                 ) AS s USING (user_id) | ||||
|             """ | ||||
|             join_args = (user_id,) | ||||
|             where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" | ||||
| 
 | ||||
|         if isinstance(self.database_engine, PostgresEngine): | ||||
|  | @ -692,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore): | |||
|                     avatar_url IS NULL | ||||
|                 LIMIT ? | ||||
|             """ % (join_clause, where_clause) | ||||
|             args = (user_id, full_query, exact_query, prefix_query, limit + 1,) | ||||
|             args = join_args + (full_query, exact_query, prefix_query, limit + 1,) | ||||
|         elif isinstance(self.database_engine, Sqlite3Engine): | ||||
|             search_query = _parse_query_sqlite(search_term) | ||||
| 
 | ||||
|  | @ -710,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore): | |||
|                     avatar_url IS NULL | ||||
|                 LIMIT ? | ||||
|             """ % (join_clause, where_clause) | ||||
|             args = (user_id, search_query, limit + 1) | ||||
|             args = join_args + (search_query, limit + 1) | ||||
|         else: | ||||
|             # This should be unreachable. | ||||
|             raise Exception("Unrecognized database engine") | ||||
|  |  | |||
							
								
								
									
										139
									
								
								synapse/util/file_consumer.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								synapse/util/file_consumer.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,139 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import threads, reactor | ||||
| 
 | ||||
| from synapse.util.logcontext import make_deferred_yieldable, preserve_fn | ||||
| 
 | ||||
| import Queue | ||||
| 
 | ||||
| 
 | ||||
| class BackgroundFileConsumer(object): | ||||
|     """A consumer that writes to a file like object. Supports both push | ||||
|     and pull producers | ||||
| 
 | ||||
|     Args: | ||||
|         file_obj (file): The file like object to write to. Closed when | ||||
|             finished. | ||||
|     """ | ||||
| 
 | ||||
|     # For PushProducers pause if we have this many unwritten slices | ||||
|     _PAUSE_ON_QUEUE_SIZE = 5 | ||||
|     # And resume once the size of the queue is less than this | ||||
|     _RESUME_ON_QUEUE_SIZE = 2 | ||||
| 
 | ||||
|     def __init__(self, file_obj): | ||||
|         self._file_obj = file_obj | ||||
| 
 | ||||
|         # Producer we're registered with | ||||
|         self._producer = None | ||||
| 
 | ||||
|         # True if PushProducer, false if PullProducer | ||||
|         self.streaming = False | ||||
| 
 | ||||
|         # For PushProducers, indicates whether we've paused the producer and | ||||
|         # need to call resumeProducing before we get more data. | ||||
|         self._paused_producer = False | ||||
| 
 | ||||
|         # Queue of slices of bytes to be written. When producer calls | ||||
|         # unregister a final None is sent. | ||||
|         self._bytes_queue = Queue.Queue() | ||||
| 
 | ||||
|         # Deferred that is resolved when finished writing | ||||
|         self._finished_deferred = None | ||||
| 
 | ||||
|         # If the _writer thread throws an exception it gets stored here. | ||||
|         self._write_exception = None | ||||
| 
 | ||||
|     def registerProducer(self, producer, streaming): | ||||
|         """Part of IConsumer interface | ||||
| 
 | ||||
|         Args: | ||||
|             producer (IProducer) | ||||
|             streaming (bool): True if push based producer, False if pull | ||||
|                 based. | ||||
|         """ | ||||
|         if self._producer: | ||||
|             raise Exception("registerProducer called twice") | ||||
| 
 | ||||
|         self._producer = producer | ||||
|         self.streaming = streaming | ||||
|         self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer) | ||||
|         if not streaming: | ||||
|             self._producer.resumeProducing() | ||||
| 
 | ||||
|     def unregisterProducer(self): | ||||
|         """Part of IProducer interface | ||||
|         """ | ||||
|         self._producer = None | ||||
|         if not self._finished_deferred.called: | ||||
|             self._bytes_queue.put_nowait(None) | ||||
| 
 | ||||
|     def write(self, bytes): | ||||
|         """Part of IProducer interface | ||||
|         """ | ||||
|         if self._write_exception: | ||||
|             raise self._write_exception | ||||
| 
 | ||||
|         if self._finished_deferred.called: | ||||
|             raise Exception("consumer has closed") | ||||
| 
 | ||||
|         self._bytes_queue.put_nowait(bytes) | ||||
| 
 | ||||
|         # If this is a PushProducer and the queue is getting behind | ||||
|         # then we pause the producer. | ||||
|         if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE: | ||||
|             self._paused_producer = True | ||||
|             self._producer.pauseProducing() | ||||
| 
 | ||||
|     def _writer(self): | ||||
|         """This is run in a background thread to write to the file. | ||||
|         """ | ||||
|         try: | ||||
|             while self._producer or not self._bytes_queue.empty(): | ||||
|                 # If we've paused the producer check if we should resume the | ||||
|                 # producer. | ||||
|                 if self._producer and self._paused_producer: | ||||
|                     if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE: | ||||
|                         reactor.callFromThread(self._resume_paused_producer) | ||||
| 
 | ||||
|                 bytes = self._bytes_queue.get() | ||||
| 
 | ||||
|                 # If we get a None (or empty list) then that's a signal used | ||||
|                 # to indicate we should check if we should stop. | ||||
|                 if bytes: | ||||
|                     self._file_obj.write(bytes) | ||||
| 
 | ||||
|                 # If its a pull producer then we need to explicitly ask for | ||||
|                 # more stuff. | ||||
|                 if not self.streaming and self._producer: | ||||
|                     reactor.callFromThread(self._producer.resumeProducing) | ||||
|         except Exception as e: | ||||
|             self._write_exception = e | ||||
|             raise | ||||
|         finally: | ||||
|             self._file_obj.close() | ||||
| 
 | ||||
|     def wait(self): | ||||
|         """Returns a deferred that resolves when finished writing to file | ||||
|         """ | ||||
|         return make_deferred_yieldable(self._finished_deferred) | ||||
| 
 | ||||
|     def _resume_paused_producer(self): | ||||
|         """Gets called if we should resume producing after being paused | ||||
|         """ | ||||
|         if self._paused_producer and self._producer: | ||||
|             self._paused_producer = False | ||||
|             self._producer.resumeProducing() | ||||
|  | @ -52,13 +52,17 @@ except Exception: | |||
| class LoggingContext(object): | ||||
|     """Additional context for log formatting. Contexts are scoped within a | ||||
|     "with" block. | ||||
| 
 | ||||
|     Args: | ||||
|         name (str): Name for the context for debugging. | ||||
|     """ | ||||
| 
 | ||||
|     __slots__ = [ | ||||
|         "previous_context", "name", "usage_start", "usage_end", "main_thread", | ||||
|         "__dict__", "tag", "alive", | ||||
|         "previous_context", "name", "ru_stime", "ru_utime", | ||||
|         "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", | ||||
|         "usage_start", "usage_end", | ||||
|         "main_thread", "alive", | ||||
|         "request", "tag", | ||||
|     ] | ||||
| 
 | ||||
|     thread_local = threading.local() | ||||
|  | @ -83,6 +87,9 @@ class LoggingContext(object): | |||
|         def add_database_transaction(self, duration_ms): | ||||
|             pass | ||||
| 
 | ||||
|         def add_database_scheduled(self, sched_ms): | ||||
|             pass | ||||
| 
 | ||||
|         def __nonzero__(self): | ||||
|             return False | ||||
| 
 | ||||
|  | @ -94,9 +101,17 @@ class LoggingContext(object): | |||
|         self.ru_stime = 0. | ||||
|         self.ru_utime = 0. | ||||
|         self.db_txn_count = 0 | ||||
|         self.db_txn_duration = 0. | ||||
| 
 | ||||
|         # ms spent waiting for db txns, excluding scheduling time | ||||
|         self.db_txn_duration_ms = 0 | ||||
| 
 | ||||
|         # ms spent waiting for db txns to be scheduled | ||||
|         self.db_sched_duration_ms = 0 | ||||
| 
 | ||||
|         self.usage_start = None | ||||
|         self.usage_end = None | ||||
|         self.main_thread = threading.current_thread() | ||||
|         self.request = None | ||||
|         self.tag = "" | ||||
|         self.alive = True | ||||
| 
 | ||||
|  | @ -105,7 +120,11 @@ class LoggingContext(object): | |||
| 
 | ||||
|     @classmethod | ||||
|     def current_context(cls): | ||||
|         """Get the current logging context from thread local storage""" | ||||
|         """Get the current logging context from thread local storage | ||||
| 
 | ||||
|         Returns: | ||||
|             LoggingContext: the current logging context | ||||
|         """ | ||||
|         return getattr(cls.thread_local, "current_context", cls.sentinel) | ||||
| 
 | ||||
|     @classmethod | ||||
|  | @ -155,11 +174,13 @@ class LoggingContext(object): | |||
|         self.alive = False | ||||
| 
 | ||||
|     def copy_to(self, record): | ||||
|         """Copy fields from this context to the record""" | ||||
|         for key, value in self.__dict__.items(): | ||||
|             setattr(record, key, value) | ||||
|         """Copy logging fields from this context to a log record or | ||||
|         another LoggingContext | ||||
|         """ | ||||
| 
 | ||||
|         record.ru_utime, record.ru_stime = self.get_resource_usage() | ||||
|         # 'request' is the only field we currently use in the logger, so that's | ||||
|         # all we need to copy | ||||
|         record.request = self.request | ||||
| 
 | ||||
|     def start(self): | ||||
|         if threading.current_thread() is not self.main_thread: | ||||
|  | @ -194,7 +215,16 @@ class LoggingContext(object): | |||
| 
 | ||||
|     def add_database_transaction(self, duration_ms): | ||||
|         self.db_txn_count += 1 | ||||
|         self.db_txn_duration += duration_ms / 1000. | ||||
|         self.db_txn_duration_ms += duration_ms | ||||
| 
 | ||||
|     def add_database_scheduled(self, sched_ms): | ||||
|         """Record a use of the database pool | ||||
| 
 | ||||
|         Args: | ||||
|             sched_ms (int): number of milliseconds it took us to get a | ||||
|                 connection | ||||
|         """ | ||||
|         self.db_sched_duration_ms += sched_ms | ||||
| 
 | ||||
| 
 | ||||
| class LoggingContextFilter(logging.Filter): | ||||
|  |  | |||
|  | @ -27,25 +27,62 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| metrics = synapse.metrics.get_metrics_for(__name__) | ||||
| 
 | ||||
| block_timer = metrics.register_distribution( | ||||
|     "block_timer", | ||||
|     labels=["block_name"] | ||||
| # total number of times we have hit this block | ||||
| block_counter = metrics.register_counter( | ||||
|     "block_count", | ||||
|     labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         # the following are all deprecated aliases for the same metric | ||||
|         metrics.name_prefix + x for x in ( | ||||
|             "_block_timer:count", | ||||
|             "_block_ru_utime:count", | ||||
|             "_block_ru_stime:count", | ||||
|             "_block_db_txn_count:count", | ||||
|             "_block_db_txn_duration:count", | ||||
|         ) | ||||
|     ) | ||||
| ) | ||||
| 
 | ||||
| block_ru_utime = metrics.register_distribution( | ||||
|     "block_ru_utime", labels=["block_name"] | ||||
| block_timer = metrics.register_counter( | ||||
|     "block_time_seconds", | ||||
|     labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_block_timer:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| block_ru_stime = metrics.register_distribution( | ||||
|     "block_ru_stime", labels=["block_name"] | ||||
| block_ru_utime = metrics.register_counter( | ||||
|     "block_ru_utime_seconds", labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_block_ru_utime:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| block_db_txn_count = metrics.register_distribution( | ||||
|     "block_db_txn_count", labels=["block_name"] | ||||
| block_ru_stime = metrics.register_counter( | ||||
|     "block_ru_stime_seconds", labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_block_ru_stime:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| block_db_txn_duration = metrics.register_distribution( | ||||
|     "block_db_txn_duration", labels=["block_name"] | ||||
| block_db_txn_count = metrics.register_counter( | ||||
|     "block_db_txn_count", labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_block_db_txn_count:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| # seconds spent waiting for db txns, excluding scheduling time, in this block | ||||
| block_db_txn_duration = metrics.register_counter( | ||||
|     "block_db_txn_duration_seconds", labels=["block_name"], | ||||
|     alternative_names=( | ||||
|         metrics.name_prefix + "_block_db_txn_duration:total", | ||||
|     ), | ||||
| ) | ||||
| 
 | ||||
| # seconds spent waiting for a db connection, in this block | ||||
| block_db_sched_duration = metrics.register_counter( | ||||
|     "block_db_sched_duration_seconds", labels=["block_name"], | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -64,7 +101,9 @@ def measure_func(name): | |||
| class Measure(object): | ||||
|     __slots__ = [ | ||||
|         "clock", "name", "start_context", "start", "new_context", "ru_utime", | ||||
|         "ru_stime", "db_txn_count", "db_txn_duration", "created_context" | ||||
|         "ru_stime", | ||||
|         "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms", | ||||
|         "created_context", | ||||
|     ] | ||||
| 
 | ||||
|     def __init__(self, clock, name): | ||||
|  | @ -84,13 +123,16 @@ class Measure(object): | |||
| 
 | ||||
|         self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() | ||||
|         self.db_txn_count = self.start_context.db_txn_count | ||||
|         self.db_txn_duration = self.start_context.db_txn_duration | ||||
|         self.db_txn_duration_ms = self.start_context.db_txn_duration_ms | ||||
|         self.db_sched_duration_ms = self.start_context.db_sched_duration_ms | ||||
| 
 | ||||
|     def __exit__(self, exc_type, exc_val, exc_tb): | ||||
|         if isinstance(exc_type, Exception) or not self.start_context: | ||||
|             return | ||||
| 
 | ||||
|         duration = self.clock.time_msec() - self.start | ||||
| 
 | ||||
|         block_counter.inc(self.name) | ||||
|         block_timer.inc_by(duration, self.name) | ||||
| 
 | ||||
|         context = LoggingContext.current_context() | ||||
|  | @ -114,7 +156,12 @@ class Measure(object): | |||
|             context.db_txn_count - self.db_txn_count, self.name | ||||
|         ) | ||||
|         block_db_txn_duration.inc_by( | ||||
|             context.db_txn_duration - self.db_txn_duration, self.name | ||||
|             (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000., | ||||
|             self.name | ||||
|         ) | ||||
|         block_db_sched_duration.inc_by( | ||||
|             (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000., | ||||
|             self.name | ||||
|         ) | ||||
| 
 | ||||
|         if self.created_context: | ||||
|  |  | |||
|  | @ -26,6 +26,18 @@ logger = logging.getLogger(__name__) | |||
| 
 | ||||
| class NotRetryingDestination(Exception): | ||||
|     def __init__(self, retry_last_ts, retry_interval, destination): | ||||
|         """Raised by the limiter (and federation client) to indicate that we are | ||||
|         are deliberately not attempting to contact a given server. | ||||
| 
 | ||||
|         Args: | ||||
|             retry_last_ts (int): the unix ts in milliseconds of our last attempt | ||||
|                 to contact the server.  0 indicates that the last attempt was | ||||
|                 successful or that we've never actually attempted to connect. | ||||
|             retry_interval (int): the time in milliseconds to wait until the next | ||||
|                 attempt. | ||||
|             destination (str): the domain in question | ||||
|         """ | ||||
| 
 | ||||
|         msg = "Not retrying server %s." % (destination,) | ||||
|         super(NotRetryingDestination, self).__init__(msg) | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										48
									
								
								synapse/util/threepids.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										48
									
								
								synapse/util/threepids.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,48 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import logging | ||||
| import re | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def check_3pid_allowed(hs, medium, address): | ||||
|     """Checks whether a given format of 3PID is allowed to be used on this HS | ||||
| 
 | ||||
|     Args: | ||||
|         hs (synapse.server.HomeServer): server | ||||
|         medium (str): 3pid medium - e.g. email, msisdn | ||||
|         address (str): address within that medium (e.g. "wotan@matrix.org") | ||||
|             msisdns need to first have been canonicalised | ||||
|     Returns: | ||||
|         bool: whether the 3PID medium/address is allowed to be added to this HS | ||||
|     """ | ||||
| 
 | ||||
|     if hs.config.allowed_local_3pids: | ||||
|         for constraint in hs.config.allowed_local_3pids: | ||||
|             logger.debug( | ||||
|                 "Checking 3PID %s (%s) against %s (%s)", | ||||
|                 address, medium, constraint['pattern'], constraint['medium'], | ||||
|             ) | ||||
|             if ( | ||||
|                 medium == constraint['medium'] and | ||||
|                 re.match(constraint['pattern'], address) | ||||
|             ): | ||||
|                 return True | ||||
|     else: | ||||
|         return True | ||||
| 
 | ||||
|     return False | ||||
|  | @ -12,9 +12,12 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import os.path | ||||
| import re | ||||
| import shutil | ||||
| import tempfile | ||||
| 
 | ||||
| from synapse.config.homeserver import HomeServerConfig | ||||
| from tests import unittest | ||||
| 
 | ||||
|  | @ -23,7 +26,6 @@ class ConfigGenerationTestCase(unittest.TestCase): | |||
| 
 | ||||
|     def setUp(self): | ||||
|         self.dir = tempfile.mkdtemp() | ||||
|         print self.dir | ||||
|         self.file = os.path.join(self.dir, "homeserver.yaml") | ||||
| 
 | ||||
|     def tearDown(self): | ||||
|  | @ -48,3 +50,16 @@ class ConfigGenerationTestCase(unittest.TestCase): | |||
|             ]), | ||||
|             set(os.listdir(self.dir)) | ||||
|         ) | ||||
| 
 | ||||
|         self.assert_log_filename_is( | ||||
|             os.path.join(self.dir, "lemurs.win.log.config"), | ||||
|             os.path.join(os.getcwd(), "homeserver.log"), | ||||
|         ) | ||||
| 
 | ||||
|     def assert_log_filename_is(self, log_config_file, expected): | ||||
|         with open(log_config_file) as f: | ||||
|             config = f.read() | ||||
|             # find the 'filename' line | ||||
|             matches = re.findall("^\s*filename:\s*(.*)$", config, re.M) | ||||
|             self.assertEqual(1, len(matches)) | ||||
|             self.assertEqual(matches[0], expected) | ||||
|  |  | |||
|  | @ -68,7 +68,7 @@ class KeyringTestCase(unittest.TestCase): | |||
| 
 | ||||
|     def check_context(self, _, expected): | ||||
|         self.assertEquals( | ||||
|             getattr(LoggingContext.current_context(), "test_key", None), | ||||
|             getattr(LoggingContext.current_context(), "request", None), | ||||
|             expected | ||||
|         ) | ||||
| 
 | ||||
|  | @ -82,7 +82,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|         lookup_2_deferred = defer.Deferred() | ||||
| 
 | ||||
|         with LoggingContext("one") as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
| 
 | ||||
|             wait_1_deferred = kr.wait_for_previous_lookups( | ||||
|                 ["server1"], | ||||
|  | @ -96,7 +96,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|             wait_1_deferred.addBoth(self.check_context, "one") | ||||
| 
 | ||||
|         with LoggingContext("two") as context_two: | ||||
|             context_two.test_key = "two" | ||||
|             context_two.request = "two" | ||||
| 
 | ||||
|             # set off another wait. It should block because the first lookup | ||||
|             # hasn't yet completed. | ||||
|  | @ -137,7 +137,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|         @defer.inlineCallbacks | ||||
|         def get_perspectives(**kwargs): | ||||
|             self.assertEquals( | ||||
|                 LoggingContext.current_context().test_key, "11", | ||||
|                 LoggingContext.current_context().request, "11", | ||||
|             ) | ||||
|             with logcontext.PreserveLoggingContext(): | ||||
|                 yield persp_deferred | ||||
|  | @ -145,7 +145,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|         self.http_client.post_json.side_effect = get_perspectives | ||||
| 
 | ||||
|         with LoggingContext("11") as context_11: | ||||
|             context_11.test_key = "11" | ||||
|             context_11.request = "11" | ||||
| 
 | ||||
|             # start off a first set of lookups | ||||
|             res_deferreds = kr.verify_json_objects_for_server( | ||||
|  | @ -167,13 +167,13 @@ class KeyringTestCase(unittest.TestCase): | |||
| 
 | ||||
|             # wait a tick for it to send the request to the perspectives server | ||||
|             # (it first tries the datastore) | ||||
|             yield async.sleep(0.005) | ||||
|             yield async.sleep(1)   # XXX find out why this takes so long! | ||||
|             self.http_client.post_json.assert_called_once() | ||||
| 
 | ||||
|             self.assertIs(LoggingContext.current_context(), context_11) | ||||
| 
 | ||||
|             context_12 = LoggingContext("12") | ||||
|             context_12.test_key = "12" | ||||
|             context_12.request = "12" | ||||
|             with logcontext.PreserveLoggingContext(context_12): | ||||
|                 # a second request for a server with outstanding requests | ||||
|                 # should block rather than start a second call | ||||
|  | @ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|                 res_deferreds_2 = kr.verify_json_objects_for_server( | ||||
|                     [("server10", json1)], | ||||
|                 ) | ||||
|                 yield async.sleep(0.005) | ||||
|                 yield async.sleep(01) | ||||
|                 self.http_client.post_json.assert_not_called() | ||||
|                 res_deferreds_2[0].addBoth(self.check_context, None) | ||||
| 
 | ||||
|  | @ -211,7 +211,7 @@ class KeyringTestCase(unittest.TestCase): | |||
|         sentinel_context = LoggingContext.current_context() | ||||
| 
 | ||||
|         with LoggingContext("one") as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
| 
 | ||||
|             defer = kr.verify_json_for_server("server9", {}) | ||||
|             try: | ||||
|  |  | |||
|  | @ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase): | |||
|         except errors.SynapseError: | ||||
|             pass | ||||
| 
 | ||||
|     @unittest.DEBUG | ||||
|     @defer.inlineCallbacks | ||||
|     def test_claim_one_time_key(self): | ||||
|         local_user = "@boris:" + self.hs.hostname | ||||
|  |  | |||
|  | @ -15,6 +15,8 @@ | |||
| from twisted.internet import defer, reactor | ||||
| from tests import unittest | ||||
| 
 | ||||
| import tempfile | ||||
| 
 | ||||
| from mock import Mock, NonCallableMock | ||||
| from tests.utils import setup_test_homeserver | ||||
| from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory | ||||
|  | @ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase): | |||
|         self.event_id = 0 | ||||
| 
 | ||||
|         server_factory = ReplicationStreamProtocolFactory(self.hs) | ||||
|         listener = reactor.listenUNIX("\0xxx", server_factory) | ||||
|         # XXX: mktemp is unsafe and should never be used. but we're just a test. | ||||
|         path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket") | ||||
|         listener = reactor.listenUNIX(path, server_factory) | ||||
|         self.addCleanup(listener.stopListening) | ||||
|         self.streamer = server_factory.streamer | ||||
| 
 | ||||
|  | @ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase): | |||
|         client_factory = ReplicationClientFactory( | ||||
|             self.hs, "client_name", self.replication_handler | ||||
|         ) | ||||
|         client_connector = reactor.connectUNIX("\0xxx", client_factory) | ||||
|         client_connector = reactor.connectUNIX(path, client_factory) | ||||
|         self.addCleanup(client_factory.stopTrying) | ||||
|         self.addCleanup(client_connector.disconnect) | ||||
| 
 | ||||
|  |  | |||
|  | @ -515,9 +515,6 @@ class RoomsCreateTestCase(RestTestCase): | |||
| 
 | ||||
|         synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) | ||||
| 
 | ||||
|     def tearDown(self): | ||||
|         pass | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_post_room_no_keys(self): | ||||
|         # POST with no config keys, expect new room id | ||||
|  |  | |||
|  | @ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase): | |||
|         self.hs.get_auth_handler = Mock(return_value=self.auth_handler) | ||||
|         self.hs.get_device_handler = Mock(return_value=self.device_handler) | ||||
|         self.hs.config.enable_registration = True | ||||
|         self.hs.config.registrations_require_3pid = [] | ||||
|         self.hs.config.auto_join_rooms = [] | ||||
| 
 | ||||
|         # init the thing we're testing | ||||
|  |  | |||
							
								
								
									
										88
									
								
								tests/storage/test_user_directory.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								tests/storage/test_user_directory.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,88 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from twisted.internet import defer | ||||
| 
 | ||||
| from synapse.storage import UserDirectoryStore | ||||
| from synapse.storage.roommember import ProfileInfo | ||||
| from tests import unittest | ||||
| from tests.utils import setup_test_homeserver | ||||
| 
 | ||||
| ALICE = "@alice:a" | ||||
| BOB = "@bob:b" | ||||
| BOBBY = "@bobby:a" | ||||
| 
 | ||||
| 
 | ||||
| class UserDirectoryStoreTestCase(unittest.TestCase): | ||||
|     @defer.inlineCallbacks | ||||
|     def setUp(self): | ||||
|         self.hs = yield setup_test_homeserver() | ||||
|         self.store = UserDirectoryStore(None, self.hs) | ||||
| 
 | ||||
|         # alice and bob are both in !room_id. bobby is not but shares | ||||
|         # a homeserver with alice. | ||||
|         yield self.store.add_profiles_to_user_dir( | ||||
|             "!room:id", | ||||
|             { | ||||
|                 ALICE: ProfileInfo(None, "alice"), | ||||
|                 BOB: ProfileInfo(None, "bob"), | ||||
|                 BOBBY: ProfileInfo(None, "bobby") | ||||
|             }, | ||||
|         ) | ||||
|         yield self.store.add_users_to_public_room( | ||||
|             "!room:id", | ||||
|             [ALICE, BOB], | ||||
|         ) | ||||
|         yield self.store.add_users_who_share_room( | ||||
|             "!room:id", | ||||
|             False, | ||||
|             ( | ||||
|                 (ALICE, BOB), | ||||
|                 (BOB, ALICE), | ||||
|             ), | ||||
|         ) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_search_user_dir(self): | ||||
|         # normally when alice searches the directory she should just find | ||||
|         # bob because bobby doesn't share a room with her. | ||||
|         r = yield self.store.search_user_dir(ALICE, "bob", 10) | ||||
|         self.assertFalse(r["limited"]) | ||||
|         self.assertEqual(1, len(r["results"])) | ||||
|         self.assertDictEqual(r["results"][0], { | ||||
|             "user_id": BOB, | ||||
|             "display_name": "bob", | ||||
|             "avatar_url": None, | ||||
|         }) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_search_user_dir_all_users(self): | ||||
|         self.hs.config.user_directory_search_all_users = True | ||||
|         try: | ||||
|             r = yield self.store.search_user_dir(ALICE, "bob", 10) | ||||
|             self.assertFalse(r["limited"]) | ||||
|             self.assertEqual(2, len(r["results"])) | ||||
|             self.assertDictEqual(r["results"][0], { | ||||
|                 "user_id": BOB, | ||||
|                 "display_name": "bob", | ||||
|                 "avatar_url": None, | ||||
|             }) | ||||
|             self.assertDictEqual(r["results"][1], { | ||||
|                 "user_id": BOBBY, | ||||
|                 "display_name": "bobby", | ||||
|                 "avatar_url": None, | ||||
|             }) | ||||
|         finally: | ||||
|             self.hs.config.user_directory_search_all_users = False | ||||
|  | @ -12,7 +12,7 @@ | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| import twisted | ||||
| from twisted.trial import unittest | ||||
| 
 | ||||
| import logging | ||||
|  | @ -65,6 +65,10 @@ class TestCase(unittest.TestCase): | |||
| 
 | ||||
|         @around(self) | ||||
|         def setUp(orig): | ||||
|             # enable debugging of delayed calls - this means that we get a | ||||
|             # traceback when a unit test exits leaving things on the reactor. | ||||
|             twisted.internet.base.DelayedCall.debug = True | ||||
| 
 | ||||
|             old_level = logging.getLogger().level | ||||
| 
 | ||||
|             if old_level != level: | ||||
|  |  | |||
							
								
								
									
										176
									
								
								tests/util/test_file_consumer.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										176
									
								
								tests/util/test_file_consumer.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,176 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| # Copyright 2018 New Vector Ltd | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| 
 | ||||
| from twisted.internet import defer, reactor | ||||
| from mock import NonCallableMock | ||||
| 
 | ||||
| from synapse.util.file_consumer import BackgroundFileConsumer | ||||
| 
 | ||||
| from tests import unittest | ||||
| from StringIO import StringIO | ||||
| 
 | ||||
| import threading | ||||
| 
 | ||||
| 
 | ||||
| class FileConsumerTests(unittest.TestCase): | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_pull_consumer(self): | ||||
|         string_file = StringIO() | ||||
|         consumer = BackgroundFileConsumer(string_file) | ||||
| 
 | ||||
|         try: | ||||
|             producer = DummyPullProducer() | ||||
| 
 | ||||
|             yield producer.register_with_consumer(consumer) | ||||
| 
 | ||||
|             yield producer.write_and_wait("Foo") | ||||
| 
 | ||||
|             self.assertEqual(string_file.getvalue(), "Foo") | ||||
| 
 | ||||
|             yield producer.write_and_wait("Bar") | ||||
| 
 | ||||
|             self.assertEqual(string_file.getvalue(), "FooBar") | ||||
|         finally: | ||||
|             consumer.unregisterProducer() | ||||
| 
 | ||||
|         yield consumer.wait() | ||||
| 
 | ||||
|         self.assertTrue(string_file.closed) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_push_consumer(self): | ||||
|         string_file = BlockingStringWrite() | ||||
|         consumer = BackgroundFileConsumer(string_file) | ||||
| 
 | ||||
|         try: | ||||
|             producer = NonCallableMock(spec_set=[]) | ||||
| 
 | ||||
|             consumer.registerProducer(producer, True) | ||||
| 
 | ||||
|             consumer.write("Foo") | ||||
|             yield string_file.wait_for_n_writes(1) | ||||
| 
 | ||||
|             self.assertEqual(string_file.buffer, "Foo") | ||||
| 
 | ||||
|             consumer.write("Bar") | ||||
|             yield string_file.wait_for_n_writes(2) | ||||
| 
 | ||||
|             self.assertEqual(string_file.buffer, "FooBar") | ||||
|         finally: | ||||
|             consumer.unregisterProducer() | ||||
| 
 | ||||
|         yield consumer.wait() | ||||
| 
 | ||||
|         self.assertTrue(string_file.closed) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def test_push_producer_feedback(self): | ||||
|         string_file = BlockingStringWrite() | ||||
|         consumer = BackgroundFileConsumer(string_file) | ||||
| 
 | ||||
|         try: | ||||
|             producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) | ||||
| 
 | ||||
|             resume_deferred = defer.Deferred() | ||||
|             producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None) | ||||
| 
 | ||||
|             consumer.registerProducer(producer, True) | ||||
| 
 | ||||
|             number_writes = 0 | ||||
|             with string_file.write_lock: | ||||
|                 for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): | ||||
|                     consumer.write("Foo") | ||||
|                     number_writes += 1 | ||||
| 
 | ||||
|                 producer.pauseProducing.assert_called_once() | ||||
| 
 | ||||
|             yield string_file.wait_for_n_writes(number_writes) | ||||
| 
 | ||||
|             yield resume_deferred | ||||
|             producer.resumeProducing.assert_called_once() | ||||
|         finally: | ||||
|             consumer.unregisterProducer() | ||||
| 
 | ||||
|         yield consumer.wait() | ||||
| 
 | ||||
|         self.assertTrue(string_file.closed) | ||||
| 
 | ||||
| 
 | ||||
| class DummyPullProducer(object): | ||||
|     def __init__(self): | ||||
|         self.consumer = None | ||||
|         self.deferred = defer.Deferred() | ||||
| 
 | ||||
|     def resumeProducing(self): | ||||
|         d = self.deferred | ||||
|         self.deferred = defer.Deferred() | ||||
|         d.callback(None) | ||||
| 
 | ||||
|     def write_and_wait(self, bytes): | ||||
|         d = self.deferred | ||||
|         self.consumer.write(bytes) | ||||
|         return d | ||||
| 
 | ||||
|     def register_with_consumer(self, consumer): | ||||
|         d = self.deferred | ||||
|         self.consumer = consumer | ||||
|         self.consumer.registerProducer(self, False) | ||||
|         return d | ||||
| 
 | ||||
| 
 | ||||
| class BlockingStringWrite(object): | ||||
|     def __init__(self): | ||||
|         self.buffer = "" | ||||
|         self.closed = False | ||||
|         self.write_lock = threading.Lock() | ||||
| 
 | ||||
|         self._notify_write_deferred = None | ||||
|         self._number_of_writes = 0 | ||||
| 
 | ||||
|     def write(self, bytes): | ||||
|         with self.write_lock: | ||||
|             self.buffer += bytes | ||||
|             self._number_of_writes += 1 | ||||
| 
 | ||||
|         reactor.callFromThread(self._notify_write) | ||||
| 
 | ||||
|     def close(self): | ||||
|         self.closed = True | ||||
| 
 | ||||
|     def _notify_write(self): | ||||
|         "Called by write to indicate a write happened" | ||||
|         with self.write_lock: | ||||
|             if not self._notify_write_deferred: | ||||
|                 return | ||||
|             d = self._notify_write_deferred | ||||
|             self._notify_write_deferred = None | ||||
|         d.callback(None) | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|     def wait_for_n_writes(self, n): | ||||
|         "Wait for n writes to have happened" | ||||
|         while True: | ||||
|             with self.write_lock: | ||||
|                 if n <= self._number_of_writes: | ||||
|                     return | ||||
| 
 | ||||
|                 if not self._notify_write_deferred: | ||||
|                     self._notify_write_deferred = defer.Deferred() | ||||
| 
 | ||||
|                 d = self._notify_write_deferred | ||||
| 
 | ||||
|             yield d | ||||
|  | @ -12,12 +12,12 @@ class LoggingContextTestCase(unittest.TestCase): | |||
| 
 | ||||
|     def _check_test_key(self, value): | ||||
|         self.assertEquals( | ||||
|             LoggingContext.current_context().test_key, value | ||||
|             LoggingContext.current_context().request, value | ||||
|         ) | ||||
| 
 | ||||
|     def test_with_context(self): | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "test" | ||||
|             context_one.request = "test" | ||||
|             self._check_test_key("test") | ||||
| 
 | ||||
|     @defer.inlineCallbacks | ||||
|  | @ -25,14 +25,14 @@ class LoggingContextTestCase(unittest.TestCase): | |||
|         @defer.inlineCallbacks | ||||
|         def competing_callback(): | ||||
|             with LoggingContext() as competing_context: | ||||
|                 competing_context.test_key = "competing" | ||||
|                 competing_context.request = "competing" | ||||
|                 yield sleep(0) | ||||
|                 self._check_test_key("competing") | ||||
| 
 | ||||
|         reactor.callLater(0, competing_callback) | ||||
| 
 | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
|             yield sleep(0) | ||||
|             self._check_test_key("one") | ||||
| 
 | ||||
|  | @ -43,14 +43,14 @@ class LoggingContextTestCase(unittest.TestCase): | |||
| 
 | ||||
|         @defer.inlineCallbacks | ||||
|         def cb(): | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
|             yield function() | ||||
|             self._check_test_key("one") | ||||
| 
 | ||||
|             callback_completed[0] = True | ||||
| 
 | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
| 
 | ||||
|             # fire off function, but don't wait on it. | ||||
|             logcontext.preserve_fn(cb)() | ||||
|  | @ -107,7 +107,7 @@ class LoggingContextTestCase(unittest.TestCase): | |||
|         sentinel_context = LoggingContext.current_context() | ||||
| 
 | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
| 
 | ||||
|             d1 = logcontext.make_deferred_yieldable(blocking_function()) | ||||
|             # make sure that the context was reset by make_deferred_yieldable | ||||
|  | @ -124,7 +124,7 @@ class LoggingContextTestCase(unittest.TestCase): | |||
|         argument isn't actually a deferred""" | ||||
| 
 | ||||
|         with LoggingContext() as context_one: | ||||
|             context_one.test_key = "one" | ||||
|             context_one.request = "one" | ||||
| 
 | ||||
|             d1 = logcontext.make_deferred_yieldable("bum") | ||||
|             self._check_test_key("one") | ||||
|  |  | |||
							
								
								
									
										249
									
								
								tests/utils.py
									
										
									
									
									
								
							
							
						
						
									
										249
									
								
								tests/utils.py
									
										
									
									
									
								
							|  | @ -13,27 +13,28 @@ | |||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| 
 | ||||
| from synapse.http.server import HttpServer | ||||
| from synapse.api.errors import cs_error, CodeMessageException, StoreError | ||||
| from synapse.api.constants import EventTypes | ||||
| from synapse.storage.prepare_database import prepare_database | ||||
| from synapse.storage.engines import create_engine | ||||
| from synapse.server import HomeServer | ||||
| from synapse.federation.transport import server | ||||
| from synapse.util.ratelimitutils import FederationRateLimiter | ||||
| 
 | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| 
 | ||||
| from twisted.internet import defer, reactor | ||||
| from twisted.enterprise.adbapi import ConnectionPool | ||||
| 
 | ||||
| from collections import namedtuple | ||||
| from mock import patch, Mock | ||||
| import hashlib | ||||
| from inspect import getcallargs | ||||
| import urllib | ||||
| import urlparse | ||||
| 
 | ||||
| from inspect import getcallargs | ||||
| from mock import Mock, patch | ||||
| from twisted.internet import defer, reactor | ||||
| 
 | ||||
| from synapse.api.errors import CodeMessageException, cs_error | ||||
| from synapse.federation.transport import server | ||||
| from synapse.http.server import HttpServer | ||||
| from synapse.server import HomeServer | ||||
| from synapse.storage import PostgresEngine | ||||
| from synapse.storage.engines import create_engine | ||||
| from synapse.storage.prepare_database import prepare_database | ||||
| from synapse.util.logcontext import LoggingContext | ||||
| from synapse.util.ratelimitutils import FederationRateLimiter | ||||
| 
 | ||||
| # set this to True to run the tests against postgres instead of sqlite. | ||||
| # It requires you to have a local postgres database called synapse_test, within | ||||
| # which ALL TABLES WILL BE DROPPED | ||||
| USE_POSTGRES_FOR_TESTS = False | ||||
| 
 | ||||
| 
 | ||||
| @defer.inlineCallbacks | ||||
|  | @ -57,32 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): | |||
|         config.worker_app = None | ||||
|         config.email_enable_notifs = False | ||||
|         config.block_non_admin_invites = False | ||||
|         config.federation_domain_whitelist = None | ||||
|         config.user_directory_search_all_users = False | ||||
| 
 | ||||
|         # disable user directory updates, because they get done in the | ||||
|         # background, which upsets the test runner. | ||||
|         config.update_user_directory = False | ||||
| 
 | ||||
|     config.use_frozen_dicts = True | ||||
|     config.database_config = {"name": "sqlite3"} | ||||
|     config.ldap_enabled = False | ||||
| 
 | ||||
|     if "clock" not in kargs: | ||||
|         kargs["clock"] = MockClock() | ||||
| 
 | ||||
|     if USE_POSTGRES_FOR_TESTS: | ||||
|         config.database_config = { | ||||
|             "name": "psycopg2", | ||||
|             "args": { | ||||
|                 "database": "synapse_test", | ||||
|                 "cp_min": 1, | ||||
|                 "cp_max": 5, | ||||
|             }, | ||||
|         } | ||||
|     else: | ||||
|         config.database_config = { | ||||
|             "name": "sqlite3", | ||||
|             "args": { | ||||
|                 "database": ":memory:", | ||||
|                 "cp_min": 1, | ||||
|                 "cp_max": 1, | ||||
|             }, | ||||
|         } | ||||
| 
 | ||||
|     db_engine = create_engine(config.database_config) | ||||
| 
 | ||||
|     # we need to configure the connection pool to run the on_new_connection | ||||
|     # function, so that we can test code that uses custom sqlite functions | ||||
|     # (like rank). | ||||
|     config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection | ||||
| 
 | ||||
|     if datastore is None: | ||||
|         db_pool = SQLiteMemoryDbPool() | ||||
|         yield db_pool.prepare() | ||||
|         hs = HomeServer( | ||||
|             name, db_pool=db_pool, config=config, | ||||
|             name, config=config, | ||||
|             db_config=config.database_config, | ||||
|             version_string="Synapse/tests", | ||||
|             database_engine=create_engine(config.database_config), | ||||
|             get_db_conn=db_pool.get_db_conn, | ||||
|             database_engine=db_engine, | ||||
|             room_list_handler=object(), | ||||
|             tls_server_context_factory=Mock(), | ||||
|             **kargs | ||||
|         ) | ||||
|         db_conn = hs.get_db_conn() | ||||
|         # make sure that the database is empty | ||||
|         if isinstance(db_engine, PostgresEngine): | ||||
|             cur = db_conn.cursor() | ||||
|             cur.execute("SELECT tablename FROM pg_tables where schemaname='public'") | ||||
|             rows = cur.fetchall() | ||||
|             for r in rows: | ||||
|                 cur.execute("DROP TABLE %s CASCADE" % r[0]) | ||||
|         yield prepare_database(db_conn, db_engine, config) | ||||
|         hs.setup() | ||||
|     else: | ||||
|         hs = HomeServer( | ||||
|             name, db_pool=None, datastore=datastore, config=config, | ||||
|             version_string="Synapse/tests", | ||||
|             database_engine=create_engine(config.database_config), | ||||
|             database_engine=db_engine, | ||||
|             room_list_handler=object(), | ||||
|             tls_server_context_factory=Mock(), | ||||
|             **kargs | ||||
|  | @ -301,168 +340,6 @@ class MockClock(object): | |||
|         return d | ||||
| 
 | ||||
| 
 | ||||
| class SQLiteMemoryDbPool(ConnectionPool, object): | ||||
|     def __init__(self): | ||||
|         super(SQLiteMemoryDbPool, self).__init__( | ||||
|             "sqlite3", ":memory:", | ||||
|             cp_min=1, | ||||
|             cp_max=1, | ||||
|         ) | ||||
| 
 | ||||
|         self.config = Mock() | ||||
|         self.config.password_providers = [] | ||||
|         self.config.database_config = {"name": "sqlite3"} | ||||
| 
 | ||||
|     def prepare(self): | ||||
|         engine = self.create_engine() | ||||
|         return self.runWithConnection( | ||||
|             lambda conn: prepare_database(conn, engine, self.config) | ||||
|         ) | ||||
| 
 | ||||
|     def get_db_conn(self): | ||||
|         conn = self.connect() | ||||
|         engine = self.create_engine() | ||||
|         prepare_database(conn, engine, self.config) | ||||
|         return conn | ||||
| 
 | ||||
|     def create_engine(self): | ||||
|         return create_engine(self.config.database_config) | ||||
| 
 | ||||
| 
 | ||||
| class MemoryDataStore(object): | ||||
| 
 | ||||
|     Room = namedtuple( | ||||
|         "Room", | ||||
|         ["room_id", "is_public", "creator"] | ||||
|     ) | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.tokens_to_users = {} | ||||
|         self.paths_to_content = {} | ||||
| 
 | ||||
|         self.members = {} | ||||
|         self.rooms = {} | ||||
| 
 | ||||
|         self.current_state = {} | ||||
|         self.events = [] | ||||
| 
 | ||||
|     class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")): | ||||
|         def fill_out_prev_events(self, event): | ||||
|             pass | ||||
| 
 | ||||
|     def snapshot_room(self, room_id, user_id, state_type=None, state_key=None): | ||||
|         return self.Snapshot( | ||||
|             room_id, user_id, self.get_room_member(user_id, room_id) | ||||
|         ) | ||||
| 
 | ||||
|     def register(self, user_id, token, password_hash): | ||||
|         if user_id in self.tokens_to_users.values(): | ||||
|             raise StoreError(400, "User in use.") | ||||
|         self.tokens_to_users[token] = user_id | ||||
| 
 | ||||
|     def get_user_by_access_token(self, token): | ||||
|         try: | ||||
|             return { | ||||
|                 "name": self.tokens_to_users[token], | ||||
|             } | ||||
|         except Exception: | ||||
|             raise StoreError(400, "User does not exist.") | ||||
| 
 | ||||
|     def get_room(self, room_id): | ||||
|         try: | ||||
|             return self.rooms[room_id] | ||||
|         except Exception: | ||||
|             return None | ||||
| 
 | ||||
|     def store_room(self, room_id, room_creator_user_id, is_public): | ||||
|         if room_id in self.rooms: | ||||
|             raise StoreError(409, "Conflicting room!") | ||||
| 
 | ||||
|         room = MemoryDataStore.Room( | ||||
|             room_id=room_id, | ||||
|             is_public=is_public, | ||||
|             creator=room_creator_user_id | ||||
|         ) | ||||
|         self.rooms[room_id] = room | ||||
| 
 | ||||
|     def get_room_member(self, user_id, room_id): | ||||
|         return self.members.get(room_id, {}).get(user_id) | ||||
| 
 | ||||
|     def get_room_members(self, room_id, membership=None): | ||||
|         if membership: | ||||
|             return [ | ||||
|                 v for k, v in self.members.get(room_id, {}).items() | ||||
|                 if v.membership == membership | ||||
|             ] | ||||
|         else: | ||||
|             return self.members.get(room_id, {}).values() | ||||
| 
 | ||||
|     def get_rooms_for_user_where_membership_is(self, user_id, membership_list): | ||||
|         return [ | ||||
|             m[user_id] for m in self.members.values() | ||||
|             if user_id in m and m[user_id].membership in membership_list | ||||
|         ] | ||||
| 
 | ||||
|     def get_room_events_stream(self, user_id=None, from_key=None, to_key=None, | ||||
|                                limit=0, with_feedback=False): | ||||
|         return ([], from_key)  # TODO | ||||
| 
 | ||||
|     def get_joined_hosts_for_room(self, room_id): | ||||
|         return defer.succeed([]) | ||||
| 
 | ||||
|     def persist_event(self, event): | ||||
|         if event.type == EventTypes.Member: | ||||
|             room_id = event.room_id | ||||
|             user = event.state_key | ||||
|             self.members.setdefault(room_id, {})[user] = event | ||||
| 
 | ||||
|         if hasattr(event, "state_key"): | ||||
|             key = (event.room_id, event.type, event.state_key) | ||||
|             self.current_state[key] = event | ||||
| 
 | ||||
|         self.events.append(event) | ||||
| 
 | ||||
|     def get_current_state(self, room_id, event_type=None, state_key=""): | ||||
|         if event_type: | ||||
|             key = (room_id, event_type, state_key) | ||||
|             if self.current_state.get(key): | ||||
|                 return [self.current_state.get(key)] | ||||
|             return None | ||||
|         else: | ||||
|             return [ | ||||
|                 e for e in self.current_state | ||||
|                 if e[0] == room_id | ||||
|             ] | ||||
| 
 | ||||
|     def set_presence_state(self, user_localpart, state): | ||||
|         return defer.succeed({"state": 0}) | ||||
| 
 | ||||
|     def get_presence_list(self, user_localpart, accepted): | ||||
|         return [] | ||||
| 
 | ||||
|     def get_room_events_max_id(self): | ||||
|         return "s0"  # TODO (erikj) | ||||
| 
 | ||||
|     def get_send_event_level(self, room_id): | ||||
|         return defer.succeed(0) | ||||
| 
 | ||||
|     def get_power_level(self, room_id, user_id): | ||||
|         return defer.succeed(0) | ||||
| 
 | ||||
|     def get_add_state_level(self, room_id): | ||||
|         return defer.succeed(0) | ||||
| 
 | ||||
|     def get_room_join_rule(self, room_id): | ||||
|         # TODO (erikj): This should be configurable | ||||
|         return defer.succeed("invite") | ||||
| 
 | ||||
|     def get_ops_levels(self, room_id): | ||||
|         return defer.succeed((5, 5, 5)) | ||||
| 
 | ||||
|     def insert_client_ip(self, user, access_token, ip, user_agent): | ||||
|         return defer.succeed(None) | ||||
| 
 | ||||
| 
 | ||||
| def _format_call(args, kwargs): | ||||
|     return ", ".join( | ||||
|         ["%r" % (a) for a in args] + | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue